mirror of
https://github.com/handsomezhuzhu/api-proxy.git
synced 2026-04-18 22:32:54 +00:00
添加可选安全开关以限制请求体大小和基础限流功能
This commit is contained in:
20
README.md
20
README.md
@@ -96,6 +96,26 @@ go build -o api-proxy main.go
|
|||||||
./api-proxy 8080
|
./api-proxy 8080
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 可选安全开关(默认关闭)
|
||||||
|
|
||||||
|
你可以在部署时通过环境变量开启请求体大小限制和基础限流;不设置时保持“开放代理”行为。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 限制请求体最大 10MB(0 或不设置=关闭)
|
||||||
|
export PROXY_MAX_BODY_BYTES=10485760
|
||||||
|
|
||||||
|
# 基础限流:每个客户端 IP 每秒 5 个请求,突发 10(任一为 0 或不设置=关闭)
|
||||||
|
export PROXY_RATE_LIMIT_RPS=5
|
||||||
|
export PROXY_RATE_LIMIT_BURST=10
|
||||||
|
|
||||||
|
./api-proxy
|
||||||
|
```
|
||||||
|
|
||||||
|
说明:
|
||||||
|
- `PROXY_MAX_BODY_BYTES`: 请求体上限(字节)
|
||||||
|
- `PROXY_RATE_LIMIT_RPS`: 每秒补充 token 数
|
||||||
|
- `PROXY_RATE_LIMIT_BURST`: 桶容量;若只设置 `RPS`,会自动取 `ceil(RPS)`
|
||||||
|
|
||||||
## 使用示例
|
## 使用示例
|
||||||
|
|
||||||
假设你的服务运行在 `http://localhost:7890`。
|
假设你的服务运行在 `http://localhost:7890`。
|
||||||
|
|||||||
@@ -7,4 +7,8 @@ services:
|
|||||||
# 如果您在 main.go 中配置的端口是 7890,那么这里保持 7890
|
# 如果您在 main.go 中配置的端口是 7890,那么这里保持 7890
|
||||||
ports:
|
ports:
|
||||||
- "7890:7890" # 映射:将主机的 80 端口映射到容器内的 7890 端口
|
- "7890:7890" # 映射:将主机的 80 端口映射到容器内的 7890 端口
|
||||||
|
environment:
|
||||||
|
- PROXY_MAX_BODY_BYTES=${PROXY_MAX_BODY_BYTES:-0}
|
||||||
|
- PROXY_RATE_LIMIT_RPS=${PROXY_RATE_LIMIT_RPS:-0}
|
||||||
|
- PROXY_RATE_LIMIT_BURST=${PROXY_RATE_LIMIT_BURST:-0}
|
||||||
restart: always
|
restart: always
|
||||||
166
main.go
166
main.go
@@ -1,15 +1,20 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"log"
|
"log"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,6 +25,23 @@ const (
|
|||||||
var (
|
var (
|
||||||
proxyMap = make(map[string]*httputil.ReverseProxy)
|
proxyMap = make(map[string]*httputil.ReverseProxy)
|
||||||
tpl *template.Template
|
tpl *template.Template
|
||||||
|
|
||||||
|
maxBodyBytes int64
|
||||||
|
rateLimitRPS float64
|
||||||
|
rateLimitBurst int
|
||||||
|
|
||||||
|
limiterState = struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
clients map[string]*tokenBucket
|
||||||
|
lastCleanup time.Time
|
||||||
|
}{
|
||||||
|
clients: make(map[string]*tokenBucket),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
limiterIdleTTL = 10 * time.Minute
|
||||||
|
limiterCleanupInterval = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
// Denied headers that should not be forwarded to the upstream API
|
// Denied headers that should not be forwarded to the upstream API
|
||||||
@@ -218,6 +240,12 @@ type loggingResponseWriter struct {
|
|||||||
statusCode int
|
statusCode int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type tokenBucket struct {
|
||||||
|
tokens float64
|
||||||
|
lastRefill time.Time
|
||||||
|
lastSeen time.Time
|
||||||
|
}
|
||||||
|
|
||||||
func (lrw *loggingResponseWriter) WriteHeader(code int) {
|
func (lrw *loggingResponseWriter) WriteHeader(code int) {
|
||||||
lrw.statusCode = code
|
lrw.statusCode = code
|
||||||
lrw.ResponseWriter.WriteHeader(code)
|
lrw.ResponseWriter.WriteHeader(code)
|
||||||
@@ -253,9 +281,119 @@ func getClientIP(r *http.Request) string {
|
|||||||
parts := strings.Split(xff, ",")
|
parts := strings.Split(xff, ",")
|
||||||
return strings.TrimSpace(parts[0])
|
return strings.TrimSpace(parts[0])
|
||||||
}
|
}
|
||||||
|
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||||
|
return host
|
||||||
|
}
|
||||||
return r.RemoteAddr
|
return r.RemoteAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseInt64Env(name string) int64 {
|
||||||
|
raw := strings.TrimSpace(os.Getenv(name))
|
||||||
|
if raw == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
v, err := strconv.ParseInt(raw, 10, 64)
|
||||||
|
if err != nil || v < 0 {
|
||||||
|
log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseFloatEnv(name string) float64 {
|
||||||
|
raw := strings.TrimSpace(os.Getenv(name))
|
||||||
|
if raw == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
v, err := strconv.ParseFloat(raw, 64)
|
||||||
|
if err != nil || v < 0 {
|
||||||
|
log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseIntEnv(name string) int {
|
||||||
|
raw := strings.TrimSpace(os.Getenv(name))
|
||||||
|
if raw == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
v, err := strconv.Atoi(raw)
|
||||||
|
if err != nil || v < 0 {
|
||||||
|
log.Printf("Invalid %s=%q, fallback to 0 (disabled)", name, raw)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadOptionalProtections() {
|
||||||
|
maxBodyBytes = parseInt64Env("PROXY_MAX_BODY_BYTES")
|
||||||
|
rateLimitRPS = parseFloatEnv("PROXY_RATE_LIMIT_RPS")
|
||||||
|
rateLimitBurst = parseIntEnv("PROXY_RATE_LIMIT_BURST")
|
||||||
|
|
||||||
|
if rateLimitRPS > 0 && rateLimitBurst <= 0 {
|
||||||
|
rateLimitBurst = int(math.Ceil(rateLimitRPS))
|
||||||
|
if rateLimitBurst < 1 {
|
||||||
|
rateLimitBurst = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if rateLimitRPS <= 0 || rateLimitBurst <= 0 {
|
||||||
|
rateLimitRPS = 0
|
||||||
|
rateLimitBurst = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Optional protections: PROXY_MAX_BODY_BYTES=%d (0=off), PROXY_RATE_LIMIT_RPS=%.2f, PROXY_RATE_LIMIT_BURST=%d (0=off)", maxBodyBytes, rateLimitRPS, rateLimitBurst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func allowByRateLimit(clientID string) bool {
|
||||||
|
if rateLimitRPS <= 0 || rateLimitBurst <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
limiterState.mu.Lock()
|
||||||
|
defer limiterState.mu.Unlock()
|
||||||
|
|
||||||
|
if now.Sub(limiterState.lastCleanup) >= limiterCleanupInterval {
|
||||||
|
for key, bucket := range limiterState.clients {
|
||||||
|
if now.Sub(bucket.lastSeen) > limiterIdleTTL {
|
||||||
|
delete(limiterState.clients, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
limiterState.lastCleanup = now
|
||||||
|
}
|
||||||
|
|
||||||
|
bucket, ok := limiterState.clients[clientID]
|
||||||
|
if !ok {
|
||||||
|
limiterState.clients[clientID] = &tokenBucket{
|
||||||
|
tokens: float64(rateLimitBurst - 1),
|
||||||
|
lastRefill: now,
|
||||||
|
lastSeen: now,
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
elapsed := now.Sub(bucket.lastRefill).Seconds()
|
||||||
|
if elapsed > 0 {
|
||||||
|
bucket.tokens = math.Min(float64(rateLimitBurst), bucket.tokens+elapsed*rateLimitRPS)
|
||||||
|
bucket.lastRefill = now
|
||||||
|
}
|
||||||
|
bucket.lastSeen = now
|
||||||
|
|
||||||
|
if bucket.tokens >= 1 {
|
||||||
|
bucket.tokens -= 1
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeJSONError(w http.ResponseWriter, status int, message, errType string) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
fmt.Fprintf(w, `{"error":{"message":%q,"type":%q}}`, message, errType)
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// Parse template
|
// Parse template
|
||||||
|
|
||||||
@@ -325,9 +463,14 @@ func init() {
|
|||||||
clientIP := getClientIP(r)
|
clientIP := getClientIP(r)
|
||||||
log.Printf("[ERROR] Client: %s | Target: %s | Error: %v", clientIP, targetURL.Host, err)
|
log.Printf("[ERROR] Client: %s | Target: %s | Error: %v", clientIP, targetURL.Host, err)
|
||||||
|
|
||||||
|
var maxErr *http.MaxBytesError
|
||||||
|
if errors.As(err, &maxErr) || strings.Contains(strings.ToLower(err.Error()), "request body too large") {
|
||||||
|
writeJSONError(w, http.StatusRequestEntityTooLarge, "Request body too large", "request_too_large")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 返回 JSON 格式错误,方便 AI 客户端解析
|
// 返回 JSON 格式错误,方便 AI 客户端解析
|
||||||
w.WriteHeader(http.StatusBadGateway)
|
writeJSONError(w, http.StatusBadGateway, "Proxy connection error", "proxy_error")
|
||||||
fmt.Fprintf(w, `{"error": {"message": "Proxy Connection Error: %v", "type": "proxy_error"}}`, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyMap[path] = proxy
|
proxyMap[path] = proxy
|
||||||
@@ -385,6 +528,24 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
|||||||
// 4. Serve Proxy
|
// 4. Serve Proxy
|
||||||
proxy := proxyMap[matchedPrefix]
|
proxy := proxyMap[matchedPrefix]
|
||||||
|
|
||||||
|
if rateLimitRPS > 0 {
|
||||||
|
if !allowByRateLimit(clientIP) {
|
||||||
|
lrw.Header().Set("Retry-After", "1")
|
||||||
|
writeJSONError(lrw, http.StatusTooManyRequests, "Rate limit exceeded", "rate_limit_exceeded")
|
||||||
|
log.Printf("[429] Path: %s | IP: %s", path, clientIP)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxBodyBytes > 0 {
|
||||||
|
if r.ContentLength > maxBodyBytes && r.ContentLength != -1 {
|
||||||
|
writeJSONError(lrw, http.StatusRequestEntityTooLarge, "Request body too large", "request_too_large")
|
||||||
|
log.Printf("[413] Path: %s | IP: %s | Content-Length: %d", path, clientIP, r.ContentLength)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.Body = http.MaxBytesReader(lrw, r.Body, maxBodyBytes)
|
||||||
|
}
|
||||||
|
|
||||||
// Rewrite the path: remove the prefix
|
// Rewrite the path: remove the prefix
|
||||||
// Example: /openai/v1/chat/completions -> /v1/chat/completions
|
// Example: /openai/v1/chat/completions -> /v1/chat/completions
|
||||||
// The SingleHostReverseProxy will append this to the target URL.
|
// The SingleHostReverseProxy will append this to the target URL.
|
||||||
@@ -440,6 +601,7 @@ func main() {
|
|||||||
if len(os.Args) > 1 {
|
if len(os.Args) > 1 {
|
||||||
port = os.Args[1]
|
port = os.Args[1]
|
||||||
}
|
}
|
||||||
|
loadOptionalProtections()
|
||||||
|
|
||||||
// Determine the HTTP server settings
|
// Determine the HTTP server settings
|
||||||
server := &http.Server{
|
server := &http.Server{
|
||||||
|
|||||||
Reference in New Issue
Block a user