From 068c192fb74f363ba3c7c3632bb62ecaf97b90e1 Mon Sep 17 00:00:00 2001 From: handsomezhuzhu <2658601135@qq.com> Date: Thu, 5 Mar 2026 11:17:59 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=8F=AF=E9=80=89=E5=AE=89?= =?UTF-8?q?=E5=85=A8=E5=BC=80=E5=85=B3=E4=BB=A5=E9=99=90=E5=88=B6=E8=AF=B7?= =?UTF-8?q?=E6=B1=82=E4=BD=93=E5=A4=A7=E5=B0=8F=E5=92=8C=E5=9F=BA=E7=A1=80?= =?UTF-8?q?=E9=99=90=E6=B5=81=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 20 ++++++ docker-compose.yml | 4 ++ main.go | 166 ++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 188 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3d63280..b9a3b39 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,26 @@ go build -o api-proxy main.go ./api-proxy 8080 ``` +### 可选安全开关(默认关闭) + +你可以在部署时通过环境变量开启请求体大小限制和基础限流;不设置时保持“开放代理”行为。 + +```bash +# 限制请求体最大 10MB(0 或不设置=关闭) +export PROXY_MAX_BODY_BYTES=10485760 + +# 基础限流:每个客户端 IP 每秒 5 个请求,突发 10(任一为 0 或不设置=关闭) +export PROXY_RATE_LIMIT_RPS=5 +export PROXY_RATE_LIMIT_BURST=10 + +./api-proxy +``` + +说明: +- `PROXY_MAX_BODY_BYTES`: 请求体上限(字节) +- `PROXY_RATE_LIMIT_RPS`: 每秒补充 token 数 +- `PROXY_RATE_LIMIT_BURST`: 桶容量;若只设置 `RPS`,会自动取 `ceil(RPS)` + ## 使用示例 假设你的服务运行在 `http://localhost:7890`。 diff --git a/docker-compose.yml b/docker-compose.yml index ca3d85f..9940c8b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,4 +7,8 @@ services: # 如果您在 main.go 中配置的端口是 7890,那么这里保持 7890 ports: - "7890:7890" # 映射:将主机的 80 端口映射到容器内的 7890 端口 + environment: + - PROXY_MAX_BODY_BYTES=${PROXY_MAX_BODY_BYTES:-0} + - PROXY_RATE_LIMIT_RPS=${PROXY_RATE_LIMIT_RPS:-0} + - PROXY_RATE_LIMIT_BURST=${PROXY_RATE_LIMIT_BURST:-0} restart: always \ No newline at end of file diff --git a/main.go b/main.go index 4e31c3f..a1762e8 100644 --- a/main.go +++ b/main.go @@ -1,15 +1,20 @@ package main import ( + "errors" "fmt" "html/template" "log" + "math" + "net" "net/http" "net/http/httputil" "net/url" "os" "sort" + "strconv" "strings" + "sync" "time" ) @@ -20,6 +25,23 @@ const ( var ( proxyMap = make(map[string]*httputil.ReverseProxy) tpl *template.Template + + maxBodyBytes int64 + rateLimitRPS float64 + rateLimitBurst int + + limiterState = struct { + mu sync.Mutex + clients map[string]*tokenBucket + lastCleanup time.Time + }{ + clients: make(map[string]*tokenBucket), + } +) + +const ( + limiterIdleTTL = 10 * time.Minute + limiterCleanupInterval = 5 * time.Minute ) // Denied headers that should not be forwarded to the upstream API @@ -218,6 +240,12 @@ type loggingResponseWriter struct { statusCode int } +type tokenBucket struct { + tokens float64 + lastRefill time.Time + lastSeen time.Time +} + func (lrw *loggingResponseWriter) WriteHeader(code int) { lrw.statusCode = code lrw.ResponseWriter.WriteHeader(code) @@ -253,9 +281,119 @@ func getClientIP(r *http.Request) string { parts := strings.Split(xff, ",") return strings.TrimSpace(parts[0]) } + if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { + return host + } return r.RemoteAddr } +func parseInt64Env(name string) int64 { + raw := strings.TrimSpace(os.Getenv(name)) + if raw == "" { + return 0 + } + v, err := strconv.ParseInt(raw, 10, 64) + if err != nil || v < 0 { + log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw) + return 0 + } + return v +} + +func parseFloatEnv(name string) float64 { + raw := strings.TrimSpace(os.Getenv(name)) + if raw == "" { + return 0 + } + v, err := strconv.ParseFloat(raw, 64) + if err != nil || v < 0 { + log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw) + return 0 + } + return v +} + +func parseIntEnv(name string) int { + raw := strings.TrimSpace(os.Getenv(name)) + if raw == "" { + return 0 + } + v, err := strconv.Atoi(raw) + if err != nil || v < 0 { + log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw) + return 0 + } + return v +} + +func loadOptionalProtections() { + maxBodyBytes = parseInt64Env("PROXY_MAX_BODY_BYTES") + rateLimitRPS = parseFloatEnv("PROXY_RATE_LIMIT_RPS") + rateLimitBurst = parseIntEnv("PROXY_RATE_LIMIT_BURST") + + if rateLimitRPS > 0 && rateLimitBurst <= 0 { + rateLimitBurst = int(math.Ceil(rateLimitRPS)) + if rateLimitBurst < 1 { + rateLimitBurst = 1 + } + } + if rateLimitRPS <= 0 || rateLimitBurst <= 0 { + rateLimitRPS = 0 + rateLimitBurst = 0 + } + + log.Printf("Optional protections: PROXY_MAX_BODY_BYTES=%d (0=off), PROXY_RATE_LIMIT_RPS=%.2f, PROXY_RATE_LIMIT_BURST=%d (0=off)", maxBodyBytes, rateLimitRPS, rateLimitBurst) +} + +func allowByRateLimit(clientID string) bool { + if rateLimitRPS <= 0 || rateLimitBurst <= 0 { + return true + } + + now := time.Now() + + limiterState.mu.Lock() + defer limiterState.mu.Unlock() + + if now.Sub(limiterState.lastCleanup) >= limiterCleanupInterval { + for key, bucket := range limiterState.clients { + if now.Sub(bucket.lastSeen) > limiterIdleTTL { + delete(limiterState.clients, key) + } + } + limiterState.lastCleanup = now + } + + bucket, ok := limiterState.clients[clientID] + if !ok { + limiterState.clients[clientID] = &tokenBucket{ + tokens: float64(rateLimitBurst - 1), + lastRefill: now, + lastSeen: now, + } + return true + } + + elapsed := now.Sub(bucket.lastRefill).Seconds() + if elapsed > 0 { + bucket.tokens = math.Min(float64(rateLimitBurst), bucket.tokens+elapsed*rateLimitRPS) + bucket.lastRefill = now + } + bucket.lastSeen = now + + if bucket.tokens >= 1 { + bucket.tokens -= 1 + return true + } + return false +} + +func writeJSONError(w http.ResponseWriter, status int, message, errType string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + fmt.Fprintf(w, `{"error":{"message":%q,"type":%q}}`, message, errType) +} + func init() { // Parse template @@ -325,9 +463,14 @@ func init() { clientIP := getClientIP(r) log.Printf("[ERROR] Client: %s | Target: %s | Error: %v", clientIP, targetURL.Host, err) + var maxErr *http.MaxBytesError + if errors.As(err, &maxErr) || strings.Contains(strings.ToLower(err.Error()), "request body too large") { + writeJSONError(w, http.StatusRequestEntityTooLarge, "Request body too large", "request_too_large") + return + } + // 返回 JSON 格式错误,方便 AI 客户端解析 - w.WriteHeader(http.StatusBadGateway) - fmt.Fprintf(w, `{"error": {"message": "Proxy Connection Error: %v", "type": "proxy_error"}}`, err) + writeJSONError(w, http.StatusBadGateway, "Proxy connection error", "proxy_error") } proxyMap[path] = proxy @@ -385,6 +528,24 @@ func handler(w http.ResponseWriter, r *http.Request) { // 4. Serve Proxy proxy := proxyMap[matchedPrefix] + if rateLimitRPS > 0 { + if !allowByRateLimit(clientIP) { + lrw.Header().Set("Retry-After", "1") + writeJSONError(lrw, http.StatusTooManyRequests, "Rate limit exceeded", "rate_limit_exceeded") + log.Printf("[429] Path: %s | IP: %s", path, clientIP) + return + } + } + + if maxBodyBytes > 0 { + if r.ContentLength > maxBodyBytes && r.ContentLength != -1 { + writeJSONError(lrw, http.StatusRequestEntityTooLarge, "Request body too large", "request_too_large") + log.Printf("[413] Path: %s | IP: %s | Content-Length: %d", path, clientIP, r.ContentLength) + return + } + r.Body = http.MaxBytesReader(lrw, r.Body, maxBodyBytes) + } + // Rewrite the path: remove the prefix // Example: /openai/v1/chat/completions -> /v1/chat/completions // The SingleHostReverseProxy will append this to the target URL. @@ -440,6 +601,7 @@ func main() { if len(os.Args) > 1 { port = os.Args[1] } + loadOptionalProtections() // Determine the HTTP server settings server := &http.Server{