# Source code for remoulade.rate_limits.backends.stub

# This file is a part of Remoulade.
#
# Copyright (C) 2017,2018 CLEARTYPE SRL <bogdan@cleartype.io>
#
# Remoulade is free software; you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at
# your option) any later version.
#
# Remoulade is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import time
from threading import Lock
from typing import Callable, Dict, List, Optional, Tuple

from ..backend import RateLimiterBackend


class StubBackend(RateLimiterBackend):
    """An in-memory rate limiter backend.  For use in unit tests.

    Entries are kept in a plain dict mapping each key to a
    ``(value, expiration)`` pair, where ``expiration`` is a
    :func:`time.monotonic` timestamp in seconds.  Every public operation
    runs under ``self.mutex``.
    """

    def __init__(self):
        self.mutex = Lock()
        # key -> (value, monotonic expiration time in seconds)
        self.db: Dict[str, Tuple[int, float]] = {}

    def add(self, key: str, value: int, ttl: int) -> bool:
        """Store ``value`` under ``key`` only if no live entry exists.

        Parameters:
          key: The key to set.
          value: The value to store.
          ttl: Time to live of the new entry, in milliseconds.

        Returns:
          True if the value was stored, False if a live (non-expired)
          entry for ``key`` already existed.
        """
        with self.mutex:
            if self._get(key) is not None:
                return False
            return self._put(key, value, ttl)

    def incr(self, key: str, amount: int, maximum: int, ttl: int) -> bool:
        """Add ``amount`` to ``key`` unless the result would exceed ``maximum``.

        A missing or expired key counts as 0.  On success the entry's
        expiration is reset to ``ttl`` milliseconds from now.

        Returns:
          True if the key was updated, False if ``maximum`` would be exceeded
          (in which case nothing is written).
        """
        with self.mutex:
            value = self._get(key, default=0) + amount
            if value > maximum:
                return False
            return self._put(key, value, ttl)

    def decr(self, key: str, amount: int, minimum: int, ttl: int) -> bool:
        """Subtract ``amount`` from ``key`` unless the result would drop below ``minimum``.

        A missing or expired key counts as 0.  On success the entry's
        expiration is reset to ``ttl`` milliseconds from now.

        Returns:
          True if the key was updated, False if ``minimum`` would be crossed
          (in which case nothing is written).
        """
        with self.mutex:
            value = self._get(key, default=0) - amount
            if value < minimum:
                return False
            return self._put(key, value, ttl)

    def incr_and_sum(
        self, key: str, keys: Callable[[], List[str]], amount: int, maximum: int, ttl: int
    ) -> bool:
        """Add ``amount`` to ``key`` if both it and the windowed sum stay within ``maximum``.

        Two checks must pass before ``key`` is updated:

        * the new value of ``key`` (its current value plus ``amount``) must
          not exceed ``maximum``;
        * ``amount`` plus the sum of the current values of every key
          returned by ``keys`` must not exceed ``maximum``.

        If ``key`` has no live entry it is first initialised to 0 with the
        given ``ttl``; that initialisation is kept even when the call is
        rejected.

        Parameters:
          key: The key to increment.
          keys: A callable returning the list of keys to sum over.  A plain
            list is still accepted for backwards compatibility.
          amount: The amount to add.
          maximum: The upper bound for both checks.
          ttl: Time to live of the entry, in milliseconds.

        Returns:
          True if ``key`` was updated, False otherwise.
        """
        with self.mutex:
            # Initialise the key inside the same critical section as the
            # update so the whole operation is atomic.  self.add() cannot be
            # called here because threading.Lock is not reentrant.
            if self._get(key) is None:
                self._put(key, 0, ttl)

            value = self._get(key, default=0) + amount
            if value > maximum:
                return False

            # TODO: Drop non-callable keys in Remoulade v2.
            key_list = keys() if callable(keys) else keys
            values = sum(self._get(k, default=0) for k in key_list)
            total = amount + values
            if total > maximum:
                return False

            return self._put(key, value, ttl)

    def _get(self, key: str, *, default: Optional[int] = None) -> Optional[int]:
        """Return the live value stored under ``key``, or ``default``.

        ``default`` is returned when the key is missing or its entry has
        expired.  Expired entries are deleted as they are encountered so the
        dict does not grow without bound.  Must be called with
        ``self.mutex`` held, since it may mutate ``self.db``.
        """
        entry = self.db.get(key)
        if entry is None:
            return default

        value, expiration = entry
        if time.monotonic() < expiration:
            return value

        # Expired: evict it so stale keys do not accumulate.
        del self.db[key]
        return default

    def _put(self, key: str, value: int, ttl: int) -> bool:
        """Store ``value`` under ``key`` for ``ttl`` milliseconds.

        Must be called with ``self.mutex`` held.  Always returns True, so
        callers can return its result directly as a success flag.
        """
        # ttl is in milliseconds; time.monotonic() is in seconds.
        self.db[key] = (value, time.monotonic() + ttl / 1000)
        return True