-
Notifications
You must be signed in to change notification settings - Fork 525
Add IP-based rate limiting to protect against abuse #833
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
91e9d56
4914d98
bc2a870
358b21a
e16cf8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,297 @@ | ||
| // Package ratelimit provides IP-based rate limiting middleware for HTTP servers. | ||
| package ratelimit | ||
|
|
||
| import ( | ||
| "encoding/json" | ||
| "log" | ||
| "net" | ||
| "net/http" | ||
| "strings" | ||
| "sync" | ||
| "time" | ||
|
|
||
| "golang.org/x/time/rate" | ||
| ) | ||
|
|
||
// OnRateLimitedFunc is the signature of the hook that runs whenever a request
// is rejected by the rate limiter. The ip argument is the client IP string that
// was used as the rate-limiting key for the rejected request.
type OnRateLimitedFunc func(ip string)
|
|
||
| // Config holds the rate limiting configuration | ||
| type Config struct { | ||
| // RequestsPerMinute is the maximum number of requests allowed per minute per IP | ||
| RequestsPerMinute int | ||
| // RequestsPerHour is the maximum number of requests allowed per hour per IP | ||
| RequestsPerHour int | ||
| // CleanupInterval is how often to clean up stale entries (default: 10 minutes) | ||
| CleanupInterval time.Duration | ||
| // SkipPaths are paths that should not be rate limited | ||
| SkipPaths []string | ||
| // MaxVisitors is the maximum number of visitor entries to track (memory protection). | ||
| // When exceeded, oldest entries are evicted. Default: 100000. | ||
| MaxVisitors int | ||
| // OnRateLimited is an optional callback invoked when a request is rate limited. | ||
| // Used for recording metrics. | ||
| OnRateLimited OnRateLimitedFunc | ||
| } | ||
|
|
||
| // DefaultConfig returns the default rate limiting configuration | ||
| func DefaultConfig() Config { | ||
| return Config{ | ||
| RequestsPerMinute: 60, | ||
| RequestsPerHour: 1000, | ||
| CleanupInterval: 10 * time.Minute, | ||
| SkipPaths: []string{"/health", "/ping", "/metrics"}, | ||
| MaxVisitors: 100000, | ||
| } | ||
| } | ||
|
|
||
| // visitor tracks rate limiting state for a single IP address | ||
| type visitor struct { | ||
| minuteLimiter *rate.Limiter | ||
| hourLimiter *rate.Limiter | ||
| lastSeen time.Time | ||
| } | ||
|
|
||
| // RateLimiter implements IP-based rate limiting | ||
| type RateLimiter struct { | ||
| config Config | ||
| visitors map[string]*visitor | ||
| mu sync.RWMutex | ||
| stopCh chan struct{} | ||
| } | ||
|
|
||
| // New creates a new RateLimiter with the given configuration | ||
| func New(cfg Config) *RateLimiter { | ||
| if cfg.MaxVisitors <= 0 { | ||
| cfg.MaxVisitors = 100000 | ||
| } | ||
|
|
||
| rl := &RateLimiter{ | ||
| config: cfg, | ||
| visitors: make(map[string]*visitor), | ||
| stopCh: make(chan struct{}), | ||
| } | ||
|
|
||
| // Start background cleanup goroutine | ||
| go rl.cleanupLoop() | ||
|
|
||
| return rl | ||
| } | ||
|
|
||
| // Stop stops the background cleanup goroutine | ||
| func (rl *RateLimiter) Stop() { | ||
| close(rl.stopCh) | ||
| } | ||
|
|
||
| // cleanupLoop periodically removes stale visitor entries | ||
| func (rl *RateLimiter) cleanupLoop() { | ||
| ticker := time.NewTicker(rl.config.CleanupInterval) | ||
| defer ticker.Stop() | ||
|
|
||
| for { | ||
| select { | ||
| case <-ticker.C: | ||
| rl.cleanup() | ||
| case <-rl.stopCh: | ||
| return | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // cleanup removes visitors that haven't been seen in the last hour | ||
| func (rl *RateLimiter) cleanup() { | ||
| rl.mu.Lock() | ||
| defer rl.mu.Unlock() | ||
|
|
||
| threshold := time.Now().Add(-time.Hour) | ||
| for ip, v := range rl.visitors { | ||
| if v.lastSeen.Before(threshold) { | ||
| delete(rl.visitors, ip) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // evictOldestLocked removes the oldest visitor entry. Must be called with lock held. | ||
| func (rl *RateLimiter) evictOldestLocked() { | ||
| var oldestIP string | ||
| var oldestTime time.Time | ||
|
|
||
| for ip, v := range rl.visitors { | ||
| if oldestIP == "" || v.lastSeen.Before(oldestTime) { | ||
| oldestIP = ip | ||
| oldestTime = v.lastSeen | ||
| } | ||
| } | ||
|
|
||
| if oldestIP != "" { | ||
| delete(rl.visitors, oldestIP) | ||
| } | ||
| } | ||
|
|
||
| // getVisitor returns the visitor for the given IP, creating one if necessary. | ||
| // Implements memory protection by evicting oldest entries when MaxVisitors is reached. | ||
| func (rl *RateLimiter) getVisitor(ip string) *visitor { | ||
| // Try read lock first for existing visitors (common case) | ||
| rl.mu.RLock() | ||
| v, exists := rl.visitors[ip] | ||
| rl.mu.RUnlock() | ||
|
|
||
| if exists { | ||
| // Update timestamp - this is a minor race but acceptable for lastSeen | ||
| v.lastSeen = time.Now() | ||
| return v | ||
| } | ||
|
|
||
| // Need to create new visitor - acquire write lock | ||
| rl.mu.Lock() | ||
| defer rl.mu.Unlock() | ||
|
|
||
| // Double-check after acquiring write lock | ||
| v, exists = rl.visitors[ip] | ||
| if exists { | ||
| v.lastSeen = time.Now() | ||
| return v | ||
| } | ||
|
|
||
| // Enforce max visitors limit (memory protection) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tadasant if there are request coming from unique ips then it will iterate over all exisiting entries to find 1 ip then remove while occupying lock. After 100K, every request will take good amount of time. Please feel free to correct me if I am missing something here.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can also implement LRU based structure here. or may be remove some % of entries if the limit is hit instead of removing 1 on every new request.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will be more than happy if I can help in anyway here. |
||
| if len(rl.visitors) >= rl.config.MaxVisitors { | ||
| rl.evictOldestLocked() | ||
| } | ||
|
|
||
| // Create rate limiters: | ||
| // - Minute limiter: allows RequestsPerMinute requests per minute with burst of same | ||
| // - Hour limiter: allows RequestsPerHour requests per hour with burst of same | ||
| minuteRate := rate.Limit(float64(rl.config.RequestsPerMinute) / 60.0) // requests per second | ||
| hourRate := rate.Limit(float64(rl.config.RequestsPerHour) / 3600.0) // requests per second | ||
|
|
||
| v = &visitor{ | ||
| minuteLimiter: rate.NewLimiter(minuteRate, rl.config.RequestsPerMinute), | ||
| hourLimiter: rate.NewLimiter(hourRate, rl.config.RequestsPerHour), | ||
| lastSeen: time.Now(), | ||
| } | ||
| rl.visitors[ip] = v | ||
|
|
||
| return v | ||
| } | ||
|
|
||
| // Allow checks if a request from the given IP should be allowed | ||
| func (rl *RateLimiter) Allow(ip string) bool { | ||
| v := rl.getVisitor(ip) | ||
|
|
||
| // Both limiters must allow the request | ||
| if !v.minuteLimiter.Allow() { | ||
| return false | ||
| } | ||
| if !v.hourLimiter.Allow() { | ||
| return false | ||
| } | ||
| return true | ||
| } | ||
|
|
||
| // shouldSkip returns true if the path should not be rate limited | ||
| func (rl *RateLimiter) shouldSkip(path string) bool { | ||
| for _, skipPath := range rl.config.SkipPaths { | ||
| if path == skipPath || strings.HasPrefix(path, skipPath+"/") { | ||
| return true | ||
| } | ||
| } | ||
| return false | ||
| } | ||
|
|
||
| // getClientIP extracts the client IP from the request. | ||
| // It considers X-Forwarded-For and X-Real-IP headers for reverse proxy scenarios, | ||
| // as the registry is deployed behind NGINX ingress with use-forwarded-headers enabled. | ||
| func getClientIP(r *http.Request) string { | ||
| // Check X-Forwarded-For header (can contain multiple IPs) | ||
| if xff := r.Header.Get("X-Forwarded-For"); xff != "" { | ||
| // Take the first IP (original client) | ||
| if idx := strings.Index(xff, ","); idx != -1 { | ||
| xff = xff[:idx] | ||
| } | ||
| xff = strings.TrimSpace(xff) | ||
| if ip := validateAndNormalizeIP(xff); ip != "" { | ||
| return ip | ||
| } | ||
| } | ||
|
|
||
| // Check X-Real-IP header | ||
| if xri := r.Header.Get("X-Real-IP"); xri != "" { | ||
| if ip := validateAndNormalizeIP(strings.TrimSpace(xri)); ip != "" { | ||
| return ip | ||
| } | ||
| } | ||
|
|
||
| // Fall back to RemoteAddr | ||
| ip, _, err := net.SplitHostPort(r.RemoteAddr) | ||
| if err != nil { | ||
| // RemoteAddr might not have a port | ||
| ip = r.RemoteAddr | ||
| } | ||
|
|
||
| // Validate and normalize the IP | ||
| if validIP := validateAndNormalizeIP(ip); validIP != "" { | ||
| return validIP | ||
| } | ||
|
|
||
| // If all else fails, use a fallback that won't cause issues | ||
| return "unknown" | ||
| } | ||
|
|
||
// validateAndNormalizeIP parses ip and returns its canonical textual form as
// produced by net.IP.String. It returns the empty string when ip is empty or
// is not a valid IP address.
func validateAndNormalizeIP(ip string) string {
	// net.ParseIP yields nil for the empty string as well as for malformed input.
	if parsed := net.ParseIP(ip); parsed != nil {
		return parsed.String()
	}
	return ""
}
|
|
||
| // Middleware returns an HTTP middleware that enforces rate limiting | ||
| func (rl *RateLimiter) Middleware(next http.Handler) http.Handler { | ||
| return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| // Skip rate limiting for certain paths | ||
| if rl.shouldSkip(r.URL.Path) { | ||
| next.ServeHTTP(w, r) | ||
| return | ||
| } | ||
|
|
||
| ip := getClientIP(r) | ||
|
|
||
| if !rl.Allow(ip) { | ||
| // Record the rate-limited request if callback is configured | ||
| if rl.config.OnRateLimited != nil { | ||
| rl.config.OnRateLimited(ip) | ||
| } | ||
|
|
||
| w.Header().Set("Content-Type", "application/problem+json") | ||
| w.Header().Set("Retry-After", "60") | ||
| w.WriteHeader(http.StatusTooManyRequests) | ||
|
|
||
| errorBody := map[string]interface{}{ | ||
| "title": "Too Many Requests", | ||
| "status": http.StatusTooManyRequests, | ||
| "detail": "Rate limit exceeded. Please reduce request frequency and retry after some time.", | ||
| } | ||
|
|
||
| jsonData, err := json.Marshal(errorBody) | ||
| if err != nil { | ||
| log.Printf("Failed to marshal rate limit error response: %v", err) | ||
| _, _ = w.Write([]byte(`{"title":"Too Many Requests","status":429}`)) | ||
| return | ||
| } | ||
| _, _ = w.Write(jsonData) | ||
| return | ||
| } | ||
|
|
||
| next.ServeHTTP(w, r) | ||
| }) | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be very slow when throughput is high: per the Grafana stats, hourly throughput is always above 15K and has peaked at ~35K. Cleanup won't remove an entry until it is more than an hour old.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe a shard-based rate limiter would suit better if we expect per-minute throughput to increase.