# This file is a part of Remoulade.
#
# Copyright (C) 2017,2018 CLEARTYPE SRL <bogdan@cleartype.io>
#
# Remoulade is free software; you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at
# your option) any later version.
#
# Remoulade is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import time
from threading import Lock
from typing import Callable, Dict, List, Optional, Tuple
from ..backend import RateLimiterBackend
class StubBackend(RateLimiterBackend):
    """An in-memory rate limiter backend.  For use in unit tests.

    Entries live in ``self.db`` as ``(value, expiration)`` pairs, where
    ``expiration`` is a :func:`time.monotonic` timestamp in seconds.  TTLs
    passed to the public methods are in milliseconds.  Each public method
    holds ``self.mutex`` for its whole duration.
    """

    def __init__(self):
        self.mutex = Lock()
        # key -> (value, monotonic expiration time in seconds)
        self.db: Dict[str, Tuple[int, float]] = {}

    def add(self, key: str, value: int, ttl: int) -> bool:
        """Store ``value`` under ``key`` unless a live entry already exists.

        Parameters:
          key: The key to set.
          value: The value to store.
          ttl: Time to live for the new entry, in milliseconds.

        Returns:
          True if the value was stored, False if ``key`` was already set
          and not yet expired.
        """
        with self.mutex:
            if self._get(key) is not None:
                return False
            return self._put(key, value, ttl)

    def incr(self, key: str, amount: int, maximum: int, ttl: int) -> bool:
        """Add ``amount`` to ``key`` unless the result would exceed ``maximum``.

        A missing or expired key is treated as 0.  On success the entry's
        TTL is reset to ``ttl`` milliseconds.

        Returns:
          True if the value was updated, False if it would exceed ``maximum``.
        """
        with self.mutex:
            value = self._get(key, default=0) + amount
            if value > maximum:
                return False
            return self._put(key, value, ttl)

    def decr(self, key: str, amount: int, minimum: int, ttl: int) -> bool:
        """Subtract ``amount`` from ``key`` unless the result would go below ``minimum``.

        A missing or expired key is treated as 0.  On success the entry's
        TTL is reset to ``ttl`` milliseconds.

        Returns:
          True if the value was updated, False if it would drop below ``minimum``.
        """
        with self.mutex:
            value = self._get(key, default=0) - amount
            if value < minimum:
                return False
            return self._put(key, value, ttl)

    def incr_and_sum(self, key: str, keys: Callable[[], List[str]], amount: int, maximum: int, ttl: int) -> bool:
        """Increment ``key`` by ``amount`` if both limits checks pass.

        The increment only happens when ``key``'s new value is at most
        ``maximum`` AND ``amount`` plus the sum of the live values of all
        ``keys`` is at most ``maximum``.  Missing or expired keys count as 0.

        ``key`` is first seeded with 0 (TTL ``ttl`` ms) if it is absent or
        expired; that seed is kept even when the increment is rejected.

        Parameters:
          keys: A zero-argument callable returning the key names to sum.
            A plain list of key names is also accepted (deprecated).

        Returns:
          True if ``key`` was incremented, False otherwise.
        """
        with self.mutex:
            # Equivalent to ``self.add(key, 0, ttl)``, done inline because
            # ``Lock`` is not reentrant and seed + check should be atomic.
            if self._get(key) is None:
                self._put(key, 0, ttl)

            value = self._get(key, default=0) + amount
            if value > maximum:
                return False

            # TODO: Drop non-callable keys in Remoulade v2.
            key_list = keys() if callable(keys) else keys
            values = sum(self._get(k, default=0) for k in key_list)
            total = amount + values
            if total > maximum:
                return False

            return self._put(key, value, ttl)

    def _get(self, key: str, *, default: Optional[int] = None) -> Optional[int]:
        """Return the live value for ``key``, or ``default`` if missing or expired.

        Expired entries are removed on read so ``self.db`` does not grow
        without bound.  Must be called with ``self.mutex`` held.
        """
        entry = self.db.get(key)
        if entry is None:
            return default
        value, expiration = entry
        if time.monotonic() < expiration:
            return value
        # Lazily evict the stale entry; callers only ever see ``default`` for it.
        del self.db[key]
        return default

    def _put(self, key: str, value: int, ttl: int) -> bool:
        """Store ``value`` under ``key`` for ``ttl`` milliseconds.

        Must be called with ``self.mutex`` held.  Always returns True so
        callers can ``return self._put(...)`` directly.
        """
        self.db[key] = (value, time.monotonic() + ttl / 1000)
        return True