mirror of
https://github.com/handsomezhuzhu/api-proxy.git
synced 2026-04-18 14:22:54 +00:00
添加可选安全开关以限制请求体大小和基础限流功能
This commit is contained in:
20
README.md
20
README.md
@@ -96,6 +96,26 @@ go build -o api-proxy main.go
|
||||
./api-proxy 8080
|
||||
```
|
||||
|
||||
### 可选安全开关(默认关闭)
|
||||
|
||||
你可以在部署时通过环境变量开启请求体大小限制和基础限流;不设置时保持“开放代理”行为。
|
||||
|
||||
```bash
|
||||
# 限制请求体最大 10MB(0 或不设置=关闭)
|
||||
export PROXY_MAX_BODY_BYTES=10485760
|
||||
|
||||
# 基础限流:每个客户端 IP 每秒 5 个请求,突发 10(任一为 0 或不设置=关闭)
|
||||
export PROXY_RATE_LIMIT_RPS=5
|
||||
export PROXY_RATE_LIMIT_BURST=10
|
||||
|
||||
./api-proxy
|
||||
```
|
||||
|
||||
说明:
|
||||
- `PROXY_MAX_BODY_BYTES`: 请求体上限(字节)
|
||||
- `PROXY_RATE_LIMIT_RPS`: 每秒补充 token 数
|
||||
- `PROXY_RATE_LIMIT_BURST`: 桶容量;若只设置 `RPS`,会自动取 `ceil(RPS)`
|
||||
|
||||
## 使用示例
|
||||
|
||||
假设你的服务运行在 `http://localhost:7890`。
|
||||
|
||||
@@ -7,4 +7,8 @@ services:
|
||||
# 如果您在 main.go 中配置的端口是 7890,那么这里保持 7890
|
||||
ports:
|
||||
- "7890:7890" # 映射:将主机的 80 端口映射到容器内的 7890 端口
|
||||
environment:
|
||||
- PROXY_MAX_BODY_BYTES=${PROXY_MAX_BODY_BYTES:-0}
|
||||
- PROXY_RATE_LIMIT_RPS=${PROXY_RATE_LIMIT_RPS:-0}
|
||||
- PROXY_RATE_LIMIT_BURST=${PROXY_RATE_LIMIT_BURST:-0}
|
||||
restart: always
|
||||
166
main.go
166
main.go
@@ -1,15 +1,20 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -20,6 +25,23 @@ const (
|
||||
var (
|
||||
proxyMap = make(map[string]*httputil.ReverseProxy)
|
||||
tpl *template.Template
|
||||
|
||||
maxBodyBytes int64
|
||||
rateLimitRPS float64
|
||||
rateLimitBurst int
|
||||
|
||||
limiterState = struct {
|
||||
mu sync.Mutex
|
||||
clients map[string]*tokenBucket
|
||||
lastCleanup time.Time
|
||||
}{
|
||||
clients: make(map[string]*tokenBucket),
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
limiterIdleTTL = 10 * time.Minute
|
||||
limiterCleanupInterval = 5 * time.Minute
|
||||
)
|
||||
|
||||
// Denied headers that should not be forwarded to the upstream API
|
||||
@@ -218,6 +240,12 @@ type loggingResponseWriter struct {
|
||||
statusCode int
|
||||
}
|
||||
|
||||
type tokenBucket struct {
|
||||
tokens float64
|
||||
lastRefill time.Time
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
func (lrw *loggingResponseWriter) WriteHeader(code int) {
|
||||
lrw.statusCode = code
|
||||
lrw.ResponseWriter.WriteHeader(code)
|
||||
@@ -253,9 +281,119 @@ func getClientIP(r *http.Request) string {
|
||||
parts := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||
return host
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
func parseInt64Env(name string) int64 {
|
||||
raw := strings.TrimSpace(os.Getenv(name))
|
||||
if raw == "" {
|
||||
return 0
|
||||
}
|
||||
v, err := strconv.ParseInt(raw, 10, 64)
|
||||
if err != nil || v < 0 {
|
||||
log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw)
|
||||
return 0
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func parseFloatEnv(name string) float64 {
|
||||
raw := strings.TrimSpace(os.Getenv(name))
|
||||
if raw == "" {
|
||||
return 0
|
||||
}
|
||||
v, err := strconv.ParseFloat(raw, 64)
|
||||
if err != nil || v < 0 {
|
||||
log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw)
|
||||
return 0
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func parseIntEnv(name string) int {
|
||||
raw := strings.TrimSpace(os.Getenv(name))
|
||||
if raw == "" {
|
||||
return 0
|
||||
}
|
||||
v, err := strconv.Atoi(raw)
|
||||
if err != nil || v < 0 {
|
||||
log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw)
|
||||
return 0
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func loadOptionalProtections() {
|
||||
maxBodyBytes = parseInt64Env("PROXY_MAX_BODY_BYTES")
|
||||
rateLimitRPS = parseFloatEnv("PROXY_RATE_LIMIT_RPS")
|
||||
rateLimitBurst = parseIntEnv("PROXY_RATE_LIMIT_BURST")
|
||||
|
||||
if rateLimitRPS > 0 && rateLimitBurst <= 0 {
|
||||
rateLimitBurst = int(math.Ceil(rateLimitRPS))
|
||||
if rateLimitBurst < 1 {
|
||||
rateLimitBurst = 1
|
||||
}
|
||||
}
|
||||
if rateLimitRPS <= 0 || rateLimitBurst <= 0 {
|
||||
rateLimitRPS = 0
|
||||
rateLimitBurst = 0
|
||||
}
|
||||
|
||||
log.Printf("Optional protections: PROXY_MAX_BODY_BYTES=%d (0=off), PROXY_RATE_LIMIT_RPS=%.2f, PROXY_RATE_LIMIT_BURST=%d (0=off)", maxBodyBytes, rateLimitRPS, rateLimitBurst)
|
||||
}
|
||||
|
||||
func allowByRateLimit(clientID string) bool {
|
||||
if rateLimitRPS <= 0 || rateLimitBurst <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
limiterState.mu.Lock()
|
||||
defer limiterState.mu.Unlock()
|
||||
|
||||
if now.Sub(limiterState.lastCleanup) >= limiterCleanupInterval {
|
||||
for key, bucket := range limiterState.clients {
|
||||
if now.Sub(bucket.lastSeen) > limiterIdleTTL {
|
||||
delete(limiterState.clients, key)
|
||||
}
|
||||
}
|
||||
limiterState.lastCleanup = now
|
||||
}
|
||||
|
||||
bucket, ok := limiterState.clients[clientID]
|
||||
if !ok {
|
||||
limiterState.clients[clientID] = &tokenBucket{
|
||||
tokens: float64(rateLimitBurst - 1),
|
||||
lastRefill: now,
|
||||
lastSeen: now,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
elapsed := now.Sub(bucket.lastRefill).Seconds()
|
||||
if elapsed > 0 {
|
||||
bucket.tokens = math.Min(float64(rateLimitBurst), bucket.tokens+elapsed*rateLimitRPS)
|
||||
bucket.lastRefill = now
|
||||
}
|
||||
bucket.lastSeen = now
|
||||
|
||||
if bucket.tokens >= 1 {
|
||||
bucket.tokens -= 1
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func writeJSONError(w http.ResponseWriter, status int, message, errType string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
fmt.Fprintf(w, `{"error":{"message":%q,"type":%q}}`, message, errType)
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Parse template
|
||||
|
||||
@@ -325,9 +463,14 @@ func init() {
|
||||
clientIP := getClientIP(r)
|
||||
log.Printf("[ERROR] Client: %s | Target: %s | Error: %v", clientIP, targetURL.Host, err)
|
||||
|
||||
var maxErr *http.MaxBytesError
|
||||
if errors.As(err, &maxErr) || strings.Contains(strings.ToLower(err.Error()), "request body too large") {
|
||||
writeJSONError(w, http.StatusRequestEntityTooLarge, "Request body too large", "request_too_large")
|
||||
return
|
||||
}
|
||||
|
||||
// 返回 JSON 格式错误,方便 AI 客户端解析
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
fmt.Fprintf(w, `{"error": {"message": "Proxy Connection Error: %v", "type": "proxy_error"}}`, err)
|
||||
writeJSONError(w, http.StatusBadGateway, "Proxy connection error", "proxy_error")
|
||||
}
|
||||
|
||||
proxyMap[path] = proxy
|
||||
@@ -385,6 +528,24 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
||||
// 4. Serve Proxy
|
||||
proxy := proxyMap[matchedPrefix]
|
||||
|
||||
if rateLimitRPS > 0 {
|
||||
if !allowByRateLimit(clientIP) {
|
||||
lrw.Header().Set("Retry-After", "1")
|
||||
writeJSONError(lrw, http.StatusTooManyRequests, "Rate limit exceeded", "rate_limit_exceeded")
|
||||
log.Printf("[429] Path: %s | IP: %s", path, clientIP)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if maxBodyBytes > 0 {
|
||||
if r.ContentLength > maxBodyBytes && r.ContentLength != -1 {
|
||||
writeJSONError(lrw, http.StatusRequestEntityTooLarge, "Request body too large", "request_too_large")
|
||||
log.Printf("[413] Path: %s | IP: %s | Content-Length: %d", path, clientIP, r.ContentLength)
|
||||
return
|
||||
}
|
||||
r.Body = http.MaxBytesReader(lrw, r.Body, maxBodyBytes)
|
||||
}
|
||||
|
||||
// Rewrite the path: remove the prefix
|
||||
// Example: /openai/v1/chat/completions -> /v1/chat/completions
|
||||
// The SingleHostReverseProxy will append this to the target URL.
|
||||
@@ -440,6 +601,7 @@ func main() {
|
||||
if len(os.Args) > 1 {
|
||||
port = os.Args[1]
|
||||
}
|
||||
loadOptionalProtections()
|
||||
|
||||
// Determine the HTTP server settings
|
||||
server := &http.Server{
|
||||
|
||||
Reference in New Issue
Block a user