File: /usr/local/lib/python3.10/dist-packages/langsmith/pytest_plugin.py
"""LangSmith Pytest hooks."""
import importlib.util
import json
import logging
import os
import time
from collections import defaultdict
from threading import Lock

import pytest

from langsmith import utils as ls_utils
from langsmith.testing._internal import test as ls_test

logger = logging.getLogger(__name__)


def pytest_addoption(parser):
    """Set a boolean flag for LangSmith output.

    Skip if --langsmith-output is already defined.
    """
try:
# Try to add the option, will raise if it already exists
group = parser.getgroup("langsmith", "LangSmith")
group.addoption(
"--langsmith-output",
action="store_true",
default=False,
help="Use LangSmith output (requires 'rich').",
)
except ValueError:
# Option already exists
logger.warning(
"LangSmith output flag cannot be added because it's already defined."
        )


def _handle_output_args(args):
    """Quiet pytest output and disable capturing for --langsmith-output."""
if any(opt in args for opt in ["--langsmith-output"]):
        # Only add -qq (quiet mode) if it's not already there
if not any(a in args for a in ["-qq"]):
args.insert(0, "-qq")
# Disable built-in output capturing
if not any(a in args for a in ["-s", "--capture=no"]):
args.insert(0, "-s")
if pytest.__version__.startswith("7."):

    def pytest_cmdline_preparse(config, args):
        """Call immediately after command line options are parsed (pytest v7)."""
        _handle_output_args(args)

else:

    def pytest_load_initial_conftests(args):
        """Handle args in pytest v8+."""
        _handle_output_args(args)


@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item):
    """Apply LangSmith tracking to tests marked with @pytest.mark.langsmith."""
marker = item.get_closest_marker("langsmith")
if marker:
# Get marker kwargs if any (e.g.,
# @pytest.mark.langsmith(output_keys=["expected"]))
kwargs = marker.kwargs if marker else {}
# Wrap the test function with our test decorator
original_func = item.obj
item.obj = ls_test(**kwargs)(original_func)
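        # Make pytest's `request` fixture available to the wrapped function,
        # even if the original test function didn't declare it.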
request_obj = getattr(item, "_request", None)
if request_obj is not None and "request" not in item.funcargs:
item.funcargs["request"] = request_obj
# Create a new FuncFixtureInfo instance with updated argnames
item._fixtureinfo = type(item._fixtureinfo)(
argnames=item._fixtureinfo.argnames + ("request",),
initialnames=item._fixtureinfo.initialnames,
names_closure=item._fixtureinfo.names_closure,
name2fixturedefs=item._fixtureinfo.name2fixturedefs,
)
    yield


@pytest.hookimpl
def pytest_report_teststatus(report, config):
    """Remove the short test-status character outputs ("./F")."""
    # The hook normally returns a 3-tuple: (category, short letter, verbose word).
    # Returning empty strings suppresses the progress characters.
    if config.getoption("--langsmith-output"):
        return "", "", ""


class LangSmithPlugin:
    """Plugin for rendering LangSmith results."""

    def __init__(self):
        """Initialize."""
        from rich.console import Console  # type: ignore[import-not-found]
        from rich.live import Live  # type: ignore[import-not-found]

        self.test_suites = defaultdict(list)
self.test_suite_urls = {}
self.process_status = {} # Track process status
self.status_lock = Lock() # Thread-safe updates
self.console = Console()
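        # Rich's Live re-renders the table group in place as test statuses
        # stream in, giving a continuously updating results view.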
self.live = Live(
self.generate_tables(), console=self.console, refresh_per_second=10
)
self.live.start()
        self.live.console.print("Collecting tests...")

    def pytest_collection_finish(self, session):
        """Call after the collection phase completes and session.items is populated."""
        self.collected_nodeids = set()
        for item in session.items:
            self.collected_nodeids.add(item.nodeid)

    def add_process_to_test_suite(self, test_suite, process_id):
        """Group a test case with its test suite."""
        self.test_suites[test_suite].append(process_id)

    def update_process_status(self, process_id, status):
        """Update test results."""
        # First update: announce that the run has started.
        if not self.process_status:
            self.live.console.print("Running tests...")
        with self.status_lock:
            current_status = self.process_status.get(process_id, {})
            # Merge dict-valued fields key-by-key instead of overwriting them.
            for field in ("feedback", "inputs", "reference_outputs", "outputs"):
                if status.get(field):
                    current_status[field] = {
                        **current_status.get(field, {}),
                        **status.pop(field),
                    }
            self.process_status[process_id] = {**current_status, **status}
        self.live.update(self.generate_tables())

    def pytest_runtest_logstart(self, nodeid):
        """Mark the test as running when it starts."""
        self.update_process_status(nodeid, {"status": "running"})

    def generate_tables(self):
        """Generate a collection of tables, one per test suite.

        Returns a 'Group' object so the tables can be rendered together by Live.
        """
        from rich.console import Group

        tables = []
        for suite_name in self.test_suites:
            table = self._generate_table(suite_name)
            tables.append(table)
        group = Group(*tables)
        return group

    def _generate_table(self, suite_name: str):
        """Generate the results table for a single test suite."""
        from rich.table import Table  # type: ignore[import-not-found]

        process_ids = self.test_suites[suite_name]
        title = f"""Test Suite: [bold]{suite_name}[/bold]
LangSmith URL: [bright_cyan]{self.test_suite_urls[suite_name]}[/bright_cyan]"""  # noqa: E501
table = Table(title=title, title_justify="left")
table.add_column("Test")
table.add_column("Inputs")
table.add_column("Ref outputs")
table.add_column("Outputs")
table.add_column("Status")
table.add_column("Feedback")
table.add_column("Duration")
        # Track the widest status and duration values so the remaining console
        # width can be divided among the text-heavy columns.
        max_status = len("status")
        max_duration = len("duration")
now = time.time()
durations = []
numeric_feedbacks = defaultdict(list)
# Gather data only for this suite
suite_statuses = {pid: self.process_status[pid] for pid in process_ids}
for pid, status in suite_statuses.items():
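            # Fall back to `now` so still-running tests show elapsed time and
            # not-yet-started tests show ~0s.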
duration = status.get("end_time", now) - status.get("start_time", now)
durations.append(duration)
for k, v in status.get("feedback", {}).items():
if isinstance(v, (float, int, bool)):
numeric_feedbacks[k].append(v)
max_duration = max(len(f"{duration:.2f}s"), max_duration)
max_status = max(len(status.get("status", "queued")), max_status)
passed_count = sum(s.get("status") == "passed" for s in suite_statuses.values())
failed_count = sum(s.get("status") == "failed" for s in suite_statuses.values())
        # Aggregate pass rate, shown in the status column of the footer row:
if passed_count + failed_count:
rate = passed_count / (passed_count + failed_count)
color = "green" if rate == 1 else "red"
aggregate_status = f"[{color}]{rate:.0%}[/{color}]"
else:
aggregate_status = "Passed: --"
if durations:
aggregate_duration = f"{sum(durations) / len(durations):.2f}s"
else:
aggregate_duration = "--s"
if numeric_feedbacks:
aggregate_feedback = "\n".join(
f"{k}: {sum(v) / len(v)}" for k, v in numeric_feedbacks.items()
)
else:
aggregate_feedback = "--"
max_duration = max(max_duration, len(aggregate_duration))
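        # Split the remaining console width across the five text-heavy columns
        # (test name, inputs, reference outputs, outputs, feedback), with a
        # floor of 8 characters per column.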
max_dynamic_col_width = (self.console.width - (max_status + max_duration)) // 5
max_dynamic_col_width = max(max_dynamic_col_width, 8)
for pid, status in suite_statuses.items():
status_color = {
"running": "yellow",
"passed": "green",
"failed": "red",
"skipped": "cyan",
}.get(status.get("status", "queued"), "white")
duration = status.get("end_time", now) - status.get("start_time", now)
feedback = "\n".join(
f"{_abbreviate(k, max_len=max_dynamic_col_width)}: {int(v) if isinstance(v, bool) else v}" # noqa: E501
for k, v in status.get("feedback", {}).items()
)
inputs = json.dumps(status.get("inputs", {}))
reference_outputs = json.dumps(status.get("reference_outputs", {}))
outputs = json.dumps(status.get("outputs", {}))
table.add_row(
_abbreviate_test_name(str(pid), max_len=max_dynamic_col_width),
_abbreviate(inputs, max_len=max_dynamic_col_width),
_abbreviate(reference_outputs, max_len=max_dynamic_col_width),
                _abbreviate(outputs, max_len=max_dynamic_col_width),
f"[{status_color}]{status.get('status', 'queued')}[/{status_color}]",
feedback,
f"{duration:.2f}s",
)
        # Blank spacer row between the per-test rows and the footer.
        table.add_row("", "", "", "", "", "", "")
        # Footer row with the aggregate results.
table.add_row(
"[bold]Averages[/bold]",
"",
"",
"",
aggregate_status,
aggregate_feedback,
aggregate_duration,
)
        return table

    def pytest_configure(self, config):
        """Disable warning reporting and show no warnings in output."""
# Disable general warning reporting
config.option.showwarnings = False
# Disable warning summary
reporter = config.pluginmanager.get_plugin("warnings-plugin")
if reporter:
            reporter.warning_summary = lambda *args, **kwargs: None

    def pytest_sessionfinish(self, session):
        """Stop the Rich Live rendering at the end of the session."""
self.live.stop()
        self.live.console.print("\nFinishing up...")


def pytest_configure(config):
    """Register the 'langsmith' marker."""
    config.addinivalue_line(
"markers", "langsmith: mark test to be tracked in LangSmith"
)
if config.getoption("--langsmith-output"):
if not importlib.util.find_spec("rich"):
msg = (
"Must have 'rich' installed to use --langsmith-output. "
"Please install with: `pip install -U 'langsmith[pytest]'`"
)
raise ValueError(msg)
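        # pytest-xdist sets PYTEST_XDIST_TESTRUNUID for its runs, so its presence
        # indicates a distributed run, which the live rendering can't handle.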
if os.environ.get("PYTEST_XDIST_TESTRUNUID"):
msg = (
"--langsmith-output not supported with pytest-xdist. "
"Please remove the '--langsmith-output' option or '-n' option."
)
raise ValueError(msg)
        if ls_utils.test_tracking_is_disabled():
            msg = (
                "--langsmith-output not supported when env var "
                "LANGSMITH_TEST_TRACKING='false'. Please remove the "
                "'--langsmith-output' option or enable test tracking."
            )
            raise ValueError(msg)
        config.pluginmanager.register(LangSmithPlugin(), "langsmith_output_plugin")
        # Suppress warnings summary
        config.option.showwarnings = False


def _abbreviate(x: str, max_len: int) -> str:
    """Truncate x to max_len characters, ending with '...' if truncated."""
    if len(x) > max_len:
        return x[: max_len - 3] + "..."
    else:
        return x


def _abbreviate_test_name(test_name: str, max_len: int) -> str:
    """Truncate a 'file::test' node ID from the left to fit within max_len."""
    if len(test_name) > max_len:
        # Split on the first '::' only, so class-based node IDs
        # ('file::Class::test') don't raise an unpacking error.
        file, test = test_name.split("::", 1)
        if len(".py::" + test) > max_len:
            return "..." + test[-(max_len - 3) :]
        file_len = max_len - len("...::" + test)
        return "..." + file[-file_len:] + "::" + test
    else:
        return test_name