File: //home/arjun/projects/aigenerator/venv/lib64/python3.12/site-packages/psycopg/types/enum.py
"""
Adapters for the enum type.
"""
from __future__ import annotations
from enum import Enum
from typing import Any, Generic, Mapping, Sequence, cast, TYPE_CHECKING
from .. import sql
from .. import postgres
from .. import errors as e
from ..pq import Format
from ..abc import AdaptContext, Query
from ..adapt import Buffer, Dumper, Loader
from .._compat import cache, TypeAlias, TypeVar
from .._encodings import conn_encoding
from .._typeinfo import TypeInfo
if TYPE_CHECKING:
from .._connection_base import BaseConnection
# Type variable for "some concrete Enum subclass".
E = TypeVar("E", bound=Enum)

# Enum member -> label bytes (used by the dumpers' `_dump_map`).
EnumDumpMap: TypeAlias = "dict[E, bytes]"
# Label bytes -> enum member (used by the loaders' `_load_map`).
EnumLoadMap: TypeAlias = "dict[bytes, E]"
# User-provided member/label overrides accepted by register_enum().
EnumMapping: TypeAlias = "Mapping[E, str] | Sequence[tuple[E, str]] | None"

# Hashable versions
# (tuples of pairs: this is the form passed to the @cache-decorated
# `_make_*()` factories, which turn them back into dicts)
_HEnumDumpMap: TypeAlias = "tuple[tuple[E, bytes], ...]"
_HEnumLoadMap: TypeAlias = "tuple[tuple[bytes, E], ...]"

# Short aliases for the two wire formats.
TEXT = Format.TEXT
BINARY = Format.BINARY
class EnumInfo(TypeInfo):
    """Catalog information about a PostgreSQL enum type.

    Besides the name and oids handled by `TypeInfo`, the object stores the
    enum labels and, after `register_enum()` is called, the Python enum
    associated with the database type.
    """

    def __init__(
        self,
        name: str,
        oid: int,
        array_oid: int,
        labels: Sequence[str],
    ):
        super().__init__(name, oid, array_oid)
        # Python enum mapped to this type: filled in by register_enum().
        self.enum: type[Enum] | None = None
        # Labels of the database enum, as passed by the caller.
        self.labels = labels

    @classmethod
    def _get_info_query(cls, conn: BaseConnection[Any]) -> Query:
        """Return the query fetching name, oids and labels of the type.

        The labels come from `pg_enum` (LEFT JOINed to `pg_type`) and are
        aggregated into a single array; the type to look up is injected
        via `cls._to_regtype(conn)`.
        """
        template = sql.SQL(
            """\
SELECT name, oid, array_oid, array_agg(label) AS labels
FROM (
SELECT
t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
e.enumlabel AS label
FROM pg_type t
LEFT JOIN pg_enum e
ON e.enumtypid = t.oid
WHERE t.oid = {regtype}
ORDER BY e.enumsortorder
) x
GROUP BY name, oid, array_oid
"""
        )
        return template.format(regtype=cls._to_regtype(conn))
class _BaseEnumLoader(Loader, Generic[E]):
    """
    Loader for a specific Enum class.

    This base class only implements the lookup: concrete subclasses (see
    `_make_loader()` and `_make_binary_loader()`) provide the `enum` and
    `_load_map` class attributes.
    """

    # The Python enum returned by this loader.
    enum: type[E]
    # Encoded database label -> enum member.
    _load_map: EnumLoadMap[E]

    def load(self, data: Buffer) -> E:
        """Return the enum member matching the label received in *data*.

        :param data: The encoded label, as received from the database.
        :return: The member of `enum` registered for that label.
        :raises DataError: If the label has no matching member.
        """
        # The map is keyed on `bytes`: normalize other buffer types first.
        if not isinstance(data, bytes):
            data = bytes(data)

        try:
            return self._load_map[data]
        except KeyError:
            enc = conn_encoding(self.connection)
            # "replace" so that building the error message can never fail
            # with a decoding error of its own.
            label = data.decode(enc, "replace")
            # `from None`: the KeyError is an implementation detail of the
            # lookup and would only add noise to the traceback.
            raise e.DataError(
                f"bad member for enum {self.enum.__qualname__}: {label!r}"
            ) from None
class _BaseEnumDumper(Dumper, Generic[E]):
    """
    Dumper for a specific Enum class.

    Concrete subclasses (see `_make_dumper()` and `_make_binary_dumper()`)
    provide the `enum` and `_dump_map` class attributes.
    """

    # The Python enum handled by this dumper.
    enum: type[E]
    # Enum member -> encoded database label.
    _dump_map: EnumDumpMap[E]

    def dump(self, value: E) -> Buffer | None:
        """Return the encoded label registered for the member *value*.

        A member missing from `_dump_map` results in a `KeyError`.
        """
        encoded = self._dump_map[value]
        return encoded
class EnumDumper(Dumper):
    """
    Dumper for a generic Enum class.

    Members are dumped as their Python name, encoded with the encoding
    that `conn_encoding()` reports for the dumper's connection.
    """

    def __init__(self, cls: type, context: AdaptContext | None = None):
        super().__init__(cls, context)
        # Look the encoding up once here instead of on every dump() call.
        self._encoding = conn_encoding(self.connection)

    def dump(self, value: E) -> Buffer | None:
        """Return the name of the member *value* as encoded bytes."""
        member_name = value.name
        return member_name.encode(self._encoding)
class EnumBinaryDumper(EnumDumper):
    """
    Binary-format variant of `EnumDumper`.

    Only the `format` attribute differs: `dump()` is inherited unchanged.
    """

    format = BINARY
def register_enum(
    info: EnumInfo,
    context: AdaptContext | None = None,
    enum: type[E] | None = None,
    *,
    mapping: EnumMapping[E] = None,
) -> None:
    """Register the adapters to load and dump an enum type.

    :param info: The object with the information about the enum to register.
    :param context: The context where to register the adapters. If `!None`,
        register it globally.
    :param enum: Python enum type matching to the PostgreSQL one. If `!None`,
        a new enum will be generated and exposed as `EnumInfo.enum`.
    :param mapping: Override the mapping between `!enum` members and `!info`
        labels.
    :raises TypeError: If `!info` is falsy (e.g. `!None`).
    """
    if not info:
        raise TypeError("no info passed. Is the requested enum available?")

    # No Python enum provided: generate one from the database labels.
    if enum is None:
        enum = cast("type[E]", _make_enum(info.name, tuple(info.labels)))

    info.enum = enum
    adapters = context.adapters if context else postgres.adapters
    info.register(context)

    # Loaders: text first, then binary, both registered on the enum oid.
    load_map = _make_load_map(info, enum, mapping, context)
    for loader_factory in (_make_loader, _make_binary_loader):
        loader = loader_factory(info.name, info.enum, load_map)
        adapters.register_loader(info.oid, loader)

    # Dumpers: text first, then binary, both registered on the Python enum.
    dump_map = _make_dump_map(info, enum, mapping, context)
    for dumper_factory in (_make_dumper, _make_binary_dumper):
        dumper = dumper_factory(info.enum, info.oid, dump_map)
        adapters.register_dumper(info.enum, dumper)
# Cache all dynamically-generated types to avoid leaks in case the types
# cannot be GC'd.
@cache
def _make_enum(name: str, labels: tuple[str, ...]) -> Enum:
return Enum(name.title(), labels, module=__name__)
@cache
def _make_loader(
    name: str, enum: type[Enum], load_map: _HEnumLoadMap[E]
) -> type[_BaseEnumLoader[E]]:
    """Build the loader class for the enum *enum*.

    *load_map* is received as a tuple of pairs (hashable, so usable as
    cache key) and stored on the class as a dict.
    """
    class_name = f"{name.title()}Loader"
    namespace = {"enum": enum, "_load_map": dict(load_map)}
    return type(class_name, (_BaseEnumLoader,), namespace)
@cache
def _make_binary_loader(
    name: str, enum: type[Enum], load_map: _HEnumLoadMap[E]
) -> type[_BaseEnumLoader[E]]:
    """Build the binary-format loader class for the enum *enum*.

    Same as `_make_loader()`, except that the generated class has its
    `format` attribute set to `BINARY`.
    """
    class_name = f"{name.title()}BinaryLoader"
    namespace = {
        "enum": enum,
        "_load_map": dict(load_map),
        "format": BINARY,
    }
    return type(class_name, (_BaseEnumLoader,), namespace)
@cache
def _make_dumper(
    enum: type[Enum], oid: int, dump_map: _HEnumDumpMap[E]
) -> type[_BaseEnumDumper[E]]:
    """Build the dumper class for the enum *enum*.

    *dump_map* is received as a tuple of pairs (hashable, so usable as
    cache key) and stored on the class as a dict, together with the *oid*
    of the database type.
    """
    class_name = f"{enum.__name__}Dumper"
    namespace = {"enum": enum, "oid": oid, "_dump_map": dict(dump_map)}
    return type(class_name, (_BaseEnumDumper,), namespace)
@cache
def _make_binary_dumper(
    enum: type[Enum], oid: int, dump_map: _HEnumDumpMap[E]
) -> type[_BaseEnumDumper[E]]:
    """Build the binary-format dumper class for the enum *enum*.

    Same as `_make_dumper()`, except that the generated class has its
    `format` attribute set to `BINARY`.
    """
    class_name = f"{enum.__name__}BinaryDumper"
    namespace = {
        "enum": enum,
        "oid": oid,
        "_dump_map": dict(dump_map),
        "format": BINARY,
    }
    return type(class_name, (_BaseEnumDumper,), namespace)
def _make_load_map(
    info: EnumInfo, enum: type[E], mapping: EnumMapping[E], context: AdaptContext | None
) -> _HEnumLoadMap[E]:
    """Return the (encoded label, member) pairs used to load *enum*.

    Every label in `info.labels` naming a member of *enum* maps to that
    member; the pairs from *mapping*, if any, are appended afterwards. The
    result is a tuple of pairs so that it can be passed to the cached
    loader factories.
    """
    encoding = conn_encoding(context.connection if context else None)

    pairs = []
    for label in info.labels:
        try:
            member = enum[label]
        except KeyError:
            # tolerate a missing enum, assuming it won't be used. If it is we
            # will get a DataError on fetch.
            continue
        pairs.append((label.encode(encoding), member))

    if mapping:
        overrides = mapping.items() if isinstance(mapping, Mapping) else mapping
        pairs.extend((label.encode(encoding), member) for member, label in overrides)

    return tuple(pairs)
def _make_dump_map(
    info: EnumInfo, enum: type[E], mapping: EnumMapping[E], context: AdaptContext | None
) -> _HEnumDumpMap[E]:
    """Return the (member, encoded label) pairs used to dump *enum*.

    Every member of *enum* maps to its own name; the pairs from *mapping*,
    if any, are appended afterwards. The result is a tuple of pairs so that
    it can be passed to the cached dumper factories.
    """
    encoding = conn_encoding(context.connection if context else None)

    pairs = [(member, member.name.encode(encoding)) for member in enum]

    if mapping:
        overrides = mapping.items() if isinstance(mapping, Mapping) else mapping
        pairs.extend((member, label.encode(encoding)) for member, label in overrides)

    return tuple(pairs)
def register_default_adapters(context: AdaptContext) -> None:
    """Register the generic `Enum` dumpers on the adapters of *context*."""
    # Keep this order (binary first, text last).
    # NOTE(review): psycopg documents that the dumper registered last is
    # the one chosen for placeholders with no explicit format - confirm
    # before reordering.
    for dumper in (EnumBinaryDumper, EnumDumper):
        context.adapters.register_dumper(Enum, dumper)