mirror of
https://github.com/handsomezhuzhu/api-proxy.git
synced 2026-04-18 22:32:54 +00:00
添加可选安全开关以限制请求体大小和基础限流功能
This commit is contained in:
166
main.go
166
main.go
@@ -1,15 +1,20 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -20,6 +25,23 @@ const (
|
||||
var (
|
||||
proxyMap = make(map[string]*httputil.ReverseProxy)
|
||||
tpl *template.Template
|
||||
|
||||
maxBodyBytes int64
|
||||
rateLimitRPS float64
|
||||
rateLimitBurst int
|
||||
|
||||
limiterState = struct {
|
||||
mu sync.Mutex
|
||||
clients map[string]*tokenBucket
|
||||
lastCleanup time.Time
|
||||
}{
|
||||
clients: make(map[string]*tokenBucket),
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
limiterIdleTTL = 10 * time.Minute
|
||||
limiterCleanupInterval = 5 * time.Minute
|
||||
)
|
||||
|
||||
// Denied headers that should not be forwarded to the upstream API
|
||||
@@ -218,6 +240,12 @@ type loggingResponseWriter struct {
|
||||
statusCode int
|
||||
}
|
||||
|
||||
type tokenBucket struct {
|
||||
tokens float64
|
||||
lastRefill time.Time
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
func (lrw *loggingResponseWriter) WriteHeader(code int) {
|
||||
lrw.statusCode = code
|
||||
lrw.ResponseWriter.WriteHeader(code)
|
||||
@@ -253,9 +281,119 @@ func getClientIP(r *http.Request) string {
|
||||
parts := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||
return host
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
func parseInt64Env(name string) int64 {
|
||||
raw := strings.TrimSpace(os.Getenv(name))
|
||||
if raw == "" {
|
||||
return 0
|
||||
}
|
||||
v, err := strconv.ParseInt(raw, 10, 64)
|
||||
if err != nil || v < 0 {
|
||||
log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw)
|
||||
return 0
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func parseFloatEnv(name string) float64 {
|
||||
raw := strings.TrimSpace(os.Getenv(name))
|
||||
if raw == "" {
|
||||
return 0
|
||||
}
|
||||
v, err := strconv.ParseFloat(raw, 64)
|
||||
if err != nil || v < 0 {
|
||||
log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw)
|
||||
return 0
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func parseIntEnv(name string) int {
|
||||
raw := strings.TrimSpace(os.Getenv(name))
|
||||
if raw == "" {
|
||||
return 0
|
||||
}
|
||||
v, err := strconv.Atoi(raw)
|
||||
if err != nil || v < 0 {
|
||||
log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw)
|
||||
return 0
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func loadOptionalProtections() {
|
||||
maxBodyBytes = parseInt64Env("PROXY_MAX_BODY_BYTES")
|
||||
rateLimitRPS = parseFloatEnv("PROXY_RATE_LIMIT_RPS")
|
||||
rateLimitBurst = parseIntEnv("PROXY_RATE_LIMIT_BURST")
|
||||
|
||||
if rateLimitRPS > 0 && rateLimitBurst <= 0 {
|
||||
rateLimitBurst = int(math.Ceil(rateLimitRPS))
|
||||
if rateLimitBurst < 1 {
|
||||
rateLimitBurst = 1
|
||||
}
|
||||
}
|
||||
if rateLimitRPS <= 0 || rateLimitBurst <= 0 {
|
||||
rateLimitRPS = 0
|
||||
rateLimitBurst = 0
|
||||
}
|
||||
|
||||
log.Printf("Optional protections: PROXY_MAX_BODY_BYTES=%d (0=off), PROXY_RATE_LIMIT_RPS=%.2f, PROXY_RATE_LIMIT_BURST=%d (0=off)", maxBodyBytes, rateLimitRPS, rateLimitBurst)
|
||||
}
|
||||
|
||||
func allowByRateLimit(clientID string) bool {
|
||||
if rateLimitRPS <= 0 || rateLimitBurst <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
limiterState.mu.Lock()
|
||||
defer limiterState.mu.Unlock()
|
||||
|
||||
if now.Sub(limiterState.lastCleanup) >= limiterCleanupInterval {
|
||||
for key, bucket := range limiterState.clients {
|
||||
if now.Sub(bucket.lastSeen) > limiterIdleTTL {
|
||||
delete(limiterState.clients, key)
|
||||
}
|
||||
}
|
||||
limiterState.lastCleanup = now
|
||||
}
|
||||
|
||||
bucket, ok := limiterState.clients[clientID]
|
||||
if !ok {
|
||||
limiterState.clients[clientID] = &tokenBucket{
|
||||
tokens: float64(rateLimitBurst - 1),
|
||||
lastRefill: now,
|
||||
lastSeen: now,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
elapsed := now.Sub(bucket.lastRefill).Seconds()
|
||||
if elapsed > 0 {
|
||||
bucket.tokens = math.Min(float64(rateLimitBurst), bucket.tokens+elapsed*rateLimitRPS)
|
||||
bucket.lastRefill = now
|
||||
}
|
||||
bucket.lastSeen = now
|
||||
|
||||
if bucket.tokens >= 1 {
|
||||
bucket.tokens -= 1
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func writeJSONError(w http.ResponseWriter, status int, message, errType string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
fmt.Fprintf(w, `{"error":{"message":%q,"type":%q}}`, message, errType)
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Parse template
|
||||
|
||||
@@ -325,9 +463,14 @@ func init() {
|
||||
clientIP := getClientIP(r)
|
||||
log.Printf("[ERROR] Client: %s | Target: %s | Error: %v", clientIP, targetURL.Host, err)
|
||||
|
||||
var maxErr *http.MaxBytesError
|
||||
if errors.As(err, &maxErr) || strings.Contains(strings.ToLower(err.Error()), "request body too large") {
|
||||
writeJSONError(w, http.StatusRequestEntityTooLarge, "Request body too large", "request_too_large")
|
||||
return
|
||||
}
|
||||
|
||||
// 返回 JSON 格式错误,方便 AI 客户端解析
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
fmt.Fprintf(w, `{"error": {"message": "Proxy Connection Error: %v", "type": "proxy_error"}}`, err)
|
||||
writeJSONError(w, http.StatusBadGateway, "Proxy connection error", "proxy_error")
|
||||
}
|
||||
|
||||
proxyMap[path] = proxy
|
||||
@@ -385,6 +528,24 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
||||
// 4. Serve Proxy
|
||||
proxy := proxyMap[matchedPrefix]
|
||||
|
||||
if rateLimitRPS > 0 {
|
||||
if !allowByRateLimit(clientIP) {
|
||||
lrw.Header().Set("Retry-After", "1")
|
||||
writeJSONError(lrw, http.StatusTooManyRequests, "Rate limit exceeded", "rate_limit_exceeded")
|
||||
log.Printf("[429] Path: %s | IP: %s", path, clientIP)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if maxBodyBytes > 0 {
|
||||
if r.ContentLength > maxBodyBytes && r.ContentLength != -1 {
|
||||
writeJSONError(lrw, http.StatusRequestEntityTooLarge, "Request body too large", "request_too_large")
|
||||
log.Printf("[413] Path: %s | IP: %s | Content-Length: %d", path, clientIP, r.ContentLength)
|
||||
return
|
||||
}
|
||||
r.Body = http.MaxBytesReader(lrw, r.Body, maxBodyBytes)
|
||||
}
|
||||
|
||||
// Rewrite the path: remove the prefix
|
||||
// Example: /openai/v1/chat/completions -> /v1/chat/completions
|
||||
// The SingleHostReverseProxy will append this to the target URL.
|
||||
@@ -440,6 +601,7 @@ func main() {
|
||||
if len(os.Args) > 1 {
|
||||
port = os.Args[1]
|
||||
}
|
||||
loadOptionalProtections()
|
||||
|
||||
// Determine the HTTP server settings
|
||||
server := &http.Server{
|
||||
|
||||
Reference in New Issue
Block a user