File: //home/arjun/projects/buyercall_forms/buyercall/buyercall/lib/rate_limit.py
import logging
from datetime import datetime, timedelta
from enum import Enum

import redis
from flask import g, jsonify, request
from flask_login import current_user
# Module-level logger; the rate limiter reports invalid configuration and runtime
# errors through it instead of raising.
logger = logging.getLogger(__name__)
class TimePeriod(Enum):
    """Units of time accepted after ``per`` in rate-limit strings such as ``"10 per minute"``."""

    SECOND = 'second'
    MINUTE = 'minute'
    HOUR = 'hour'
    DAY = 'day'
    WEEK = 'week'
    MONTH = 'month'
    YEAR = 'year'

    @classmethod
    def is_valid(cls, value: str) -> bool:
        """Return True when ``value`` equals the value of one of the members."""
        return any(member.value == value for member in cls)

    @classmethod
    def get_by_value(cls, value: str):
        """Return the member whose value equals ``value``, or None when no member matches."""
        return next((member for member in cls if member.value == value), None)
class RateLimit:
    """Custom rate-limiting functionality.

    ``set()`` counts the current request and returns a ``(json, 429)`` response
    tuple once the configured allowance is exhausted; ``get(response)`` copies the
    counters into ``X-RateLimit-*`` response headers. Presumably these are wired
    up as Flask ``before_request`` / ``after_request`` hooks - confirm at the
    call site.

    Counters live in a Redis hash per client (user or partnership) and per URL
    prefix matched from ``REQUEST_RATE_LIMIT``.
    """

    # Timestamp format of the ``expiry`` hash field and the X-RateLimit-Reset header.
    _DT_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self):
        self.redis_client = None
        self.parsed_limits = None  # result of _parse_limits(); None when the config is invalid
        self.excluded_urls = None  # URL prefixes that are never rate limited

    def init_app(self, flask_app):
        """Create the Redis client and load the rate-limit settings from ``flask_app.config``."""
        self.redis_client = redis.StrictRedis(host=flask_app.config.get('REDIS_CONFIG_URL'),
                                              port=flask_app.config.get('REDIS_CONFIG_PORT'))
        self.parsed_limits = self._parse_limits(flask_app.config.get('REQUEST_RATE_LIMIT'))
        self.excluded_urls = self._validate_excludes(flask_app.config.get('REQUEST_RATE_LIMIT_EXCLUDE'))
        if not self._is_config_valid():
            # Reported once at start-up rather than on every request.
            logger.warning('Invalid configuration: REQUEST_RATE_LIMIT. Rate limit is not applied.')

    def set(self):
        """Count the current request against its rate limit.

        Returns ``(jsonify(errors=...), 429)`` when the allowance for the current
        window is used up, otherwise ``None``. Any unexpected error is logged and
        the request is allowed (best effort).
        """
        # Reset first: ``g`` can outlive a single request when an app context is
        # pushed manually, and get() must never see a stale key.
        g.rate_limit_key = None
        pattern = self._match_pattern(request.path)
        if not pattern:
            return None
        try:
            identity = self._resolve_identity(pattern)
            if not identity:
                return None
            key, user_id = identity
            # get() reads the very same hash instead of rebuilding the key.
            g.rate_limit_key = key
            limit_item = self._find_rule(pattern)
            if not limit_item:
                return None
            return self._consume(key, user_id, limit_item)
        except Exception as exc:
            logger.warning(f'Rate-Limit Error : {str(exc)}')
            return None

    def get(self, response):
        """Add X-RateLimit-* headers for the current client to ``response`` and return it.

        Only requests tracked by ``set()`` in this request get headers. Redis
        errors are logged and the response is returned unchanged.
        """
        key = getattr(g, 'rate_limit_key', None)
        pattern = self._match_pattern(request.path) if key else None
        limit_item = self._find_rule(pattern) if pattern else None
        if not limit_item:
            return response
        try:
            number = limit_item['number']
            remaining, sent, expiry_date = number, 0, None
            data = self.redis_client.hgetall(key)
            if data:
                remaining = int(data.get(b'remaining_requests', number))
                sent = int(data.get(b'requests_sent', 0))
                expiry_date = self._parse_expiry(data.get(b'expiry'))
            if expiry_date is None:
                expiry_date = datetime.utcnow() + timedelta(seconds=self._get_expiry(limit_item['period']))
        except Exception as exc:
            logger.warning(f'Rate-Limit Error : {str(exc)}')
            return response
        headers = [
            ("X-RateLimit-Limit", str(number)),
            ("X-RateLimit-Sent", str(sent)),
            ("X-RateLimit-Remaining", str(remaining)),
            ("X-RateLimit-Reset", expiry_date.strftime(self._DT_FORMAT)),
        ]
        for name, value in headers:
            response.headers.add_header(name, value)
        return response

    def _match_pattern(self, path):
        """Return the first configured URL prefix that ``path`` starts with, or None.

        None is also returned when the config is invalid or ``path`` starts with
        an excluded prefix.
        """
        if not self._is_config_valid():
            return None
        path = str(path)
        if any(path.startswith(url) for url in self.excluded_urls or []):
            return None
        return next((url for url in self.parsed_limits['urls'] if path.startswith(url)), None)

    def _find_rule(self, pattern):
        """Return the parsed rule whose ``urls`` contain ``pattern``, or None."""
        return next((rule for rule in self.parsed_limits.get('detail', []) if pattern in rule.get('urls')),
                    None)

    @staticmethod
    def _user_key(user_id, pattern):
        """Redis key of the per-user counter hash for ``pattern``."""
        return f"rate_limit:user-{user_id}:{pattern}"

    @staticmethod
    def _partnership_key(partnership_id, pattern):
        """Redis key of the per-partnership counter hash for ``pattern``."""
        return f"rate_limit:partnership-{partnership_id}:{pattern}"

    def _resolve_identity(self, pattern):
        """Return ``(redis_key, user_id)`` for the current request, or None to skip limiting.

        ``/api`` requests are counted per partnership (resolved from the
        ``Authorize`` header); every other request per ``current_user`` id.
        """
        if request.path.startswith('/api'):
            return self._resolve_partner_identity(pattern)
        # NOTE(review): flask_login's anonymous user is presumably truthy, so all
        # anonymous requests share the "user-None" bucket - confirm this is intended.
        if not current_user:
            return None
        user_id = current_user.get_id()
        return self._user_key(user_id, pattern), user_id

    def _resolve_partner_identity(self, pattern):
        """Resolve the partnership behind an ``Authorize: <scheme> <token>`` header."""
        auth_header = request.headers.get('Authorize')
        if not auth_header:
            return None
        parts = auth_header.split(" ")
        # A header without a token part is not rate limited here.
        if len(parts) < 2 or not parts[1]:
            return None
        token = parts[1]
        from buyercall.blueprints.partnership.models import ApiToken, Partnership
        from buyercall.blueprints.user.models import User
        api_token = ApiToken.check_token(token)
        if not api_token:
            return None
        partnership = Partnership.query.filter(Partnership.api_token_id == api_token).first()
        if not partnership:
            return None
        partner_user = User.query.filter(User.role == 'partner',
                                         User.partnership_id == partnership.id).first()
        # The bucket is per partnership; user_id is only passed on to the sysadmin
        # notification, so a missing partner user must not disable limiting.
        user_id = partner_user.id if partner_user else None
        return self._partnership_key(partnership.id, pattern), user_id

    def _consume(self, key, user_id, limit_item):
        """Count one request against ``key``; return the 429 response tuple when exhausted, else None."""
        number = limit_item['number']
        now = datetime.utcnow()
        remaining, sent, notify_count, expiry_date = number, 0, 0, None
        data = self.redis_client.hgetall(key)
        if data:
            remaining = int(data.get(b'remaining_requests', number))
            sent = int(data.get(b'requests_sent', 0))
            notify_count = int(data.get(b'notify_limit', 0))
            expiry_date = self._parse_expiry(data.get(b'expiry'))

        updates = {}
        if expiry_date is None or now > expiry_date:
            # No window yet, the window elapsed, or its expiry field is missing or
            # unreadable: open a fresh window with the full allowance. ``number``
            # is >= 1 (enforced by _parse_limit_value), so this request is admitted.
            remaining, sent = number, 0
            expiry_date = now + timedelta(seconds=self._get_expiry(limit_item['period']))
            updates[b'expiry'] = expiry_date.strftime(self._DT_FORMAT)
            updates[b'notify_limit'] = 0
        elif remaining <= 0:
            return self._reject(key, user_id, limit_item, notify_count + 1)

        updates[b'remaining_requests'] = remaining - 1
        # Capped so the counter never exceeds the configured limit.
        updates[b'requests_sent'] = min(sent + 1, number)
        self._store(key, updates)
        return None

    def _reject(self, key, user_id, limit_item, notify_count):
        """Return the 429 response; once rejections exceed ``notify_limit``, alert the sysadmin.

        ``notify_count`` is the number of rejected requests since the counter was
        last reset, including this one.
        """
        threshold = limit_item.get('notify_limit')
        if threshold:
            if notify_count > threshold and notify_count != 1:
                try:
                    self._notify_sysadmin(user_id, limit_item)
                except Exception as exc:
                    # A failed notification must not let the blocked request through.
                    logger.warning(f'Rate-Limit notification error : {str(exc)}')
                notify_count = 0
            self.redis_client.hset(key, b'notify_limit', notify_count)
        return jsonify(errors='Rate limit exceeded. Try again later.'), 429

    @staticmethod
    def _notify_sysadmin(user_id, limit_item):
        """Queue the "request limit exceeded" e-mail for the current request."""
        from buyercall.blueprints.sysadmin.tasks import send_request_limit_exceeded_mail
        from buyercall.blueprints.sysadmin.utilities.ip_api import IpApi
        period = limit_item['period']
        # Use the enum value so the text reads "10 per day", not "10 per TimePeriod.DAY".
        limit_string = f"{limit_item['number']} per {getattr(period, 'value', period)}"
        forwarded_for = request.environ.get('HTTP_X_FORWARDED_FOR')
        # X-Forwarded-For may be "client, proxy1, ..."; by convention the left-most
        # entry is the originating client (client-supplied, used only for this alert).
        user_ip = forwarded_for.split(',')[0].strip() if forwarded_for else request.environ.get('REMOTE_ADDR')
        ip_details = IpApi.get_request_complete_details(user_ip) or {}
        location = None
        if ip_details.get('city', None):
            location = f"{ip_details.get('city', '')}, {ip_details.get('regionName', '')}, " \
                       f"{ip_details.get('country', '')}"
        send_request_limit_exceeded_mail.delay(user_id, request.path, limit_string,
                                               method=request.method, user_ip=user_ip,
                                               location=location)

    def _store(self, key, fields):
        """Write ``fields`` into the hash at ``key`` with one pipeline round trip.

        redis-py pipelines are transactional (MULTI/EXEC) by default, so the
        fields are updated together.
        """
        pipe = self.redis_client.pipeline()
        for field, value in fields.items():
            pipe.hset(key, field, value)
        pipe.execute()

    @classmethod
    def _parse_expiry(cls, raw):
        """Parse an expiry timestamp read from Redis (bytes or str); None if missing or invalid."""
        if raw is None:
            return None
        try:
            if isinstance(raw, bytes):
                raw = raw.decode()
            return datetime.strptime(raw, cls._DT_FORMAT)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _parse_limit_value(limit_string: str = None):
        """Parse strings like ``"10 per minute"`` into ``(10, TimePeriod.MINUTE)``.

        Returns None (after logging a warning) when the string is malformed, the
        period is unknown, or the count is not a positive integer.
        """
        try:
            if not limit_string or 'per' not in str(limit_string):
                logger.warning(f'Invalid config: "{limit_string}"')
                return None
            limit, period = limit_string.split(" per ")
            period = period.strip()
            number = int(limit)
            if number <= 0 or not TimePeriod.is_valid(period):
                logger.warning(f'Invalid config: "{limit_string}"')
                return None
            return number, TimePeriod.get_by_value(period)
        except Exception:
            logger.warning(f'Invalid config: "{limit_string}"')
            return None

    def _parse_limits(self, limit_rule: list = None):
        """Parse REQUEST_RATE_LIMIT, a list of rule dicts, into the internal config.

        Each rule needs ``url_starts_with`` (non-empty list of str) and ``limit``
        (e.g. ``"10 per minute"``); ``notify_limit`` is optional. The result is
        ``{'urls': [every prefix], 'detail': [{'urls', 'number', 'period', 'notify_limit'}]}``.
        Returns None if any rule is invalid, which disables rate limiting.
        """
        try:
            rl_config = {'urls': [], 'detail': []}
            for rule in limit_rule:
                if not all(key in rule for key in ('url_starts_with', 'limit')):
                    return None
                urls = rule.get('url_starts_with', [])
                limit = rule.get('limit')
                if not isinstance(urls, list) or not urls or not limit:
                    return None
                # Prefixes are later passed to str.startswith; reject anything else up front.
                if not all(isinstance(url, str) for url in urls):
                    return None
                parsed = self._parse_limit_value(limit)
                if not parsed:
                    return None
                number, period = parsed
                rl_config['urls'].extend(urls)
                rl_config['detail'].append({
                    'urls': urls,
                    'number': number,
                    'period': period,
                    'notify_limit': rule.get('notify_limit'),
                })
            return rl_config
        except Exception as e:
            logger.debug(f'Invalid config: "{str(e)}"')
            return None

    @staticmethod
    def _get_expiry(period: 'TimePeriod') -> int:
        """Return the window length in seconds for ``period``; 0 for unknown periods.

        Accepts a TimePeriod member or its string value (e.g. ``'minute'``).
        """
        seconds = {
            'second': 1,
            'minute': 60,
            'hour': 60 * 60,
            'day': 24 * 60 * 60,
            'week': 7 * 24 * 60 * 60,
            'month': 30 * 24 * 60 * 60,
            'year': 365 * 24 * 60 * 60,
        }
        return seconds.get(getattr(period, 'value', period), 0)

    def _is_config_valid(self):
        """Return True when REQUEST_RATE_LIMIT was parsed successfully."""
        return bool(self.parsed_limits)

    @staticmethod
    def _validate_excludes(urls) -> list:
        """Return ``urls`` if it is a list of str, otherwise an empty list."""
        if not isinstance(urls, list) or not all(isinstance(url, str) for url in urls):
            return []
        return urls