添加可选安全开关以限制请求体大小和基础限流功能

This commit is contained in:
2026-03-05 11:17:59 +08:00
parent fc011785d8
commit 068c192fb7
3 changed files with 188 additions and 2 deletions

166
main.go
View File

@@ -1,15 +1,20 @@
package main
import (
"errors"
"fmt"
"html/template"
"log"
"math"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"sort"
"strconv"
"strings"
"sync"
"time"
)
@@ -20,6 +25,23 @@ const (
var (
proxyMap = make(map[string]*httputil.ReverseProxy)
tpl *template.Template
maxBodyBytes int64
rateLimitRPS float64
rateLimitBurst int
limiterState = struct {
mu sync.Mutex
clients map[string]*tokenBucket
lastCleanup time.Time
}{
clients: make(map[string]*tokenBucket),
}
)
const (
limiterIdleTTL = 10 * time.Minute
limiterCleanupInterval = 5 * time.Minute
)
// Denied headers that should not be forwarded to the upstream API
@@ -218,6 +240,12 @@ type loggingResponseWriter struct {
statusCode int
}
type tokenBucket struct {
tokens float64
lastRefill time.Time
lastSeen time.Time
}
func (lrw *loggingResponseWriter) WriteHeader(code int) {
lrw.statusCode = code
lrw.ResponseWriter.WriteHeader(code)
@@ -253,9 +281,119 @@ func getClientIP(r *http.Request) string {
parts := strings.Split(xff, ",")
return strings.TrimSpace(parts[0])
}
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
return host
}
return r.RemoteAddr
}
func parseInt64Env(name string) int64 {
raw := strings.TrimSpace(os.Getenv(name))
if raw == "" {
return 0
}
v, err := strconv.ParseInt(raw, 10, 64)
if err != nil || v < 0 {
log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw)
return 0
}
return v
}
func parseFloatEnv(name string) float64 {
raw := strings.TrimSpace(os.Getenv(name))
if raw == "" {
return 0
}
v, err := strconv.ParseFloat(raw, 64)
if err != nil || v < 0 {
log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw)
return 0
}
return v
}
func parseIntEnv(name string) int {
raw := strings.TrimSpace(os.Getenv(name))
if raw == "" {
return 0
}
v, err := strconv.Atoi(raw)
if err != nil || v < 0 {
log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw)
return 0
}
return v
}
func loadOptionalProtections() {
maxBodyBytes = parseInt64Env("PROXY_MAX_BODY_BYTES")
rateLimitRPS = parseFloatEnv("PROXY_RATE_LIMIT_RPS")
rateLimitBurst = parseIntEnv("PROXY_RATE_LIMIT_BURST")
if rateLimitRPS > 0 && rateLimitBurst <= 0 {
rateLimitBurst = int(math.Ceil(rateLimitRPS))
if rateLimitBurst < 1 {
rateLimitBurst = 1
}
}
if rateLimitRPS <= 0 || rateLimitBurst <= 0 {
rateLimitRPS = 0
rateLimitBurst = 0
}
log.Printf("Optional protections: PROXY_MAX_BODY_BYTES=%d (0=off), PROXY_RATE_LIMIT_RPS=%.2f, PROXY_RATE_LIMIT_BURST=%d (0=off)", maxBodyBytes, rateLimitRPS, rateLimitBurst)
}
func allowByRateLimit(clientID string) bool {
if rateLimitRPS <= 0 || rateLimitBurst <= 0 {
return true
}
now := time.Now()
limiterState.mu.Lock()
defer limiterState.mu.Unlock()
if now.Sub(limiterState.lastCleanup) >= limiterCleanupInterval {
for key, bucket := range limiterState.clients {
if now.Sub(bucket.lastSeen) > limiterIdleTTL {
delete(limiterState.clients, key)
}
}
limiterState.lastCleanup = now
}
bucket, ok := limiterState.clients[clientID]
if !ok {
limiterState.clients[clientID] = &tokenBucket{
tokens: float64(rateLimitBurst - 1),
lastRefill: now,
lastSeen: now,
}
return true
}
elapsed := now.Sub(bucket.lastRefill).Seconds()
if elapsed > 0 {
bucket.tokens = math.Min(float64(rateLimitBurst), bucket.tokens+elapsed*rateLimitRPS)
bucket.lastRefill = now
}
bucket.lastSeen = now
if bucket.tokens >= 1 {
bucket.tokens -= 1
return true
}
return false
}
func writeJSONError(w http.ResponseWriter, status int, message, errType string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
fmt.Fprintf(w, `{"error":{"message":%q,"type":%q}}`, message, errType)
}
func init() {
// Parse template
@@ -325,9 +463,14 @@ func init() {
clientIP := getClientIP(r)
log.Printf("[ERROR] Client: %s | Target: %s | Error: %v", clientIP, targetURL.Host, err)
var maxErr *http.MaxBytesError
if errors.As(err, &maxErr) || strings.Contains(strings.ToLower(err.Error()), "request body too large") {
writeJSONError(w, http.StatusRequestEntityTooLarge, "Request body too large", "request_too_large")
return
}
// 返回 JSON 格式错误,方便 AI 客户端解析
w.WriteHeader(http.StatusBadGateway)
fmt.Fprintf(w, `{"error": {"message": "Proxy Connection Error: %v", "type": "proxy_error"}}`, err)
writeJSONError(w, http.StatusBadGateway, "Proxy connection error", "proxy_error")
}
proxyMap[path] = proxy
@@ -385,6 +528,24 @@ func handler(w http.ResponseWriter, r *http.Request) {
// 4. Serve Proxy
proxy := proxyMap[matchedPrefix]
if rateLimitRPS > 0 {
if !allowByRateLimit(clientIP) {
lrw.Header().Set("Retry-After", "1")
writeJSONError(lrw, http.StatusTooManyRequests, "Rate limit exceeded", "rate_limit_exceeded")
log.Printf("[429] Path: %s | IP: %s", path, clientIP)
return
}
}
if maxBodyBytes > 0 {
if r.ContentLength > maxBodyBytes && r.ContentLength != -1 {
writeJSONError(lrw, http.StatusRequestEntityTooLarge, "Request body too large", "request_too_large")
log.Printf("[413] Path: %s | IP: %s | Content-Length: %d", path, clientIP, r.ContentLength)
return
}
r.Body = http.MaxBytesReader(lrw, r.Body, maxBodyBytes)
}
// Rewrite the path: remove the prefix
// Example: /openai/v1/chat/completions -> /v1/chat/completions
// The SingleHostReverseProxy will append this to the target URL.
@@ -440,6 +601,7 @@ func main() {
if len(os.Args) > 1 {
port = os.Args[1]
}
loadOptionalProtections()
// Determine the HTTP server settings
server := &http.Server{