# File: /usr/local/lib/python3.10/dist-packages/langsmith/_internal/_operations.py
from __future__ import annotations
import itertools
import logging
import os
import uuid
from io import BufferedReader
from typing import Dict, Iterable, Literal, Optional, Tuple, Union, cast
from langsmith import schemas as ls_schemas
from langsmith._internal import _orjson
from langsmith._internal._compressed_traces import CompressedTraces
from langsmith._internal._multipart import MultipartPart, MultipartPartsAndContext
from langsmith._internal._serde import dumps_json as _dumps_json
logger = logging.getLogger(__name__)
class SerializedRunOperation:
    """A run create/update that has already been serialized for transport.

    The run payload is split into independently serialized pieces
    (``inputs``/``outputs``/``events``/``attachments``) so each can be
    shipped as its own multipart part; ``_none`` holds everything else.
    """

    # "post" creates a run, "patch" updates one.
    operation: Literal["post", "patch"]
    id: uuid.UUID
    trace_id: uuid.UUID

    # this is the whole object, minus the other fields which
    # are popped (inputs/outputs/events/attachments)
    _none: bytes

    inputs: Optional[bytes]
    outputs: Optional[bytes]
    events: Optional[bytes]
    attachments: Optional[ls_schemas.Attachments]

    # __slots__ keeps per-instance memory low for high-volume tracing.
    __slots__ = (
        "operation",
        "id",
        "trace_id",
        "_none",
        "inputs",
        "outputs",
        "events",
        "attachments",
    )

    def __init__(
        self,
        operation: Literal["post", "patch"],
        id: uuid.UUID,
        trace_id: uuid.UUID,
        _none: bytes,
        inputs: Optional[bytes] = None,
        outputs: Optional[bytes] = None,
        events: Optional[bytes] = None,
        attachments: Optional[ls_schemas.Attachments] = None,
    ) -> None:
        self.operation = operation
        self.id = id
        self.trace_id = trace_id
        self._none = _none
        self.inputs = inputs
        self.outputs = outputs
        self.events = events
        self.attachments = attachments

    def _as_tuple(self):
        # All fields, in declaration order, for field-wise comparison.
        return (
            self.operation,
            self.id,
            self.trace_id,
            self._none,
            self.inputs,
            self.outputs,
            self.events,
            self.attachments,
        )

    def __eq__(self, other: object) -> bool:
        """Field-wise equality; non-SerializedRunOperation compares unequal."""
        if not isinstance(other, SerializedRunOperation):
            return False
        return self._as_tuple() == other._as_tuple()
class SerializedFeedbackOperation:
    """A feedback payload, pre-serialized and addressed to a run/trace."""

    id: uuid.UUID
    trace_id: uuid.UUID
    feedback: bytes

    # __slots__ keeps per-instance memory low for high-volume tracing.
    __slots__ = ("id", "trace_id", "feedback")

    def __init__(self, id: uuid.UUID, trace_id: uuid.UUID, feedback: bytes) -> None:
        self.id = id
        self.trace_id = trace_id
        self.feedback = feedback

    def __eq__(self, other: object) -> bool:
        """Field-wise equality; non-SerializedFeedbackOperation compares unequal."""
        if not isinstance(other, SerializedFeedbackOperation):
            return False
        return (
            self.id == other.id
            and self.trace_id == other.trace_id
            and self.feedback == other.feedback
        )
def serialize_feedback_dict(
    feedback: Union[ls_schemas.FeedbackCreate, dict],
) -> SerializedFeedbackOperation:
    """Normalize a feedback payload and serialize it for the queue.

    Accepts either a pydantic-style model (anything with a callable
    ``.dict()``) or a plain dict.  Ensures ``id`` and ``trace_id`` are
    ``uuid.UUID`` values, generating fresh ones when absent.

    Note: mutates the underlying dict in place when coercing ids.
    """
    dict_method = getattr(feedback, "dict", None)
    if callable(dict_method):
        feedback_create: dict = dict_method()  # type: ignore
    else:
        feedback_create = cast(dict, feedback)

    # Missing ids get fresh UUIDs; string ids are parsed into UUID objects.
    for key in ("id", "trace_id"):
        if key not in feedback_create:
            feedback_create[key] = uuid.uuid4()
        elif isinstance(feedback_create[key], str):
            feedback_create[key] = uuid.UUID(feedback_create[key])

    return SerializedFeedbackOperation(
        id=feedback_create["id"],
        trace_id=feedback_create["trace_id"],
        feedback=_dumps_json(feedback_create),
    )
def serialize_run_dict(
    operation: Literal["post", "patch"], payload: dict
) -> SerializedRunOperation:
    """Split a run payload and serialize each piece independently.

    inputs/outputs/events/attachments are popped off the payload and
    serialized (or carried through, for attachments) on their own; the
    remainder is serialized as the main run object.

    Note: mutates ``payload`` in place via ``pop``.
    """
    split = {
        key: payload.pop(key, None) for key in ("inputs", "outputs", "events")
    }
    attachments = payload.pop("attachments", None)
    return SerializedRunOperation(
        operation=operation,
        id=payload["id"],
        trace_id=payload["trace_id"],
        _none=_dumps_json(payload),
        inputs=None if split["inputs"] is None else _dumps_json(split["inputs"]),
        outputs=None if split["outputs"] is None else _dumps_json(split["outputs"]),
        events=None if split["events"] is None else _dumps_json(split["events"]),
        attachments=attachments,
    )
def combine_serialized_queue_operations(
    ops: list[Union[SerializedRunOperation, SerializedFeedbackOperation]],
) -> list[Union[SerializedRunOperation, SerializedFeedbackOperation]]:
    """Fold queued "patch" run operations into their matching "post" ops.

    A patch whose run id also has a queued post is merged into that post
    (the post is mutated in place); patches without a matching post and
    all feedback operations pass through unchanged.  Returns the merged
    posts first, then the passthrough operations.
    """
    # Index queued creations by run id (last post wins for duplicate ids).
    posts_by_id = {
        op.id: op
        for op in ops
        if isinstance(op, SerializedRunOperation) and op.operation == "post"
    }
    passthrough: list[
        Union[SerializedRunOperation, SerializedFeedbackOperation]
    ] = []
    for op in ops:
        if not isinstance(op, SerializedRunOperation):
            # Feedback operations are never merged.
            passthrough.append(op)
            continue
        if op.operation == "post":
            # Already captured in posts_by_id above.
            continue
        # must be patch
        target = posts_by_id.get(op.id)
        if target is None:
            passthrough.append(op)
            continue
        if op._none is not None and op._none != target._none:
            # TODO optimize this more - this would currently be slowest
            # for large payloads
            merged = _orjson.loads(target._none)
            merged.update(
                {k: v for k, v in _orjson.loads(op._none).items() if v is not None}
            )
            target._none = _orjson.dumps(merged)
        # Patch-side serialized fields replace the post's versions wholesale.
        for field in ("inputs", "outputs", "events"):
            patched = getattr(op, field)
            if patched is not None:
                setattr(target, field, patched)
        if op.attachments is not None:
            if target.attachments is None:
                target.attachments = {}
            target.attachments.update(op.attachments)
    return list(itertools.chain(posts_by_id.values(), passthrough))
def serialized_feedback_operation_to_multipart_parts_and_context(
    op: SerializedFeedbackOperation,
) -> MultipartPartsAndContext:
    """Render a serialized feedback operation as a single multipart part."""
    part: MultipartPart = (
        f"feedback.{op.id}",
        (
            None,
            op.feedback,
            "application/json",
            {"Content-Length": str(len(op.feedback))},
        ),
    )
    context = f"trace={op.trace_id},id={op.id}"
    return MultipartPartsAndContext([part], context)
def serialized_run_operation_to_multipart_parts_and_context(
    op: SerializedRunOperation,
) -> tuple[MultipartPartsAndContext, Dict[str, BufferedReader]]:
    """Render a serialized run operation as multipart parts.

    Returns the parts plus a dict of file handles opened for path-based
    attachments; the caller is responsible for closing those handles.
    """
    parts: list[MultipartPart] = []
    opened_files: Dict[str, BufferedReader] = {}

    def _json_part(name: str, payload: bytes) -> MultipartPart:
        # JSON parts advertise their exact byte length via Content-Length.
        return (
            name,
            (
                None,
                payload,
                "application/json",
                {"Content-Length": str(len(payload))},
            ),
        )

    # this is main object, minus inputs/outputs/events/attachments
    parts.append(_json_part(f"{op.operation}.{op.id}", op._none))

    for key in ("inputs", "outputs", "events"):
        payload = getattr(op, key)
        if payload is not None:
            parts.append(_json_part(f"{op.operation}.{op.id}.{key}", payload))

    for n, (content_type, data_or_path) in (op.attachments or {}).items():
        # Periods are part-name separators, so they can't appear in names.
        if "." in n:
            logger.warning(
                f"Skipping logging of attachment '{n}' "
                f"for run {op.id}:"
                " Invalid attachment name. Attachment names must not contain"
                " periods ('.'). Please rename the attachment and try again."
            )
            continue
        part_name = f"attachment.{op.id}.{n}"
        if isinstance(data_or_path, bytes):
            parts.append(
                (
                    part_name,
                    (
                        None,
                        data_or_path,
                        content_type,
                        {"Content-Length": str(len(data_or_path))},
                    ),
                )
            )
        else:
            # Path-based attachments are streamed from an open handle; the
            # size is carried in the content type rather than Content-Length.
            file_size = os.path.getsize(data_or_path)
            handle = open(data_or_path, "rb")
            # Suffix with a UUID so the same path attached twice keeps both handles.
            opened_files[str(data_or_path) + str(uuid.uuid4())] = handle
            parts.append(
                (
                    part_name,
                    (
                        None,
                        handle,
                        f"{content_type}; length={file_size}",
                        {},
                    ),
                )
            )
    return (
        MultipartPartsAndContext(parts, f"trace={op.trace_id},id={op.id}"),
        opened_files,
    )
def encode_multipart_parts_and_context(
    parts_and_context: MultipartPartsAndContext,
    boundary: str,
) -> Iterable[Tuple[bytes, Union[bytes, BufferedReader]]]:
    """Encode multipart parts as (header_bytes, body) pairs.

    For each part, yields the encoded multipart boundary + header block
    together with the part's body, which is either raw bytes or an open
    file object to be streamed by the caller.

    Args:
        parts_and_context: the parts to encode (name, (filename, data,
            content_type, extra_headers)).
        boundary: the multipart boundary string (without leading dashes).
    """
    for part_name, (filename, data, content_type, headers) in parts_and_context.parts:
        header_parts = [
            f"--{boundary}\r\n",
            f'Content-Disposition: form-data; name="{part_name}"',
        ]
        if filename:
            # Bug fix: interpolate the part's actual filename; previously the
            # f-string had no placeholder and emitted the literal "(unknown)".
            header_parts.append(f'; filename="{filename}"')
        header_parts.extend(
            [
                f"\r\nContent-Type: {content_type}\r\n",
                *[f"{k}: {v}\r\n" for k, v in headers.items()],
                "\r\n",
            ]
        )
        yield ("".join(header_parts).encode(), data)
def compress_multipart_parts_and_context(
    parts_and_context: MultipartPartsAndContext,
    compressed_traces: CompressedTraces,
    boundary: str,
) -> None:
    """Stream the encoded multipart parts into the shared compressor.

    Increments ``compressed_traces.uncompressed_size`` by each part body's
    byte length (headers and terminators are written but not counted).
    """
    writer = compressed_traces.compressor_writer
    for headers, data in encode_multipart_parts_and_context(
        parts_and_context, boundary
    ):
        writer.write(headers)
        if not isinstance(data, (bytes, bytearray)):
            # File bodies are drained fully; anything else is stringified.
            if isinstance(data, BufferedReader):
                data = data.read()
            else:
                data = str(data).encode()
        compressed_traces.uncompressed_size += len(data)
        writer.write(data)
        # Write part terminator
        writer.write(b"\r\n")