diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py
index da4787ed4..632f35d10 100644
--- a/api/core/app/features/rate_limiting/rate_limit.py
+++ b/api/core/app/features/rate_limiting/rate_limit.py
@@ -27,6 +27,9 @@ class RateLimit:
 
     def __init__(self, client_id: str, max_active_requests: int):
         self.max_active_requests = max_active_requests
+        # must be called after max_active_requests is set
+        if self.disabled():
+            return
         if hasattr(self, "initialized"):
             return
         self.initialized = True
@@ -37,6 +40,8 @@ class RateLimit:
         self.flush_cache(use_local_value=True)
 
     def flush_cache(self, use_local_value=False):
+        if self.disabled():
+            return
         self.last_recalculate_time = time.time()
         # flush max active requests
         if use_local_value or not redis_client.exists(self.max_active_requests_key):
@@ -59,18 +64,18 @@ class RateLimit:
             redis_client.hdel(self.active_requests_key, *timeout_requests)
 
     def enter(self, request_id: Optional[str] = None) -> str:
+        if self.disabled():
+            return RateLimit._UNLIMITED_REQUEST_ID
         if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
             self.flush_cache()
-        if self.max_active_requests <= 0:
-            return RateLimit._UNLIMITED_REQUEST_ID
         if not request_id:
             request_id = RateLimit.gen_request_key()
 
         active_requests_count = redis_client.hlen(self.active_requests_key)
         if active_requests_count >= self.max_active_requests:
             raise AppInvokeQuotaExceededError(
-                "Too many requests. Please try again later. The current maximum "
-                "concurrent requests allowed is {}.".format(self.max_active_requests)
+                f"Too many requests. Please try again later. The current maximum concurrent requests allowed "
+                f"for {self.client_id} is {self.max_active_requests}."
             )
         redis_client.hset(self.active_requests_key, request_id, str(time.time()))
         return request_id
@@ -80,6 +85,9 @@ class RateLimit:
             return
         redis_client.hdel(self.active_requests_key, request_id)
 
+    def disabled(self):
+        return self.max_active_requests <= 0
+
     @staticmethod
     def gen_request_key() -> str:
         return str(uuid.uuid4())