重构代理处理逻辑,优化模板渲染和请求转发

This commit is contained in:
2026-01-20 13:43:36 +08:00
parent 4a8bfa248d
commit a70730886f

293
main.go
View File

@@ -2,65 +2,38 @@ package main
import ( import (
"fmt" "fmt"
"io" "html/template"
"log" "log"
"net/http" "net/http"
"net/http/httputil"
"net/url"
"os" "os"
"sort" "sort"
"strings" "strings"
"time"
) )
var apiMapping = map[string]string{ var (
// "/discord": "https://discord.com/api", proxyMap = make(map[string]*httputil.ReverseProxy)
// "/telegram": "https://api.telegram.org", tpl *template.Template
"/openai": "https://api.openai.com", )
"/claude": "https://api.anthropic.com",
"/gemini": "https://generativelanguage.googleapis.com", // Denied headers that should not be forwarded to the upstream API
"/meta": "https://www.meta.ai/api", var deniedHeaderPrefixes = []string{"cf-", "forward", "cdn"}
"/groq": "https://api.groq.com/openai", var deniedExactHeaders = map[string]bool{
"/xai": "https://api.x.ai", "host": true,
"/cohere": "https://api.cohere.ai", "referer": true,
"/huggingface": "https://api-inference.huggingface.co", "connection": true,
"/together": "https://api.together.xyz", "keep-alive": true,
"/novita": "https://api.novita.ai", "proxy-authenticate": true,
"/portkey": "https://api.portkey.ai", "proxy-authorization": true,
"/fireworks": "https://api.fireworks.ai", "te": true,
"/openrouter": "https://openrouter.ai/api", "trailers": true,
"/cerebras": "https://api.cerebras.ai", "transfer-encoding": true,
"upgrade": true,
} }
var deniedHeaders = []string{"host", "referer", "cf-", "forward", "cdn"} const htmlTemplate = `<!DOCTYPE html>
func isAllowedHeader(key string) bool {
for _, deniedHeader := range deniedHeaders {
if strings.Contains(strings.ToLower(key), deniedHeader) {
return false
}
}
return true
}
func targetURL(pathname string) string {
split := strings.Index(pathname[1:], "/")
prefix := pathname[:split+1]
if base, exists := apiMapping[prefix]; exists {
return base + pathname[len(prefix):]
}
return ""
}
func handler(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
var paths []string
for k := range apiMapping {
paths = append(paths, k)
}
sort.Strings(paths)
html := `<!DOCTYPE html>
<html lang="en"> <html lang="en">
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
@@ -206,12 +179,16 @@ func handler(w http.ResponseWriter, r *http.Request) {
<th>Target Service URL</th> <th>Target Service URL</th>
</tr> </tr>
</thead> </thead>
<tbody>` <tbody>
for _, path := range paths { {{range .Items}}
target := apiMapping[path] <tr>
html += fmt.Sprintf("<tr><td><code>%s</code></td><td><span class=\"target-url\">%s</span></td></tr>", path, target) <td><code>{{.Path}}</code></td>
} <td><span class="target-url">{{.Target}}</span></td>
html += `</tbody></table></div> </tr>
{{end}}
</tbody>
</table>
</div>
<div class="footer"> <div class="footer">
AI API Proxy &copy; 2024 AI API Proxy &copy; 2024
<br> <br>
@@ -220,10 +197,102 @@ func handler(w http.ResponseWriter, r *http.Request) {
</div> </div>
</body> </body>
</html>` </html>`
fmt.Fprint(w, html)
type PageData struct {
Items []MappingItem
}
type MappingItem struct {
Path string
Target string
}
var apiMapping = map[string]string{
// "/discord": "https://discord.com/api",
// "/telegram": "https://api.telegram.org",
"/openai": "https://api.openai.com",
"/claude": "https://api.anthropic.com",
"/gemini": "https://generativelanguage.googleapis.com",
"/meta": "https://www.meta.ai/api",
"/groq": "https://api.groq.com/openai",
"/xai": "https://api.x.ai",
"/cohere": "https://api.cohere.ai",
"/huggingface": "https://api-inference.huggingface.co",
"/together": "https://api.together.xyz",
"/novita": "https://api.novita.ai",
"/portkey": "https://api.portkey.ai",
"/fireworks": "https://api.fireworks.ai",
"/openrouter": "https://openrouter.ai/api",
"/cerebras": "https://api.cerebras.ai",
}
func init() {
// Parse template
var err error
tpl, err = template.New("index").Parse(htmlTemplate)
if err != nil {
log.Fatalf("Failed to parse template: %v", err)
}
// Initialize proxies
for path, target := range apiMapping {
targetURL, err := url.Parse(target)
if err != nil {
log.Fatalf("Invalid URL for %s: %v", path, err)
}
proxy := httputil.NewSingleHostReverseProxy(targetURL)
// 1. Optimize connection reuse and streaming
// FlushInterval is crucial for SSE (Server-Sent Events) to work properly with AI APIs
proxy.FlushInterval = 100 * time.Millisecond
// Custom Director to handle headers and request modification
originalDirector := proxy.Director
proxy.Director = func(req *http.Request) {
originalDirector(req)
// Set the Host header to the target host (required by many APIs like OpenAI, Cloudflare)
req.Host = targetURL.Host
// Filter denied headers
// Note: httputil already removes standard hop-by-hop headers
for k := range req.Header {
lowerKey := strings.ToLower(k)
if deniedExactHeaders[lowerKey] {
req.Header.Del(k)
continue
}
for _, prefix := range deniedHeaderPrefixes {
if strings.HasPrefix(lowerKey, prefix) {
req.Header.Del(k)
break
}
}
}
// Anonymize forward headers
req.Header.Del("X-Forwarded-For")
}
// Optional: Custom error handler
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
log.Printf("Proxy error for %s: %v", r.URL.Path, err)
http.Error(w, "Bad Gateway", http.StatusBadGateway)
}
proxyMap[path] = proxy
}
}
func handler(w http.ResponseWriter, r *http.Request) {
// 1. Handle Home Page
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
renderHome(w)
return return
} }
// 2. Handle Robots.txt
if r.URL.Path == "/robots.txt" { if r.URL.Path == "/robots.txt" {
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -231,63 +300,68 @@ func handler(w http.ResponseWriter, r *http.Request) {
return return
} }
query := r.URL.RawQuery // 3. Find Matching Proxy
// Iterate to find the matching prefix.
// Optimistically look for exact prefix match or prefix/
var matchedPrefix string
path := r.URL.Path
if query != "" { // Note: Iterating map is random, but since keys are distinct root segments (e.g. /openai),
query = "?" + query // simple prefix check works. For nested paths, one would need to sort keys by length descending.
for prefix := range apiMapping {
// Match "/openai" or "/openai/..."
if path == prefix || strings.HasPrefix(path, prefix+"/") {
matchedPrefix = prefix
break
}
} }
targetURL := targetURL(r.URL.Path + query) if matchedPrefix == "" {
if targetURL == "" {
http.Error(w, "Not Found", http.StatusNotFound) http.Error(w, "Not Found", http.StatusNotFound)
return return
} }
// Create new request // 4. Serve Proxy
client := &http.Client{} proxy := proxyMap[matchedPrefix]
proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body)
if err != nil { // Rewrite the path: remove the prefix
http.Error(w, "Internal Server Error", http.StatusInternalServerError) // Example: /openai/v1/chat/completions -> /v1/chat/completions
return // The SingleHostReverseProxy will append this to the target URL.
// Target: https://api.openai.com
// Result: https://api.openai.com/v1/chat/completions
r.URL.Path = strings.TrimPrefix(path, matchedPrefix)
// Ensure path starts with / if it became empty
if r.URL.Path == "" {
r.URL.Path = "/"
} }
for key, values := range r.Header { // ServeHTTP automatically handles:
if isAllowedHeader(key) { // - Connection reuse (Keep-Alive)
for _, value := range values { // - Context cancellation (client disconnects -> stops upstream request)
proxyReq.Header.Add(key, value) // - Header copying (with hop-by-hop removal)
} // - Body copying
} // - Streaming responses
proxy.ServeHTTP(w, r)
}
func renderHome(w http.ResponseWriter) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
// Prepare data for template
var items []MappingItem
for k, v := range apiMapping {
items = append(items, MappingItem{Path: k, Target: v})
} }
// Make the request // Sort by path for consistent display
resp, err := client.Do(proxyReq) sort.Slice(items, func(i, j int) bool {
if err != nil { return items[i].Path < items[j].Path
log.Printf("Failed to fetch: %v", err) })
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
defer resp.Body.Close()
// Copy response headers if err := tpl.Execute(w, PageData{Items: items}); err != nil {
for key, values := range resp.Header { log.Printf("Template execution error: %v", err)
for _, value := range values {
w.Header().Add(key, value)
}
}
// Set security headers
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("Referrer-Policy", "no-referrer")
// Set status code
w.WriteHeader(resp.StatusCode)
// Copy response body
_, err = io.Copy(w, resp.Body)
if err != nil {
log.Printf("Error copying response: %v", err)
} }
} }
@@ -296,9 +370,18 @@ func main() {
if len(os.Args) > 1 { if len(os.Args) > 1 {
port = os.Args[1] port = os.Args[1]
} }
http.HandleFunc("/", handler)
log.Printf("Starting server on :" + port) // Determine the HTTP server settings
if err := http.ListenAndServe(":"+port, nil); err != nil { server := &http.Server{
Addr: ":" + port,
Handler: http.HandlerFunc(handler),
ReadTimeout: 10 * time.Minute, // Allow long headers/body reading
WriteTimeout: 0, // MUST be 0 for streaming responses (SSE) to work indefinitely
IdleTimeout: 60 * time.Second, // Keep-alive connection idle time
}
log.Printf("Starting proxy server on " + server.Addr)
if err := server.ListenAndServe(); err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }