添加可选安全开关以限制请求体大小和基础限流功能

This commit is contained in:
2026-03-05 11:17:59 +08:00
parent fc011785d8
commit 068c192fb7
3 changed files with 188 additions and 2 deletions

View File

@@ -96,6 +96,26 @@ go build -o api-proxy main.go
./api-proxy 8080 ./api-proxy 8080
``` ```
### 可选安全开关(默认关闭)
你可以在部署时通过环境变量开启请求体大小限制和基础限流;不设置时保持“开放代理”行为。
```bash
# 限制请求体最大 10MB0 或不设置=关闭)
export PROXY_MAX_BODY_BYTES=10485760
# 基础限流:每个客户端 IP 每秒 5 个请求,突发 10任一为 0 或不设置=关闭)
export PROXY_RATE_LIMIT_RPS=5
export PROXY_RATE_LIMIT_BURST=10
./api-proxy
```
说明:
- `PROXY_MAX_BODY_BYTES`: 请求体上限(字节)
- `PROXY_RATE_LIMIT_RPS`: 每秒补充 token 数
- `PROXY_RATE_LIMIT_BURST`: 桶容量;若只设置 `RPS`,会自动取 `ceil(RPS)`
## 使用示例 ## 使用示例
假设你的服务运行在 `http://localhost:7890` 假设你的服务运行在 `http://localhost:7890`

View File

@@ -7,4 +7,8 @@ services:
# 如果您在 main.go 中配置的端口是 7890那么这里保持 7890 # 如果您在 main.go 中配置的端口是 7890那么这里保持 7890
ports: ports:
- "7890:7890" # 映射:将主机的 80 端口映射到容器内的 7890 端口 - "7890:7890" # 映射:将主机的 80 端口映射到容器内的 7890 端口
environment:
- PROXY_MAX_BODY_BYTES=${PROXY_MAX_BODY_BYTES:-0}
- PROXY_RATE_LIMIT_RPS=${PROXY_RATE_LIMIT_RPS:-0}
- PROXY_RATE_LIMIT_BURST=${PROXY_RATE_LIMIT_BURST:-0}
restart: always restart: always

166
main.go
View File

@@ -1,15 +1,20 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"html/template" "html/template"
"log" "log"
"math"
"net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
"os" "os"
"sort" "sort"
"strconv"
"strings" "strings"
"sync"
"time" "time"
) )
@@ -20,6 +25,23 @@ const (
var ( var (
proxyMap = make(map[string]*httputil.ReverseProxy) proxyMap = make(map[string]*httputil.ReverseProxy)
tpl *template.Template tpl *template.Template
maxBodyBytes int64
rateLimitRPS float64
rateLimitBurst int
limiterState = struct {
mu sync.Mutex
clients map[string]*tokenBucket
lastCleanup time.Time
}{
clients: make(map[string]*tokenBucket),
}
)
const (
limiterIdleTTL = 10 * time.Minute
limiterCleanupInterval = 5 * time.Minute
) )
// Denied headers that should not be forwarded to the upstream API // Denied headers that should not be forwarded to the upstream API
@@ -218,6 +240,12 @@ type loggingResponseWriter struct {
statusCode int statusCode int
} }
type tokenBucket struct {
tokens float64
lastRefill time.Time
lastSeen time.Time
}
func (lrw *loggingResponseWriter) WriteHeader(code int) { func (lrw *loggingResponseWriter) WriteHeader(code int) {
lrw.statusCode = code lrw.statusCode = code
lrw.ResponseWriter.WriteHeader(code) lrw.ResponseWriter.WriteHeader(code)
@@ -253,9 +281,119 @@ func getClientIP(r *http.Request) string {
parts := strings.Split(xff, ",") parts := strings.Split(xff, ",")
return strings.TrimSpace(parts[0]) return strings.TrimSpace(parts[0])
} }
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
return host
}
return r.RemoteAddr return r.RemoteAddr
} }
func parseInt64Env(name string) int64 {
raw := strings.TrimSpace(os.Getenv(name))
if raw == "" {
return 0
}
v, err := strconv.ParseInt(raw, 10, 64)
if err != nil || v < 0 {
log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw)
return 0
}
return v
}
func parseFloatEnv(name string) float64 {
raw := strings.TrimSpace(os.Getenv(name))
if raw == "" {
return 0
}
v, err := strconv.ParseFloat(raw, 64)
if err != nil || v < 0 {
log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw)
return 0
}
return v
}
func parseIntEnv(name string) int {
raw := strings.TrimSpace(os.Getenv(name))
if raw == "" {
return 0
}
v, err := strconv.Atoi(raw)
if err != nil || v < 0 {
log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw)
return 0
}
return v
}
func loadOptionalProtections() {
maxBodyBytes = parseInt64Env("PROXY_MAX_BODY_BYTES")
rateLimitRPS = parseFloatEnv("PROXY_RATE_LIMIT_RPS")
rateLimitBurst = parseIntEnv("PROXY_RATE_LIMIT_BURST")
if rateLimitRPS > 0 && rateLimitBurst <= 0 {
rateLimitBurst = int(math.Ceil(rateLimitRPS))
if rateLimitBurst < 1 {
rateLimitBurst = 1
}
}
if rateLimitRPS <= 0 || rateLimitBurst <= 0 {
rateLimitRPS = 0
rateLimitBurst = 0
}
log.Printf("Optional protections: PROXY_MAX_BODY_BYTES=%d (0=off), PROXY_RATE_LIMIT_RPS=%.2f, PROXY_RATE_LIMIT_BURST=%d (0=off)", maxBodyBytes, rateLimitRPS, rateLimitBurst)
}
func allowByRateLimit(clientID string) bool {
if rateLimitRPS <= 0 || rateLimitBurst <= 0 {
return true
}
now := time.Now()
limiterState.mu.Lock()
defer limiterState.mu.Unlock()
if now.Sub(limiterState.lastCleanup) >= limiterCleanupInterval {
for key, bucket := range limiterState.clients {
if now.Sub(bucket.lastSeen) > limiterIdleTTL {
delete(limiterState.clients, key)
}
}
limiterState.lastCleanup = now
}
bucket, ok := limiterState.clients[clientID]
if !ok {
limiterState.clients[clientID] = &tokenBucket{
tokens: float64(rateLimitBurst - 1),
lastRefill: now,
lastSeen: now,
}
return true
}
elapsed := now.Sub(bucket.lastRefill).Seconds()
if elapsed > 0 {
bucket.tokens = math.Min(float64(rateLimitBurst), bucket.tokens+elapsed*rateLimitRPS)
bucket.lastRefill = now
}
bucket.lastSeen = now
if bucket.tokens >= 1 {
bucket.tokens -= 1
return true
}
return false
}
func writeJSONError(w http.ResponseWriter, status int, message, errType string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
fmt.Fprintf(w, `{"error":{"message":%q,"type":%q}}`, message, errType)
}
func init() { func init() {
// Parse template // Parse template
@@ -325,9 +463,14 @@ func init() {
clientIP := getClientIP(r) clientIP := getClientIP(r)
log.Printf("[ERROR] Client: %s | Target: %s | Error: %v", clientIP, targetURL.Host, err) log.Printf("[ERROR] Client: %s | Target: %s | Error: %v", clientIP, targetURL.Host, err)
var maxErr *http.MaxBytesError
if errors.As(err, &maxErr) || strings.Contains(strings.ToLower(err.Error()), "request body too large") {
writeJSONError(w, http.StatusRequestEntityTooLarge, "Request body too large", "request_too_large")
return
}
// 返回 JSON 格式错误,方便 AI 客户端解析 // 返回 JSON 格式错误,方便 AI 客户端解析
w.WriteHeader(http.StatusBadGateway) writeJSONError(w, http.StatusBadGateway, "Proxy connection error", "proxy_error")
fmt.Fprintf(w, `{"error": {"message": "Proxy Connection Error: %v", "type": "proxy_error"}}`, err)
} }
proxyMap[path] = proxy proxyMap[path] = proxy
@@ -385,6 +528,24 @@ func handler(w http.ResponseWriter, r *http.Request) {
// 4. Serve Proxy // 4. Serve Proxy
proxy := proxyMap[matchedPrefix] proxy := proxyMap[matchedPrefix]
if rateLimitRPS > 0 {
if !allowByRateLimit(clientIP) {
lrw.Header().Set("Retry-After", "1")
writeJSONError(lrw, http.StatusTooManyRequests, "Rate limit exceeded", "rate_limit_exceeded")
log.Printf("[429] Path: %s | IP: %s", path, clientIP)
return
}
}
if maxBodyBytes > 0 {
if r.ContentLength > maxBodyBytes && r.ContentLength != -1 {
writeJSONError(lrw, http.StatusRequestEntityTooLarge, "Request body too large", "request_too_large")
log.Printf("[413] Path: %s | IP: %s | Content-Length: %d", path, clientIP, r.ContentLength)
return
}
r.Body = http.MaxBytesReader(lrw, r.Body, maxBodyBytes)
}
// Rewrite the path: remove the prefix // Rewrite the path: remove the prefix
// Example: /openai/v1/chat/completions -> /v1/chat/completions // Example: /openai/v1/chat/completions -> /v1/chat/completions
// The SingleHostReverseProxy will append this to the target URL. // The SingleHostReverseProxy will append this to the target URL.
@@ -440,6 +601,7 @@ func main() {
if len(os.Args) > 1 { if len(os.Args) > 1 {
port = os.Args[1] port = os.Args[1]
} }
loadOptionalProtections()
// Determine the HTTP server settings // Determine the HTTP server settings
server := &http.Server{ server := &http.Server{