File: //home/arjun/projects/aigenerator/venv/lib64/python3.12/site-packages/celery/backends/gcs.py
"""Google Cloud Storage result store backend for Celery."""
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
from os import getpid
from threading import RLock
from kombu.utils.encoding import bytes_to_str
from kombu.utils.functional import dictfilter
from kombu.utils.url import url_to_parts
from celery.exceptions import ImproperlyConfigured
from .base import KeyValueStoreBackend
try:
    import requests
    from google.cloud import storage
    from google.cloud.storage import Client
    from google.cloud.storage.retry import DEFAULT_RETRY
except ImportError:
    storage = None

__all__ = ('GCSBackend',)


class GCSBackend(KeyValueStoreBackend):
    """Google Cloud Storage task result backend."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._lock = RLock()
        self._pid = getpid()
        self._retry_policy = DEFAULT_RETRY
        self._client = None

        if not storage:
            raise ImproperlyConfigured(
                'You must install google-cloud-storage to use gcs backend'
            )
        conf = self.app.conf
        if self.url:
            url_params = self._params_from_url()
            conf.update(**dictfilter(url_params))

        self.bucket_name = conf.get('gcs_bucket')
        if not self.bucket_name:
            raise ImproperlyConfigured(
                'Missing bucket name: specify gcs_bucket to use gcs backend'
            )
        self.project = conf.get('gcs_project')
        if not self.project:
            raise ImproperlyConfigured(
                'Missing project: specify gcs_project to use gcs backend'
            )
        self.base_path = conf.get('gcs_base_path', '').strip('/')
        self._threadpool_maxsize = int(conf.get('gcs_threadpool_maxsize', 10))
        self.ttl = float(conf.get('gcs_ttl') or 0)
        if self.ttl < 0:
            raise ImproperlyConfigured(
                f'Invalid ttl: {self.ttl} must be greater than or equal to 0'
            )
        elif self.ttl:
            if not self._is_bucket_lifecycle_rule_exists():
                raise ImproperlyConfigured(
                    f'Missing lifecycle rule to use gcs backend with ttl on '
                    f'bucket: {self.bucket_name}'
                )
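    # A minimal sketch of the settings read above, configured directly on the
    # app (the names 'my-project', 'my-bucket' and 'celery-results' are
    # illustrative assumptions, not defaults):
    #
    #   app.conf.gcs_project = 'my-project'
    #   app.conf.gcs_bucket = 'my-bucket'
    #   app.conf.gcs_base_path = 'celery-results'
    #   app.conf.gcs_ttl = 86400  # needs a 'Delete' lifecycle rule, see below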
    def get(self, key):
        key = bytes_to_str(key)
        blob = self._get_blob(key)
        try:
            return blob.download_as_bytes(retry=self._retry_policy)
        except storage.blob.NotFound:
            return None
    def set(self, key, value):
        key = bytes_to_str(key)
        blob = self._get_blob(key)
        if self.ttl:
            blob.custom_time = (
                datetime.now(timezone.utc) + timedelta(seconds=self.ttl))
        blob.upload_from_string(value, retry=self._retry_policy)
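    # Note: setting ``custom_time`` does not delete anything by itself; expiry
    # relies on a bucket lifecycle rule (for example one conditioned on
    # ``daysSinceCustomTime``) that __init__ checks for when a ttl is set.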
    def delete(self, key):
        key = bytes_to_str(key)
        blob = self._get_blob(key)
        if blob.exists():
            blob.delete(retry=self._retry_policy)
    def mget(self, keys):
        with ThreadPoolExecutor() as pool:
            return list(pool.map(self.get, keys))
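    # The downloads above run concurrently; the HTTP connection pool that
    # backs them is sized via gcs_threadpool_maxsize in the client property
    # below (a deliberately rough pairing, not a hard guarantee).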
    @property
    def client(self):
        """Returns a storage client."""
        # make sure it's thread-safe, as creating a new client is expensive
        with self._lock:
            if self._client and self._pid == getpid():
                return self._client
            # make sure each process gets its own connection after a fork
            self._client = Client(project=self.project)
            self._pid = getpid()
            # configure the number of connections to the server
            adapter = requests.adapters.HTTPAdapter(
                pool_connections=self._threadpool_maxsize,
                pool_maxsize=self._threadpool_maxsize,
                max_retries=3,
            )
            client_http = self._client._http
            client_http.mount("https://", adapter)
            client_http._auth_request.session.mount("https://", adapter)
            return self._client
    @property
    def bucket(self):
        return self.client.bucket(self.bucket_name)

    def _get_blob(self, key):
        key_bucket_path = f'{self.base_path}/{key}' if self.base_path else key
        return self.bucket.blob(key_bucket_path)
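    # For example (an assumed layout, since the actual key format comes from
    # KeyValueStoreBackend): with base_path 'celery-results', a result key
    # like 'celery-task-meta-<task-id>' maps to the object name
    # 'celery-results/celery-task-meta-<task-id>'.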
    def _is_bucket_lifecycle_rule_exists(self):
        bucket = self.bucket
        bucket.reload()
        for rule in bucket.lifecycle_rules:
            if rule['action']['type'] == 'Delete':
                return True
        return False
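    # A lifecycle configuration that would satisfy this check could look like
    # the following sketch; only the 'Delete' action type is required here,
    # and the condition (daysSinceCustomTime is an assumption) is up to the
    # bucket owner:
    #
    #   {"rule": [{"action": {"type": "Delete"},
    #              "condition": {"daysSinceCustomTime": 1}}]}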
    def _params_from_url(self):
        url_parts = url_to_parts(self.url)
        return {
            'gcs_bucket': url_parts.hostname,
            'gcs_base_path': url_parts.path,
            **url_parts.query,
        }
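# A minimal sketch of URL-style configuration that this parser understands
# (bucket and project names are illustrative assumptions):
#
#   app.conf.result_backend = (
#       'gcs://my-bucket/celery-results?gcs_project=my-project')
#
# The hostname becomes gcs_bucket, the path becomes gcs_base_path, and any
# query parameters (such as gcs_project) are merged into the app
# configuration by __init__ above.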