File: //home/arjun/projects/aigenerator/venv/lib/python3.12/site-packages/django_redis/client/sharded.py
import re
from collections import OrderedDict
from datetime import datetime
from typing import Union
from redis.exceptions import ConnectionError
from ..exceptions import ConnectionInterrupted
from ..hash_ring import HashRing
from ..util import CacheKey
from .default import DEFAULT_TIMEOUT, DefaultClient
class ShardClient(DefaultClient):
    """
    django-redis client that spreads keys over several Redis servers.

    Every key is routed to exactly one server through ``self._ring`` (a
    ``HashRing`` built from the configured server list). Single-key
    operations run against that key's server; ``keys``, ``delete_pattern``
    and ``clear`` visit every server.
    """

    # Optional hash tag: when a key contains ``{...}`` only the captured text
    # is hashed, so keys sharing a tag are routed to the same server. The
    # match is greedy: with several brace pairs, the text between the last
    # ``{`` and the final ``}`` after it is used.
    _findhash = re.compile(r".*\{(.*)\}.*", re.I)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Normalize a single (non list/tuple) server entry into a list so the
        # ring and connection map can always iterate it.
        if not isinstance(self._server, (list, tuple)):
            self._server = [self._server]
        self._ring = HashRing(self._server)
        self._serverdict = self.connect()

    def get_client(self, *args, **kwargs):
        # A sharded setup has no single client; use ``get_server(key)`` to
        # obtain the connection that owns a given key.
        raise NotImplementedError

    def connect(self, index=0):
        """
        Open one connection per configured server.

        Returns a dict mapping each server name to its connection. ``index``
        is unused here (presumably kept for signature parity with the parent
        client).
        """
        connection_dict = {}
        for name in self._server:
            connection_dict[name] = self.connection_factory.connect(name)
        return connection_dict

    def get_server_name(self, _key):
        """Return the name of the server that owns ``_key``."""
        key = str(_key)
        match = self._findhash.match(key)
        if match is not None:
            # Hash only the tag so that tagged keys share a server.
            key = match.group(1)
        return self._ring.get_node(key)

    def get_server(self, key):
        """Return the connection of the server that owns ``key``."""
        return self._serverdict[self.get_server_name(key)]

    def _route(self, key, version, client):
        """
        Resolve ``(key, client)`` for a single-key operation.

        When ``client`` is None the key is built with ``make_key`` and the
        connection owning it is selected. When a client is supplied, ``key``
        and ``client`` are returned unchanged.
        """
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)
        return key, client

    def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, client=None):
        """Delegate to the parent ``add`` on the shard that owns ``key``."""
        key, client = self._route(key, version, client)
        return super().add(
            key=key, value=value, version=version, client=client, timeout=timeout
        )

    def get(self, key, default=None, version=None, client=None):
        """Delegate to the parent ``get`` on the shard that owns ``key``."""
        key, client = self._route(key, version, client)
        return super().get(key=key, default=default, version=version, client=client)

    def get_many(self, keys, version=None):
        """
        Fetch several keys, issuing one ``get`` per key on its own shard.

        Returns an OrderedDict keyed by the caller's original keys, in input
        order. Keys whose value comes back as None are omitted. An empty
        ``keys`` returns a plain empty dict.
        """
        if not keys:
            return {}
        recovered_data = OrderedDict()
        new_keys = [self.make_key(key, version=version) for key in keys]
        # Map each built key back to the key the caller passed in.
        map_keys = dict(zip(new_keys, keys))
        for key in new_keys:
            client = self.get_server(key)
            value = self.get(key=key, version=version, client=client)
            if value is None:
                continue
            recovered_data[map_keys[key]] = value
        return recovered_data

    def set(
        self, key, value, timeout=DEFAULT_TIMEOUT, version=None, client=None, nx=False
    ):
        """
        Persist a value to the cache, and set an optional expiration time.
        """
        key, client = self._route(key, version, client)
        return super().set(
            key=key, value=value, timeout=timeout, version=version, client=client, nx=nx
        )

    def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
        """
        Set a bunch of values in the cache at once from a dict of key/value
        pairs.

        Each pair is written with its own ``set`` call so that it reaches the
        shard owning its key. If timeout is given, that timeout will be used
        for every key; otherwise the default cache timeout will be used.
        """
        for key, value in data.items():
            self.set(key, value, timeout, version=version)

    def has_key(self, key, version=None, client=None):
        """
        Test if key exists.

        Raises ConnectionInterrupted if the owning server is unreachable.
        """
        # Build the key once and use it both for routing and for EXISTS.
        key = self.make_key(key, version=version)
        if client is None:
            client = self.get_server(key)
        try:
            return client.exists(key) == 1
        except ConnectionError as e:
            raise ConnectionInterrupted(connection=client) from e

    def delete(self, key, version=None, client=None):
        """Delegate to the parent ``delete`` on the shard that owns ``key``."""
        key, client = self._route(key, version, client)
        return super().delete(key=key, version=version, client=client)

    def ttl(self, key, version=None, client=None):
        """
        Executes TTL redis command and return the "time-to-live" of specified key.
        If key is a non volatile key, it returns None.
        """
        key, client = self._route(key, version, client)
        return super().ttl(key=key, version=version, client=client)

    def pttl(self, key, version=None, client=None):
        """
        Executes PTTL redis command and return the "time-to-live" of specified key
        in milliseconds. If key is a non volatile key, it returns None.
        """
        key, client = self._route(key, version, client)
        return super().pttl(key=key, version=version, client=client)

    def persist(self, key, version=None, client=None):
        """Delegate to the parent ``persist`` on the shard that owns ``key``."""
        key, client = self._route(key, version, client)
        return super().persist(key=key, version=version, client=client)

    def expire(self, key, timeout, version=None, client=None):
        """Delegate to the parent ``expire`` on the shard that owns ``key``."""
        key, client = self._route(key, version, client)
        return super().expire(key=key, timeout=timeout, version=version, client=client)

    def pexpire(self, key, timeout, version=None, client=None):
        """Delegate to the parent ``pexpire`` on the shard that owns ``key``."""
        key, client = self._route(key, version, client)
        return super().pexpire(key=key, timeout=timeout, version=version, client=client)

    def pexpire_at(self, key, when: Union[datetime, int], version=None, client=None):
        """
        Set an expire flag on a ``key`` to ``when`` on a shard client.
        ``when`` which can be represented as an integer indicating unix
        time or a Python datetime object.
        """
        key, client = self._route(key, version, client)
        return super().pexpire_at(key=key, when=when, version=version, client=client)

    def expire_at(self, key, when: Union[datetime, int], version=None, client=None):
        """
        Set an expire flag on a ``key`` to ``when`` on a shard client.
        ``when`` which can be represented as an integer indicating unix
        time or a Python datetime object.
        """
        key, client = self._route(key, version, client)
        return super().expire_at(key=key, when=when, version=version, client=client)

    def lock(
        self,
        key,
        version=None,
        timeout=None,
        sleep=0.1,
        blocking_timeout=None,
        client=None,
        thread_local=True,
    ):
        """Delegate to the parent ``lock`` on the shard that owns ``key``."""
        # Build the key once and use it both for routing and as the lock name.
        key = self.make_key(key, version=version)
        if client is None:
            client = self.get_server(key)
        return super().lock(
            key,
            timeout=timeout,
            sleep=sleep,
            client=client,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
        )

    def delete_many(self, keys, version=None):
        """
        Remove multiple keys at once.

        Each key is deleted on its own shard; returns the sum of the
        individual ``delete`` results.
        """
        res = 0
        for key in [self.make_key(k, version=version) for k in keys]:
            client = self.get_server(key)
            res += self.delete(key, client=client)
        return res

    def incr_version(self, key, delta=1, version=None, client=None):
        """
        Move ``key`` from ``version`` to ``version + delta``.

        The value and TTL are read from the old key, written under the new
        key on whichever shard owns it (which can differ from the old key's
        shard), and the old key is then deleted. Returns the new version.

        Raises ValueError if the old key has no value.
        """
        key, client = self._route(key, version, client)
        if version is None:
            version = self._backend.version
        old_key = self.make_key(key, version)
        value = self.get(old_key, version=version, client=client)
        try:
            ttl = self.ttl(old_key, version=version, client=client)
        except ConnectionError as e:
            raise ConnectionInterrupted(connection=client) from e
        if value is None:
            raise ValueError("Key '%s' not found" % key)
        if isinstance(key, CacheKey):
            new_key = self.make_key(key.original_key(), version=version + delta)
        else:
            new_key = self.make_key(key, version=version + delta)
        # The new versioned key may hash to a different shard.
        self.set(new_key, value, timeout=ttl, client=self.get_server(new_key))
        self.delete(old_key, client=client)
        return version + delta

    def incr(self, key, delta=1, version=None, client=None):
        """Delegate to the parent ``incr`` on the shard that owns ``key``."""
        key, client = self._route(key, version, client)
        return super().incr(key=key, delta=delta, version=version, client=client)

    def decr(self, key, delta=1, version=None, client=None):
        """Delegate to the parent ``decr`` on the shard that owns ``key``."""
        key, client = self._route(key, version, client)
        return super().decr(key=key, delta=delta, version=version, client=client)

    def iter_keys(self, key, version=None):
        raise NotImplementedError("iter_keys not supported on sharded client")

    def keys(self, search, version=None):
        """
        Return the reversed (caller-facing) keys matching ``search`` on all shards.

        Raises ConnectionInterrupted, carrying the connection that failed, if
        any shard is unreachable.
        """
        pattern = self.make_pattern(search, version=version)
        keys = []
        for connection in self._serverdict.values():
            try:
                keys.extend(connection.keys(pattern))
            except ConnectionError as e:
                # Report the shard that actually failed, not one derived
                # from hashing the pattern string.
                raise ConnectionInterrupted(connection=connection) from e
        return [self.reverse_key(k.decode()) for k in keys]

    def delete_pattern(
        self, pattern, version=None, client=None, itersize=None, prefix=None
    ):
        """
        Remove all keys matching pattern.

        Every shard is scanned first. Only then are deletions issued, and each
        shard only receives the keys that were found on it. ``client`` is
        unused. ``itersize``, when truthy, is passed to SCAN as ``count``.

        Returns the sum of the per-shard DELETE results. Raises
        ConnectionInterrupted, carrying the failing connection, if a shard is
        unreachable.
        """
        pattern = self.make_pattern(pattern, version=version, prefix=prefix)
        kwargs = {"match": pattern}
        if itersize:
            kwargs["count"] = itersize

        # Phase 1: collect matches per shard. Nothing is deleted if any scan
        # fails, matching the original collect-then-delete ordering.
        found_per_shard = []
        for connection in self._serverdict.values():
            try:
                shard_keys = list(connection.scan_iter(**kwargs))
            except ConnectionError as e:
                raise ConnectionInterrupted(connection=connection) from e
            if shard_keys:
                found_per_shard.append((connection, shard_keys))

        # Phase 2: delete each shard's own keys on that shard only.
        res = 0
        for connection, shard_keys in found_per_shard:
            try:
                res += connection.delete(*shard_keys)
            except ConnectionError as e:
                raise ConnectionInterrupted(connection=connection) from e
        return res

    def do_close_clients(self):
        """Disconnect every shard connection."""
        for client in self._serverdict.values():
            self.disconnect(client=client)

    def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None, client=None):
        """Delegate to the parent ``touch`` on the shard that owns ``key``."""
        key, client = self._route(key, version, client)
        return super().touch(key=key, timeout=timeout, version=version, client=client)

    def clear(self, client=None):
        """
        Run FLUSHDB on every shard. ``client`` is unused.

        Raises ConnectionInterrupted, carrying the failing connection, if a
        shard is unreachable.
        """
        for connection in self._serverdict.values():
            try:
                connection.flushdb()
            except ConnectionError as e:
                raise ConnectionInterrupted(connection=connection) from e