HEX
Server: Apache/2.4.52 (Ubuntu)
System: Linux spn-python 5.15.0-89-generic #99-Ubuntu SMP Mon Oct 30 20:42:41 UTC 2023 x86_64
User: arjun (1000)
PHP: 8.1.2-1ubuntu2.20
Disabled: NONE
Upload Files
File: //home/arjun/projects/buyercall/buyercall/lib/util_sqlalchemy.py
import datetime
import logging
from uuid import UUID

from sqlalchemy.inspection import inspect
from sqlalchemy import DateTime
from sqlalchemy.types import TypeDecorator
from sqlalchemy_serializer import SerializerMixin

from buyercall.lib.util_datetime import tzware_datetime
from buyercall.extensions import db


class AwareDateTime(TypeDecorator):
    """
    A DateTime type which can only store tz-aware DateTimes.

    Binding a naive ``datetime.datetime`` raises ``ValueError``; every other
    value (aware datetimes, ``None``, ...) is passed through unchanged to the
    underlying ``DateTime(timezone=True)`` implementation.

    Source:
      https://gist.github.com/inklesspen/90b554c864b99340747e
    """
    impl = DateTime(timezone=True)
    # This type carries no per-instance state, so statements using it are
    # safe for SQLAlchemy's compiled-statement cache. ``False`` would disable
    # caching for every statement that touches an AwareDateTime column.
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """
        Reject naive datetimes before they are bound to a statement.

        :param value: Python value being bound as a statement parameter
        :param dialect: Dialect in use (not consulted)
        :return: ``value`` unchanged
        :raises ValueError: if ``value`` is a naive ``datetime.datetime``
        """
        # A datetime is aware only if tzinfo is set AND utcoffset() is not
        # None; ``utcoffset()`` returns None in both naive cases.
        if isinstance(value, datetime.datetime) and value.utcoffset() is None:
            raise ValueError('{!r} must be TZ-aware'.format(value))
        return value

    def __repr__(self):
        return 'AwareDateTime()'


class CustomSerializerMixin(SerializerMixin):
    """
    SerializerMixin variant that renders ``uuid.UUID`` values as strings.

    Each ``serialize_types`` entry pairs a type with the callable applied to
    values of that type during serialization.
    """
    serialize_types = ((UUID, str),)


class ResourceMixin(CustomSerializerMixin):
    """
    Mixin adding audit timestamps and common CRUD/lookup helpers to a model.

    The helpers use attributes the concrete model must provide: ``query`` and
    ``__table__`` (presumably from the Flask-SQLAlchemy model base), an ``id``
    column, and - for the sid helpers - a ``sid`` attribute.
    ``get_bulk_action_ids`` with ``scope='all_search_results'`` also needs a
    ``search(query)`` classmethod returning a filter expression.
    """
    # Keep track when records are created and updated.
    created_on = db.Column(AwareDateTime(),
                           nullable=False,
                           default=tzware_datetime)
    updated_on = db.Column(AwareDateTime(),
                           nullable=False,
                           default=tzware_datetime,
                           onupdate=tzware_datetime)

    @classmethod
    def sort_by(cls, field, direction):
        """
        Sanitize a requested sort field and direction.

        Field names that are not columns of this model fall back to
        ``'created_on'``; directions other than ``'asc'``/``'desc'`` fall
        back to ``'asc'``. Only known column names and directions are
        ever returned.

        :param field: Field name
        :type field: str
        :param direction: Direction
        :type direction: str
        :return: tuple of (field, direction)
        """
        if field not in cls.__table__.columns:
            field = 'created_on'

        if direction not in ('asc', 'desc'):
            direction = 'asc'

        return field, direction

    @classmethod
    def get_bulk_action_ids(cls, scope, ids, omit_ids=None,
                            query=''):
        """
        Determine which IDs are to be modified.

        :param scope: ``'all_search_results'`` replaces ``ids`` with every
            id matching ``cls.search(query)``; any other value keeps ``ids``
        :type scope: str
        :param ids: List of ids to be modified
        :type ids: list
        :param omit_ids: Remove 1 or more IDs from the list (compared as
            strings, so ``1`` and ``'1'`` match)
        :type omit_ids: list or None
        :param query: Search query (if applicable)
        :type query: str
        :return: list
        """
        # A set of strings: O(1) membership and type-agnostic matching.
        omitted = {str(omit_id) for omit_id in (omit_ids or ())}

        if scope == 'all_search_results':
            # Change the scope to go from selected ids to all search results.
            rows = cls.query.with_entities(cls.id).filter(cls.search(query))

            # SQLAlchemy returns back a list of tuples, we want a list of strs.
            ids = [str(row[0]) for row in rows]

        # Remove 1 or more items from the list, this could be useful in spots
        # where you may want to protect the current user from deleting himself
        # when bulk deleting user accounts. Compare as strings so int ids are
        # matched as well; the original elements are kept unchanged.
        if omitted:
            ids = [item_id for item_id in ids if str(item_id) not in omitted]

        return ids

    @classmethod
    def bulk_delete(cls, ids):
        """
        Delete 1 or more model instances and commit.

        Uses ``synchronize_session=False``, so objects already loaded in the
        session are not updated to reflect the delete.

        :param ids: List of ids to be deleted
        :type ids: list
        :return: Number of deleted instances
        """
        delete_count = cls.query.filter(cls.id.in_(ids)).delete(
            synchronize_session=False)
        db.session.commit()

        return delete_count

    def save(self):
        """
        Add this instance to the session and commit.

        :return: self
        """
        db.session.add(self)
        db.session.commit()

        return self

    def delete(self):
        """
        Delete this instance and commit.

        :return: db.session.commit()'s result
        """
        db.session.delete(self)
        return db.session.commit()

    def __str__(self):
        """
        Create a human readable version of a class instance.

        :return: '<0x... ClassName(col=value, ...)>' listing every table column
        """
        obj_id = hex(id(self))
        columns = self.__table__.c.keys()

        values = ', '.join("%s=%r" % (n, getattr(self, n, None)) for n in columns)
        return '<%s %s(%s)>' % (obj_id, self.__class__.__name__, values)

    @classmethod
    def get_sid_from_id(cls, id):
        """
        Look up the ``sid`` of the row with the given ``id``.

        Best-effort: a falsy ``id``, a missing row, or a failed query (logged)
        all yield ``None``.

        :param id: Primary key value
        :return: the row's ``sid`` or None
        """
        if not id:
            return None
        try:
            obj = cls.query.filter(cls.id == id).first()
        except Exception:
            logging.getLogger(__name__).exception(
                'Failed to look up %s sid for id %r', cls.__name__, id)
            return None
        # getattr also covers obj being None (no matching row).
        return getattr(obj, 'sid', None)

    @classmethod
    def get_id_from_sid(cls, sid):
        """
        Look up the ``id`` of the row with the given ``sid``.

        Best-effort: a falsy ``sid``, a missing row, or a failed query (logged)
        all yield ``None``.

        :param sid: sid value to search for
        :return: the row's ``id`` or None
        """
        if not sid:
            return None
        try:
            obj = cls.query.filter(cls.sid == sid).first()
        except Exception:
            logging.getLogger(__name__).exception(
                'Failed to look up %s id for sid %r', cls.__name__, sid)
            return None
        # getattr also covers obj being None (no matching row).
        return getattr(obj, 'id', None)

    @classmethod
    def create(cls, **params):
        """
        Construct an instance from keyword arguments, save (commit) and return it.

        :return: the new instance
        """
        obj = cls(**params)
        obj.save()
        return obj

    @classmethod
    def get_by_id(cls, id_):
        """Return the first row whose ``id`` equals ``id_``; None if ``id_`` is falsy or no row matches."""
        return cls.query.filter(cls.id == id_).first() if id_ else None

    @classmethod
    def get_by_sid(cls, sid):
        """Return the first row whose ``sid`` equals ``sid``; None if ``sid`` is falsy or no row matches."""
        return cls.query.filter(cls.sid == sid).first() if sid else None

    def put(self, return_object=False, **kwargs):
        """
        Set the given attributes on this instance and save (commit).

        :param return_object: return ``self`` instead of ``True`` on success
        :param kwargs: attribute name -> new value
        :return: ``self``/``True`` on success, ``False`` if anything failed
            (the error is logged and the session rolled back)
        """
        try:
            # Only the caller-supplied fields; locals() would also have
            # included self, return_object and kwargs themselves.
            for attr, value in kwargs.items():
                setattr(self, attr, value)
            self.save()
        except Exception:
            # A failed commit leaves the session unusable until rolled back.
            db.session.rollback()
            logging.getLogger(__name__).exception(
                'Failed to update %s', self.__class__.__name__)
            return False
        return self if return_object else True