package main
import (
	"encoding/json"
	"fmt"
	"html/template"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"sort"
	"strings"
	"time"
)
// Version is the release version of this proxy service. It is handed to the
// home page template through PageData.Version.
const Version = "1.2.0"
// proxyMap holds one reverse proxy per path prefix of apiMapping. It is filled
// in init, and handler looks proxies up in it by matched prefix.
var proxyMap = make(map[string]*httputil.ReverseProxy)

// tpl is the parsed home page template, built from htmlTemplate in init.
var tpl *template.Template
// deniedHeaderPrefixes lists lower-cased header-name prefixes. Any request
// header whose lower-cased name starts with one of them is not forwarded to the
// upstream API. "x-real-ip" is denied so that internal IP structure is not
// leaked when the service sits behind multiple proxies.
//
// NOTE(review): names such as "Ali-Cdn-Real-Ip" or "X-Forwarded-Proto" do not
// start with any of these prefixes, so this filter does not strip them.
// Confirm whether forwarding them upstream is intended.
var deniedHeaderPrefixes = []string{
	"cf-",
	"forward",
	"cdn",
	"x-real-ip",
}

// deniedExactHeaders lists lower-cased header names that are never forwarded
// upstream. Apart from "host", these are hop-by-hop headers.
var deniedExactHeaders = map[string]bool{
	"connection":          true,
	"host":                true,
	"keep-alive":          true,
	"proxy-authenticate":  true,
	"proxy-authorization": true,
	"te":                  true,
	"trailers":            true,
	"transfer-encoding":   true,
	"upgrade":             true,
	// "referer" is deliberately not denied: some CDN hotlink protection may
	// need it.
}
// htmlTemplate is the html/template source for the home page, served for "/"
// and "/index.html". It is parsed in init and executed with a PageData value;
// it ranges over .Items to list each path prefix and its target URL.
//
// NOTE(review): the template text as it appears here contains no HTML markup
// and never references .Version, although PageData carries it. If the markup
// was lost, restore it. TODO confirm.
const htmlTemplate = `
AI API Proxy
AI API Proxy Service
Maintained by Simon
✅
Service is active and running
This service routes requests to various AI provider APIs through a unified interface.
Available Endpoints
| Path Prefix |
Target Service URL |
{{range .Items}}
{{.Path}} |
{{.Target}} |
{{end}}
`
type (
	// PageData is the value the home page template is executed with.
	PageData struct {
		Items   []MappingItem
		Version string
	}

	// MappingItem is one row of the home page listing: a path prefix served by
	// this proxy and the upstream URL that prefix forwards to.
	MappingItem struct {
		Path   string
		Target string
	}
)
// loggingResponseWriter wraps an http.ResponseWriter and records the status
// code written through it, so the handler can log it after proxying.
//
// Embedding the http.ResponseWriter interface promotes only that interface's
// methods. Optional interfaces of the underlying writer, notably http.Flusher,
// would otherwise be hidden. Flush and Unwrap re-expose them so
// httputil.ReverseProxy can flush streamed (e.g. SSE) responses to the client
// instead of buffering them.
type loggingResponseWriter struct {
	http.ResponseWriter
	statusCode int
}

// WriteHeader records code and forwards it to the wrapped writer.
func (lrw *loggingResponseWriter) WriteHeader(code int) {
	lrw.statusCode = code
	lrw.ResponseWriter.WriteHeader(code)
}

// Flush implements http.Flusher. It delegates to the wrapped writer when that
// writer supports flushing and is a no-op otherwise.
func (lrw *loggingResponseWriter) Flush() {
	if f, ok := lrw.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}

// Unwrap returns the wrapped writer so http.ResponseController can reach the
// optional features (flush, deadlines, ...) of the underlying connection.
func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
	return lrw.ResponseWriter
}
// apiMapping maps each public path prefix to the upstream API base URL it is
// proxied to. The handler strips the prefix before forwarding, so
// /openai/v1/chat/completions is sent to
// https://api.openai.com/v1/chat/completions.
var apiMapping = map[string]string{
	"/cerebras":    "https://api.cerebras.ai",
	"/claude":      "https://api.anthropic.com",
	"/cohere":      "https://api.cohere.ai",
	"/fireworks":   "https://api.fireworks.ai",
	"/gemini":      "https://generativelanguage.googleapis.com",
	"/groq":        "https://api.groq.com/openai",
	"/huggingface": "https://api-inference.huggingface.co",
	"/meta":        "https://www.meta.ai/api",
	"/novita":      "https://api.novita.ai",
	"/openai":      "https://api.openai.com",
	"/openrouter":  "https://openrouter.ai/api",
	"/portkey":     "https://api.portkey.ai",
	"/together":    "https://api.together.xyz",
	"/xai":         "https://api.x.ai",
	// Currently disabled:
	// "/discord":  "https://discord.com/api",
	// "/telegram": "https://api.telegram.org",
}
// getClientIP returns a best-effort client IP for logging. Sources are tried in
// this order:
//   - the Ali-Cdn-Real-Ip header (Alibaba Cloud CDN / ESA),
//   - the first entry of X-Forwarded-For,
//   - the host part of r.RemoteAddr (the port is stripped).
//
// Request headers can be set by any client that reaches the server directly,
// so the result is only suitable for logging, not for access control.
func getClientIP(r *http.Request) string {
	if ip := strings.TrimSpace(r.Header.Get("Ali-Cdn-Real-Ip")); ip != "" {
		return ip
	}
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// The original client is the left-most entry; skip it if it is blank.
		if ip := strings.TrimSpace(strings.SplitN(xff, ",", 2)[0]); ip != "" {
			return ip
		}
	}
	// RemoteAddr is normally "host:port". Strip the port so every branch
	// returns a bare IP; keep the raw value if it is not in that form.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
// init parses the home page template and builds one reverse proxy per entry of
// apiMapping, storing each in proxyMap under its path prefix. An unparsable
// template or target URL aborts startup.
func init() {
	var err error
	tpl, err = template.New("index").Parse(htmlTemplate)
	if err != nil {
		log.Fatalf("Failed to parse template: %v", err)
	}

	for path, target := range apiMapping {
		targetURL, err := url.Parse(target)
		if err != nil {
			log.Fatalf("Invalid URL for %s: %v", path, err)
		}
		proxyMap[path] = newAPIProxy(targetURL)
	}
}

// newAPIProxy returns a reverse proxy to targetURL. The proxy:
//   - presents the upstream host as Host,
//   - strips denied and forwarding headers,
//   - disables caching and buffering of responses,
//   - reports upstream failures as a JSON error body.
func newAPIProxy(targetURL *url.URL) *httputil.ReverseProxy {
	proxy := httputil.NewSingleHostReverseProxy(targetURL)

	// Flush periodically so streamed responses (SSE from AI APIs) reach the
	// client promptly instead of sitting in buffers.
	proxy.FlushInterval = 100 * time.Millisecond

	baseDirector := proxy.Director
	proxy.Director = func(req *http.Request) {
		baseDirector(req)
		// Many APIs (OpenAI, Cloudflare-fronted hosts, ...) route on Host, so
		// send the upstream host rather than ours.
		req.Host = targetURL.Host
		stripDeniedHeaders(req.Header)
		// ReverseProxy appends the client address to X-Forwarded-For after the
		// Director runs. Deleting the header is not enough; a nil value is the
		// documented way to tell it to leave the header out.
		req.Header["X-Forwarded-For"] = nil
	}

	proxy.ModifyResponse = func(res *http.Response) error {
		// Nginx and most CDNs recognise this header and stop buffering.
		res.Header.Set("X-Accel-Buffering", "no")
		// Forbid caching of every proxied response. Combined with a CDN
		// "follow origin" cache policy, this gives per-response control.
		res.Header.Set("Cache-Control", "no-cache, no-store, must-revalidate")
		res.Header.Set("Pragma", "no-cache")
		res.Header.Set("Expires", "0")
		return nil
	}

	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
		log.Printf("[ERROR] Client: %s | Target: %s | Error: %v", getClientIP(r), targetURL.Host, err)
		writeProxyError(w, err)
	}
	return proxy
}

// stripDeniedHeaders removes every header whose lower-cased name is listed in
// deniedExactHeaders or starts with one of deniedHeaderPrefixes.
// httputil already removes the standard hop-by-hop headers on its own.
func stripDeniedHeaders(h http.Header) {
	for name := range h {
		lower := strings.ToLower(name)
		if deniedExactHeaders[lower] {
			delete(h, name)
			continue
		}
		for _, prefix := range deniedHeaderPrefixes {
			if strings.HasPrefix(lower, prefix) {
				delete(h, name)
				break
			}
		}
	}
}

// writeProxyError replies 502 Bad Gateway with a JSON body of the form
// {"error":{"message":...,"type":"proxy_error"}}, so AI clients can parse it.
func writeProxyError(w http.ResponseWriter, err error) {
	var body struct {
		Error struct {
			Message string `json:"message"`
			Type    string `json:"type"`
		} `json:"error"`
	}
	body.Error.Message = "Proxy Connection Error: " + err.Error()
	body.Error.Type = "proxy_error"

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusBadGateway)
	// Encoding, rather than formatting the error into a string, keeps the body
	// valid JSON even when the error text contains quotes or control characters.
	if encErr := json.NewEncoder(w).Encode(body); encErr != nil {
		log.Printf("[ERROR] writing proxy error response: %v", encErr)
	}
}
func handler(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
clientIP := getClientIP(r)
// 包装 Writer 记录状态码
lrw := &loggingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
// 简单日志
if r.URL.Path != "/" {
log.Printf("[REQ] %s %s from %s", r.Method, r.URL.Path, clientIP)
}
// 1. Handle Home Page
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
renderHome(lrw)
return
}
// 2. Handle Robots.txt
if r.URL.Path == "/robots.txt" {
lrw.Header().Set("Content-Type", "text/plain")
lrw.WriteHeader(http.StatusOK)
fmt.Fprint(lrw, "User-agent: *\nDisallow: /")
return
}
// 3. Find Matching Proxy
// Iterate to find the matching prefix.
// Optimistically look for exact prefix match or prefix/
var matchedPrefix string
path := r.URL.Path
// Note: Iterating map is random, but since keys are distinct root segments (e.g. /openai),
// simple prefix check works. For nested paths, one would need to sort keys by length descending.
for prefix := range apiMapping {
// Match "/openai" or "/openai/..."
if path == prefix || strings.HasPrefix(path, prefix+"/") {
matchedPrefix = prefix
break
}
}
if matchedPrefix == "" {
http.Error(lrw, "Not Found", http.StatusNotFound)
log.Printf("[404] Path: %s | IP: %s", path, clientIP)
return
}
// 4. Serve Proxy
proxy := proxyMap[matchedPrefix]
// Rewrite the path: remove the prefix
// Example: /openai/v1/chat/completions -> /v1/chat/completions
// The SingleHostReverseProxy will append this to the target URL.
// Target: https://api.openai.com
// Result: https://api.openai.com/v1/chat/completions
r.URL.Path = strings.TrimPrefix(path, matchedPrefix)
// Ensure path starts with / if it became empty
if r.URL.Path == "" {
r.URL.Path = "/"
}
// ServeHTTP automatically handles:
// - Connection reuse (Keep-Alive)
// - Context cancellation (client disconnects -> stops upstream request)
// - Header copying (with hop-by-hop removal)
// - Body copying
// - Streaming responses
proxy.ServeHTTP(lrw, r)
// 只有非 200 或者耗时较长时才打印结束日志,避免日志刷屏
duration := time.Since(startTime)
if lrw.statusCode != 200 || duration > 5*time.Second {
log.Printf("[RES] %d | %v | %s -> %s", lrw.statusCode, duration, path, apiMapping[matchedPrefix])
}
}
// renderHome writes the home page, which lists every path prefix of apiMapping
// and its upstream target, sorted by prefix.
//
// The template is rendered into memory before anything is sent. A template
// failure then yields a 500 instead of a truncated page behind a 200 status.
func renderHome(w http.ResponseWriter) {
	items := make([]MappingItem, 0, len(apiMapping))
	for prefix, target := range apiMapping {
		items = append(items, MappingItem{Path: prefix, Target: target})
	}
	// Map iteration order is random; sort for a stable listing.
	sort.Slice(items, func(i, j int) bool { return items[i].Path < items[j].Path })

	var page strings.Builder
	if err := tpl.Execute(&page, PageData{Items: items, Version: Version}); err != nil {
		log.Printf("Template execution error: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	w.WriteHeader(http.StatusOK)
	if _, err := w.Write([]byte(page.String())); err != nil {
		log.Printf("Home page write error: %v", err)
	}
}
// main starts the proxy on the port given as the first command-line argument
// (default 7890) and serves until ListenAndServe fails.
func main() {
	port := "7890"
	if len(os.Args) > 1 {
		port = os.Args[1]
	}

	// ReadHeaderTimeout bounds header reading on its own, so slow-header
	// clients cannot hold a connection for the whole ReadTimeout.
	server := &http.Server{
		Addr:              ":" + port,
		Handler:           http.HandlerFunc(handler),
		ReadHeaderTimeout: 30 * time.Second, // Headers are small
		ReadTimeout:       10 * time.Minute, // Allow long body reading (large prompts)
		WriteTimeout:      0,                // MUST be 0 for streaming responses (SSE) to work indefinitely
		IdleTimeout:       60 * time.Second, // Keep-alive connection idle time
	}

	log.Printf("Starting proxy server on %s (Behind CDN mode)", server.Addr)
	if err := server.ListenAndServe(); err != nil {
		log.Fatal(err)
	}
}