Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions py/autoevals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,5 @@ async def evaluate_qa():
from .ragas import *
from .score import Score, Scorer, SerializableDataClass
from .string import *
from .thread_utils import *
from .value import ExactMatch
82 changes: 81 additions & 1 deletion py/autoevals/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,13 @@
```
"""

import asyncio
import inspect
import json
import os
import re
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass

import chevron
Expand All @@ -58,6 +61,11 @@

from .oai import Client, arun_cached_request, get_default_model, run_cached_request
from .score import Score
from .thread_utils import (
THREAD_VARIABLE_NAMES,
compute_thread_template_vars,
template_uses_thread_variables,
)

# Disable HTML escaping in chevron.
chevron.renderer._html_escape = lambda x: x # type: ignore[attr-defined]
Expand Down Expand Up @@ -243,6 +251,9 @@ def _request_args(self, output, expected, **kwargs):

return ret

async def _request_args_async(self, output, expected, **kwargs):
    """Async hook for building request arguments.

    The base implementation has nothing to await: it simply delegates to
    the synchronous `_request_args`. Subclasses override this when argument
    construction genuinely needs `await` (e.g. fetching thread context).
    """
    request_args = self._request_args(output, expected, **kwargs)
    return request_args

def _process_response(self, resp):
metadata = {}
if "tool_calls" not in resp:
Expand All @@ -268,7 +279,9 @@ def _postprocess_response(self, resp):
raise ValueError("Empty response from OpenAI")

async def _run_eval_async(self, output, expected, **kwargs):
return self._postprocess_response(await arun_cached_request(**self._request_args(output, expected, **kwargs)))
return self._postprocess_response(
await arun_cached_request(**(await self._request_args_async(output, expected, **kwargs)))
)

def _run_eval_sync(self, output, expected, **kwargs):
return self._postprocess_response(run_cached_request(**self._request_args(output, expected, **kwargs)))
Expand Down Expand Up @@ -330,10 +343,15 @@ class LLMClassifier(OpenAILLMClassifier):
api_key: Deprecated. Use client instead.
base_url: Deprecated. Use client instead.
client: OpenAI client. If not provided, uses global client from init().
trace: Optional trace object for multi-turn scoring. When provided at
evaluation time and the template references thread variables
(`{{thread}}`, `{{thread_count}}`, etc.), thread variables are
derived from `trace.get_thread()` / `trace.getThread()`.
**extra_render_args: Additional template variables
"""

_SPEC_FILE_CONTENTS: dict[str, str] = defaultdict(str)
_thread_variable_names = THREAD_VARIABLE_NAMES

def __init__(
self,
Expand All @@ -353,6 +371,7 @@ def __init__(
client: Client | None = None,
**extra_render_args,
):
self._template_uses_thread_variables = template_uses_thread_variables(prompt_template)
choice_strings = list(choice_scores.keys())
# Use configured default model if not specified
if model is None:
Expand Down Expand Up @@ -384,6 +403,67 @@ def __init__(
client=client,
)

@staticmethod
def _get_trace_thread_method(trace) -> Callable[..., object] | None:
if hasattr(trace, "get_thread") and callable(trace.get_thread):
return trace.get_thread
return None

def _compute_thread_vars_sync(self, trace) -> dict[str, object]:
    """Fetch the trace's thread synchronously and build template variables.

    Drives the trace's async ``get_thread()`` to completion with
    ``asyncio.run``.

    Raises:
        TypeError: if ``trace`` has no callable ``get_thread`` or the call
            does not return an awaitable.
        RuntimeError: if invoked while an event loop is already running
            (callers must use ``eval_async()`` in that situation).
    """
    method = self._get_trace_thread_method(trace)
    if method is None:
        raise TypeError("trace must implement async get_thread(options=None)")

    thread_awaitable = method()
    if not inspect.isawaitable(thread_awaitable):
        raise TypeError("trace.get_thread() must return an awaitable")
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running: safe to spin one up and await the result here.
        thread = asyncio.run(thread_awaitable)
    else:
        # Bug fix: the coroutine created by method() was abandoned here,
        # which triggers a "coroutine was never awaited" RuntimeWarning on
        # GC. Close it explicitly before raising.
        if inspect.iscoroutine(thread_awaitable):
            thread_awaitable.close()
        raise RuntimeError("trace.get_thread() is async; use eval_async() when already inside an event loop")

    # Normalize any iterable of messages into a concrete list.
    if not isinstance(thread, list):
        thread = list(thread)

    computed = compute_thread_template_vars(thread)
    return {name: computed[name] for name in self._thread_variable_names}

async def _compute_thread_vars_async(self, trace) -> dict[str, object]:
    """Await the trace's thread and build the thread template variables.

    Raises:
        TypeError: if ``trace`` has no callable ``get_thread`` or calling
            it does not produce an awaitable.
    """
    fetch = self._get_trace_thread_method(trace)
    if fetch is None:
        raise TypeError("trace must implement async get_thread(options=None)")

    pending = fetch()
    if not inspect.isawaitable(pending):
        raise TypeError("trace.get_thread() must return an awaitable")

    messages = await pending
    # Normalize any iterable of messages into a concrete list.
    messages = messages if isinstance(messages, list) else list(messages)

    all_vars = compute_thread_template_vars(messages)
    return {key: all_vars[key] for key in self._thread_variable_names}

def _request_args(self, output, expected, **kwargs):
    """Build the completion request, injecting thread template variables.

    When a ``trace`` kwarg is present and the prompt template actually
    references thread variables, the thread is fetched synchronously and
    its derived variables are merged into the render args.
    """
    thread_vars: dict[str, object] = {}
    current_trace = kwargs.get("trace")
    if self._template_uses_thread_variables and current_trace is not None:
        thread_vars = self._compute_thread_vars_sync(current_trace)

    # Thread vars come first so explicit render args can override.
    return super()._request_args(output, expected, **thread_vars, **kwargs)

async def _request_args_async(self, output, expected, **kwargs):
    """Async counterpart of `_request_args`.

    Awaits the trace's thread only when the template needs it, then
    delegates to the synchronous base builder. Thread vars are passed
    before ``kwargs`` so explicit render args can override them.
    """
    current_trace = kwargs.get("trace")
    if current_trace is None or not self._template_uses_thread_variables:
        # Nothing to fetch: behave exactly like the base builder.
        return super()._request_args(output, expected, **kwargs)

    thread_vars = await self._compute_thread_vars_async(current_trace)
    return super()._request_args(output, expected, **thread_vars, **kwargs)

@classmethod
def from_spec(cls, name: str, spec: ModelGradedSpec, client: Client | None = None, **kwargs):
spec_kwargs = {}
Expand Down
169 changes: 169 additions & 0 deletions py/autoevals/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from autoevals import init
from autoevals.llm import Battle, Factuality, LLMClassifier, OpenAILLMClassifier, build_classification_tools
from autoevals.oai import OpenAIV1Module, get_default_model
from autoevals.thread_utils import compute_thread_template_vars


class TestModel(BaseModel):
Expand Down Expand Up @@ -54,6 +55,47 @@ def test_render_messages():
assert rendered[5]["content"] == ""


def test_render_messages_with_thread_variables():
    """Rendering expands every thread template variable ({{thread}},
    {{thread.N}}, {{thread_count}}, {{first_message}}, {{user_messages}},
    {{human_ai_pairs}}) and supports mustache section iteration over
    {{#thread}}...{{/thread}}."""
    # One message per template variable so each expansion is asserted in
    # isolation below.
    classifier = OpenAILLMClassifier(
        "test",
        messages=[
            {"role": "user", "content": "{{thread}}"},
            {"role": "user", "content": "First message: {{thread.0}}"},
            {"role": "user", "content": "Count: {{thread_count}}"},
            {"role": "user", "content": "First: {{first_message}}"},
            {"role": "user", "content": "Users: {{user_messages}}"},
            {"role": "user", "content": "Pairs: {{human_ai_pairs}}"},
            {
                "role": "user",
                "content": "Messages:{{#thread}}\n- {{role}}: {{content}}{{/thread}}",
            },
        ],
        model="gpt-4",
        choice_scores={"A": 1},
        classification_tools=[],
    )

    # Two user/assistant exchanges; thread_count is therefore 4.
    sample_thread = [
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I am doing well, thank you!"},
        {"role": "user", "content": "What is the weather like?"},
        {"role": "assistant", "content": "It is sunny and warm today."},
    ]
    thread_vars = compute_thread_template_vars(sample_thread)
    rendered = classifier._render_messages(**thread_vars)

    # {{thread}} renders the transcript with User:/Assistant: labels.
    assert "User:" in rendered[0]["content"]
    assert "Hello, how are you?" in rendered[0]["content"]
    assert "Assistant:" in rendered[0]["content"]
    # Indexed access and scalar variables render as flat strings.
    assert rendered[1]["content"] == "First message: user: Hello, how are you?"
    assert rendered[2]["content"] == "Count: 4"
    assert rendered[3]["content"] == "First: user: Hello, how are you?"
    assert "Users: User:" in rendered[4]["content"]
    assert "Pairs:" in rendered[5]["content"]
    assert "human" in rendered[5]["content"]
    # Section iteration walks the raw thread dicts (role/content keys).
    assert rendered[6]["content"].startswith("Messages:\n- user: Hello, how are you?")


def test_openai():
e = OpenAILLMClassifier(
"title",
Expand Down Expand Up @@ -547,3 +589,130 @@ def capture_model(request):

# Reset for other tests
init(None)


@respx.mock
def test_llm_classifier_injects_thread_vars_from_trace():
    """When the prompt template references thread variables, eval() must
    await trace.get_thread() exactly once and render the thread content
    into the outgoing chat-completion request."""
    captured_request_body = None

    class TraceStub:
        # Minimal trace double: serves a fixed thread and counts how many
        # times get_thread is awaited.
        def __init__(self, thread):
            self.thread = thread
            self.calls = 0

        async def get_thread(self):
            self.calls += 1
            return self.thread

    thread = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there"},
        {"role": "user", "content": "Can you help me?"},
    ]
    trace = TraceStub(thread)

    def capture_request(request):
        # Record the outgoing request body, then reply with a canned
        # tool-call completion so classification succeeds.
        nonlocal captured_request_body
        captured_request_body = json.loads(request.content.decode("utf-8"))
        return Response(
            200,
            json={
                "id": "chatcmpl-test",
                "object": "chat.completion",
                "created": 1234567890,
                "model": "gpt-4o",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": None,
                            "tool_calls": [
                                {
                                    "id": "call_test",
                                    "type": "function",
                                    "function": {"name": "select_choice", "arguments": '{"choice": "1"}'},
                                }
                            ],
                        },
                        "finish_reason": "tool_calls",
                    }
                ],
                "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
            },
        )

    # Register the mock route before wiring the client into autoevals.
    respx.post("https://api.openai.com/v1/chat/completions").mock(side_effect=capture_request)
    client = OpenAI(api_key="test-api-key", base_url="https://api.openai.com/v1")
    init(client)

    # Template references thread, thread_count, first_message, user_messages.
    classifier = LLMClassifier(
        "thread_test",
        "Thread:\n{{thread}}\nCount: {{thread_count}}\nFirst: {{first_message}}\nUsers:\n{{user_messages}}",
        {"1": 1, "2": 0},
    )
    classifier.eval(output="irrelevant", expected="irrelevant", trace=trace)

    # The rendered prompt lands in the first message of the request.
    content = captured_request_body["messages"][0]["content"]
    assert trace.calls == 1
    assert "Thread:" in content
    assert "User:" in content
    assert "Assistant:" in content
    assert "Count: 3" in content
    assert "First: user: Hello" in content
    assert "Users:" in content


@respx.mock
def test_llm_classifier_does_not_fetch_thread_when_template_does_not_use_it():
    """A trace may be passed to eval(), but get_thread() must never be
    called when the prompt template contains no thread variables."""

    class TraceStub:
        # Trace double whose only job is to count get_thread invocations.
        def __init__(self):
            self.calls = 0

        async def get_thread(self):
            self.calls += 1
            return [{"role": "user", "content": "unused"}]

    trace = TraceStub()

    # Canned tool-call completion so eval() succeeds without a real API.
    respx.post("https://api.openai.com/v1/chat/completions").mock(
        return_value=Response(
            200,
            json={
                "id": "chatcmpl-test",
                "object": "chat.completion",
                "created": 1234567890,
                "model": "gpt-4o",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": None,
                            "tool_calls": [
                                {
                                    "id": "call_test",
                                    "type": "function",
                                    "function": {"name": "select_choice", "arguments": '{"choice": "1"}'},
                                }
                            ],
                        },
                        "finish_reason": "tool_calls",
                    }
                ],
                "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
            },
        )
    )

    client = OpenAI(api_key="test-api-key", base_url="https://api.openai.com/v1")
    init(client)

    # Template uses only {{output}} — no thread variables.
    classifier = LLMClassifier(
        "thread_unused_test",
        "Output: {{output}}",
        {"1": 1, "2": 0},
    )
    classifier.eval(output="x", expected="y", trace=trace)

    # The thread must not have been fetched.
    assert trace.calls == 0
Loading