diff --git a/py/autoevals/__init__.py b/py/autoevals/__init__.py index 4d19a57..b6e1dd4 100644 --- a/py/autoevals/__init__.py +++ b/py/autoevals/__init__.py @@ -134,4 +134,5 @@ async def evaluate_qa(): from .ragas import * from .score import Score, Scorer, SerializableDataClass from .string import * +from .thread_utils import * from .value import ExactMatch diff --git a/py/autoevals/llm.py b/py/autoevals/llm.py index 0bbc7d4..e3a6482 100644 --- a/py/autoevals/llm.py +++ b/py/autoevals/llm.py @@ -45,10 +45,13 @@ ``` """ +import asyncio +import inspect import json import os import re from collections import defaultdict +from collections.abc import Callable from dataclasses import dataclass import chevron @@ -58,6 +61,11 @@ from .oai import Client, arun_cached_request, get_default_model, run_cached_request from .score import Score +from .thread_utils import ( + THREAD_VARIABLE_NAMES, + compute_thread_template_vars, + template_uses_thread_variables, +) # Disable HTML escaping in chevron. chevron.renderer._html_escape = lambda x: x # type: ignore[attr-defined] @@ -243,6 +251,9 @@ def _request_args(self, output, expected, **kwargs): return ret + async def _request_args_async(self, output, expected, **kwargs): + return self._request_args(output, expected, **kwargs) + def _process_response(self, resp): metadata = {} if "tool_calls" not in resp: @@ -268,7 +279,9 @@ def _postprocess_response(self, resp): raise ValueError("Empty response from OpenAI") async def _run_eval_async(self, output, expected, **kwargs): - return self._postprocess_response(await arun_cached_request(**self._request_args(output, expected, **kwargs))) + return self._postprocess_response( + await arun_cached_request(**(await self._request_args_async(output, expected, **kwargs))) + ) def _run_eval_sync(self, output, expected, **kwargs): return self._postprocess_response(run_cached_request(**self._request_args(output, expected, **kwargs))) @@ -330,10 +343,15 @@ class LLMClassifier(OpenAILLMClassifier): api_key: Deprecated. Use client instead. base_url: Deprecated. Use client instead. client: OpenAI client. If not provided, uses global client from init(). + trace: Optional trace object for multi-turn scoring. When provided at + evaluation time and the template references thread variables + (`{{thread}}`, `{{thread_count}}`, etc.), thread variables are + derived from `trace.get_thread()` / `trace.getThread()`. **extra_render_args: Additional template variables """ _SPEC_FILE_CONTENTS: dict[str, str] = defaultdict(str) + _thread_variable_names = THREAD_VARIABLE_NAMES def __init__( self, @@ -353,6 +371,7 @@ def __init__( client: Client | None = None, **extra_render_args, ): + self._template_uses_thread_variables = template_uses_thread_variables(prompt_template) choice_strings = list(choice_scores.keys()) # Use configured default model if not specified if model is None: @@ -384,6 +403,67 @@ def __init__( client=client, ) + @staticmethod + def _get_trace_thread_method(trace) -> Callable[..., object] | None: + if hasattr(trace, "get_thread") and callable(trace.get_thread): + return trace.get_thread + return None + + def _compute_thread_vars_sync(self, trace) -> dict[str, object]: + method = self._get_trace_thread_method(trace) + if method is None: + raise TypeError("trace must implement async get_thread(options=None)") + + thread_awaitable = method() + if not inspect.isawaitable(thread_awaitable): + raise TypeError("trace.get_thread() must return an awaitable") + try: + asyncio.get_running_loop() + except RuntimeError: + thread = asyncio.run(thread_awaitable) + else: + raise RuntimeError("trace.get_thread() is async; use eval_async() when already inside an event loop") + + if not isinstance(thread, list): + thread = list(thread) + + computed = compute_thread_template_vars(thread) + return {name: computed[name] for name in self._thread_variable_names} + + async def _compute_thread_vars_async(self, trace) -> dict[str, object]: + method = self._get_trace_thread_method(trace) + if method is None: + raise TypeError("trace must implement async get_thread(options=None)") + + thread_awaitable = method() + if not inspect.isawaitable(thread_awaitable): + raise TypeError("trace.get_thread() must return an awaitable") + thread = await thread_awaitable + + if not isinstance(thread, list): + thread = list(thread) + + computed = compute_thread_template_vars(thread) + return {name: computed[name] for name in self._thread_variable_names} + + def _request_args(self, output, expected, **kwargs): + trace = kwargs.get("trace") + thread_vars: dict[str, object] = {} + if trace is not None and self._template_uses_thread_variables: + thread_vars = self._compute_thread_vars_sync(trace) + + # Thread vars come first so explicit render args can override. + return super()._request_args(output, expected, **thread_vars, **kwargs) + + async def _request_args_async(self, output, expected, **kwargs): + trace = kwargs.get("trace") + thread_vars: dict[str, object] = {} + if trace is not None and self._template_uses_thread_variables: + thread_vars = await self._compute_thread_vars_async(trace) + + # Thread vars come first so explicit render args can override. + return super()._request_args(output, expected, **thread_vars, **kwargs) + @classmethod def from_spec(cls, name: str, spec: ModelGradedSpec, client: Client | None = None, **kwargs): spec_kwargs = {} diff --git a/py/autoevals/test_llm.py b/py/autoevals/test_llm.py index 3b129b3..fd2dd56 100644 --- a/py/autoevals/test_llm.py +++ b/py/autoevals/test_llm.py @@ -11,6 +11,7 @@ from autoevals import init from autoevals.llm import Battle, Factuality, LLMClassifier, OpenAILLMClassifier, build_classification_tools from autoevals.oai import OpenAIV1Module, get_default_model +from autoevals.thread_utils import compute_thread_template_vars class TestModel(BaseModel): @@ -54,6 +55,47 @@ def test_render_messages(): assert rendered[5]["content"] == "" +def test_render_messages_with_thread_variables(): + classifier = OpenAILLMClassifier( + "test", + messages=[ + {"role": "user", "content": "{{thread}}"}, + {"role": "user", "content": "First message: {{thread.0}}"}, + {"role": "user", "content": "Count: {{thread_count}}"}, + {"role": "user", "content": "First: {{first_message}}"}, + {"role": "user", "content": "Users: {{user_messages}}"}, + {"role": "user", "content": "Pairs: {{human_ai_pairs}}"}, + { + "role": "user", + "content": "Messages:{{#thread}}\n- {{role}}: {{content}}{{/thread}}", + }, + ], + model="gpt-4", + choice_scores={"A": 1}, + classification_tools=[], + ) + + sample_thread = [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am doing well, thank you!"}, + {"role": "user", "content": "What is the weather like?"}, + {"role": "assistant", "content": "It is sunny and warm today."}, + ] + thread_vars = compute_thread_template_vars(sample_thread) + rendered = classifier._render_messages(**thread_vars) + + assert "User:" in rendered[0]["content"] + assert "Hello, how are you?" in rendered[0]["content"] + assert "Assistant:" in rendered[0]["content"] + assert rendered[1]["content"] == "First message: user: Hello, how are you?" + assert rendered[2]["content"] == "Count: 4" + assert rendered[3]["content"] == "First: user: Hello, how are you?" + assert "Users: User:" in rendered[4]["content"] + assert "Pairs:" in rendered[5]["content"] + assert "human" in rendered[5]["content"] + assert rendered[6]["content"].startswith("Messages:\n- user: Hello, how are you?") + + def test_openai(): e = OpenAILLMClassifier( "title", @@ -547,3 +589,130 @@ def capture_model(request): # Reset for other tests init(None) + + +@respx.mock +def test_llm_classifier_injects_thread_vars_from_trace(): + captured_request_body = None + + class TraceStub: + def __init__(self, thread): + self.thread = thread + self.calls = 0 + + async def get_thread(self): + self.calls += 1 + return self.thread + + thread = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + {"role": "user", "content": "Can you help me?"}, + ] + trace = TraceStub(thread) + + def capture_request(request): + nonlocal captured_request_body + captured_request_body = json.loads(request.content.decode("utf-8")) + return Response( + 200, + json={ + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1234567890, + "model": "gpt-4o", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_test", + "type": "function", + "function": {"name": "select_choice", "arguments": '{"choice": "1"}'}, + } + ], + }, + "finish_reason": "tool_calls", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + }, + ) + + respx.post("https://api.openai.com/v1/chat/completions").mock(side_effect=capture_request) + client = OpenAI(api_key="test-api-key", base_url="https://api.openai.com/v1") + init(client) + + classifier = LLMClassifier( + "thread_test", + "Thread:\n{{thread}}\nCount: {{thread_count}}\nFirst: {{first_message}}\nUsers:\n{{user_messages}}", + {"1": 1, "2": 0}, + ) + classifier.eval(output="irrelevant", expected="irrelevant", trace=trace) + + content = captured_request_body["messages"][0]["content"] + assert trace.calls == 1 + assert "Thread:" in content + assert "User:" in content + assert "Assistant:" in content + assert "Count: 3" in content + assert "First: user: Hello" in content + assert "Users:" in content + + +@respx.mock +def test_llm_classifier_does_not_fetch_thread_when_template_does_not_use_it(): + class TraceStub: + def __init__(self): + self.calls = 0 + + async def get_thread(self): + self.calls += 1 + return [{"role": "user", "content": "unused"}] + + trace = TraceStub() + + respx.post("https://api.openai.com/v1/chat/completions").mock( + return_value=Response( + 200, + json={ + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1234567890, + "model": "gpt-4o", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_test", + "type": "function", + "function": {"name": "select_choice", "arguments": '{"choice": "1"}'}, + } + ], + }, + "finish_reason": "tool_calls", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + }, + ) + ) + + client = OpenAI(api_key="test-api-key", base_url="https://api.openai.com/v1") + init(client) + + classifier = LLMClassifier( + "thread_unused_test", + "Output: {{output}}", + {"1": 1, "2": 0}, + ) + classifier.eval(output="x", expected="y", trace=trace) + + assert trace.calls == 0 diff --git a/py/autoevals/thread_utils.py b/py/autoevals/thread_utils.py new file mode 100644 index 0000000..e816348 --- /dev/null +++ b/py/autoevals/thread_utils.py @@ -0,0 +1,269 @@ +"""Thread utilities for LLM-as-a-judge scorers. + +This module provides helpers for deriving template variables from conversation +threads and formatting thread values for Mustache rendering. +""" + +from __future__ import annotations + +import json +import re +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +THREAD_VARIABLE_NAMES = [ + "thread", + "thread_count", + "first_message", + "last_message", + "user_messages", + "assistant_messages", + "human_ai_pairs", +] + +# Match variables after "{{" or "{%" (e.g., {{thread}}, {{ thread }}, {% if thread %}). +THREAD_VARIABLE_PATTERN = re.compile(r"\{[\{%]\s*(" + "|".join(THREAD_VARIABLE_NAMES) + r")") + + +def template_uses_thread_variables(template: str) -> bool: + return bool(THREAD_VARIABLE_PATTERN.search(template)) + + +def is_role_content_message(item: Any) -> bool: + return isinstance(item, Mapping) and "role" in item and "content" in item + + +def is_llm_message_array(value: Any) -> bool: + return isinstance(value, list) and all(is_role_content_message(item) for item in value) + + +def _indent(text: str, prefix: str = " ") -> str: + return "\n".join(prefix + line for line in text.split("\n")) + + +def _truncate_middle(text: str, max_len: int) -> str: + if len(text) <= max_len: + return text + chars_removed = len(text) - max_len + 30 + ellipsis = f" [...{chars_removed} chars truncated...] " + avail = max_len - len(ellipsis) + if avail <= 0: + return text[:max_len] + left = avail // 2 + right = avail - left + return text[:left] + ellipsis + text[-right:] + + +def _is_typed_part(part: Any) -> bool: + return isinstance(part, Mapping) and isinstance(part.get("type"), str) + + +@dataclass +class _PendingToolCall: + name: str + args: str + + +def _extract_tool_calls(content: list[Any]) -> dict[str, _PendingToolCall]: + tool_calls: dict[str, _PendingToolCall] = {} + for part in content: + if not _is_typed_part(part) or part["type"] != "tool_call": + continue + + tool_call_id = part.get("tool_call_id") + if not isinstance(tool_call_id, str) or not tool_call_id: + continue + + tool_name = part.get("tool_name") if isinstance(part.get("tool_name"), str) else "unknown" + + args = "" + args_obj = part.get("arguments") + if isinstance(args_obj, Mapping): + if args_obj.get("type") == "valid": + args = json.dumps(args_obj.get("value")) + else: + value = args_obj.get("value") + if isinstance(value, str): + args = value + else: + args = json.dumps(value) + + tool_calls[tool_call_id] = _PendingToolCall(name=tool_name, args=args) + + return tool_calls + + +def _unwrap_content(content: Any) -> str: + if isinstance(content, str): + try: + return _unwrap_content(json.loads(content)) + except Exception: + error_match = re.match(r"^error:\s*'(.+)'$", content, flags=re.DOTALL) + if error_match: + return error_match.group(1) + return content + + if isinstance(content, list): + text_parts: list[str] = [] + for item in content: + if isinstance(item, Mapping) and isinstance(item.get("text"), str): + text_parts.append(_unwrap_content(item["text"])) + elif isinstance(item, str): + text_parts.append(_unwrap_content(item)) + if text_parts: + return "\n".join(text_parts) + + if isinstance(content, Mapping) and isinstance(content.get("text"), str): + return _unwrap_content(content["text"]) + + return content if isinstance(content, str) else json.dumps(content) + + +def _format_tool_result( + tool_call_id: str, + tool_name: str, + output: Any, + pending_tool_calls: dict[str, _PendingToolCall], +) -> str: + pending_call = pending_tool_calls.get(tool_call_id) + name = tool_name or (pending_call.name if pending_call else "tool") + args = pending_call.args if pending_call else "" + + result_content = _unwrap_content(output) + lines = [f"Tool ({name}):"] + + if args: + lines.append(" Args:") + lines.append(f" {_truncate_middle(args, 500)}") + + is_error = ( + "error:" in result_content.lower() + or '"error"' in result_content.lower() + or result_content.lower().startswith("error") + ) + if is_error: + lines.append(" Error:") + lines.append(f" {_truncate_middle(result_content, 500)}") + else: + lines.append(" Result:") + lines.append(f" {_truncate_middle(result_content, 500)}") + + if pending_call: + del pending_tool_calls[tool_call_id] + + return "\n".join(lines) + + +def _format_tool_results(content: list[Any], pending_tool_calls: dict[str, _PendingToolCall]) -> list[str]: + results: list[str] = [] + for part in content: + if not _is_typed_part(part) or part["type"] != "tool_result": + continue + + tool_call_id = part.get("tool_call_id") if isinstance(part.get("tool_call_id"), str) else "" + tool_name = part.get("tool_name") if isinstance(part.get("tool_name"), str) else "" + results.append(_format_tool_result(tool_call_id, tool_name, part.get("output"), pending_tool_calls)) + + return results + + +def _extract_text_content(content: Any) -> str: + if isinstance(content, str): + return content if content.strip() else "" + if not isinstance(content, list): + return "" + + parts: list[str] = [] + for part in content: + if isinstance(part, str) and part.strip(): + parts.append(part) + elif _is_typed_part(part): + if part["type"] == "text" and isinstance(part.get("text"), str): + parts.append(part["text"]) + elif part["type"] == "reasoning" and isinstance(part.get("text"), str): + parts.append(f"[thinking: {part['text'][:100]}...]") + elif isinstance(part, Mapping) and isinstance(part.get("text"), str): + parts.append(part["text"]) + + return "\n".join(parts) + + +def format_message_array_as_text(messages: list[Mapping[str, Any]]) -> str: + pending_tool_calls: dict[str, _PendingToolCall] = {} + for msg in messages: + if msg.get("role") == "assistant" and isinstance(msg.get("content"), list): + for tool_call_id, pending in _extract_tool_calls(msg["content"]).items(): + pending_tool_calls[tool_call_id] = pending + + parts: list[str] = [] + for msg in messages: + role = msg.get("role", "") + if not isinstance(role, str): + role = str(role) + capitalized_role = role[:1].upper() + role[1:] + content = msg.get("content") + + if role == "tool" and isinstance(content, list): + parts.extend(_format_tool_results(content, pending_tool_calls)) + else: + text = _extract_text_content(content) + if text: + parts.append(f"{capitalized_role}:\n{_indent(text)}") + + return "\n\n".join(parts) + + +class RenderableMessage(dict[str, Any]): + def __str__(self) -> str: + role = self.get("role", "") + content = self.get("content") + content_str = content if isinstance(content, str) else json.dumps(content) + return f"{role}: {content_str}" + + +class RenderableMessageArray(list[Any]): + def __str__(self) -> str: + return format_message_array_as_text(self) + + +def _to_renderable_message(message: Mapping[str, Any]) -> RenderableMessage: + return RenderableMessage(message) + + +def _to_renderable_message_array(messages: list[Any]) -> RenderableMessageArray: + wrapped: list[Any] = [] + for message in messages: + if is_role_content_message(message): + wrapped.append(_to_renderable_message(message)) + else: + wrapped.append(message) + return RenderableMessageArray(wrapped) + + +def compute_thread_template_vars(thread: list[Any]) -> dict[str, Any]: + renderable_thread = _to_renderable_message_array(thread) if is_llm_message_array(thread) else thread + + first_message = renderable_thread[0] if len(renderable_thread) > 0 else None + last_message = renderable_thread[-1] if len(renderable_thread) > 0 else None + + user_messages = [ + message for message in renderable_thread if is_role_content_message(message) and message.get("role") == "user" + ] + assistant_messages = [ + message + for message in renderable_thread + if is_role_content_message(message) and message.get("role") == "assistant" + ] + pair_count = min(len(user_messages), len(assistant_messages)) + human_ai_pairs = [{"human": user_messages[idx], "assistant": assistant_messages[idx]} for idx in range(pair_count)] + + return { + "thread": renderable_thread, + "thread_count": len(thread), + "first_message": first_message, + "last_message": last_message, + "user_messages": _to_renderable_message_array(user_messages), + "assistant_messages": _to_renderable_message_array(assistant_messages), + "human_ai_pairs": human_ai_pairs, + }