mirror of
https://github.com/handsomezhuzhu/api-proxy.git
synced 2026-02-20 20:00:15 +00:00
重构代理处理逻辑,优化模板渲染和请求转发
This commit is contained in:
293
main.go
293
main.go
@@ -2,65 +2,38 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"html/template"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var apiMapping = map[string]string{
|
var (
|
||||||
// "/discord": "https://discord.com/api",
|
proxyMap = make(map[string]*httputil.ReverseProxy)
|
||||||
// "/telegram": "https://api.telegram.org",
|
tpl *template.Template
|
||||||
"/openai": "https://api.openai.com",
|
)
|
||||||
"/claude": "https://api.anthropic.com",
|
|
||||||
"/gemini": "https://generativelanguage.googleapis.com",
|
// Denied headers that should not be forwarded to the upstream API
|
||||||
"/meta": "https://www.meta.ai/api",
|
var deniedHeaderPrefixes = []string{"cf-", "forward", "cdn"}
|
||||||
"/groq": "https://api.groq.com/openai",
|
var deniedExactHeaders = map[string]bool{
|
||||||
"/xai": "https://api.x.ai",
|
"host": true,
|
||||||
"/cohere": "https://api.cohere.ai",
|
"referer": true,
|
||||||
"/huggingface": "https://api-inference.huggingface.co",
|
"connection": true,
|
||||||
"/together": "https://api.together.xyz",
|
"keep-alive": true,
|
||||||
"/novita": "https://api.novita.ai",
|
"proxy-authenticate": true,
|
||||||
"/portkey": "https://api.portkey.ai",
|
"proxy-authorization": true,
|
||||||
"/fireworks": "https://api.fireworks.ai",
|
"te": true,
|
||||||
"/openrouter": "https://openrouter.ai/api",
|
"trailers": true,
|
||||||
"/cerebras": "https://api.cerebras.ai",
|
"transfer-encoding": true,
|
||||||
|
"upgrade": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
var deniedHeaders = []string{"host", "referer", "cf-", "forward", "cdn"}
|
const htmlTemplate = `<!DOCTYPE html>
|
||||||
|
|
||||||
func isAllowedHeader(key string) bool {
|
|
||||||
for _, deniedHeader := range deniedHeaders {
|
|
||||||
if strings.Contains(strings.ToLower(key), deniedHeader) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func targetURL(pathname string) string {
|
|
||||||
split := strings.Index(pathname[1:], "/")
|
|
||||||
prefix := pathname[:split+1]
|
|
||||||
if base, exists := apiMapping[prefix]; exists {
|
|
||||||
return base + pathname[len(prefix):]
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func handler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
|
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
|
|
||||||
var paths []string
|
|
||||||
for k := range apiMapping {
|
|
||||||
paths = append(paths, k)
|
|
||||||
}
|
|
||||||
sort.Strings(paths)
|
|
||||||
|
|
||||||
html := `<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
<html lang="en">
|
||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8">
|
<meta charset="UTF-8">
|
||||||
@@ -206,12 +179,16 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
|||||||
<th>Target Service URL</th>
|
<th>Target Service URL</th>
|
||||||
</tr>
|
</tr>
|
||||||
</thead>
|
</thead>
|
||||||
<tbody>`
|
<tbody>
|
||||||
for _, path := range paths {
|
{{range .Items}}
|
||||||
target := apiMapping[path]
|
<tr>
|
||||||
html += fmt.Sprintf("<tr><td><code>%s</code></td><td><span class=\"target-url\">%s</span></td></tr>", path, target)
|
<td><code>{{.Path}}</code></td>
|
||||||
}
|
<td><span class="target-url">{{.Target}}</span></td>
|
||||||
html += `</tbody></table></div>
|
</tr>
|
||||||
|
{{end}}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
<div class="footer">
|
<div class="footer">
|
||||||
AI API Proxy © 2024
|
AI API Proxy © 2024
|
||||||
<br>
|
<br>
|
||||||
@@ -220,10 +197,102 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
|||||||
</div>
|
</div>
|
||||||
</body>
|
</body>
|
||||||
</html>`
|
</html>`
|
||||||
fmt.Fprint(w, html)
|
|
||||||
|
type PageData struct {
|
||||||
|
Items []MappingItem
|
||||||
|
}
|
||||||
|
|
||||||
|
type MappingItem struct {
|
||||||
|
Path string
|
||||||
|
Target string
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiMapping = map[string]string{
|
||||||
|
// "/discord": "https://discord.com/api",
|
||||||
|
// "/telegram": "https://api.telegram.org",
|
||||||
|
"/openai": "https://api.openai.com",
|
||||||
|
"/claude": "https://api.anthropic.com",
|
||||||
|
"/gemini": "https://generativelanguage.googleapis.com",
|
||||||
|
"/meta": "https://www.meta.ai/api",
|
||||||
|
"/groq": "https://api.groq.com/openai",
|
||||||
|
"/xai": "https://api.x.ai",
|
||||||
|
"/cohere": "https://api.cohere.ai",
|
||||||
|
"/huggingface": "https://api-inference.huggingface.co",
|
||||||
|
"/together": "https://api.together.xyz",
|
||||||
|
"/novita": "https://api.novita.ai",
|
||||||
|
"/portkey": "https://api.portkey.ai",
|
||||||
|
"/fireworks": "https://api.fireworks.ai",
|
||||||
|
"/openrouter": "https://openrouter.ai/api",
|
||||||
|
"/cerebras": "https://api.cerebras.ai",
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Parse template
|
||||||
|
var err error
|
||||||
|
tpl, err = template.New("index").Parse(htmlTemplate)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize proxies
|
||||||
|
for path, target := range apiMapping {
|
||||||
|
targetURL, err := url.Parse(target)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Invalid URL for %s: %v", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||||
|
|
||||||
|
// 1. Optimize connection reuse and streaming
|
||||||
|
// FlushInterval is crucial for SSE (Server-Sent Events) to work properly with AI APIs
|
||||||
|
proxy.FlushInterval = 100 * time.Millisecond
|
||||||
|
|
||||||
|
// Custom Director to handle headers and request modification
|
||||||
|
originalDirector := proxy.Director
|
||||||
|
proxy.Director = func(req *http.Request) {
|
||||||
|
originalDirector(req)
|
||||||
|
|
||||||
|
// Set the Host header to the target host (required by many APIs like OpenAI, Cloudflare)
|
||||||
|
req.Host = targetURL.Host
|
||||||
|
|
||||||
|
// Filter denied headers
|
||||||
|
// Note: httputil already removes standard hop-by-hop headers
|
||||||
|
for k := range req.Header {
|
||||||
|
lowerKey := strings.ToLower(k)
|
||||||
|
if deniedExactHeaders[lowerKey] {
|
||||||
|
req.Header.Del(k)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, prefix := range deniedHeaderPrefixes {
|
||||||
|
if strings.HasPrefix(lowerKey, prefix) {
|
||||||
|
req.Header.Del(k)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Anonymize forward headers
|
||||||
|
req.Header.Del("X-Forwarded-For")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional: Custom error handler
|
||||||
|
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
log.Printf("Proxy error for %s: %v", r.URL.Path, err)
|
||||||
|
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyMap[path] = proxy
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// 1. Handle Home Page
|
||||||
|
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
|
||||||
|
renderHome(w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 2. Handle Robots.txt
|
||||||
if r.URL.Path == "/robots.txt" {
|
if r.URL.Path == "/robots.txt" {
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@@ -231,63 +300,68 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
query := r.URL.RawQuery
|
// 3. Find Matching Proxy
|
||||||
|
// Iterate to find the matching prefix.
|
||||||
|
// Optimistically look for exact prefix match or prefix/
|
||||||
|
var matchedPrefix string
|
||||||
|
path := r.URL.Path
|
||||||
|
|
||||||
if query != "" {
|
// Note: Iterating map is random, but since keys are distinct root segments (e.g. /openai),
|
||||||
query = "?" + query
|
// simple prefix check works. For nested paths, one would need to sort keys by length descending.
|
||||||
|
for prefix := range apiMapping {
|
||||||
|
// Match "/openai" or "/openai/..."
|
||||||
|
if path == prefix || strings.HasPrefix(path, prefix+"/") {
|
||||||
|
matchedPrefix = prefix
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
targetURL := targetURL(r.URL.Path + query)
|
if matchedPrefix == "" {
|
||||||
|
|
||||||
if targetURL == "" {
|
|
||||||
http.Error(w, "Not Found", http.StatusNotFound)
|
http.Error(w, "Not Found", http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new request
|
// 4. Serve Proxy
|
||||||
client := &http.Client{}
|
proxy := proxyMap[matchedPrefix]
|
||||||
proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body)
|
|
||||||
if err != nil {
|
// Rewrite the path: remove the prefix
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
// Example: /openai/v1/chat/completions -> /v1/chat/completions
|
||||||
return
|
// The SingleHostReverseProxy will append this to the target URL.
|
||||||
|
// Target: https://api.openai.com
|
||||||
|
// Result: https://api.openai.com/v1/chat/completions
|
||||||
|
r.URL.Path = strings.TrimPrefix(path, matchedPrefix)
|
||||||
|
|
||||||
|
// Ensure path starts with / if it became empty
|
||||||
|
if r.URL.Path == "" {
|
||||||
|
r.URL.Path = "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, values := range r.Header {
|
// ServeHTTP automatically handles:
|
||||||
if isAllowedHeader(key) {
|
// - Connection reuse (Keep-Alive)
|
||||||
for _, value := range values {
|
// - Context cancellation (client disconnects -> stops upstream request)
|
||||||
proxyReq.Header.Add(key, value)
|
// - Header copying (with hop-by-hop removal)
|
||||||
}
|
// - Body copying
|
||||||
}
|
// - Streaming responses
|
||||||
|
proxy.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func renderHome(w http.ResponseWriter) {
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
// Prepare data for template
|
||||||
|
var items []MappingItem
|
||||||
|
for k, v := range apiMapping {
|
||||||
|
items = append(items, MappingItem{Path: k, Target: v})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make the request
|
// Sort by path for consistent display
|
||||||
resp, err := client.Do(proxyReq)
|
sort.Slice(items, func(i, j int) bool {
|
||||||
if err != nil {
|
return items[i].Path < items[j].Path
|
||||||
log.Printf("Failed to fetch: %v", err)
|
})
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
// Copy response headers
|
if err := tpl.Execute(w, PageData{Items: items}); err != nil {
|
||||||
for key, values := range resp.Header {
|
log.Printf("Template execution error: %v", err)
|
||||||
for _, value := range values {
|
|
||||||
w.Header().Add(key, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set security headers
|
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
||||||
w.Header().Set("X-Frame-Options", "DENY")
|
|
||||||
w.Header().Set("Referrer-Policy", "no-referrer")
|
|
||||||
|
|
||||||
// Set status code
|
|
||||||
w.WriteHeader(resp.StatusCode)
|
|
||||||
|
|
||||||
// Copy response body
|
|
||||||
_, err = io.Copy(w, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error copying response: %v", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -296,9 +370,18 @@ func main() {
|
|||||||
if len(os.Args) > 1 {
|
if len(os.Args) > 1 {
|
||||||
port = os.Args[1]
|
port = os.Args[1]
|
||||||
}
|
}
|
||||||
http.HandleFunc("/", handler)
|
|
||||||
log.Printf("Starting server on :" + port)
|
// Determine the HTTP server settings
|
||||||
if err := http.ListenAndServe(":"+port, nil); err != nil {
|
server := &http.Server{
|
||||||
|
Addr: ":" + port,
|
||||||
|
Handler: http.HandlerFunc(handler),
|
||||||
|
ReadTimeout: 10 * time.Minute, // Allow long headers/body reading
|
||||||
|
WriteTimeout: 0, // MUST be 0 for streaming responses (SSE) to work indefinitely
|
||||||
|
IdleTimeout: 60 * time.Second, // Keep-alive connection idle time
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Starting proxy server on " + server.Addr)
|
||||||
|
if err := server.ListenAndServe(); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user