diff --git a/CHANGELOG.md b/CHANGELOG.md index 098673aa..8a7bec29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### Added +- **Daytona usage telemetry by default** — Daytona runs now start a sandbox-local provider usage proxy so token/cost telemetry works without an external tunnel; use `--usage-tracking off` to bypass proxying when needed. - **Azure AI Foundry providers** — new `azure-foundry-openai/` and `azure-foundry-anthropic/` prefixes routing through Foundry's unified resource. Export `AZURE_API_KEY` plus `AZURE_API_ENDPOINT` (e.g. `https://<resource>.openai.azure.com/`); benchflow derives the resource name from the endpoint host, builds the per-surface base URL, and maps the key onto the agent-native auth env automatically. Missing/unrecognized endpoints and unsupported agent/provider protocol pairings fail fast with clear errors instead of falling through to the wrong endpoint. - **Azure Foundry auth guidance** — agent discovery output and docs now call out that provider-prefixed models can use provider-specific credentials instead of the agent's native/default API key. 
diff --git a/docs/reference/cli.md b/docs/reference/cli.md index 764de6a6..e190a394 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -45,7 +45,7 @@ bench eval create \ --concurrency 64 \ --sandbox-setup-timeout 300 -# From remote repo with required token usage telemetry through an external tunnel +# From remote repo with required token usage telemetry bench eval create \ --source-repo benchflow-ai/skillsbench \ --source-path tasks \ @@ -53,9 +53,7 @@ bench eval create \ --model gemini-3.1-flash-lite-preview \ --sandbox daytona \ --usage-tracking required \ - --usage-proxy-url https://your-tunnel.example.com \ - --usage-proxy-port 18081 \ - --concurrency 1 \ + --concurrency 16 \ --sandbox-setup-timeout 300 # From local directory @@ -98,9 +96,6 @@ bench eval create \ | `--model` | Agent default | Model ID | | `--sandbox` | `docker` | Sandbox: docker, daytona, or modal | | `--usage-tracking` | `auto` | Token usage telemetry policy: `auto`, `required`, or `off` | -| `--usage-proxy-url` | — | Externally reachable usage-proxy base URL for remote sandboxes such as Daytona | -| `--usage-proxy-bind-host` | auto | Local interface for the usage proxy; external proxy mode defaults to `127.0.0.1` | -| `--usage-proxy-port` | random | Fixed local port for externally tunneled usage tracking | | `--environment-manifest` | — | Path to an Environment-plane manifest (`environment.toml`); applied to every rollout in the batch | | `--concurrency` | `4` | Max concurrent tasks (batch mode only) | | `--agent-idle-timeout` | (built-in default) | Abort ACP prompts after this many idle seconds; `0` disables idle detection | @@ -120,15 +115,10 @@ When mounting skills, the recommended docs default is [Architecture: skill loading](../architecture.md#skill-loading) for how `--skills-dir` is registered with each agent and how the nudge modes differ. 
-For official Daytona batch runs that must report provider token/cost telemetry, -use `--usage-tracking required` with a tunnel or ingress URL pointing at the -fixed `--usage-proxy-port`. The fixed-port tunnel mode supports one rollout per -BenchFlow process; use `--concurrency 1`, or run multiple jobs with separate -ports/tunnels. This limit applies only to metered external-tunnel mode; Daytona -batch runs that do not require usage telemetry can still use higher concurrency. -Without an external URL, Daytona runs continue in `auto` mode and record -`usage_source=unavailable` because the remote sandbox cannot reach a host-bound -proxy. +Daytona batch runs collect provider token/cost telemetry by default with a +sandbox-local proxy. Use `--usage-tracking required` when missing telemetry +should fail the rollout, or `--usage-tracking off` for recovery runs that should +leave provider traffic untouched. `--source-env` is for external hosted environment hubs. The first supported runner is PrimeIntellect / Verifiers: BenchFlow preserves the hosted identity diff --git a/docs/v05-e2e-testing-guide.md b/docs/v05-e2e-testing-guide.md index 41547975..2b4486be 100644 --- a/docs/v05-e2e-testing-guide.md +++ b/docs/v05-e2e-testing-guide.md @@ -25,18 +25,12 @@ All commands below assume you are in the repo root. > The examples below use `weighted-gdp-calc` (fast, ~5 tool calls) as the > default lightweight task. Swap in any task name from `$TASKS/`. -> **Usage telemetry caveat (Daytona / Modal):** Remote sandboxes run the agent -> on a host that cannot reach BenchFlow's host-bound usage proxy. Default -> `--usage-tracking auto` therefore records `agent_result.usage_source == -> "unavailable"` unless you configure an external tunnel/ingress with -> `--usage-proxy-url` and `--usage-proxy-port`. Official batch runs that need -> token/cost telemetry should use `--usage-tracking required` so the run fails -> before the agent starts if the external endpoint is missing or unhealthy. 
The -> fixed-port tunnel mode supports one rollout per BenchFlow process; use -> `--concurrency 1`, or run multiple jobs with separate ports/tunnels. This -> constraint is specific to metered external-tunnel mode; Daytona batches that do -> not require usage telemetry can still run with higher concurrency. Local -> sandboxes (e.g. `--sandbox docker`) populate usage telemetry without a tunnel. +> **Usage telemetry:** Docker uses a host-side provider proxy; Daytona uses a +> sandbox-local provider proxy because the agent runs on a remote host. Default +> `--usage-tracking auto` records provider token/cost telemetry when the proxy can +> be started. Use `--usage-tracking required` when missing telemetry should fail +> the rollout, or `--usage-tracking off` for recovery runs that should leave +> provider traffic untouched. --- diff --git a/src/benchflow/agents/codex_config.py b/src/benchflow/agents/codex_config.py new file mode 100644 index 00000000..1407bd21 --- /dev/null +++ b/src/benchflow/agents/codex_config.py @@ -0,0 +1,68 @@ +"""Helpers for writing Codex ACP provider configuration.""" + +from __future__ import annotations + +import json +from typing import Any + +CODEX_CONFIG_ENV = "CODEX_CONFIG" +CODEX_MODEL_PROVIDER_ENV = "MODEL_PROVIDER" + +_CODEX_PROVIDER_ID_PREFIX = "benchflow-" + + +def codex_provider_id(provider_name: str | None) -> str: + safe_name = "".join( + char if char.isalnum() or char in {"-", "_"} else "-" + for char in (provider_name or "provider").lower() + ).strip("-") + return f"{_CODEX_PROVIDER_ID_PREFIX}{safe_name or 'provider'}" + + +def apply_codex_provider_config( + agent_env: dict[str, str], + *, + base_url: str, + model: str | None, + provider_name: str, + strict: bool = False, +) -> None: + """Create or update Codex's model provider entry in ``agent_env``.""" + raw_config = agent_env.get(CODEX_CONFIG_ENV) + if not raw_config: + config: dict[str, Any] = {} + else: + try: + config = json.loads(raw_config) + except json.JSONDecodeError 
as exc: + if strict: + raise ValueError(f"{CODEX_CONFIG_ENV} must be valid JSON") from exc + return + if not isinstance(config, dict): + if strict: + raise ValueError(f"{CODEX_CONFIG_ENV} must decode to a JSON object") + return + + provider_id = ( + agent_env.get(CODEX_MODEL_PROVIDER_ENV) + or config.get("model_provider") + or codex_provider_id(provider_name) + ) + providers = config.get("model_providers") + providers = {} if not isinstance(providers, dict) else dict(providers) + provider = providers.get(provider_id) + provider = dict(provider) if isinstance(provider, dict) else {} + provider.setdefault("name", provider_name) + provider["base_url"] = base_url + provider.setdefault("env_key", "OPENAI_API_KEY") + provider.setdefault("wire_api", "responses") + provider.setdefault("supports_websockets", False) + + providers[provider_id] = provider + config["model_providers"] = providers + config["model_provider"] = provider_id + if model: + config["model"] = model + + agent_env[CODEX_MODEL_PROVIDER_ENV] = str(provider_id) + agent_env[CODEX_CONFIG_ENV] = json.dumps(config, separators=(",", ":")) diff --git a/src/benchflow/agents/credentials.py b/src/benchflow/agents/credentials.py index e6df0294..4570bebf 100644 --- a/src/benchflow/agents/credentials.py +++ b/src/benchflow/agents/credentials.py @@ -102,6 +102,16 @@ async def write_credential_files( await write_gemini_vertex_settings(env, agent, model, cred_home) # Agent credential files (e.g. 
codex auth.json) + if ( + agent == "codex-acp" + and "OPENAI_API_KEY" not in agent_env + and agent_env.get("CODEX_AUTH_JSON") + ): + path = f"{cred_home}/.codex/auth.json" + await upload_credential(env, path, agent_env["CODEX_AUTH_JSON"], owner=owner) + logger.info("Agent credential file written: %s", path) + return + if agent_cfg and agent_cfg.credential_files: for cf in agent_cfg.credential_files: value = agent_env.get(cf.env_source) diff --git a/src/benchflow/agents/env.py b/src/benchflow/agents/env.py index 8c9784a4..9b2926ba 100644 --- a/src/benchflow/agents/env.py +++ b/src/benchflow/agents/env.py @@ -23,12 +23,20 @@ from urllib.parse import urlparse from benchflow._dotenv import load_dotenv_env +from benchflow.agents.codex_config import apply_codex_provider_config from benchflow.agents.registry import AGENTS logger = logging.getLogger(__name__) _AUTH_CONTEXT_GROUPS = ( - frozenset({"ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"}), + frozenset( + { + "ANTHROPIC_API_KEY", + "ANTHROPIC_AUTH_TOKEN", + "CLAUDE_CODE_OAUTH_TOKEN", + "CLAUDE_OAUTH_TOKEN", + } + ), frozenset({"GEMINI_API_KEY", "GOOGLE_API_KEY"}), frozenset({"OPENAI_API_KEY", "CODEX_API_KEY", "CODEX_ACCESS_TOKEN"}), ) @@ -36,15 +44,18 @@ _BEDROCK_PROXY_PLACEHOLDER_API_KEY = "bedrock-proxy" _CODEX_API_KEY_ENV = "CODEX_API_KEY" _CODEX_ACCESS_TOKEN_ENV = "CODEX_ACCESS_TOKEN" +_CODEX_AUTH_JSON_ENV = "CODEX_AUTH_JSON" +_CLAUDE_CODE_OAUTH_TOKEN_ENV = "CLAUDE_CODE_OAUTH_TOKEN" +_CLAUDE_OAUTH_TOKEN_ENV = "CLAUDE_OAUTH_TOKEN" _CUSTOM_OPENAI_ENDPOINT_KEYS = frozenset( {"BENCHFLOW_PROVIDER_BASE_URL", "OPENAI_BASE_URL"} ) +_GENERIC_PROVIDER_OVERRIDE_KEYS = frozenset( + {"BENCHFLOW_PROVIDER_BASE_URL", "BENCHFLOW_PROVIDER_API_KEY"} +) _AZURE_RESOURCE_ENV = "AZURE_RESOURCE" _AZURE_ENDPOINT_ENV = "AZURE_API_ENDPOINT" _AZURE_HOST_SUFFIXES = (".openai.azure.com", ".services.ai.azure.com") -_CODEX_CONFIG_ENV = "CODEX_CONFIG" -_CODEX_MODEL_PROVIDER_ENV = "MODEL_PROVIDER" -_CODEX_PROVIDER_ID_PREFIX 
= "benchflow-" def _derive_azure_resource(agent_env: dict[str, str]) -> None: @@ -121,7 +132,7 @@ def _normalize_openhands_model(model: str) -> str: OpenHands expects provider-qualified model names for some providers even when benchflow uses bare model IDs or its own provider prefixes. """ - from benchflow.agents.providers import strip_provider_prefix + from benchflow.agents.providers import find_provider, strip_provider_prefix from benchflow.agents.registry import is_vertex_model if model.startswith(("gemini/", "vertex_ai/", "openhands/")): @@ -138,6 +149,11 @@ def _normalize_openhands_model(model: str) -> str: return f"vertex_ai/{stripped}" if "gemini" in lower: return f"gemini/{stripped}" + provider = find_provider(model) + if provider is not None: + _, cfg = provider + if cfg.api_protocol == "openai-completions": + return f"openai/{stripped}" return stripped @@ -157,9 +173,12 @@ def auto_inherit_env( "AWS_BEARER_TOKEN_BEDROCK", "AWS_DEFAULT_REGION", "AWS_REGION", - "CLAUDE_CODE_OAUTH_TOKEN", + "AWS_REGION_NAME", + _CLAUDE_CODE_OAUTH_TOKEN_ENV, + _CLAUDE_OAUTH_TOKEN_ENV, "CODEX_ACCESS_TOKEN", "CODEX_API_KEY", + "CODEX_AUTH_JSON", "OPENAI_API_KEY", "OPENAI_BASE_URL", "GOOGLE_API_KEY", @@ -210,6 +229,13 @@ def auto_inherit_env( "AWS_REGION" in explicit_keys and "AWS_DEFAULT_REGION" not in explicit_keys ) or ("AWS_REGION" in agent_env and "AWS_DEFAULT_REGION" not in agent_env): agent_env["AWS_DEFAULT_REGION"] = agent_env["AWS_REGION"] + if "AWS_REGION" in agent_env and "AWS_REGION_NAME" not in agent_env: + agent_env["AWS_REGION_NAME"] = agent_env["AWS_REGION"] + if ( + _CLAUDE_OAUTH_TOKEN_ENV in agent_env + and _CLAUDE_CODE_OAUTH_TOKEN_ENV not in agent_env + ): + agent_env[_CLAUDE_CODE_OAUTH_TOKEN_ENV] = agent_env[_CLAUDE_OAUTH_TOKEN_ENV] _derive_azure_resource(agent_env) # CLAUDE_CODE_OAUTH_TOKEN is a separate auth path — Claude CLI reads it # directly. Don't map to ANTHROPIC_API_KEY (different auth mechanism). 
@@ -300,6 +326,21 @@ def _has_codex_access_token_auth( ) and bool(agent_env.get(_CODEX_ACCESS_TOKEN_ENV)) +def _has_codex_auth_json_auth( + agent: str, + model: str | None, + required_key: str | None, + agent_env: dict[str, str], +) -> bool: + """Return True when inline Codex auth.json can satisfy native OpenAI auth.""" + return _can_use_codex_subscription_auth( + agent, + model, + required_key, + agent_env, + ) and bool(agent_env.get(_CODEX_AUTH_JSON_ENV)) + + def inject_vertex_credentials(agent_env: dict[str, str], model: str) -> None: """Inject ADC credentials and defaults for Vertex AI models.""" from benchflow.agents.registry import is_vertex_model @@ -445,28 +486,6 @@ def _shares_auth_context(required_key: str | None, candidate_key: str | None) -> ) -def _codex_provider_id(agent_env: dict[str, str]) -> str: - provider_name = agent_env.get("BENCHFLOW_PROVIDER_NAME", "openai-compatible") - safe_name = "".join( - char if char.isalnum() or char in {"-", "_"} else "-" - for char in provider_name.lower() - ).strip("-") - return f"{_CODEX_PROVIDER_ID_PREFIX}{safe_name or 'provider'}" - - -def _load_codex_config(agent_env: dict[str, str]) -> dict: - raw_config = agent_env.get(_CODEX_CONFIG_ENV) - if not raw_config: - return {} - try: - config = json.loads(raw_config) - except json.JSONDecodeError as exc: - raise ValueError(f"{_CODEX_CONFIG_ENV} must be valid JSON") from exc - if not isinstance(config, dict): - raise ValueError(f"{_CODEX_CONFIG_ENV} must decode to a JSON object") - return config - - def _configure_codex_custom_provider( agent: str, model: str | None, @@ -483,24 +502,40 @@ def _configure_codex_custom_provider( if not base_url or not provider_model: return - provider_id = _codex_provider_id(agent_env) - config = _load_codex_config(agent_env) - providers = config.get("model_providers") - providers = {} if not isinstance(providers, dict) else dict(providers) - - providers[provider_id] = { - "name": agent_env.get("BENCHFLOW_PROVIDER_NAME", "BenchFlow 
provider"), - "base_url": base_url, - "env_key": "OPENAI_API_KEY", - "wire_api": "responses", - "supports_websockets": False, - } - config["model_providers"] = providers - config["model_provider"] = provider_id - config["model"] = provider_model + apply_codex_provider_config( + agent_env, + base_url=base_url, + model=provider_model, + provider_name=agent_env.get("BENCHFLOW_PROVIDER_NAME", "openai-compatible"), + strict=True, + ) + - agent_env[_CODEX_MODEL_PROVIDER_ENV] = provider_id - agent_env[_CODEX_CONFIG_ENV] = json.dumps(config, separators=(",", ":")) +def _drop_inherited_generic_provider_overrides( + agent_env: dict[str, str], + *, + model: str | None, + explicit_agent_env_keys: set[str], +) -> None: + """Let registered providers use their own endpoint/key over host defaults.""" + if not model: + return + + from benchflow.agents.providers import find_provider + + provider = find_provider(model) + if provider is None: + return + _, provider_cfg = provider + # Providers with an empty base URL (for example vllm/) are explicitly + # user-supplied endpoints, so inherited BENCHFLOW_PROVIDER_* is the normal + # configuration path. Providers with a registered URL/auth env should not be + # shadowed by a global generic proxy from .env unless the caller explicitly + # passed that override for this run. + if not provider_cfg.base_url: + return + for key in _GENERIC_PROVIDER_OVERRIDE_KEYS - explicit_agent_env_keys: + agent_env.pop(key, None) def resolve_agent_env( @@ -522,6 +557,11 @@ def resolve_agent_env( # API-key validation are skipped even if a caller forwards a model. 
if model and agent != "oracle": inject_vertex_credentials(agent_env, model) + _drop_inherited_generic_provider_overrides( + agent_env, + model=model, + explicit_agent_env_keys=explicit_agent_env_keys, + ) resolve_provider_env(agent_env, model, agent) from benchflow.agents.providers import find_provider @@ -567,7 +607,11 @@ def resolve_agent_env( ] has_oauth = any( key in agent_env and _shares_auth_context(required_key, key) - for key in ("CLAUDE_CODE_OAUTH_TOKEN", "ANTHROPIC_AUTH_TOKEN") + for key in ( + _CLAUDE_CODE_OAUTH_TOKEN_ENV, + _CLAUDE_OAUTH_TOKEN_ENV, + "ANTHROPIC_AUTH_TOKEN", + ) ) has_codex_access_token = _has_codex_access_token_auth( agent, @@ -575,12 +619,19 @@ def resolve_agent_env( required_key, agent_env, ) + has_codex_auth_json = _has_codex_auth_json_auth( + agent, + model, + required_key, + agent_env, + ) if ( required_key and required_key not in agent_env and not has_oauth and not has_agent_native_bridge_key and not has_codex_access_token + and not has_codex_auth_json ): if _can_use_subscription_auth( agent, @@ -612,6 +663,12 @@ def resolve_agent_env( req_key, agent_env, ) + and not _has_codex_auth_json_auth( + agent, + model, + req_key, + agent_env, + ) and _can_use_subscription_auth(agent, model, req_key, agent_env) and check_subscription_auth(agent, req_key) ): diff --git a/src/benchflow/agents/providers.py b/src/benchflow/agents/providers.py index a51034e2..27b30aed 100644 --- a/src/benchflow/agents/providers.py +++ b/src/benchflow/agents/providers.py @@ -163,6 +163,14 @@ def all_endpoints(self) -> dict[str, str]: auth_type="api_key", auth_env="OPENAI_API_KEY", # vLLM uses OpenAI-compatible auth ), + "litellm": ProviderConfig( + name="litellm", + base_url="{base_url}", + api_protocol="openai-completions", + auth_type="api_key", + auth_env="LITELLM_API_KEY", + url_params={"base_url": "LITELLM_BASE_URL"}, + ), "aws-bedrock": ProviderConfig( name="aws-bedrock", base_url="", # local Bedrock proxy supplies the runtime URL later @@ -205,6 +213,89 
@@ def all_endpoints(self) -> dict[str, str]: }, ], ), + "kimi": ProviderConfig( + name="kimi", + base_url="{base_url}", + api_protocol="openai-completions", + auth_type="api_key", + auth_env="KIMI_API_KEY", + url_params={"base_url": "KIMI_BASE_URL"}, + ), + "minimax": ProviderConfig( + name="minimax", + base_url="{base_url}", + api_protocol="openai-completions", + auth_type="api_key", + auth_env="MINIMAX_API_KEY", + url_params={"base_url": "MINIMAX_BASE_URL"}, + ), + "qwen-dashscope": ProviderConfig( + name="qwen-dashscope", + base_url="{base_url}", + api_protocol="openai-completions", + auth_type="api_key", + auth_env="QWEN_API_KEY", + url_params={"base_url": "QWEN_BASE_URL"}, + ), + "glm": ProviderConfig( + name="glm", + base_url="{base_url}", + api_protocol="openai-completions", + auth_type="api_key", + auth_env="GLM_API_KEY", + url_params={"base_url": "GLM_BASE_URL"}, + models=[ + { + "id": "glm-5.1", + "name": "GLM-5.1", + "reasoning": True, + "input": ["text"], + "cost": {"input": 0, "output": 0, "cacheRead": 0, "cacheWrite": 0}, + "contextWindow": 200000, + "maxTokens": 131072, + }, + ], + ), + "deepseek": ProviderConfig( + name="deepseek", + base_url="{base_url}", + api_protocol="openai-completions", + auth_type="api_key", + auth_env="DEEPSEEK_API_KEY", + url_params={"base_url": "DEEPSEEK_BASE_URL"}, + ), + "xiaomi": ProviderConfig( + name="xiaomi", + base_url="{base_url}", + api_protocol="openai-completions", + auth_type="api_key", + auth_env="XIAOMI_API_KEY", + url_params={"base_url": "XIAOMI_BASE_URL"}, + ), + "doubao-seed-2-lite": ProviderConfig( + name="doubao-seed-2-lite", + base_url="{base_url}", + api_protocol="openai-completions", + auth_type="api_key", + auth_env="DOUBAO_SEED_2_LITE_API_KEY", + url_params={"base_url": "DOUBAO_VOLCES_BASE_URL"}, + ), + "doubao-seed-2-pro": ProviderConfig( + name="doubao-seed-2-pro", + base_url="{base_url}", + api_protocol="openai-completions", + auth_type="api_key", + auth_env="DOUBAO_SEED_2_PRO_API_KEY", + 
url_params={"base_url": "DOUBAO_VOLCES_BASE_URL"}, + ), + "hunyuan": ProviderConfig( + name="hunyuan", + base_url="{base_url}", + api_protocol="openai-completions", + auth_type="api_key", + auth_env="HUNYUAN_API_KEY", + url_params={"base_url": "HUNYUAN_BASE_URL"}, + ), } diff --git a/src/benchflow/cli/main.py b/src/benchflow/cli/main.py index 0c57f4e2..0491a41f 100644 --- a/src/benchflow/cli/main.py +++ b/src/benchflow/cli/main.py @@ -1069,27 +1069,6 @@ def eval_create( help="Token usage tracking policy: auto, required, or off", ), ] = None, - usage_proxy_url: Annotated[ - str | None, - typer.Option( - "--usage-proxy-url", - help="Externally reachable base URL for remote sandbox usage tracking", - ), - ] = None, - usage_proxy_bind_host: Annotated[ - str | None, - typer.Option( - "--usage-proxy-bind-host", - help="Local interface for the usage proxy to bind", - ), - ] = None, - usage_proxy_port: Annotated[ - int | None, - typer.Option( - "--usage-proxy-port", - help="Fixed local port for externally tunneled usage tracking", - ), - ] = None, environment_manifest: Annotated[ Path | None, typer.Option( @@ -1249,22 +1228,9 @@ def eval_create( eval_environment = environment or "docker" sandbox_user = normalize_sandbox_user(sandbox_user) eval_concurrency = concurrency if concurrency is not None else 4 - usage_tracking_overridden = any( - value is not None - for value in ( - usage_tracking, - usage_proxy_url, - usage_proxy_bind_host, - usage_proxy_port, - ) - ) + usage_tracking_overridden = usage_tracking is not None try: - eval_usage_tracking = UsageTrackingConfig( - mode=usage_tracking, - advertised_base_url=usage_proxy_url, - bind_host=usage_proxy_bind_host, - port=usage_proxy_port, - ) + eval_usage_tracking = UsageTrackingConfig(mode=usage_tracking) except (TypeError, ValueError) as exc: console.print(f"[red]Invalid usage tracking config: {exc}[/red]") raise typer.Exit(1) from None @@ -1301,6 +1267,8 @@ def _run_batch_eval( resolved_tasks_dir: Path, eval_config: 
EvaluationConfig, ): + from benchflow.eval_sharding import ShardWorkerError + try: if worker_concurrency is None: result = asyncio.run( @@ -1311,10 +1279,7 @@ def _run_batch_eval( ).run() ) else: - from benchflow.eval_sharding import ( - ShardWorkerError, - run_sharded_evaluation, - ) + from benchflow.eval_sharding import run_sharded_evaluation result = asyncio.run( run_sharded_evaluation( @@ -1330,7 +1295,7 @@ def _run_batch_eval( except EmptyTaskSelectionError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from None - except (ValueError, ShardWorkerError) as e: + except (ValueError, RuntimeError, ShardWorkerError) as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from None diff --git a/src/benchflow/eval_sharding.py b/src/benchflow/eval_sharding.py index 713a34b5..cf8f9c47 100644 --- a/src/benchflow/eval_sharding.py +++ b/src/benchflow/eval_sharding.py @@ -282,10 +282,6 @@ async def run_sharded_evaluation( total_concurrency=config.concurrency, worker_concurrency=worker_concurrency, ) - config.usage_tracking.with_env_defaults().validate_parallelism( - concurrency=plan.total_concurrency, - worker_count=plan.worker_count, - ) if not plan.shards: from benchflow.evaluation import EmptyTaskSelectionError diff --git a/src/benchflow/evaluation.py b/src/benchflow/evaluation.py index 479bbffd..855b1c89 100644 --- a/src/benchflow/evaluation.py +++ b/src/benchflow/evaluation.py @@ -907,7 +907,6 @@ def _preflight_usage_tracking(self) -> None: cfg = self._config usage = cfg.usage_tracking.with_env_defaults() - usage.validate_parallelism(concurrency=cfg.concurrency) failure = validate_usage_proxy_preconditions( usage, environment=cfg.environment, diff --git a/src/benchflow/providers/assets/sandbox_usage_proxy.js b/src/benchflow/providers/assets/sandbox_usage_proxy.js new file mode 100644 index 00000000..bfa60061 --- /dev/null +++ b/src/benchflow/providers/assets/sandbox_usage_proxy.js @@ -0,0 +1,231 @@ +#!/usr/bin/env node +const fs = require("fs"); +const 
http = require("http"); +const https = require("https"); +const { URL } = require("url"); + +function getArg(name) { + const prefix = `--${name}=`; + const arg = process.argv.find((value) => value.startsWith(prefix)); + return arg ? arg.slice(prefix.length) : ""; +} + +function getConfig(name, argName) { + return process.env[`BENCHFLOW_USAGE_PROXY_${name}`] || getArg(argName); +} + +const target = new URL(getConfig("TARGET", "target").replace(/\/+$/, "")); +const statePath = getConfig("STATE_PATH", "state"); +const logPath = getConfig("LOG_PATH", "log"); +const pidPath = getConfig("PID_PATH", "pid"); +const sessionId = getConfig("SESSION_ID", "session-id"); +const agentName = getConfig("AGENT_NAME", "agent-name"); +const promptCacheRetention = getConfig("PROMPT_CACHE_RETENTION", "prompt-cache-retention"); + +function sanitizeHeaders(headers) { + const result = { ...headers }; + for (const key of Object.keys(result)) { + const lower = key.toLowerCase(); + if (["connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailer", "upgrade"].includes(lower)) { + delete result[key]; + } + } + return result; +} + +function responseHeaders(headers) { + const result = sanitizeHeaders(headers); + delete result["content-length"]; + delete result["Content-Length"]; + delete result["transfer-encoding"]; + delete result["Transfer-Encoding"]; + result["connection"] = "close"; + return result; +} + +const sensitiveHeaderNames = new Set([ + "authorization", + "proxy-authorization", + "x-api-key", + "api-key", + "openai-api-key", + "anthropic-api-key", + "x-goog-api-key", + "cookie", + "set-cookie", +]); +const sensitiveQueryNames = new Set([ + "key", + "api_key", + "apikey", + "access_token", +]); + +function captureHeaders(headers) { + const result = { ...headers }; + for (const key of Object.keys(result)) { + if (sensitiveHeaderNames.has(key.toLowerCase())) { + result[key] = "__BENCHFLOW_REDACTED__"; + } + } + return result; +} + +function 
capturePath(requestUrl) { + const parsed = new URL(requestUrl, "http://benchflow.local"); + for (const key of Array.from(parsed.searchParams.keys())) { + if (sensitiveQueryNames.has(key.toLowerCase())) { + parsed.searchParams.set(key, "__BENCHFLOW_REDACTED__"); + } + } + return `${parsed.pathname}${parsed.search}`; +} + +function appendCapture(record) { + fs.appendFileSync(logPath, JSON.stringify(record) + "\n", { encoding: "utf8" }); +} + +function bodyB64(chunks) { + return Buffer.concat(chunks).toString("base64"); +} + +function upstreamPath(requestUrl) { + const basePath = target.pathname.replace(/\/+$/, ""); + if (!basePath) return requestUrl; + return `${basePath}${requestUrl.startsWith("/") ? requestUrl : `/${requestUrl}`}`; +} + +function maybeApplyPromptCacheRetention(requestUrl, headers, body) { + if (!promptCacheRetention) return { headers, body }; + const requestPath = new URL(requestUrl, "http://127.0.0.1").pathname.replace(/\/+$/, ""); + if (!requestPath.endsWith("/responses") && !requestPath.endsWith("/chat/completions")) { + return { headers, body }; + } + const contentEncoding = String(headers["content-encoding"] || headers["Content-Encoding"] || "identity").toLowerCase(); + if (contentEncoding !== "identity") return { headers, body }; + try { + const parsed = body.length > 0 ? 
JSON.parse(body.toString("utf8")) : {}; + if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) { + return { headers, body }; + } + if (Object.prototype.hasOwnProperty.call(parsed, "prompt_cache_retention")) { + return { headers, body }; + } + parsed.prompt_cache_retention = promptCacheRetention; + const updatedBody = Buffer.from(JSON.stringify(parsed)); + const updatedHeaders = { ...headers }; + delete updatedHeaders["content-encoding"]; + delete updatedHeaders["Content-Encoding"]; + updatedHeaders["content-type"] = updatedHeaders["content-type"] || updatedHeaders["Content-Type"] || "application/json"; + updatedHeaders["content-length"] = updatedBody.length; + return { headers: updatedHeaders, body: updatedBody }; + } catch (_error) { + return { headers, body }; + } +} + +const server = http.createServer((clientReq, clientRes) => { + const healthPath = new URL(clientReq.url, "http://127.0.0.1").pathname; + if ((clientReq.method === "GET" || clientReq.method === "HEAD") && ["/health", "/__benchflow_health"].includes(healthPath.replace(/\/+$/, "") || "/")) { + const payload = Buffer.from(JSON.stringify({ status: "ok" })); + clientRes.writeHead(200, { + "content-type": "application/json", + "content-length": clientReq.method === "HEAD" ? 
0 : payload.length, + }); + if (clientReq.method !== "HEAD") clientRes.end(payload); + else clientRes.end(); + return; + } + + const requestChunks = []; + clientReq.on("data", (chunk) => requestChunks.push(chunk)); + clientReq.on("end", () => { + const originalRequestBody = Buffer.concat(requestChunks); + const startedAt = Date.now(); + const prepared = maybeApplyPromptCacheRetention( + clientReq.url, + sanitizeHeaders(clientReq.headers), + originalRequestBody, + ); + const requestBody = prepared.body; + const upstreamHeaders = prepared.headers; + upstreamHeaders.host = target.host; + if (requestBody.length > 0) upstreamHeaders["content-length"] = requestBody.length; + + const options = { + protocol: target.protocol, + hostname: target.hostname, + port: target.port || (target.protocol === "https:" ? 443 : 80), + method: clientReq.method, + path: upstreamPath(clientReq.url), + headers: upstreamHeaders, + }; + const transport = target.protocol === "https:" ? https : http; + const upstreamReq = transport.request(options, (upstreamRes) => { + const responseChunks = []; + clientRes.writeHead( + upstreamRes.statusCode || 502, + upstreamRes.statusMessage || "OK", + responseHeaders(upstreamRes.headers), + ); + upstreamRes.on("data", (chunk) => { + responseChunks.push(chunk); + clientRes.write(chunk); + }); + upstreamRes.on("end", () => { + clientRes.end(); + appendCapture({ + session_id: sessionId, + agent_name: agentName, + duration_ms: Date.now() - startedAt, + request: { + method: clientReq.method, + path: capturePath(clientReq.url), + headers: captureHeaders(clientReq.headers), + body_b64: requestBody.toString("base64"), + }, + response: { + status_code: upstreamRes.statusCode || 0, + headers: captureHeaders(upstreamRes.headers), + body_b64: bodyB64(responseChunks), + }, + }); + }); + }); + upstreamReq.on("error", (error) => { + const payload = Buffer.from(JSON.stringify({ error: String(error.message || error) })); + clientRes.writeHead(502, { + "content-type": 
"application/json", + "content-length": payload.length, + "connection": "close", + }); + clientRes.end(payload); + appendCapture({ + session_id: sessionId, + agent_name: agentName, + duration_ms: Date.now() - startedAt, + request: { + method: clientReq.method, + path: capturePath(clientReq.url), + headers: captureHeaders(clientReq.headers), + body_b64: requestBody.toString("base64"), + }, + response: { + status_code: 502, + headers: { "content-type": "application/json" }, + body_b64: payload.toString("base64"), + }, + }); + }); + if (requestBody.length > 0) upstreamReq.write(requestBody); + upstreamReq.end(); + }); +}); + +server.listen(0, "127.0.0.1", () => { + const address = server.address(); + fs.writeFileSync(pidPath, String(process.pid)); + fs.writeFileSync(statePath, JSON.stringify({ port: address.port, pid: process.pid })); +}); + +process.on("SIGTERM", () => server.close(() => process.exit(0))); diff --git a/src/benchflow/providers/runtime.py b/src/benchflow/providers/runtime.py index 203fdf4e..f4be50f5 100644 --- a/src/benchflow/providers/runtime.py +++ b/src/benchflow/providers/runtime.py @@ -3,34 +3,18 @@ from __future__ import annotations import logging -import os -import secrets from dataclasses import dataclass from typing import Any -from urllib.parse import urlsplit, urlunsplit - -import httpx from benchflow.agents.providers import find_provider, strip_provider_prefix from benchflow.agents.registry import AGENTS from benchflow.providers.bedrock_proxy import BedrockProxyServer -from benchflow.trajectories.pricing import PRICING_USD_PER_MTOK, PricingEntry -from benchflow.trajectories.proxy import TrajectoryProxy -from benchflow.usage_tracking import ( - DEFAULT_USAGE_PROXY_BIND_HOST, - USAGE_PROXY_ADVERTISED_BASE_URL_ENV, - USAGE_PROXY_PORT_ENV, - UsageTrackingConfig, -) +from benchflow.usage_tracking import UsageTrackingConfig logger = logging.getLogger(__name__) BEDROCK_PROXY_BIND_HOST = "0.0.0.0" BEDROCK_PROXY_LOCAL_HOST = "127.0.0.1" 
-USAGE_PROXY_BIND_HOST = DEFAULT_USAGE_PROXY_BIND_HOST -PROMPT_CACHE_RETENTION_ENV = "BENCHFLOW_PROVIDER_PROMPT_CACHE_RETENTION" -DISABLE_USAGE_PROXY_ENV = "BENCHFLOW_DISABLE_USAGE_PROXY" -_PROMPT_CACHE_RETENTION_VALUES = {"in_memory", "24h"} @dataclass @@ -134,6 +118,8 @@ def _apply_direct_bedrock_agent_mapping( agent=agent, backend_model=backend_model, ) + if updated.get("AWS_REGION") and not updated.get("AWS_REGION_NAME"): + updated["AWS_REGION_NAME"] = updated["AWS_REGION"] if updated.get("AWS_BEARER_TOKEN_BEDROCK"): updated["LLM_API_KEY"] = updated["AWS_BEARER_TOKEN_BEDROCK"] return updated @@ -219,46 +205,6 @@ def _bedrock_proxy_command( return BEDROCK_PROXY_LOCAL_HOST -def _host_side_proxy_target_url(target: str, *, environment: str) -> str: - """Return the upstream URL a host-side proxy should dial. - - The URL injected into an agent container may use Docker's host alias - (``host.docker.internal`` or the Linux bridge gateway). That address is for - the container. A proxy process running on the host should reach another - host-bound BenchFlow proxy through loopback instead. 
- """ - if not host_proxy_reachable_from_agent(environment): - return target - parsed = urlsplit(target) - if not parsed.hostname: - return target - if parsed.hostname != _bedrock_proxy_command(environment=environment): - return target - netloc = BEDROCK_PROXY_LOCAL_HOST - if parsed.port is not None: - netloc = f"{netloc}:{parsed.port}" - return urlunsplit( - (parsed.scheme, netloc, parsed.path, parsed.query, parsed.fragment) - ) - - -def _usage_unavailable() -> dict[str, Any]: - return { - "n_input_tokens": None, - "n_output_tokens": None, - "n_cache_read_tokens": None, - "n_cache_creation_tokens": None, - "total_tokens": None, - "cost_usd": None, - "usage_source": "unavailable", - "price_source": None, - } - - -def _env_flag_enabled(value: str | None) -> bool: - return value is not None and value.strip().lower() in {"1", "true", "yes", "on"} - - def _agent_base_url_envs(agent: str) -> list[str]: envs: list[str] = [] agent_cfg = AGENTS.get(agent) @@ -280,234 +226,24 @@ def _agent_base_url_envs(agent: str) -> list[str]: return [e for e in envs if not (e in seen or seen.add(e))] -def _infer_default_provider_url(agent: str, model: str | None) -> str | None: - bare = strip_provider_prefix(model) if model else "" - m = bare.lower() - if "claude" in m or "anthropic" in m or agent == "claude-agent-acp": - return "https://api.anthropic.com" - if ( - "gpt" in m - or "openai" in m - or m.startswith(("o1", "o3", "o4")) - or agent in {"codex-acp", "opencode"} - ): - return "https://api.openai.com/v1" - if "gemini" in m or "gemma" in m or agent == "gemini": - return "https://generativelanguage.googleapis.com" - return None - - -def _resolve_usage_proxy_target( - agent: str, - agent_env: dict[str, str], - model: str | None, -) -> str | None: - if agent_env.get("BENCHFLOW_PROVIDER_BASE_URL"): - return agent_env["BENCHFLOW_PROVIDER_BASE_URL"] - for env_name in _agent_base_url_envs(agent): - if agent_env.get(env_name): - return agent_env[env_name] - return 
_infer_default_provider_url(agent, model) - - -def _external_usage_proxy_error(environment: str) -> str: - return ( - f"Token usage tracking is required for sandbox={environment!r}, but " - "that sandbox runs the agent on a remote host and cannot reach a " - "host-bound usage proxy and no external usage proxy endpoint is " - "configured. Configure an external usage proxy endpoint with " - f"{USAGE_PROXY_ADVERTISED_BASE_URL_ENV} plus a fixed " - f"{USAGE_PROXY_PORT_ENV}, or rerun with --usage-tracking auto/off." - ) - - -@dataclass(frozen=True) -class UsageProxyPreconditionFailure: - """Why the usage proxy cannot be wired for this rollout.""" - - required_message: str - skip_message: str - log_level: int = logging.WARNING - - -def _usage_proxy_path_prefix() -> str: - return f"/__benchflow/{secrets.token_urlsafe(24)}" - - -def _agent_usage_proxy_base_url( - *, - environment: str, - port: int, - usage_tracking: UsageTrackingConfig, - path_prefix: str, -) -> str: - if usage_tracking.advertised_base_url: - return f"{usage_tracking.advertised_base_url}{path_prefix}" - return f"http://{_bedrock_proxy_command(environment=environment)}:{port}" - - def validate_usage_proxy_preconditions( usage_cfg: UsageTrackingConfig, *, environment: str, model: str | None, disable_usage_proxy: bool | None = None, -) -> UsageProxyPreconditionFailure | None: +) -> Any: """Return the first reason usage telemetry cannot be wired, if any.""" - if usage_cfg.mode == "off": - return None - - if disable_usage_proxy is None: - disable_usage_proxy = _env_flag_enabled(os.environ.get(DISABLE_USAGE_PROXY_ENV)) - if disable_usage_proxy: - return UsageProxyPreconditionFailure( - required_message=( - f"Token usage tracking is required, but {DISABLE_USAGE_PROXY_ENV} " - "is enabled." - ), - skip_message=( - f"Skipping host-side usage telemetry proxy: {DISABLE_USAGE_PROXY_ENV} " - "is enabled." 
- ), - log_level=logging.INFO, - ) - - host_reachable = host_proxy_reachable_from_agent(environment) - if not host_reachable and not usage_cfg.uses_external_proxy: - return UsageProxyPreconditionFailure( - required_message=_external_usage_proxy_error(environment or "unknown"), - skip_message=( - "Skipping host-side usage telemetry proxy: the " - f"{environment or 'unknown'!r} sandbox runs the agent on a remote " - "host unreachable from the host proxy and no external usage proxy " - "endpoint is configured." - ), - log_level=logging.INFO, - ) - - if usage_cfg.uses_external_proxy and not usage_cfg.has_fixed_proxy_port: - message = ( - "External usage proxy tracking requires a fixed positive local proxy port. " - f"Set {USAGE_PROXY_PORT_ENV} or pass --usage-proxy-port." - ) - return UsageProxyPreconditionFailure( - required_message=message, - skip_message=message, - log_level=logging.WARNING, - ) - - if ( - needs_provider_runtime(model) - and not host_reachable - and usage_cfg.uses_external_proxy - ): - message = ( - "Remote Bedrock-direct runs cannot be metered by the generic usage " - "proxy because the agent calls AWS Bedrock natively instead of an " - "OpenAI/Anthropic-compatible HTTP endpoint. Use an OpenAI-compatible " - "provider proxy for this run, run with --sandbox docker, or leave " - "usage tracking as auto/off." 
- ) - return UsageProxyPreconditionFailure( - required_message=message, - skip_message=message, - log_level=logging.WARNING, - ) - - return None - - -async def _external_usage_proxy_reachable(base_url: str) -> bool: - health_url = f"{base_url.rstrip('/')}/__benchflow_health" - try: - async with httpx.AsyncClient(timeout=httpx.Timeout(5.0)) as client: - response = await client.get(health_url) - return response.status_code == 200 - except Exception as exc: - logger.debug("External usage proxy health check failed: %s", exc) - return False - - -def _pricing_for_model(model: str | None) -> PricingEntry | None: - if not model: - return None - bare = strip_provider_prefix(model).lower() - for prefix, pricing in PRICING_USD_PER_MTOK.items(): - if bare.startswith(prefix): - return pricing - return None - + from benchflow.providers.usage_proxy_runtime import ( + validate_usage_proxy_preconditions as _validate_usage_proxy_preconditions, + ) -def _estimate_cost_usd( - *, - model: str | None, - input_tokens: int, - output_tokens: int, - cache_read_tokens: int, - cache_creation_tokens: int, - cache_tokens_included_in_input: bool = False, -) -> float | None: - pricing = _pricing_for_model(model) - if pricing is None: - return None - priced_input_tokens = input_tokens - if cache_tokens_included_in_input: - priced_input_tokens = max( - input_tokens - cache_read_tokens - cache_creation_tokens, 0 - ) - cost = ( - priced_input_tokens * pricing.input - + output_tokens * pricing.output - + cache_read_tokens * pricing.cache_read - + cache_creation_tokens * pricing.cache_creation - ) / 1_000_000 - return round(cost, 10) - - -def _model_from_trajectory(runtime: ProviderRuntime) -> str | None: - # Prefer the model the provider actually reported in captured exchanges; - # backend_model is only the model requested at proxy-creation time and can - # be stale if a role switched models. Falls back to it when no exchange - # carries a model (e.g. Gemini, which puts the model in the URL path). 
- trajectory = getattr(runtime.server, "trajectory", None) - if trajectory: - for exchange in trajectory.exchanges: - response_model = exchange.response.body.get("model") - if response_model: - return response_model - request_model = exchange.request.body.get("model") - if request_model: - return request_model - return runtime.backend_model - - -def _cache_tokens_are_input_breakdown(trajectory: Any) -> bool: - for exchange in trajectory.exchanges: - usage = exchange.response.body.get("usage", {}) - if (usage.get("prompt_tokens_details") or {}).get("cached_tokens") is not None: - return True - if (usage.get("input_tokens_details") or {}).get("cached_tokens") is not None: - return True - return False - - -async def _skip_or_block_usage_proxy( - *, - usage_cfg: UsageTrackingConfig, - failure: UsageProxyPreconditionFailure, - agent_env: dict[str, str], - runtime: ProviderRuntime | None, -) -> tuple[dict[str, str], ProviderRuntime | None]: - if runtime is not None: - await stop_provider_runtime(runtime) - if usage_cfg.mode == "required": - raise RuntimeError(failure.required_message) - logger.log( - failure.log_level, - "%s Usage telemetry will be unavailable for this run.", - failure.skip_message, + return _validate_usage_proxy_preconditions( + usage_cfg, + environment=environment, + model=model, + disable_usage_proxy=disable_usage_proxy, ) - return agent_env, None async def ensure_usage_proxy_runtime( @@ -519,166 +255,30 @@ async def ensure_usage_proxy_runtime( environment: str, session_id: str = "", usage_tracking: UsageTrackingConfig | dict[str, Any] | str | None = None, + sandbox: Any | None = None, ) -> tuple[dict[str, str], ProviderRuntime | None]: - """Start the host-side usage proxy and wire env vars to it. - - Remote cloud sandboxes (e.g. Daytona) can only use this proxy when the - operator supplies an externally reachable URL. 
The local bind endpoint and - the URL advertised to the agent are deliberately separate: Docker sees the - host bridge address, while Daytona sees the tunnel/ingress URL. - """ - usage_cfg = UsageTrackingConfig.coerce(usage_tracking).with_env_defaults() - if agent == "oracle": - return agent_env, runtime - - if usage_cfg.mode == "off": - if runtime is not None: - await stop_provider_runtime(runtime) - logger.info("Skipping host-side usage telemetry proxy: usage_tracking=off.") - return agent_env, None - - host_reachable = host_proxy_reachable_from_agent(environment) - failure = validate_usage_proxy_preconditions( - usage_cfg, - environment=environment, - model=model, + """Start a reachable usage proxy and wire agent provider env vars to it.""" + from benchflow.providers.usage_proxy_runtime import ( + ensure_usage_proxy_runtime as _ensure_usage_proxy_runtime, ) - if failure is not None: - return await _skip_or_block_usage_proxy( - usage_cfg=usage_cfg, - failure=failure, - agent_env=agent_env, - runtime=runtime, - ) - target = _resolve_usage_proxy_target(agent, agent_env, model) - if not target: - if usage_cfg.mode == "required": - raise RuntimeError( - "Token usage tracking is required, but BenchFlow could not " - "resolve a provider base URL for this agent/model." - ) - return agent_env, runtime - target = target.rstrip("/") - if host_reachable: - target = _host_side_proxy_target_url(target, environment=environment) - - # A multi-role scene can switch providers between connect_as() calls. The - # running proxy forwards to a fixed upstream, so reusing it would route the - # new role's traffic to the wrong endpoint — retire it and start a fresh - # one for the new target. 
- if runtime is not None and getattr(runtime.server, "target", None) != target: - await stop_provider_runtime(runtime) - runtime = None - - if runtime is None: - prompt_cache_retention = agent_env.get(PROMPT_CACHE_RETENTION_ENV) - if ( - prompt_cache_retention is not None - and prompt_cache_retention not in _PROMPT_CACHE_RETENTION_VALUES - ): - raise ValueError( - f"{PROMPT_CACHE_RETENTION_ENV} must be one of: " - f"{', '.join(sorted(_PROMPT_CACHE_RETENTION_VALUES))}" - ) - logger.info("Starting host-side usage telemetry proxy") - bind_host = usage_cfg.bind_host - if bind_host is None: - bind_host = ( - "127.0.0.1" if usage_cfg.uses_external_proxy else USAGE_PROXY_BIND_HOST - ) - bind_port = usage_cfg.port if usage_cfg.port is not None else 0 - path_prefix = ( - _usage_proxy_path_prefix() if usage_cfg.uses_external_proxy else "" - ) - proxy_kwargs: dict[str, Any] = { - "target": target, - "session_id": session_id, - "agent_name": agent, - "host": bind_host, - "port": bind_port, - "prompt_cache_retention": prompt_cache_retention, - } - if path_prefix: - proxy_kwargs["path_prefix"] = path_prefix - server = TrajectoryProxy(**proxy_kwargs) - await server.start() - agent_base_url = _agent_usage_proxy_base_url( - environment=environment, - port=server.port, - usage_tracking=usage_cfg, - path_prefix=path_prefix, - ) - runtime = ProviderRuntime( - kind="usage-proxy", - agent_base_url=agent_base_url, - backend_model=strip_provider_prefix(model) if model else None, - server=server, - ) - - if usage_cfg.uses_external_proxy: - reachable = await _external_usage_proxy_reachable(runtime.base_url) - if not reachable: - await stop_provider_runtime(runtime) - runtime = None - message = ( - "External usage proxy endpoint was configured but did not " - f"respond to its health check: {usage_cfg.advertised_base_url}" - ) - if usage_cfg.mode == "required": - raise RuntimeError(message) - logger.warning( - "%s. 
Usage telemetry will be unavailable for this run.", message - ) - return agent_env, None - - updated = dict(agent_env) - updated["BENCHFLOW_PROVIDER_BASE_URL"] = runtime.base_url - agent_cfg = AGENTS.get(agent) - mapped_base = ( - agent_cfg.env_mapping.get("BENCHFLOW_PROVIDER_BASE_URL") if agent_cfg else None + return await _ensure_usage_proxy_runtime( + agent=agent, + agent_env=agent_env, + model=model, + runtime=runtime, + environment=environment, + session_id=session_id, + usage_tracking=usage_tracking, + sandbox=sandbox, ) - for env_name in _agent_base_url_envs(agent): - if env_name in updated or env_name == mapped_base: - updated[env_name] = runtime.base_url - return updated, runtime def extract_usage(runtime: ProviderRuntime | None) -> dict[str, Any]: """Extract aggregate token/cost metrics from a usage proxy runtime.""" - if runtime is None or runtime.kind != "usage-proxy" or runtime.server is None: - return _usage_unavailable() - trajectory = getattr(runtime.server, "trajectory", None) - if trajectory is None or not trajectory.exchanges: - return _usage_unavailable() - - input_tokens = trajectory.total_input_tokens - output_tokens = trajectory.total_output_tokens - cache_read_tokens = trajectory.total_cache_read_tokens - cache_creation_tokens = trajectory.total_cache_creation_tokens - total_tokens = trajectory.total_provider_tokens - model = _model_from_trajectory(runtime) - pricing = _pricing_for_model(model) - cost_usd = _estimate_cost_usd( - model=model, - input_tokens=input_tokens, - output_tokens=output_tokens, - cache_read_tokens=cache_read_tokens, - cache_creation_tokens=cache_creation_tokens, - cache_tokens_included_in_input=_cache_tokens_are_input_breakdown(trajectory), - ) - return { - "n_input_tokens": input_tokens, - "n_output_tokens": output_tokens, - "n_cache_read_tokens": cache_read_tokens, - "n_cache_creation_tokens": cache_creation_tokens, - "total_tokens": total_tokens, - "cost_usd": cost_usd, - "usage_source": "provider_response", - 
"price_source": pricing.price_source - if cost_usd is not None and pricing - else None, - } + from benchflow.providers.usage_proxy_runtime import extract_usage as _extract_usage + + return _extract_usage(runtime) async def ensure_bedrock_proxy_runtime( diff --git a/src/benchflow/providers/sandbox_usage_proxy.py b/src/benchflow/providers/sandbox_usage_proxy.py new file mode 100644 index 00000000..fdd4cc8a --- /dev/null +++ b/src/benchflow/providers/sandbox_usage_proxy.py @@ -0,0 +1,273 @@ +"""Sandbox-local provider usage proxy runtime. + +The host-side :class:`benchflow.trajectories.proxy.TrajectoryProxy` works when +the agent can route back to the host. Remote sandboxes such as Daytona cannot, +so this module starts a tiny byte-forwarding proxy inside the same sandbox +network namespace as the agent and imports its raw captures back into the host +trajectory model during cleanup. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +import shlex +import tempfile +from pathlib import Path +from typing import Any +from uuid import uuid4 + +from benchflow.agents.registry import _NODE_INSTALL +from benchflow.trajectories.proxy import exchange_from_raw_capture +from benchflow.trajectories.types import Trajectory + +logger = logging.getLogger(__name__) + +_RUNTIME_ROOT = "/tmp/benchflow-usage-proxy" + + +def _read_node_proxy_source() -> str: + return (Path(__file__).with_name("assets") / "sandbox_usage_proxy.js").read_text() + + +_NODE_PROXY_SOURCE = _read_node_proxy_source() +_NODE_LAUNCHER_SOURCE = r""" +const fs = require("fs"); +const { spawn } = require("child_process"); + +const config = JSON.parse(process.env.BENCHFLOW_USAGE_PROXY_CONFIG || "{}"); +const stdout = fs.openSync(config.stdout, "a"); +const stderr = fs.openSync(config.stderr, "a"); +const child = spawn(config.node, [config.script], { + detached: true, + stdio: ["ignore", stdout, stderr], + env: { ...process.env, ...config.env }, +}); 
+child.unref(); +console.log(child.pid); +""" + + +class SandboxUsageProxy: + """Long-lived proxy process running in the agent sandbox.""" + + def __init__( + self, + *, + sandbox: Any, + target: str, + session_id: str, + agent_name: str, + prompt_cache_retention: str | None = None, + ) -> None: + self.sandbox = sandbox + self.target = target.rstrip("/") + self.session_id = session_id + self.agent_name = agent_name + self.prompt_cache_retention = prompt_cache_retention + self.trajectory = Trajectory(session_id=session_id, agent_name=agent_name) + self._token = uuid4().hex[:16] + self._runtime_dir = f"{_RUNTIME_ROOT}/{self._token}" + self._script_path = f"{self._runtime_dir}/proxy.js" + self._state_path = f"{self._runtime_dir}/state.json" + self._log_path = f"{self._runtime_dir}/captures.jsonl" + self._pid_path = f"{self._runtime_dir}/proxy.pid" + self._base_url: str | None = None + + @property + def base_url(self) -> str: + if self._base_url is None: + raise RuntimeError("sandbox usage proxy has not started") + return self._base_url + + async def start(self) -> None: + await self._upload_proxy_script() + node = await self._ensure_node() + stdout_path = f"{_RUNTIME_ROOT}/{self._token}/stdout.log" + stderr_path = f"{_RUNTIME_ROOT}/{self._token}/stderr.log" + launcher_config = { + "node": node, + "script": self._script_path, + "stdout": stdout_path, + "stderr": stderr_path, + "env": { + "BENCHFLOW_USAGE_PROXY_TARGET": self.target, + "BENCHFLOW_USAGE_PROXY_STATE_PATH": self._state_path, + "BENCHFLOW_USAGE_PROXY_LOG_PATH": self._log_path, + "BENCHFLOW_USAGE_PROXY_PID_PATH": self._pid_path, + "BENCHFLOW_USAGE_PROXY_SESSION_ID": self.session_id, + "BENCHFLOW_USAGE_PROXY_AGENT_NAME": self.agent_name, + "BENCHFLOW_USAGE_PROXY_PROMPT_CACHE_RETENTION": ( + self.prompt_cache_retention or "" + ), + }, + } + command = " ".join( + [ + "mkdir", + "-p", + shlex.quote(str(Path(self._script_path).parent)), + "&&", + "rm", + "-f", + shlex.quote(self._state_path), + 
shlex.quote(self._log_path), + shlex.quote(self._pid_path), + "&&", + f"BENCHFLOW_USAGE_PROXY_CONFIG={shlex.quote(json.dumps(launcher_config))}", + shlex.quote(node), + "-e", + shlex.quote(_NODE_LAUNCHER_SOURCE), + ] + ) + result = await self.sandbox.exec(command, timeout_sec=15) + if result.return_code != 0: + raise RuntimeError(_exec_details("start sandbox usage proxy", result)) + state = await self._wait_for_state() + self._base_url = f"http://127.0.0.1:{state['port']}" + logger.info("Sandbox usage telemetry proxy listening on %s", self._base_url) + + async def is_running(self) -> bool: + result = await self.sandbox.exec( + ( + f"if [ -s {shlex.quote(self._pid_path)} ] && " + f"kill -0 $(cat {shlex.quote(self._pid_path)}) 2>/dev/null; " + "then echo yes; else echo no; fi" + ), + timeout_sec=5, + ) + return result.return_code == 0 and (result.stdout or "").strip() == "yes" + + async def stop(self) -> None: + try: + await self._load_captures() + except Exception as exc: + logger.warning("Could not import sandbox usage captures: %s", exc) + finally: + await self._terminate() + await self._cleanup_runtime_dir() + + async def _terminate(self) -> None: + kill_cmd = ( + f"if [ -s {shlex.quote(self._pid_path)} ]; then " + f"kill -TERM $(cat {shlex.quote(self._pid_path)}) 2>/dev/null || true; " + "fi" + ) + with contextlib.suppress(Exception): + await self.sandbox.exec(kill_cmd, timeout_sec=10) + + async def _cleanup_runtime_dir(self) -> None: + with contextlib.suppress(Exception): + await self.sandbox.exec( + f"rm -rf {shlex.quote(self._runtime_dir)}", + timeout_sec=10, + ) + + async def _upload_proxy_script(self) -> None: + parent = shlex.quote(str(Path(self._script_path).parent)) + result = await self.sandbox.exec(f"mkdir -p {parent}", timeout_sec=15) + if result.return_code != 0: + raise RuntimeError(_exec_details("prepare sandbox usage proxy dir", result)) + + with tempfile.NamedTemporaryFile("w", suffix=".js", delete=False) as tmp: + tmp.write(_NODE_PROXY_SOURCE) + 
tmp_path = Path(tmp.name) + try: + await self.sandbox.upload_file(tmp_path, self._script_path) + finally: + tmp_path.unlink(missing_ok=True) + + async def _ensure_node(self) -> str: + node_probe = ( + "if [ -x /opt/benchflow/node/bin/node ]; then " + "echo /opt/benchflow/node/bin/node; " + "elif command -v node >/dev/null 2>&1; then command -v node; " + "else echo ''; fi" + ) + result = await self.sandbox.exec(node_probe, timeout_sec=10) + node = (result.stdout or "").strip().splitlines()[-1:] or [""] + if node[0]: + return node[0] + + install = await self.sandbox.exec(_NODE_INSTALL, timeout_sec=300) + if install.return_code != 0: + raise RuntimeError(_exec_details("install Node for usage proxy", install)) + result = await self.sandbox.exec(node_probe, timeout_sec=10) + node = (result.stdout or "").strip().splitlines()[-1:] or [""] + if not node[0]: + raise RuntimeError("Node.js was not available after usage proxy bootstrap") + return node[0] + + async def _wait_for_state(self) -> dict[str, Any]: + last_output = "" + for _ in range(50): + result = await self.sandbox.exec( + f"cat {shlex.quote(self._state_path)} 2>/dev/null || true", + timeout_sec=5, + ) + last_output = (result.stdout or "").strip() + if last_output: + try: + state = json.loads(last_output) + except (json.JSONDecodeError, ValueError): + await asyncio.sleep(0.2) + continue + if int(state.get("port") or 0) > 0: + return state + await asyncio.sleep(0.2) + stderr = await self.sandbox.exec( + f"cat {shlex.quote(f'{_RUNTIME_ROOT}/{self._token}/stderr.log')} " + "2>/dev/null || true", + timeout_sec=5, + ) + raise RuntimeError( + "sandbox usage proxy did not publish its state" + f": {last_output or (stderr.stdout or '').strip()}" + ) + + async def _load_captures(self) -> None: + capture_text = await self._read_capture_log() + trajectory = Trajectory(session_id=self.session_id, agent_name=self.agent_name) + for line in capture_text.splitlines(): + if not line.strip(): + continue + try: + 
trajectory.exchanges.append(exchange_from_raw_capture(json.loads(line))) + except Exception as exc: + logger.warning("Skipping malformed sandbox usage capture: %s", exc) + self.trajectory = trajectory + + async def _read_capture_log(self) -> str: + download_file = getattr(self.sandbox, "download_file", None) + if download_file is not None: + with tempfile.NamedTemporaryFile("r", delete=False) as tmp: + tmp_path = Path(tmp.name) + try: + await download_file(self._log_path, tmp_path) + return tmp_path.read_text() + except Exception as exc: + logger.debug("Sandbox usage capture download failed: %s", exc) + finally: + tmp_path.unlink(missing_ok=True) + + result = await self.sandbox.exec( + f"cat {shlex.quote(self._log_path)} 2>/dev/null || true", + timeout_sec=15, + ) + if result.return_code != 0: + logger.warning("Could not read sandbox usage captures: %s", result.stderr) + return "" + return result.stdout or "" + + +def _exec_details(label: str, result: Any) -> str: + stdout = (getattr(result, "stdout", "") or "").strip() + stderr = (getattr(result, "stderr", "") or "").strip() + details = [f"{label} failed with exit code {getattr(result, 'return_code', '?')}"] + if stdout: + details.append(f"stdout: {stdout[:1000]}") + if stderr: + details.append(f"stderr: {stderr[:1000]}") + return "; ".join(details) diff --git a/src/benchflow/providers/usage_proxy_runtime.py b/src/benchflow/providers/usage_proxy_runtime.py new file mode 100644 index 00000000..6ce0821f --- /dev/null +++ b/src/benchflow/providers/usage_proxy_runtime.py @@ -0,0 +1,727 @@ +"""Usage telemetry proxy runtime orchestration.""" + +from __future__ import annotations + +import contextlib +import logging +import os +from dataclasses import dataclass +from typing import Any +from urllib.parse import urlsplit, urlunsplit + +from benchflow.agents.codex_config import apply_codex_provider_config +from benchflow.agents.providers import strip_provider_prefix +from benchflow.agents.registry import AGENTS +from 
benchflow.providers.runtime import ( + BEDROCK_PROXY_LOCAL_HOST, + ProviderRuntime, + _agent_base_url_envs, + _bedrock_proxy_command, + host_proxy_reachable_from_agent, + needs_provider_runtime, + stop_provider_runtime, +) +from benchflow.providers.sandbox_usage_proxy import SandboxUsageProxy +from benchflow.trajectories.pricing import PRICING_USD_PER_MTOK, PricingEntry +from benchflow.trajectories.proxy import TrajectoryProxy +from benchflow.usage_tracking import UsageTrackingConfig + +logger = logging.getLogger(__name__) + +USAGE_PROXY_BIND_HOST = "0.0.0.0" +PROMPT_CACHE_RETENTION_ENV = "BENCHFLOW_PROVIDER_PROMPT_CACHE_RETENTION" +DISABLE_USAGE_PROXY_ENV = "BENCHFLOW_DISABLE_USAGE_PROXY" +_PROMPT_CACHE_RETENTION_VALUES = {"in_memory", "24h"} +_BEDROCK_RUNTIME_ENDPOINT_ENVS = ( + "AWS_ENDPOINT_URL_BEDROCK_RUNTIME", + "AWS_ENDPOINT_URL_BEDROCK", +) + + +def _host_side_proxy_target_url(target: str, *, environment: str) -> str: + """Return the upstream URL a host-side proxy should dial. + + The URL injected into an agent container may use Docker's host alias + (``host.docker.internal`` or the Linux bridge gateway). That address is for + the container. A proxy process running on the host should reach another + host-bound BenchFlow proxy through loopback instead. 
+ """ + if not host_proxy_reachable_from_agent(environment): + return target + parsed = urlsplit(target) + if not parsed.hostname: + return target + if parsed.hostname != _bedrock_proxy_command(environment=environment): + return target + netloc = BEDROCK_PROXY_LOCAL_HOST + if parsed.port is not None: + netloc = f"{netloc}:{parsed.port}" + return urlunsplit( + (parsed.scheme, netloc, parsed.path, parsed.query, parsed.fragment) + ) + + +def _usage_unavailable() -> dict[str, Any]: + return { + "n_input_tokens": None, + "n_output_tokens": None, + "n_cache_read_tokens": None, + "n_cache_creation_tokens": None, + "total_tokens": None, + "cost_usd": None, + "usage_source": "unavailable", + "price_source": None, + } + + +def _env_flag_enabled(value: str | None) -> bool: + return value is not None and value.strip().lower() in {"1", "true", "yes", "on"} + + +def _apply_codex_config_base_url( + agent: str, + agent_env: dict[str, str], + base_url: str, +) -> None: + """Keep Codex's provider config pointed at the active usage proxy.""" + if agent != "codex-acp": + return + apply_codex_provider_config( + agent_env, + base_url=base_url, + model=agent_env.get("BENCHFLOW_PROVIDER_MODEL"), + provider_name=agent_env.get("BENCHFLOW_PROVIDER_NAME") or "openai", + ) + + +def _infer_default_provider_url(agent: str, model: str | None) -> str | None: + bare = strip_provider_prefix(model) if model else "" + m = bare.lower() + if "claude" in m or "anthropic" in m or agent == "claude-agent-acp": + return "https://api.anthropic.com" + if ( + "gpt" in m + or "openai" in m + or m.startswith(("o1", "o3", "o4")) + or agent in {"codex-acp", "opencode"} + ): + return "https://api.openai.com/v1" + if "gemini" in m or "gemma" in m or agent == "gemini": + return "https://generativelanguage.googleapis.com" + return None + + +def _resolve_usage_proxy_target( + agent: str, + agent_env: dict[str, str], + model: str | None, +) -> str | None: + if agent_env.get("BENCHFLOW_PROVIDER_BASE_URL"): + return 
agent_env["BENCHFLOW_PROVIDER_BASE_URL"] + for env_name in _agent_base_url_envs(agent): + if agent_env.get(env_name): + return agent_env[env_name] + return _infer_default_provider_url(agent, model) + + +def _bedrock_runtime_target(agent_env: dict[str, str]) -> str | None: + if agent_env.get("ANTHROPIC_BEDROCK_BASE_URL"): + return agent_env["ANTHROPIC_BEDROCK_BASE_URL"].rstrip("/") + region = agent_env.get("AWS_REGION") or agent_env.get("AWS_DEFAULT_REGION") + if not region: + return None + return f"https://bedrock-runtime.{region}.amazonaws.com" + + +def _is_remote_direct_bedrock_usage( + *, + model: str | None, + environment: str, +) -> bool: + return needs_provider_runtime(model) and not host_proxy_reachable_from_agent( + environment + ) + + +def _usage_runtime_target(runtime: ProviderRuntime | None) -> str | None: + """Return the unproxied upstream target for a running usage proxy.""" + if runtime is None or runtime.kind != "usage-proxy": + return None + target = getattr(getattr(runtime, "server", None), "target", None) + if isinstance(target, str) and target.strip(): + return target.rstrip("/") + return None + + +def _unproxied_usage_target( + target: str, + runtime: ProviderRuntime | None, +) -> str: + """Recover the provider URL when reconnect env already points at our proxy.""" + runtime_target = _usage_runtime_target(runtime) + if ( + runtime is not None + and runtime_target + and target.rstrip("/") == runtime.base_url.rstrip("/") + ): + return runtime_target + return target.rstrip("/") + + +@dataclass(frozen=True) +class UsageProxyRouting: + """Provider-specific usage proxy routing resolved before proxy lifecycle.""" + + target: str | None + missing_target_detail: str + remote_direct_bedrock: bool = False + + +def _usage_proxy_routing( + *, + agent: str, + agent_env: dict[str, str], + model: str | None, + environment: str, +) -> UsageProxyRouting: + if _is_remote_direct_bedrock_usage(model=model, environment=environment): + return UsageProxyRouting( + 
target=_bedrock_runtime_target(agent_env), + missing_target_detail=( + "resolve AWS_REGION or AWS_DEFAULT_REGION for Bedrock Runtime." + ), + remote_direct_bedrock=True, + ) + return UsageProxyRouting( + target=_resolve_usage_proxy_target(agent, agent_env, model), + missing_target_detail="resolve a provider base URL for this agent/model.", + ) + + +@dataclass(frozen=True) +class UsageProxyPreconditionFailure: + """Why the usage proxy cannot be wired for this rollout.""" + + required_message: str + skip_message: str + log_level: int = logging.WARNING + + +def _agent_usage_proxy_base_url( + *, + environment: str, + port: int, +) -> str: + return f"http://{_bedrock_proxy_command(environment=environment)}:{port}" + + +def validate_usage_proxy_preconditions( + usage_cfg: UsageTrackingConfig, + *, + environment: str, + model: str | None, + disable_usage_proxy: bool | None = None, +) -> UsageProxyPreconditionFailure | None: + """Return the first reason usage telemetry cannot be wired, if any.""" + if usage_cfg.mode == "off": + return None + + if disable_usage_proxy is None: + disable_usage_proxy = _env_flag_enabled(os.environ.get(DISABLE_USAGE_PROXY_ENV)) + if disable_usage_proxy: + return UsageProxyPreconditionFailure( + required_message=( + f"Token usage tracking is required, but {DISABLE_USAGE_PROXY_ENV} " + "is enabled." + ), + skip_message=( + f"Skipping host-side usage telemetry proxy: {DISABLE_USAGE_PROXY_ENV} " + "is enabled." 
+ ), + log_level=logging.INFO, + ) + + return None + + +def _pricing_for_model(model: str | None) -> PricingEntry | None: + if not model: + return None + bare = strip_provider_prefix(model).lower() + for prefix, pricing in PRICING_USD_PER_MTOK.items(): + if bare.startswith(prefix): + return pricing + return None + + +def _estimate_cost_usd( + *, + model: str | None, + input_tokens: int, + output_tokens: int, + cache_read_tokens: int, + cache_creation_tokens: int, + cache_tokens_included_in_input: bool = False, +) -> float | None: + pricing = _pricing_for_model(model) + if pricing is None: + return None + priced_input_tokens = input_tokens + if cache_tokens_included_in_input: + priced_input_tokens = max( + input_tokens - cache_read_tokens - cache_creation_tokens, 0 + ) + cost = ( + priced_input_tokens * pricing.input + + output_tokens * pricing.output + + cache_read_tokens * pricing.cache_read + + cache_creation_tokens * pricing.cache_creation + ) / 1_000_000 + return round(cost, 10) + + +def _model_from_trajectory(runtime: ProviderRuntime) -> str | None: + # Prefer the model the provider actually reported in captured exchanges; + # backend_model is only the model requested at proxy-creation time and can + # be stale if a role switched models. Falls back to it when no exchange + # carries a model (e.g. Gemini, which puts the model in the URL path). 
+ trajectory = getattr(runtime.server, "trajectory", None) + if trajectory: + for exchange in trajectory.exchanges: + response_model = exchange.response.body.get("model") + if response_model: + return response_model + request_model = exchange.request.body.get("model") + if request_model: + return request_model + return runtime.backend_model + + +def _cache_tokens_are_input_breakdown(trajectory: Any) -> bool: + for exchange in trajectory.exchanges: + usage = exchange.response.body.get("usage", {}) + if (usage.get("prompt_tokens_details") or {}).get("cached_tokens") is not None: + return True + if (usage.get("input_tokens_details") or {}).get("cached_tokens") is not None: + return True + return False + + +async def _skip_or_block_usage_proxy( + *, + usage_cfg: UsageTrackingConfig, + failure: UsageProxyPreconditionFailure, + agent_env: dict[str, str], + runtime: ProviderRuntime | None, +) -> tuple[dict[str, str], ProviderRuntime | None]: + if runtime is not None: + await stop_provider_runtime(runtime) + if usage_cfg.mode == "required": + raise RuntimeError(failure.required_message) + logger.log( + failure.log_level, + "%s Usage telemetry will be unavailable for this run.", + failure.skip_message, + ) + return agent_env, None + + +@dataclass(frozen=True) +class UsageProxyEndpoint: + """Resolved upstream and runner for one usage proxy lifecycle.""" + + target: str + routing: UsageProxyRouting + runner: UsageProxyRunner + + +class UsageProxyRunner: + """Strategy for starting a usage proxy reachable by the agent.""" + + async def start( + self, + *, + target: str, + session_id: str, + agent: str, + model: str | None, + prompt_cache_retention: str | None, + ) -> ProviderRuntime: + raise NotImplementedError + + +@dataclass(frozen=True) +class HostUsageProxyRunner(UsageProxyRunner): + """Start a host-side proxy for same-host sandboxes.""" + + environment: str + + async def start( + self, + *, + target: str, + session_id: str, + agent: str, + model: str | None, + 
prompt_cache_retention: str | None, + ) -> ProviderRuntime: + logger.info("Starting host-side usage telemetry proxy") + server: Any | None = None + try: + host_server = TrajectoryProxy( + target=target, + session_id=session_id, + agent_name=agent, + host=USAGE_PROXY_BIND_HOST, + port=0, + prompt_cache_retention=prompt_cache_retention, + ) + server = host_server + await host_server.start() + agent_base_url = _agent_usage_proxy_base_url( + environment=self.environment, + port=host_server.port, + ) + return ProviderRuntime( + kind="usage-proxy", + agent_base_url=agent_base_url, + backend_model=strip_provider_prefix(model) if model else None, + server=host_server, + ) + except Exception: + if server is not None: + with contextlib.suppress(Exception): + await server.stop() + raise + + +@dataclass(frozen=True) +class SandboxUsageProxyRunner(UsageProxyRunner): + """Start a sandbox-local proxy for remote sandboxes such as Daytona.""" + + sandbox: Any + + async def start( + self, + *, + target: str, + session_id: str, + agent: str, + model: str | None, + prompt_cache_retention: str | None, + ) -> ProviderRuntime: + logger.info("Starting Daytona sandbox-local usage telemetry proxy") + server: Any | None = None + try: + sandbox_server = SandboxUsageProxy( + sandbox=self.sandbox, + target=target, + session_id=session_id, + agent_name=agent, + prompt_cache_retention=prompt_cache_retention, + ) + server = sandbox_server + await sandbox_server.start() + return ProviderRuntime( + kind="usage-proxy", + agent_base_url=sandbox_server.base_url, + backend_model=strip_provider_prefix(model) if model else None, + server=sandbox_server, + ) + except Exception: + if server is not None: + with contextlib.suppress(Exception): + await server.stop() + raise + + +def _usage_proxy_runner( + *, + environment: str, + sandbox: Any | None, +) -> UsageProxyRunner | UsageProxyPreconditionFailure: + if host_proxy_reachable_from_agent(environment): + return HostUsageProxyRunner(environment=environment) 
+ if environment == "daytona" and sandbox is not None: + return SandboxUsageProxyRunner(sandbox=sandbox) + return UsageProxyPreconditionFailure( + required_message=( + "Token usage tracking is required, but BenchFlow could not " + f"start a sandbox-local usage proxy for sandbox={environment!r}." + ), + skip_message=( + "Skipping usage telemetry proxy: BenchFlow could not start a " + f"sandbox-local usage proxy for sandbox={environment!r}." + ), + log_level=logging.INFO, + ) + + +def _validate_prompt_cache_retention(agent_env: dict[str, str]) -> str | None: + prompt_cache_retention = agent_env.get(PROMPT_CACHE_RETENTION_ENV) + if ( + prompt_cache_retention is not None + and prompt_cache_retention not in _PROMPT_CACHE_RETENTION_VALUES + ): + raise ValueError( + f"{PROMPT_CACHE_RETENTION_ENV} must be one of: " + f"{', '.join(sorted(_PROMPT_CACHE_RETENTION_VALUES))}" + ) + return prompt_cache_retention + + +def _resolve_usage_proxy_endpoint( + *, + agent: str, + agent_env: dict[str, str], + model: str | None, + environment: str, + runtime: ProviderRuntime | None, + sandbox: Any | None, + usage_cfg: UsageTrackingConfig, +) -> UsageProxyEndpoint | UsageProxyPreconditionFailure | None: + routing = _usage_proxy_routing( + agent=agent, + agent_env=agent_env, + model=model, + environment=environment, + ) + target = routing.target + if not target: + if usage_cfg.mode == "required": + return UsageProxyPreconditionFailure( + required_message=( + "Token usage tracking is required, but BenchFlow could not " + f"{routing.missing_target_detail}" + ), + skip_message=( + "Skipping usage telemetry proxy: BenchFlow could not " + f"{routing.missing_target_detail}" + ), + ) + return None + + target = _unproxied_usage_target(target, runtime) + if host_proxy_reachable_from_agent(environment): + target = _host_side_proxy_target_url(target, environment=environment) + runner = _usage_proxy_runner(environment=environment, sandbox=sandbox) + if isinstance(runner, UsageProxyPreconditionFailure): 
+ return runner + return UsageProxyEndpoint(target=target, routing=routing, runner=runner) + + +async def _retire_unusable_usage_proxy_runtime( + runtime: ProviderRuntime | None, + *, + target: str, +) -> ProviderRuntime | None: + # A multi-role scene can switch providers between connect_as() calls. The + # running proxy forwards to a fixed upstream, so reusing it would route the + # new role's traffic to the wrong endpoint — retire it and start a fresh + # one for the new target. + if runtime is not None and getattr(runtime.server, "target", None) != target: + await stop_provider_runtime(runtime) + return None + if runtime is None: + return None + + is_running = getattr(runtime.server, "is_running", None) + if is_running is None: + return runtime + try: + alive = await is_running() + except Exception as exc: + logger.info("Usage telemetry proxy liveness check failed: %s", exc) + alive = False + if alive: + return runtime + + logger.info("Retiring stale usage telemetry proxy runtime") + await stop_provider_runtime(runtime) + return None + + +async def _ensure_started_usage_proxy_runtime( + *, + endpoint: UsageProxyEndpoint, + runtime: ProviderRuntime | None, + agent_env: dict[str, str], + session_id: str, + agent: str, + model: str | None, + usage_cfg: UsageTrackingConfig, +) -> ProviderRuntime | None: + runtime = await _retire_unusable_usage_proxy_runtime( + runtime, + target=endpoint.target, + ) + if runtime is not None: + return runtime + + prompt_cache_retention = _validate_prompt_cache_retention(agent_env) + try: + return await endpoint.runner.start( + target=endpoint.target, + session_id=session_id, + agent=agent, + model=model, + prompt_cache_retention=prompt_cache_retention, + ) + except Exception as exc: + if usage_cfg.mode == "required": + raise RuntimeError( + "Token usage tracking is required, but the usage proxy " + f"failed to start: {exc}" + ) from exc + logger.warning( + "Skipping usage telemetry proxy: failed to start usage proxy: %s", + exc, + ) + 
return None + + +def _apply_usage_proxy_env( + *, + agent: str, + agent_env: dict[str, str], + runtime: ProviderRuntime, + routing: UsageProxyRouting, +) -> dict[str, str]: + updated = dict(agent_env) + if routing.remote_direct_bedrock: + for env_name in _BEDROCK_RUNTIME_ENDPOINT_ENVS: + updated[env_name] = runtime.base_url + agent_cfg = AGENTS.get(agent) + mapped_base = ( + agent_cfg.env_mapping.get("BENCHFLOW_PROVIDER_BASE_URL") + if agent_cfg + else None + ) + if mapped_base: + updated[mapped_base] = runtime.base_url + return updated + + updated["BENCHFLOW_PROVIDER_BASE_URL"] = runtime.base_url + agent_cfg = AGENTS.get(agent) + mapped_base = ( + agent_cfg.env_mapping.get("BENCHFLOW_PROVIDER_BASE_URL") if agent_cfg else None + ) + for env_name in _agent_base_url_envs(agent): + if env_name in updated or env_name == mapped_base: + updated[env_name] = runtime.base_url + _apply_codex_config_base_url(agent, updated, runtime.base_url) + return updated + + +async def ensure_usage_proxy_runtime( + *, + agent: str, + agent_env: dict[str, str], + model: str | None, + runtime: ProviderRuntime | None, + environment: str, + session_id: str = "", + usage_tracking: UsageTrackingConfig | dict[str, Any] | str | None = None, + sandbox: Any | None = None, +) -> tuple[dict[str, str], ProviderRuntime | None]: + """Start a reachable usage proxy and wire agent provider env vars to it.""" + usage_cfg = UsageTrackingConfig.coerce(usage_tracking).with_env_defaults() + if agent == "oracle": + return agent_env, runtime + if usage_cfg.mode == "off": + if runtime is not None: + await stop_provider_runtime(runtime) + logger.info("Skipping host-side usage telemetry proxy: usage_tracking=off.") + return agent_env, None + failure = validate_usage_proxy_preconditions( + usage_cfg, + environment=environment, + model=model, + ) + if failure is not None: + return await _skip_or_block_usage_proxy( + usage_cfg=usage_cfg, + failure=failure, + agent_env=agent_env, + runtime=runtime, + ) + + endpoint = 
_resolve_usage_proxy_endpoint( + agent=agent, + agent_env=agent_env, + model=model, + environment=environment, + runtime=runtime, + sandbox=sandbox, + usage_cfg=usage_cfg, + ) + if endpoint is None: + return agent_env, runtime + if isinstance(endpoint, UsageProxyPreconditionFailure): + return await _skip_or_block_usage_proxy( + usage_cfg=usage_cfg, + failure=endpoint, + agent_env=agent_env, + runtime=runtime, + ) + + runtime = await _ensure_started_usage_proxy_runtime( + endpoint=endpoint, + runtime=runtime, + agent_env=agent_env, + session_id=session_id, + agent=agent, + model=model, + usage_cfg=usage_cfg, + ) + if runtime is None: + return agent_env, None + + return ( + _apply_usage_proxy_env( + agent=agent, + agent_env=agent_env, + runtime=runtime, + routing=endpoint.routing, + ), + runtime, + ) + + +def extract_usage(runtime: ProviderRuntime | None) -> dict[str, Any]: + """Extract aggregate token/cost metrics from a usage proxy runtime.""" + if runtime is None or runtime.kind != "usage-proxy" or runtime.server is None: + return _usage_unavailable() + trajectory = getattr(runtime.server, "trajectory", None) + if trajectory is None or not trajectory.exchanges: + return _usage_unavailable() + if not getattr(trajectory, "has_provider_usage", False): + return _usage_unavailable() + + input_tokens = trajectory.total_input_tokens + output_tokens = trajectory.total_output_tokens + cache_read_tokens = trajectory.total_cache_read_tokens + cache_creation_tokens = trajectory.total_cache_creation_tokens + total_tokens = trajectory.total_provider_tokens + model = _model_from_trajectory(runtime) + pricing = _pricing_for_model(model) + cost_usd = _estimate_cost_usd( + model=model, + input_tokens=input_tokens, + output_tokens=output_tokens, + cache_read_tokens=cache_read_tokens, + cache_creation_tokens=cache_creation_tokens, + cache_tokens_included_in_input=_cache_tokens_are_input_breakdown(trajectory), + ) + return { + "n_input_tokens": input_tokens, + "n_output_tokens": 
output_tokens, + "n_cache_read_tokens": cache_read_tokens, + "n_cache_creation_tokens": cache_creation_tokens, + "total_tokens": total_tokens, + "cost_usd": cost_usd, + "usage_source": "provider_response", + "price_source": pricing.price_source + if cost_usd is not None and pricing + else None, + } diff --git a/src/benchflow/rollout.py b/src/benchflow/rollout.py index 94c8be01..7b5b7045 100644 --- a/src/benchflow/rollout.py +++ b/src/benchflow/rollout.py @@ -42,6 +42,7 @@ import json import logging import os +import re import shlex import tempfile from dataclasses import dataclass, field @@ -111,6 +112,16 @@ def _agent_launch_with_web_policy( ) +def _agent_process_kill_pattern(agent_launch: str) -> str | None: + """Return a pkill -f pattern for the launched agent binary.""" + if not agent_launch.strip(): + return None + agent_cmd = agent_launch.split()[0].split("/")[-1] + if not agent_cmd: + return None + return rf"(^|[ /]){re.escape(agent_cmd)}( |$)" + + def _skill_nudge(agent_env: dict[str, str] | None) -> str: """Read skill nudge from explicit agent env or the host environment.""" return (agent_env or {}).get("BENCHFLOW_SKILL_NUDGE") or os.environ.get( @@ -1464,6 +1475,7 @@ async def connect(self) -> None: environment=cfg.environment, session_id=getattr(self, "_rollout_name", "") or "", usage_tracking=cfg.usage_tracking, + sandbox=self._env, ) ( self._acp_client, @@ -1500,10 +1512,13 @@ async def disconnect(self) -> None: self._session = None self._session_adapter = None # Kill any lingering agent processes to prevent context bleed between scenes - if self._env and self._agent_launch.strip(): - agent_cmd = self._agent_launch.split()[0].split("/")[-1] + agent_pattern = _agent_process_kill_pattern(self._agent_launch) + if self._env and agent_pattern: with contextlib.suppress(Exception): - await self._env.exec(f"pkill -f '{agent_cmd}' || true", timeout_sec=10) + await self._env.exec( + f"pkill -f {shlex.quote(agent_pattern)} || true", + timeout_sec=10, + ) 
self._active_role = None self._session_tool_count = 0 self._session_traj_count = 0 @@ -1877,6 +1892,7 @@ async def cleanup(self) -> None: logger.warning(f"LLM trajectory write failed: {e}") finally: self._usage_runtime = None + self._enforce_required_usage_tracking() if self._environment is not None: with contextlib.suppress(Exception): @@ -1908,6 +1924,20 @@ async def cleanup(self) -> None: self._phase = "cleaned" + def _enforce_required_usage_tracking(self) -> None: + usage_cfg = self._config.usage_tracking.with_env_defaults() + if usage_cfg.mode != "required" or self._config.primary_agent == "oracle": + return + if self._usage_metrics.get("usage_source") == "provider_response": + return + if self._error is not None: + return + self._error = ( + "Token usage tracking is required, but no provider token usage was " + "captured." + ) + logger.error(self._error) + # ── Full run ── async def run(self) -> RolloutResult: @@ -2317,6 +2347,7 @@ async def connect_as(self, role: Role) -> None: environment=cfg.environment, session_id=getattr(self, "_rollout_name", "") or "", usage_tracking=cfg.usage_tracking, + sandbox=self._env, ) role_agent_differs = role.agent != cfg.primary_agent diff --git a/src/benchflow/sandbox/docker.py b/src/benchflow/sandbox/docker.py index efbb538f..4bd8d835 100644 --- a/src/benchflow/sandbox/docker.py +++ b/src/benchflow/sandbox/docker.py @@ -405,7 +405,7 @@ async def stop(self, delete: bool) -> None: self._chown_to_host_user(str(SandboxPaths.logs_dir), recursive=True), timeout=30, ) - except (TimeoutError, asyncio.TimeoutError): + except TimeoutError: self.logger.warning("Chown logs directory timed out; continuing teardown.") except Exception as e: self.logger.warning(f"Failed to chown logs directory: {e}") diff --git a/src/benchflow/trajectories/proxy.py b/src/benchflow/trajectories/proxy.py index a074d925..78d393f1 100644 --- a/src/benchflow/trajectories/proxy.py +++ b/src/benchflow/trajectories/proxy.py @@ -4,6 +4,7 @@ """ import asyncio 
+import base64 import gzip import importlib import io @@ -419,6 +420,90 @@ def _parse_request_body(body_bytes: bytes, headers: dict[str, str]) -> dict[str, return {"raw": parsed} +def exchange_from_raw_capture(record: dict[str, Any]) -> LLMExchange: + """Build a canonical LLM exchange from a raw proxy capture record. + + Sandbox-local proxies intentionally stay dumb: they forward bytes and log + raw request/response bodies. The host process owns provider-specific JSON, + SSE, and usage parsing so Docker and Daytona share one interpretation path. + """ + request = record.get("request") or {} + response = record.get("response") or {} + request_headers = _lower_headers(request.get("headers") or {}) + response_headers = _lower_headers(response.get("headers") or {}) + request_body_bytes = _decode_b64(request.get("body_b64")) + response_body_bytes = _decode_b64(response.get("body_b64")) + path = str(request.get("path") or "") + request_body = _parse_request_body(request_body_bytes, request_headers) + response_body = _parse_response_body( + response_body_bytes, + response_headers, + path=path, + request_body=request_body, + ) + return LLMExchange( + request=LLMRequest( + method=str(request.get("method") or "POST"), + path=path, + headers=request_headers, + body=request_body, + ), + response=LLMResponse( + status_code=int(response.get("status_code") or 0), + headers=response_headers, + body=response_body, + ), + duration_ms=float(record.get("duration_ms") or 0.0), + ) + + +def _lower_headers(headers: dict[str, Any]) -> dict[str, str]: + return {str(k).lower(): str(v) for k, v in headers.items()} + + +def _decode_b64(value: Any) -> bytes: + if not value: + return b"" + return base64.b64decode(str(value).encode()) + + +def _parse_response_body( + body_bytes: bytes, + headers: dict[str, str], + *, + path: str, + request_body: dict[str, Any], +) -> dict[str, Any]: + if not body_bytes: + return {} + + decoded = _decode_request_body( + body_bytes, headers.get("content-encoding", 
"identity") + ) + content_type = headers.get("content-type", "").lower() + if "text/event-stream" in content_type: + try: + return _reconstruct_sse_response(decoded) + except Exception as e: + logger.warning(f"SSE response reconstruction failed: {e}") + return {"raw": decoded.decode(errors="replace")[:_RAW_RESP_TRUNCATE]} + + try: + parsed = json.loads(decoded) + except (json.JSONDecodeError, UnicodeDecodeError): + if not content_type and ( + request_body.get("stream", False) or _is_sse_request_path(path) + ): + try: + return _reconstruct_sse_response(decoded) + except Exception as e: + logger.warning(f"SSE response reconstruction failed: {e}") + return {"raw": decoded.decode(errors="replace")[:_RAW_RESP_TRUNCATE]} + if isinstance(parsed, dict): + return parsed + return {"raw": parsed} + + def _normalize_path_prefix(path_prefix: str) -> str: if not path_prefix: return "" diff --git a/src/benchflow/trajectories/types.py b/src/benchflow/trajectories/types.py index a36f7d48..46ef3b77 100644 --- a/src/benchflow/trajectories/types.py +++ b/src/benchflow/trajectories/types.py @@ -1,10 +1,140 @@ """Trajectory types — raw LLM API request/response pairs captured by proxy.""" +from dataclasses import dataclass from datetime import datetime from typing import Any from pydantic import BaseModel, Field +_USAGE_KEYS = { + "input_tokens", + "output_tokens", + "prompt_tokens", + "completion_tokens", + "total_tokens", + "cache_read_input_tokens", + "cache_creation_input_tokens", + "inputTokens", + "outputTokens", + "totalTokens", + "cacheReadInputTokenCount", + "cacheReadInputTokens", + "cacheWriteInputTokenCount", + "cacheWriteInputTokens", +} +_USAGE_DETAIL_KEYS = { + "cached_tokens", +} +_USAGE_METADATA_KEYS = { + "promptTokenCount", + "candidatesTokenCount", + "totalTokenCount", + "cachedContentTokenCount", +} + + +def _has_non_null_key(payload: dict[str, Any], keys: set[str]) -> bool: + return any(key in payload and payload[key] is not None for key in keys) + + +def 
_has_provider_usage(payload: dict[str, Any]) -> bool: + if _has_non_null_key(payload, _USAGE_KEYS): + return True + for key in ("prompt_tokens_details", "input_tokens_details"): + details = payload.get(key) + if isinstance(details, dict) and _has_non_null_key(details, _USAGE_DETAIL_KEYS): + return True + return False + + +def _first_int(*values: Any) -> int: + """Return the first non-null usage value as an integer.""" + for value in values: + if value is None: + continue + try: + return int(value) + except (TypeError, ValueError): + continue + return 0 + + +def _first_optional_int(*values: Any) -> int | None: + for value in values: + if value is None: + continue + try: + return int(value) + except (TypeError, ValueError): + continue + return None + + +@dataclass(frozen=True) +class TokenUsage: + input_tokens: int = 0 + output_tokens: int = 0 + cache_read_tokens: int = 0 + cache_creation_tokens: int = 0 + provider_total_tokens: int | None = None + + @property + def total_tokens(self) -> int: + if self.provider_total_tokens is not None: + return self.provider_total_tokens + return ( + self.input_tokens + + self.output_tokens + + self.cache_read_tokens + + self.cache_creation_tokens + ) + + +def _exchange_token_usage(exchange: "LLMExchange") -> TokenUsage: + usage = exchange.response.body.get("usage") + usage = usage if isinstance(usage, dict) else {} + usage_metadata = exchange.response.body.get("usageMetadata") + usage_metadata = usage_metadata if isinstance(usage_metadata, dict) else {} + # OpenAI may return these keys with an explicit null value, so + # `or {}` is required — `.get(key, {})` would still yield None. 
+ prompt_details = usage.get("prompt_tokens_details") or {} + prompt_details = prompt_details if isinstance(prompt_details, dict) else {} + input_details = usage.get("input_tokens_details") or {} + input_details = input_details if isinstance(input_details, dict) else {} + + return TokenUsage( + input_tokens=_first_int( + usage.get("input_tokens"), + usage.get("prompt_tokens"), + usage.get("inputTokens"), + usage_metadata.get("promptTokenCount"), + ), + output_tokens=_first_int( + usage.get("output_tokens"), + usage.get("completion_tokens"), + usage.get("outputTokens"), + usage_metadata.get("candidatesTokenCount"), + ), + cache_read_tokens=_first_int( + usage.get("cache_read_input_tokens"), + usage.get("cacheReadInputTokens"), + usage.get("cacheReadInputTokenCount"), + prompt_details.get("cached_tokens"), + input_details.get("cached_tokens"), + usage_metadata.get("cachedContentTokenCount"), + ), + cache_creation_tokens=_first_int( + usage.get("cache_creation_input_tokens"), + usage.get("cacheWriteInputTokens"), + usage.get("cacheWriteInputTokenCount"), + ), + provider_total_tokens=_first_optional_int( + usage.get("total_tokens"), + usage_metadata.get("totalTokenCount"), + usage.get("totalTokens"), + ), + ) + class LLMRequest(BaseModel): """A single request to an LLM API, captured by the proxy.""" @@ -44,86 +174,40 @@ class Trajectory(BaseModel): metadata: dict[str, Any] = Field(default_factory=dict) @property - def total_input_tokens(self) -> int: - total = 0 + def has_provider_usage(self) -> bool: + """Whether any exchange contains provider-supplied usage fields.""" for ex in self.exchanges: - usage = ex.response.body.get("usage", {}) - usage_metadata = ex.response.body.get("usageMetadata", {}) - total += ( - usage.get("input_tokens", 0) - or usage.get("prompt_tokens", 0) - or usage_metadata.get("promptTokenCount", 0) - ) - return total + usage = ex.response.body.get("usage") + if isinstance(usage, dict) and _has_provider_usage(usage): + return True + 
usage_metadata = ex.response.body.get("usageMetadata") + if isinstance(usage_metadata, dict) and _has_non_null_key( + usage_metadata, _USAGE_METADATA_KEYS + ): + return True + return False + + @property + def total_input_tokens(self) -> int: + return sum(_exchange_token_usage(ex).input_tokens for ex in self.exchanges) @property def total_output_tokens(self) -> int: - total = 0 - for ex in self.exchanges: - usage = ex.response.body.get("usage", {}) - usage_metadata = ex.response.body.get("usageMetadata", {}) - total += ( - usage.get("output_tokens", 0) - or usage.get("completion_tokens", 0) - or usage_metadata.get("candidatesTokenCount", 0) - ) - return total + return sum(_exchange_token_usage(ex).output_tokens for ex in self.exchanges) @property def total_cache_read_tokens(self) -> int: - total = 0 - for ex in self.exchanges: - usage = ex.response.body.get("usage", {}) - usage_metadata = ex.response.body.get("usageMetadata", {}) - # OpenAI may return these keys with an explicit null value, so - # `or {}` is required — `.get(key, {})` would still yield None. 
- prompt_details = usage.get("prompt_tokens_details") or {} - input_details = usage.get("input_tokens_details") or {} - total += ( - usage.get("cache_read_input_tokens", 0) - or prompt_details.get("cached_tokens", 0) - or input_details.get("cached_tokens", 0) - or usage_metadata.get("cachedContentTokenCount", 0) - or 0 - ) - return total + return sum(_exchange_token_usage(ex).cache_read_tokens for ex in self.exchanges) @property def total_cache_creation_tokens(self) -> int: - total = 0 - for ex in self.exchanges: - usage = ex.response.body.get("usage", {}) - total += usage.get("cache_creation_input_tokens", 0) or 0 - return total + return sum( + _exchange_token_usage(ex).cache_creation_tokens for ex in self.exchanges + ) @property def total_provider_tokens(self) -> int: - total = 0 - for ex in self.exchanges: - usage = ex.response.body.get("usage", {}) - usage_metadata = ex.response.body.get("usageMetadata", {}) - provider_total = usage.get("total_tokens") or usage_metadata.get( - "totalTokenCount" - ) - if provider_total is not None: - total += provider_total - continue - input_tokens = ( - usage.get("input_tokens", 0) - or usage.get("prompt_tokens", 0) - or usage_metadata.get("promptTokenCount", 0) - ) - output_tokens = ( - usage.get("output_tokens", 0) - or usage.get("completion_tokens", 0) - or usage_metadata.get("candidatesTokenCount", 0) - ) - cache_read_tokens = usage.get("cache_read_input_tokens", 0) or 0 - cache_creation_tokens = usage.get("cache_creation_input_tokens", 0) or 0 - total += ( - input_tokens + output_tokens + cache_read_tokens + cache_creation_tokens - ) - return total + return sum(_exchange_token_usage(ex).total_tokens for ex in self.exchanges) @property def total_cost_usd(self) -> float | None: diff --git a/src/benchflow/usage_tracking.py b/src/benchflow/usage_tracking.py index 732f6a6e..d012f490 100644 --- a/src/benchflow/usage_tracking.py +++ b/src/benchflow/usage_tracking.py @@ -5,18 +5,21 @@ import os from dataclasses import dataclass 
from typing import Any, Literal, cast -from urllib.parse import urlsplit UsageTrackingMode = Literal["auto", "required", "off"] USAGE_TRACKING_ENV = "BENCHFLOW_USAGE_TRACKING" -USAGE_PROXY_ADVERTISED_BASE_URL_ENV = "BENCHFLOW_USAGE_PROXY_ADVERTISED_BASE_URL" -USAGE_PROXY_BIND_HOST_ENV = "BENCHFLOW_USAGE_PROXY_BIND_HOST" -USAGE_PROXY_PORT_ENV = "BENCHFLOW_USAGE_PROXY_PORT" - -DEFAULT_USAGE_PROXY_BIND_HOST = "0.0.0.0" _MODES: set[str] = {"auto", "required", "off"} +_LEGACY_USAGE_PROXY_KEYS: frozenset[str] = frozenset( + { + "usage_proxy", + "usage_proxy_advertised_base_url", + "usage_proxy_bind_host", + "usage_proxy_port", + "usage_proxy_url", + } +) def normalize_usage_tracking_mode(value: str) -> UsageTrackingMode: @@ -33,47 +36,6 @@ def _optional_mode(value: Any) -> UsageTrackingMode | None: return normalize_usage_tracking_mode(str(value)) -def _optional_str(value: Any) -> str | None: - if value is None: - return None - text = str(value).strip() - return text or None - - -def _first_present(*values: Any) -> Any: - for value in values: - if value is not None: - return value - return None - - -def _optional_port(value: Any) -> int | None: - if value is None or value == "": - return None - port = int(value) - if port < 0 or port > 65535: - raise ValueError("usage proxy port must be between 0 and 65535") - return port - - -def normalize_advertised_base_url(value: str | None) -> str | None: - url = _optional_str(value) - if url is None: - return None - parsed = urlsplit(url) - if parsed.scheme not in {"http", "https"} or not parsed.netloc: - raise ValueError( - "usage_proxy.advertised_base_url must be an absolute http(s) URL" - ) - if parsed.query or parsed.fragment: - raise ValueError( - "usage_proxy.advertised_base_url must not include query or fragment" - ) - if parsed.path not in {"", "/"}: - raise ValueError("usage_proxy.advertised_base_url must not include a path") - return url.rstrip("/") - - @dataclass(frozen=True, init=False) class UsageTrackingConfig: 
"""User-facing token/cost telemetry policy. @@ -85,25 +47,12 @@ class UsageTrackingConfig: """ _mode: UsageTrackingMode | None - advertised_base_url: str | None = None - bind_host: str | None = None - port: int | None = None def __init__( self, mode: str | None = None, - advertised_base_url: str | None = None, - bind_host: str | None = None, - port: int | str | None = None, ) -> None: object.__setattr__(self, "_mode", _optional_mode(mode)) - object.__setattr__( - self, - "advertised_base_url", - normalize_advertised_base_url(advertised_base_url), - ) - object.__setattr__(self, "bind_host", _optional_str(bind_host)) - object.__setattr__(self, "port", _optional_port(port)) @property def mode(self) -> UsageTrackingMode: @@ -113,64 +62,21 @@ def mode(self) -> UsageTrackingMode: def mode_is_explicit(self) -> bool: return self._mode is not None - @property - def uses_external_proxy(self) -> bool: - return self.advertised_base_url is not None - - @property - def has_fixed_proxy_port(self) -> bool: - return self.port is not None and self.port > 0 - def overlay(self, override: UsageTrackingConfig) -> UsageTrackingConfig: """Return this config with explicitly supplied override fields applied.""" return UsageTrackingConfig( mode=override._mode if override.mode_is_explicit else self._mode, - advertised_base_url=( - override.advertised_base_url - if override.advertised_base_url is not None - else self.advertised_base_url - ), - bind_host=( - override.bind_host if override.bind_host is not None else self.bind_host - ), - port=override.port if override.port is not None else self.port, ) - def validate_parallelism(self, *, concurrency: int, worker_count: int = 1) -> None: - if ( - self.mode != "off" - and self.uses_external_proxy - and max(concurrency, worker_count) > 1 - ): - raise ValueError( - "External usage proxy tracking currently supports only one " - "rollout per fixed local proxy port. 
Use --concurrency 1 and " - "one worker, or run separate jobs with separate " - f"{USAGE_PROXY_PORT_ENV} values/tunnels." - ) - @classmethod def from_mapping(cls, raw: dict[str, Any]) -> UsageTrackingConfig: - proxy = raw.get("usage_proxy") - if proxy is None: - proxy = {} - elif not isinstance(proxy, dict): - raise ValueError("usage_proxy must be a mapping") - return cls( - mode=raw.get("usage_tracking"), - advertised_base_url=( - _first_present( - proxy.get("advertised_base_url"), - raw.get("usage_proxy_advertised_base_url"), - raw.get("usage_proxy_url"), - ) - ), - bind_host=_first_present( - proxy.get("bind_host"), - raw.get("usage_proxy_bind_host"), - ), - port=_first_present(proxy.get("port"), raw.get("usage_proxy_port")), - ) + legacy_keys = sorted(set(raw) & _LEGACY_USAGE_PROXY_KEYS) + if legacy_keys: + raise ValueError( + f"{', '.join(legacy_keys)} is no longer supported; use usage_tracking=" + "auto|required|off instead." + ) + return cls(mode=raw.get("usage_tracking")) @classmethod def coerce( @@ -187,19 +93,9 @@ def coerce( raise TypeError(f"invalid usage_tracking config: {type(value).__name__}") def to_mapping(self) -> dict[str, Any]: - proxy: dict[str, Any] = {} - if self.advertised_base_url is not None: - proxy["advertised_base_url"] = self.advertised_base_url - if self.bind_host is not None: - proxy["bind_host"] = self.bind_host - if self.port is not None: - proxy["port"] = self.port - payload: dict[str, Any] = {} if self.mode_is_explicit: payload["usage_tracking"] = self.mode - if proxy: - payload["usage_proxy"] = proxy return payload def with_env_defaults(self) -> UsageTrackingConfig: @@ -208,27 +104,10 @@ def with_env_defaults(self) -> UsageTrackingConfig: if env_mode and not self.mode_is_explicit: mode = normalize_usage_tracking_mode(env_mode) - return UsageTrackingConfig( - mode=mode, - advertised_base_url=( - self.advertised_base_url - or os.environ.get(USAGE_PROXY_ADVERTISED_BASE_URL_ENV) - ), - bind_host=self.bind_host or 
os.environ.get(USAGE_PROXY_BIND_HOST_ENV), - port=( - self.port - if self.port is not None - else os.environ.get(USAGE_PROXY_PORT_ENV) - ), - ) + return UsageTrackingConfig(mode=mode) def to_config_artifact(self) -> dict[str, Any]: - return { - "requested": self.mode, - "advertised_base_url_configured": self.uses_external_proxy, - "bind_host": self.bind_host, - "port": self.port, - } + return {"requested": self.mode} def to_result_metadata( self, @@ -237,11 +116,13 @@ def to_result_metadata( status: str, usage_source: str, ) -> dict[str, Any]: + endpoint_kind = "sandbox" if environment == "daytona" else "host" + if self.mode == "off": + endpoint_kind = "none" return { "requested": self.mode, "status": status, "environment": environment, - "endpoint_kind": "external" if self.uses_external_proxy else "host", + "endpoint_kind": endpoint_kind, "usage_source": usage_source, - "advertised_base_url_configured": self.uses_external_proxy, } diff --git a/tests/test_agent_registry.py b/tests/test_agent_registry.py index 51328825..fe86a240 100644 --- a/tests/test_agent_registry.py +++ b/tests/test_agent_registry.py @@ -64,7 +64,7 @@ def test_openhands_normalizes_model(self): agent_env=env, ) - assert env["LLM_MODEL"] == "glm-5" + assert env["LLM_MODEL"] == "openai/glm-5" def test_openhands_bedrock_anthropic_model_uses_litellm_provider_prefix(self): """Guards v0.5-integration@e55219d against OpenHands receiving a bare Bedrock profile id.""" diff --git a/tests/test_cli_run.py b/tests/test_cli_run.py index 0fcc2bcd..67e92f86 100644 --- a/tests/test_cli_run.py +++ b/tests/test_cli_run.py @@ -120,3 +120,40 @@ def test_benchflow_eval_list_surfaces_root_summary_memory_score(tmp_path): assert "1/2" in result.output assert "50.0%" in result.output assert "75.0%" in result.output + + +def test_eval_create_reports_runtime_config_errors_without_traceback( + tmp_path, monkeypatch +): + """Guards PR #587: required usage preflight failures stay user-facing.""" + tasks_dir = tmp_path / 
"tasks" + tasks_dir.mkdir() + + class FakeEvaluation: + def __init__(self, **_kwargs): + pass + + async def run(self): + raise RuntimeError("Token usage tracking is required") + + monkeypatch.setattr("benchflow.evaluation.Evaluation", FakeEvaluation) + + result = CliRunner().invoke( + app, + [ + "eval", + "create", + "--tasks-dir", + str(tasks_dir), + "--agent", + "openhands", + "--model", + "aws-bedrock/example-model", + "--usage-tracking", + "required", + ], + ) + + assert result.exit_code == 1 + assert "Token usage tracking is required" in result.output + assert "Traceback" not in result.output diff --git a/tests/test_config_redaction.py b/tests/test_config_redaction.py index 596bc022..72f95ae7 100644 --- a/tests/test_config_redaction.py +++ b/tests/test_config_redaction.py @@ -119,7 +119,7 @@ def test_write_config_drops_secret_env_vars(tmp_path: Path) -> None: def test_write_config_drops_usage_proxy_secret_base_urls(tmp_path: Path) -> None: - """Guards PR #568: external usage proxy path prefixes are bearer secrets.""" + """Provider proxy URLs with BenchFlow secret path segments must be redacted.""" secret_base = "https://usage.example.test/__benchflow/secret-prefix" agent_env = { "BENCHFLOW_PROVIDER_BASE_URL": secret_base, diff --git a/tests/test_daytona_usage_runtime.py b/tests/test_daytona_usage_runtime.py new file mode 100644 index 00000000..74919037 --- /dev/null +++ b/tests/test_daytona_usage_runtime.py @@ -0,0 +1,414 @@ +"""Tests for Daytona-specific provider usage runtime wiring.""" + +from __future__ import annotations + +import json +from types import SimpleNamespace + +import pytest + +from benchflow.providers import usage_proxy_runtime as usage_runtime_mod +from benchflow.trajectories.types import Trajectory + + +@pytest.mark.asyncio +async def test_registered_openai_compatible_provider_uses_sandbox_usage_proxy( + monkeypatch, +): + """Guards PR #587: direct provider envs route through Daytona usage proxy.""" + from benchflow.agents.env import 
resolve_agent_env + from benchflow.providers.runtime import ensure_usage_proxy_runtime + + started = [] + + class FakeSandboxUsageProxy: + base_url = "http://127.0.0.1:49001" + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + self.trajectory = Trajectory( + session_id=kwargs["session_id"], agent_name=kwargs["agent_name"] + ) + + async def start(self): + started.append(self.target) + + async def stop(self): + return None + + monkeypatch.setattr(usage_runtime_mod, "SandboxUsageProxy", FakeSandboxUsageProxy) + + env = resolve_agent_env( + "openhands", + "kimi/kimi-k2.6", + { + "KIMI_API_KEY": "sk-kimi", + "KIMI_BASE_URL": "https://api.moonshot.ai/v1", + }, + ) + + updated, runtime = await ensure_usage_proxy_runtime( + agent="openhands", + agent_env=env, + model="kimi/kimi-k2.6", + runtime=None, + environment="daytona", + session_id="rollout-1", + sandbox=object(), + ) + + assert runtime is not None + assert started == ["https://api.moonshot.ai/v1"] + assert runtime.server.target == "https://api.moonshot.ai/v1" + assert updated["LLM_BASE_URL"] == "http://127.0.0.1:49001" + assert updated["BENCHFLOW_PROVIDER_BASE_URL"] == "http://127.0.0.1:49001" + + +@pytest.mark.asyncio +async def test_usage_runtime_reconnect_ignores_own_proxy_url(monkeypatch): + """Guards PR #587: reconnects must not point a new proxy at the old proxy.""" + from benchflow.providers.runtime import ensure_usage_proxy_runtime + + started = [] + stopped = [] + + class FakeSandboxUsageProxy: + base_url = "http://127.0.0.1:49001" + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + self.trajectory = Trajectory( + session_id=kwargs["session_id"], agent_name=kwargs["agent_name"] + ) + + async def start(self): + started.append(self.target) + + async def is_running(self): + return True + + async def stop(self): + stopped.append(self.target) + + monkeypatch.setattr(usage_runtime_mod, "SandboxUsageProxy", FakeSandboxUsageProxy) + + env = { + "ANTHROPIC_BASE_URL": 
"https://api.anthropic.com", + "ANTHROPIC_API_KEY": "sk-real-key", + } + first_env, first_runtime = await ensure_usage_proxy_runtime( + agent="claude-agent-acp", + agent_env=env, + model="claude-haiku-4-5-20251001", + runtime=None, + environment="daytona", + session_id="rollout-1", + sandbox=object(), + ) + second_env, second_runtime = await ensure_usage_proxy_runtime( + agent="claude-agent-acp", + agent_env=first_env, + model="claude-haiku-4-5-20251001", + runtime=first_runtime, + environment="daytona", + session_id="rollout-1", + sandbox=object(), + ) + + assert first_runtime is not None + assert second_runtime is first_runtime + assert started == ["https://api.anthropic.com"] + assert stopped == [] + assert first_runtime.server.target == "https://api.anthropic.com" + assert second_env["ANTHROPIC_BASE_URL"] == "http://127.0.0.1:49001" + + +@pytest.mark.asyncio +async def test_dead_usage_runtime_reconnect_uses_original_upstream(monkeypatch): + """Guards PR #587: stale-proxy replacement must dial the provider, not itself.""" + from benchflow.providers.runtime import ProviderRuntime, ensure_usage_proxy_runtime + + stopped = [] + started = [] + + class DeadServer: + target = "https://api.anthropic.com" + + async def is_running(self): + return False + + async def stop(self): + stopped.append("dead") + + class FakeSandboxUsageProxy: + base_url = "http://127.0.0.1:49001" + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + self.trajectory = Trajectory( + session_id=kwargs["session_id"], agent_name=kwargs["agent_name"] + ) + + async def start(self): + started.append(self.target) + + async def stop(self): + stopped.append("new") + + monkeypatch.setattr(usage_runtime_mod, "SandboxUsageProxy", FakeSandboxUsageProxy) + + stale_runtime = ProviderRuntime( + kind="usage-proxy", + agent_base_url="http://127.0.0.1:49000", + backend_model="claude-haiku-4-5-20251001", + server=DeadServer(), + ) + + updated, runtime = await ensure_usage_proxy_runtime( + 
agent="claude-agent-acp", + agent_env={ + "ANTHROPIC_BASE_URL": "http://127.0.0.1:49000", + "BENCHFLOW_PROVIDER_BASE_URL": "http://127.0.0.1:49000", + "ANTHROPIC_API_KEY": "sk-real-key", + }, + model="claude-haiku-4-5-20251001", + runtime=stale_runtime, + environment="daytona", + session_id="rollout-1", + sandbox=object(), + ) + + assert stopped == ["dead"] + assert started == ["https://api.anthropic.com"] + assert runtime is not None + assert runtime is not stale_runtime + assert runtime.server.target == "https://api.anthropic.com" + assert updated["ANTHROPIC_BASE_URL"] == "http://127.0.0.1:49001" + + +@pytest.mark.asyncio +async def test_codex_provider_config_is_repointed_at_usage_proxy(monkeypatch): + """Guards PR #587: Codex custom providers must not bypass telemetry proxy.""" + from benchflow.providers.runtime import ensure_usage_proxy_runtime + + class FakeSandboxUsageProxy: + target = "https://example-resource.openai.azure.com/openai/v1" + base_url = "http://127.0.0.1:49001" + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + self.trajectory = Trajectory( + session_id=kwargs["session_id"], agent_name=kwargs["agent_name"] + ) + + async def start(self): + return None + + async def stop(self): + return None + + monkeypatch.setattr(usage_runtime_mod, "SandboxUsageProxy", FakeSandboxUsageProxy) + + env = { + "BENCHFLOW_PROVIDER_BASE_URL": ( + "https://example-resource.openai.azure.com/openai/v1" + ), + "OPENAI_BASE_URL": "https://example-resource.openai.azure.com/openai/v1", + "BENCHFLOW_PROVIDER_MODEL": "gpt-5.5", + "OPENAI_API_KEY": "az-test", + "MODEL_PROVIDER": "benchflow-azure-foundry-openai", + "CODEX_CONFIG": json.dumps( + { + "model_provider": "benchflow-azure-foundry-openai", + "model": "gpt-5.5", + "model_providers": { + "benchflow-azure-foundry-openai": { + "name": "azure-foundry-openai", + "base_url": ( + "https://example-resource.openai.azure.com/openai/v1" + ), + "env_key": "OPENAI_API_KEY", + "wire_api": "responses", + 
"supports_websockets": False, + } + }, + } + ), + } + + updated, runtime = await ensure_usage_proxy_runtime( + agent="codex-acp", + agent_env=env, + model="azure-foundry-openai/gpt-5.5", + runtime=None, + environment="daytona", + session_id="rollout-1", + sandbox=object(), + ) + + assert runtime is not None + assert updated["OPENAI_BASE_URL"] == "http://127.0.0.1:49001" + codex_config = json.loads(updated["CODEX_CONFIG"]) + provider = codex_config["model_providers"]["benchflow-azure-foundry-openai"] + assert provider["base_url"] == "http://127.0.0.1:49001" + + +@pytest.mark.asyncio +async def test_codex_native_openai_gets_usage_proxy_provider_config(monkeypatch): + """Guards PR #587: native Codex OpenAI runs must not bypass telemetry.""" + from benchflow.providers.runtime import ensure_usage_proxy_runtime + + class FakeSandboxUsageProxy: + target = "https://api.openai.com/v1" + base_url = "http://127.0.0.1:49001" + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + self.trajectory = Trajectory( + session_id=kwargs["session_id"], agent_name=kwargs["agent_name"] + ) + + async def start(self): + return None + + async def stop(self): + return None + + monkeypatch.setattr(usage_runtime_mod, "SandboxUsageProxy", FakeSandboxUsageProxy) + + updated, runtime = await ensure_usage_proxy_runtime( + agent="codex-acp", + agent_env={ + "OPENAI_API_KEY": "sk-test", + "BENCHFLOW_PROVIDER_MODEL": "gpt-5.4-mini", + }, + model="gpt-5.4-mini", + runtime=None, + environment="daytona", + session_id="rollout-1", + sandbox=object(), + ) + + assert runtime is not None + assert updated["OPENAI_BASE_URL"] == "http://127.0.0.1:49001" + assert updated["MODEL_PROVIDER"] == "benchflow-openai" + codex_config = json.loads(updated["CODEX_CONFIG"]) + assert codex_config["model"] == "gpt-5.4-mini" + assert codex_config["model_provider"] == "benchflow-openai" + provider = codex_config["model_providers"]["benchflow-openai"] + assert provider == { + "name": "openai", + "base_url": 
"http://127.0.0.1:49001", + "env_key": "OPENAI_API_KEY", + "wire_api": "responses", + "supports_websockets": False, + } + + +@pytest.mark.asyncio +async def test_daytona_openhands_bedrock_usage_proxy_sets_aws_endpoint(monkeypatch): + """Guards PR #587: remote Bedrock-direct OpenHands is metered in-sandbox.""" + from benchflow.providers.runtime import ( + ensure_bedrock_proxy_runtime, + ensure_usage_proxy_runtime, + ) + from benchflow.usage_tracking import UsageTrackingConfig + + class FakeSandboxUsageProxy: + base_url = "http://127.0.0.1:49001" + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + self.trajectory = Trajectory( + session_id=kwargs["session_id"], agent_name=kwargs["agent_name"] + ) + + async def start(self): + return None + + async def stop(self): + return None + + monkeypatch.setattr(usage_runtime_mod, "SandboxUsageProxy", FakeSandboxUsageProxy) + + agent_env = { + "AWS_BEARER_TOKEN_BEDROCK": "bedrock-token", + "AWS_REGION": "us-west-2", + "LLM_BASE_URL": "", + "LLM_MODEL": "anthropic/us.anthropic.claude-opus-4-7", + } + bedrock_env, bedrock_runtime = await ensure_bedrock_proxy_runtime( + agent="openhands", + agent_env=agent_env, + model="aws-bedrock/us.anthropic.claude-opus-4-7", + runtime=None, + environment="daytona", + ) + + assert bedrock_runtime is None + assert "LLM_BASE_URL" not in bedrock_env + assert bedrock_env["LLM_MODEL"] == "bedrock/us.anthropic.claude-opus-4-7" + + usage_env, usage_runtime = await ensure_usage_proxy_runtime( + agent="openhands", + agent_env=bedrock_env, + model="aws-bedrock/us.anthropic.claude-opus-4-7", + runtime=None, + environment="daytona", + usage_tracking=UsageTrackingConfig(mode="required"), + sandbox=object(), + ) + + assert usage_runtime is not None + assert ( + usage_runtime.server.target == "https://bedrock-runtime.us-west-2.amazonaws.com" + ) + assert usage_env["LLM_BASE_URL"] == usage_runtime.base_url + assert "BENCHFLOW_PROVIDER_BASE_URL" not in usage_env + assert 
usage_env["AWS_REGION_NAME"] == "us-west-2" + assert usage_env["AWS_ENDPOINT_URL_BEDROCK_RUNTIME"] == usage_runtime.base_url + assert usage_env["AWS_ENDPOINT_URL_BEDROCK"] == usage_runtime.base_url + + +def test_extract_usage_accepts_bedrock_converse_usage_shape(): + """Guards PR #587: Bedrock Converse usage fields count as provider usage.""" + from benchflow.providers.runtime import ProviderRuntime, extract_usage + from benchflow.trajectories.types import LLMExchange, LLMRequest, LLMResponse + + trajectory = Trajectory(session_id="rollout-1", agent_name="openhands") + trajectory.exchanges.append( + LLMExchange( + request=LLMRequest( + path="/model/us.anthropic.claude-opus-4-7/converse", + body={"modelId": "us.anthropic.claude-opus-4-7"}, + ), + response=LLMResponse( + status_code=200, + body={ + "usage": { + "cacheReadInputTokens": 100, + "cacheWriteInputTokens": 200, + "inputTokens": 34, + "outputTokens": 13, + "totalTokens": 347, + } + }, + ), + duration_ms=12, + ) + ) + + runtime = ProviderRuntime( + kind="usage-proxy", + agent_base_url="http://127.0.0.1:49000", + backend_model="us.anthropic.claude-opus-4-7", + server=SimpleNamespace(trajectory=trajectory), + ) + usage = extract_usage(runtime) + + assert usage["usage_source"] == "provider_response" + assert usage["n_input_tokens"] == 34 + assert usage["n_output_tokens"] == 13 + assert usage["n_cache_read_tokens"] == 100 + assert usage["n_cache_creation_tokens"] == 200 + assert usage["total_tokens"] == 347 diff --git a/tests/test_provider_runtime.py b/tests/test_provider_runtime.py index b3f54d66..81b938e3 100644 --- a/tests/test_provider_runtime.py +++ b/tests/test_provider_runtime.py @@ -7,6 +7,7 @@ import pytest from benchflow.providers import runtime as provider_runtime_mod +from benchflow.providers import usage_proxy_runtime as usage_runtime_mod from benchflow.providers.runtime import ( ProviderRuntime, _bedrock_frontend_model, @@ -332,7 +333,7 @@ async def start(self): async def stop(self): return None - 
monkeypatch.setattr(provider_runtime_mod, "TrajectoryProxy", FakeProxy) + monkeypatch.setattr(usage_runtime_mod, "TrajectoryProxy", FakeProxy) updated, runtime = await ensure_usage_proxy_runtime( agent="gemini", diff --git a/tests/test_providers.py b/tests/test_providers.py index 02847f3d..5af5ff46 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -54,6 +54,42 @@ def test_azure_foundry_prefixes(self, model, expected_protocol): assert cfg.api_protocol == expected_protocol assert cfg.auth_env == "AZURE_API_KEY" + @pytest.mark.parametrize( + ("model", "expected_name", "expected_auth_env"), + [ + ("litellm/glm-5.1", "litellm", "LITELLM_API_KEY"), + ("kimi/kimi-k2.6", "kimi", "KIMI_API_KEY"), + ("minimax/MiniMax-M2.7", "minimax", "MINIMAX_API_KEY"), + ( + "qwen-dashscope/qwen3.6-max-preview", + "qwen-dashscope", + "QWEN_API_KEY", + ), + ("glm/glm-5.1", "glm", "GLM_API_KEY"), + ("deepseek/deepseek-v4-pro", "deepseek", "DEEPSEEK_API_KEY"), + ("xiaomi/mimo-v2.5-pro", "xiaomi", "XIAOMI_API_KEY"), + ( + "doubao-seed-2-lite/ep-test", + "doubao-seed-2-lite", + "DOUBAO_SEED_2_LITE_API_KEY", + ), + ( + "doubao-seed-2-pro/ep-test", + "doubao-seed-2-pro", + "DOUBAO_SEED_2_PRO_API_KEY", + ), + ("hunyuan/hy3-preview", "hunyuan", "HUNYUAN_API_KEY"), + ], + ) + def test_openai_compatible_provider_prefixes( + self, model, expected_name, expected_auth_env + ): + """Guards PR #587: direct provider keys resolve without generic vllm envs.""" + name, cfg = find_provider(model) + assert name == expected_name + assert cfg.api_protocol == "openai-completions" + assert cfg.auth_env == expected_auth_env + # ── resolve_base_url: template expansion ── @@ -140,6 +176,12 @@ def test_azure_anthropic_resource_expansion(self): == "https://example-resource.services.ai.azure.com/anthropic" ) + def test_openai_compatible_provider_base_url_expansion(self): + p = PROVIDERS["kimi"] + env = {"KIMI_BASE_URL": "https://api.moonshot.ai/v1"} + + assert resolve_base_url(p, env) == 
"https://api.moonshot.ai/v1" + # ── resolve_auth_env: which env var does this provider need? ── @@ -164,6 +206,11 @@ def test_azure_foundry_uses_shared_key(self): == "AZURE_API_KEY" ) + def test_direct_openai_compatible_provider_keys(self): + assert resolve_auth_env("kimi/kimi-k2.6") == "KIMI_API_KEY" + assert resolve_auth_env("qwen-dashscope/qwen3.6-max-preview") == "QWEN_API_KEY" + assert resolve_auth_env("glm/glm-5.1") == "GLM_API_KEY" + # ── Integration: backward compat with registry.py ── diff --git a/tests/test_registry_invariants.py b/tests/test_registry_invariants.py index 2102de51..55a52af9 100644 --- a/tests/test_registry_invariants.py +++ b/tests/test_registry_invariants.py @@ -401,6 +401,9 @@ def test_provider_models_and_credentials(name, cfg): ("aws-bedrock/openai.gpt-oss-20b-1:0", "aws-bedrock"), ("zai/glm-5", "zai"), ("vllm/local-model", "vllm"), + ("kimi/kimi-k2.6", "kimi"), + ("qwen-dashscope/qwen3.6-max-preview", "qwen-dashscope"), + ("doubao-seed-2-pro/ep-test", "doubao-seed-2-pro"), ], ) def test_find_provider_resolves_known_prefixes(model, expected): diff --git a/tests/test_resolve_env_helpers.py b/tests/test_resolve_env_helpers.py index 67ab4286..386aa45f 100644 --- a/tests/test_resolve_env_helpers.py +++ b/tests/test_resolve_env_helpers.py @@ -29,12 +29,16 @@ class TestAutoInheritEnv: ("env_name", "env_value"), [ pytest.param("ANTHROPIC_API_KEY", "sk-host", id="anthropic"), + pytest.param("CODEX_AUTH_JSON", '{"tokens": {}}', id="codex-auth-json"), pytest.param("CODEX_ACCESS_TOKEN", "codex-access", id="codex-token"), pytest.param("CODEX_API_KEY", "codex-key", id="codex-api-key"), + pytest.param("CLAUDE_OAUTH_TOKEN", "claude-oauth", id="claude-oauth"), pytest.param("OPENAI_API_KEY", "sk-oai", id="openai"), pytest.param("AWS_BEARER_TOKEN_BEDROCK", "bedrock-token", id="bedrock"), pytest.param("AWS_REGION", "us-east-1", id="bedrock-region"), pytest.param("ZAI_API_KEY", "zk-host", id="provider"), + pytest.param("KIMI_API_KEY", "sk-kimi", 
id="kimi-api-key"), + pytest.param("KIMI_BASE_URL", "https://api.moonshot.ai/v1", id="kimi-url"), pytest.param("AZURE_API_KEY", "az-host", id="azure-api-key"), pytest.param( "AZURE_API_ENDPOINT", @@ -81,6 +85,12 @@ def test_aws_region_mirrored_to_aws_default_region(self): auto_inherit_env(env) assert env["AWS_DEFAULT_REGION"] == "us-east-1" + def test_claude_oauth_alias_mirrored_to_claude_code_token(self): + """Guards PR #587: pasted Claude Code OAuth vars use both common names.""" + env = {"CLAUDE_OAUTH_TOKEN": "oauth-token"} + auto_inherit_env(env) + assert env["CLAUDE_CODE_OAUTH_TOKEN"] == "oauth-token" + def test_inherits_openai_base_url(self, monkeypatch): """Guards fix from PR #255: OPENAI_BASE_URL must be inherited. @@ -231,6 +241,23 @@ def test_injects_benchflow_provider_vars(self): assert "BENCHFLOW_PROVIDER_PROTOCOL" in env assert env["BENCHFLOW_PROVIDER_API_KEY"] == "zk-test" + def test_openai_compatible_provider_maps_to_openhands_env(self): + """Guards PR #587: direct provider envs reach OpenHands and usage proxy.""" + env = { + "KIMI_API_KEY": "sk-kimi", + "KIMI_BASE_URL": "https://api.moonshot.ai/v1", + } + + resolve_provider_env(env, "kimi/kimi-k2.6", "openhands") + + assert env["BENCHFLOW_PROVIDER_NAME"] == "kimi" + assert env["BENCHFLOW_PROVIDER_MODEL"] == "kimi-k2.6" + assert env["BENCHFLOW_PROVIDER_BASE_URL"] == "https://api.moonshot.ai/v1" + assert env["BENCHFLOW_PROVIDER_API_KEY"] == "sk-kimi" + assert env["LLM_BASE_URL"] == "https://api.moonshot.ai/v1" + assert env["LLM_API_KEY"] == "sk-kimi" + assert env["LLM_MODEL"] == "openai/kimi-k2.6" + def test_env_mapping_applied(self): """claude-agent-acp maps BENCHFLOW_PROVIDER_* → agent-native vars.""" env = {"ZAI_API_KEY": "zk-test"} @@ -771,6 +798,8 @@ def _clean_env(self, monkeypatch, tmp_path): "OPENAI_API_KEY", "OPENAI_BASE_URL", "ZAI_API_KEY", + "KIMI_API_KEY", + "KIMI_BASE_URL", ): monkeypatch.delenv(k, raising=False) empty = tmp_path / "empty.env" @@ -801,18 +830,43 @@ def 
test_explicit_agent_env_beats_host_provider_base_url(self, monkeypatch): assert result["BENCHFLOW_PROVIDER_BASE_URL"] == "http://explicit/v1" - def test_host_provider_base_url_overrides_resolved_provider_url(self, monkeypatch): - """For a provider with a real registry URL (zai), the host value wins. + def test_inherited_provider_base_url_does_not_shadow_registered_provider( + self, monkeypatch + ): + """A global .env provider proxy must not override direct provider prefixes. - vllm resolves to an empty base_url; zai resolves to a real endpoint, so - this proves the host override beats a *non-empty* resolved value. + Guards PR #587: a LiteLLM BENCHFLOW_PROVIDER_* default in .env broke + direct Kimi/GLM/etc. runs by replacing the provider's own key and URL. """ - monkeypatch.setenv("ZAI_API_KEY", "zk-test") + monkeypatch.setenv("KIMI_API_KEY", "sk-kimi") + monkeypatch.setenv("KIMI_BASE_URL", "https://api.moonshot.ai/v1") monkeypatch.setenv("BENCHFLOW_PROVIDER_BASE_URL", "http://host-proxy:9000/v1") + monkeypatch.setenv("BENCHFLOW_PROVIDER_API_KEY", "sk-host-proxy") - result = resolve_agent_env("codex-acp", "zai/glm-5", {}) + result = resolve_agent_env("openhands", "kimi/kimi-k2.6", {}) + + assert result["BENCHFLOW_PROVIDER_BASE_URL"] == "https://api.moonshot.ai/v1" + assert result["BENCHFLOW_PROVIDER_API_KEY"] == "sk-kimi" + assert result["LLM_BASE_URL"] == "https://api.moonshot.ai/v1" + assert result["LLM_API_KEY"] == "sk-kimi" + + def test_explicit_provider_base_url_can_override_registered_provider(self): + """An explicit --agent-env generic endpoint remains a valid override.""" + result = resolve_agent_env( + "openhands", + "kimi/kimi-k2.6", + { + "KIMI_API_KEY": "sk-kimi", + "KIMI_BASE_URL": "https://api.moonshot.ai/v1", + "BENCHFLOW_PROVIDER_BASE_URL": "http://explicit-proxy:9000/v1", + "BENCHFLOW_PROVIDER_API_KEY": "sk-explicit-proxy", + }, + ) - assert result["BENCHFLOW_PROVIDER_BASE_URL"] == "http://host-proxy:9000/v1" + assert 
result["BENCHFLOW_PROVIDER_BASE_URL"] == "http://explicit-proxy:9000/v1" + assert result["BENCHFLOW_PROVIDER_API_KEY"] == "sk-explicit-proxy" + assert result["LLM_BASE_URL"] == "http://explicit-proxy:9000/v1" + assert result["LLM_API_KEY"] == "sk-explicit-proxy" def test_no_host_override_keeps_resolved_provider_url(self, monkeypatch): """Sanity counterpart: without a host override zai's own endpoint is used.""" diff --git a/tests/test_sandbox_usage_proxy.py b/tests/test_sandbox_usage_proxy.py new file mode 100644 index 00000000..0af43e63 --- /dev/null +++ b/tests/test_sandbox_usage_proxy.py @@ -0,0 +1,631 @@ +"""Tests for sandbox-local provider usage telemetry.""" + +from __future__ import annotations + +import base64 +import contextlib +import gzip +import json +import os +import re +import shutil +import signal +import subprocess +import threading +import time +from pathlib import Path +from types import SimpleNamespace +from urllib.error import HTTPError +from urllib.request import Request, urlopen + +import pytest + +from benchflow.providers import usage_proxy_runtime as usage_runtime_mod +from benchflow.trajectories.types import Trajectory + + +def test_agent_kill_pattern_excludes_usage_proxy_agent_name_argument(): + """Guards PR #587: agent cleanup must not kill the usage proxy.""" + from benchflow.rollout import _agent_process_kill_pattern + + pattern = _agent_process_kill_pattern("/opt/benchflow/bin/codex-acp") + + assert pattern is not None + assert re.search(pattern, "/opt/benchflow/bin/codex-acp") + assert re.search(pattern, "node /opt/benchflow/js-agents/bin/codex-acp --flag") + assert not re.search(pattern, "node /tmp/benchflow-usage-proxy/proxy.js") + assert not re.search(pattern, "proxy.js --agent-name=codex-acp") + + +@pytest.mark.asyncio +async def test_daytona_uses_sandbox_local_proxy_not_host_proxy(monkeypatch): + """Guards PR #587: Daytona agents must not use host-local proxy URLs.""" + from benchflow.providers.runtime import 
ensure_usage_proxy_runtime + + class FakeSandboxUsageProxy: + target = "https://api.anthropic.com" + base_url = "http://127.0.0.1:49000" + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + self.started = False + self.trajectory = Trajectory( + session_id=kwargs["session_id"], agent_name=kwargs["agent_name"] + ) + + async def start(self): + self.started = True + + async def stop(self): + return None + + monkeypatch.setattr( + usage_runtime_mod, + "TrajectoryProxy", + lambda *a, **k: (_ for _ in ()).throw( + AssertionError("host proxy must not start") + ), + ) + monkeypatch.setattr(usage_runtime_mod, "SandboxUsageProxy", FakeSandboxUsageProxy) + + env = { + "ANTHROPIC_BASE_URL": "https://api.anthropic.com", + "ANTHROPIC_API_KEY": "sk-real-key", + } + updated, runtime = await ensure_usage_proxy_runtime( + agent="claude-agent-acp", + agent_env=env, + model="claude-haiku-4-5-20251001", + runtime=None, + environment="daytona", + session_id="rollout-1", + sandbox=object(), + ) + + assert runtime is not None + assert runtime.server.started is True + assert updated["ANTHROPIC_BASE_URL"] == "http://127.0.0.1:49000" + + +@pytest.mark.asyncio +async def test_sandbox_usage_proxy_imports_raw_captures(): + """Guards PR #587: sandbox captures reuse the canonical usage parser.""" + from benchflow.providers.runtime import ProviderRuntime, extract_usage + from benchflow.providers.sandbox_usage_proxy import SandboxUsageProxy + + capture = { + "duration_ms": 12, + "request": { + "method": "POST", + "path": "/v1/messages", + "headers": {"content-type": "application/json"}, + "body_b64": base64.b64encode( + json.dumps({"model": "claude-haiku-4-5-20251001"}).encode() + ).decode(), + }, + "response": { + "status_code": 200, + "headers": {"content-type": "application/json"}, + "body_b64": base64.b64encode( + json.dumps( + { + "model": "claude-haiku-4-5-20251001", + "usage": {"input_tokens": 13, "output_tokens": 5}, + } + ).encode() + ).decode(), + }, + } + + class 
FakeSandbox: + def __init__(self): + self.uploads = [] + self.commands = [] + self.state_reads = 0 + + async def upload_file(self, source_path, target_path): + assert any(command.startswith("mkdir -p ") for command in self.commands) + self.uploads.append((source_path, target_path)) + + async def exec(self, command, timeout_sec=None): + self.commands.append(command) + if command.startswith("mkdir -p "): + return SimpleNamespace(return_code=0, stdout="", stderr="") + if "command -v node" in command: + return SimpleNamespace( + return_code=0, stdout="/usr/bin/node\n", stderr="" + ) + if "node -e" in command or "node' -e" in command: + assert "nohup" not in command + assert "--agent-name" not in command + return SimpleNamespace(return_code=0, stdout="123\n", stderr="") + if "state.json" in command and command.strip().startswith("cat "): + self.state_reads += 1 + if self.state_reads == 1: + return SimpleNamespace(return_code=0, stdout="{", stderr="") + return SimpleNamespace( + return_code=0, + stdout='{"port":49000,"pid":123}\n', + stderr="", + ) + if "captures.jsonl" in command and command.strip().startswith("cat "): + return SimpleNamespace( + return_code=0, + stdout=json.dumps(capture) + "\n", + stderr="", + ) + if "kill -TERM" in command: + return SimpleNamespace(return_code=0, stdout="", stderr="") + if command.startswith("rm -rf "): + return SimpleNamespace(return_code=0, stdout="", stderr="") + return SimpleNamespace(return_code=1, stdout="", stderr=command) + + sandbox = FakeSandbox() + proxy = SandboxUsageProxy( + sandbox=sandbox, + target="https://api.anthropic.com", + session_id="rollout-1", + agent_name="claude-agent-acp", + ) + await proxy.start() + await proxy.stop() + + runtime = ProviderRuntime( + kind="usage-proxy", + agent_base_url=proxy.base_url, + backend_model="claude-haiku-4-5-20251001", + server=proxy, + ) + usage = extract_usage(runtime) + + assert proxy.base_url == "http://127.0.0.1:49000" + assert sandbox.state_reads == 2 + assert 
usage["usage_source"] == "provider_response" + assert usage["n_input_tokens"] == 13 + assert usage["n_output_tokens"] == 5 + + +@pytest.mark.asyncio +async def test_sandbox_usage_proxy_downloads_capture_log(tmp_path): + """Guards PR #587: large sandbox capture logs avoid exec stdout limits.""" + from benchflow.providers.runtime import ProviderRuntime, extract_usage + from benchflow.providers.sandbox_usage_proxy import SandboxUsageProxy + + capture = { + "duration_ms": 12, + "request": { + "method": "POST", + "path": "/responses", + "headers": {"content-type": "application/json"}, + "body_b64": base64.b64encode( + json.dumps({"model": "gpt-5.5"}).encode() + ).decode(), + }, + "response": { + "status_code": 200, + "headers": {"content-type": "application/json"}, + "body_b64": base64.b64encode( + json.dumps( + { + "model": "gpt-5.5", + "usage": { + "input_tokens": 21, + "output_tokens": 8, + "total_tokens": 29, + }, + } + ).encode() + ).decode(), + }, + } + + class FakeSandbox: + def __init__(self): + self.exec_commands = [] + self.downloads = [] + + async def download_file(self, source_path, target_path): + self.downloads.append((source_path, target_path)) + Path(target_path).write_text(json.dumps(capture) + "\n") + + async def exec(self, command, timeout_sec=None): + self.exec_commands.append(command) + return SimpleNamespace(return_code=0, stdout="", stderr="") + + sandbox = FakeSandbox() + proxy = SandboxUsageProxy( + sandbox=sandbox, + target="https://api.openai.com/v1", + session_id="rollout-1", + agent_name="codex-acp", + ) + await proxy._load_captures() + + runtime = ProviderRuntime( + kind="usage-proxy", + agent_base_url="http://127.0.0.1:49000", + backend_model="gpt-5.5", + server=proxy, + ) + usage = extract_usage(runtime) + + assert [source for source, _target in sandbox.downloads] == [proxy._log_path] + assert not any("captures.jsonl" in command for command in sandbox.exec_commands) + assert usage["usage_source"] == "provider_response" + assert 
usage["n_input_tokens"] == 21 + assert usage["n_output_tokens"] == 8 + + +@pytest.mark.asyncio +async def test_sandbox_usage_proxy_liveness_reports_pid_status(): + """Guards PR #587: stale sandbox proxies are detected by PID liveness.""" + from benchflow.providers.sandbox_usage_proxy import SandboxUsageProxy + + class FakeSandbox: + async def exec(self, command, timeout_sec=None): + assert "kill -0" in command + return SimpleNamespace(return_code=0, stdout="yes\n", stderr="") + + proxy = SandboxUsageProxy( + sandbox=FakeSandbox(), + target="https://api.anthropic.com", + session_id="rollout-1", + agent_name="claude-agent-acp", + ) + + assert await proxy.is_running() is True + + +@pytest.mark.asyncio +async def test_sandbox_usage_proxy_stop_kills_and_cleans_when_capture_read_fails(): + """Guards PR #587: capture import failures still terminate the proxy.""" + from benchflow.providers.sandbox_usage_proxy import SandboxUsageProxy + + commands = [] + + class FakeSandbox: + async def exec(self, command, timeout_sec=None): + commands.append(command) + if "captures.jsonl" in command: + raise TimeoutError("capture read timed out") + return SimpleNamespace(return_code=0, stdout="", stderr="") + + proxy = SandboxUsageProxy( + sandbox=FakeSandbox(), + target="https://api.anthropic.com", + session_id="rollout-1", + agent_name="claude-agent-acp", + ) + + await proxy.stop() + + assert any("kill -TERM" in command for command in commands) + assert any(command.startswith("rm -rf ") for command in commands) + + +@pytest.mark.asyncio +async def test_daytona_auto_usage_proxy_start_failure_leaves_env_untouched(monkeypatch): + """Guards PR #587: auto mode degrades instead of failing Daytona runs.""" + from benchflow.providers.runtime import ensure_usage_proxy_runtime + from benchflow.usage_tracking import UsageTrackingConfig + + stopped = [] + + class BrokenSandboxUsageProxy: + target = "https://api.anthropic.com" + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + async 
def start(self): + raise RuntimeError("launcher failed") + + async def stop(self): + stopped.append(True) + + monkeypatch.setattr(usage_runtime_mod, "SandboxUsageProxy", BrokenSandboxUsageProxy) + + env = {"ANTHROPIC_BASE_URL": "https://api.anthropic.com"} + updated, runtime = await ensure_usage_proxy_runtime( + agent="claude-agent-acp", + agent_env=env, + model="claude-haiku-4-5-20251001", + runtime=None, + environment="daytona", + session_id="rollout-1", + usage_tracking=UsageTrackingConfig(mode="auto"), + sandbox=object(), + ) + + assert updated == env + assert runtime is None + assert stopped == [True] + + +@pytest.mark.asyncio +async def test_daytona_required_usage_proxy_start_failure_raises(monkeypatch): + """Guards PR #587: required mode still fails fast on proxy startup errors.""" + from benchflow.providers.runtime import ensure_usage_proxy_runtime + from benchflow.usage_tracking import UsageTrackingConfig + + class BrokenSandboxUsageProxy: + target = "https://api.anthropic.com" + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + async def start(self): + raise RuntimeError("launcher failed") + + async def stop(self): + return None + + monkeypatch.setattr(usage_runtime_mod, "SandboxUsageProxy", BrokenSandboxUsageProxy) + + with pytest.raises(RuntimeError, match=r"required.*failed to start"): + await ensure_usage_proxy_runtime( + agent="claude-agent-acp", + agent_env={"ANTHROPIC_BASE_URL": "https://api.anthropic.com"}, + model="claude-haiku-4-5-20251001", + runtime=None, + environment="daytona", + session_id="rollout-1", + usage_tracking=UsageTrackingConfig(mode="required"), + sandbox=object(), + ) + + +@pytest.mark.asyncio +async def test_usage_runtime_recreated_when_sandbox_proxy_is_dead(monkeypatch): + """Guards PR #587: dead sandbox proxies are not reused across reconnects.""" + from benchflow.providers.runtime import ProviderRuntime, ensure_usage_proxy_runtime + + stopped = [] + started = [] + + class DeadServer: + target = 
"https://api.anthropic.com" + + async def is_running(self): + return False + + async def stop(self): + stopped.append("dead") + + class FakeSandboxUsageProxy: + target = "https://api.anthropic.com" + base_url = "http://127.0.0.1:49001" + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + self.trajectory = Trajectory( + session_id=kwargs["session_id"], agent_name=kwargs["agent_name"] + ) + + async def start(self): + started.append(self.target) + + async def stop(self): + stopped.append("new") + + monkeypatch.setattr(usage_runtime_mod, "SandboxUsageProxy", FakeSandboxUsageProxy) + + stale_runtime = ProviderRuntime( + kind="usage-proxy", + agent_base_url="http://127.0.0.1:49000", + backend_model="claude-haiku-4-5-20251001", + server=DeadServer(), + ) + + updated, runtime = await ensure_usage_proxy_runtime( + agent="claude-agent-acp", + agent_env={"ANTHROPIC_BASE_URL": "https://api.anthropic.com"}, + model="claude-haiku-4-5-20251001", + runtime=stale_runtime, + environment="daytona", + session_id="rollout-1", + sandbox=object(), + ) + + assert stopped == ["dead"] + assert started == ["https://api.anthropic.com"] + assert runtime is not None + assert runtime is not stale_runtime + assert updated["ANTHROPIC_BASE_URL"] == "http://127.0.0.1:49001" + + +def test_raw_capture_json_error_beats_stream_request_hint(): + """Guards PR #587: JSON error responses are not parsed as SSE.""" + from benchflow.trajectories.proxy import exchange_from_raw_capture + + exchange = exchange_from_raw_capture( + { + "request": { + "method": "POST", + "path": "/v1/messages", + "headers": {"content-type": "application/json"}, + "body_b64": base64.b64encode( + json.dumps({"stream": True}).encode() + ).decode(), + }, + "response": { + "status_code": 400, + "headers": {"content-type": "application/json"}, + "body_b64": base64.b64encode( + json.dumps( + {"error": {"message": "Budget has been exceeded"}} + ).encode() + ).decode(), + }, + } + ) + + assert 
exchange.response.body["error"]["message"] == "Budget has been exceeded" + + +def test_node_proxy_forwards_and_imports_redacted_raw_captures(tmp_path): + """Guards PR #587: Node proxy forwards, redacts, and imports captures.""" + from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer + + from benchflow.providers.sandbox_usage_proxy import _NODE_PROXY_SOURCE + from benchflow.trajectories.proxy import exchange_from_raw_capture + + node = shutil.which("node") + if node is None: + pytest.skip("node is required for sandbox usage proxy integration smoke") + + class Upstream(BaseHTTPRequestHandler): + def do_POST(self): + length = int(self.headers.get("content-length", "0")) + body = self.rfile.read(length) + if self.path == "/v1/error": + self.send_response(400) + self.send_header("content-type", "application/json") + self.send_header("set-cookie", "secret-cookie") + self.end_headers() + self.wfile.write( + json.dumps( + {"error": {"message": "Budget has been exceeded"}} + ).encode() + ) + return + if self.path == "/v1/stream": + self.send_response(200) + self.send_header("content-type", "text/event-stream") + self.end_headers() + self.wfile.write( + b'data: {"model":"gpt-4.1-mini","choices":[{"delta":{"content":"hi"}}],"usage":{"prompt_tokens":4,"completion_tokens":1,"total_tokens":5}}\n\n' + ) + return + if self.path == "/v1/gzip": + payload = gzip.compress( + json.dumps( + { + "model": "gpt-4.1-mini", + "usage": {"prompt_tokens": 7, "completion_tokens": 2}, + } + ).encode() + ) + self.send_response(200) + self.send_header("content-type", "application/json") + self.send_header("content-encoding", "gzip") + self.end_headers() + self.wfile.write(payload) + return + self.send_response(200) + self.send_header("content-type", "application/json") + self.end_headers() + self.wfile.write( + json.dumps( + { + "received_gzip": self.headers.get("content-encoding") == "gzip", + "body_len": len(body), + "model": "claude-haiku-4-5-20251001", + "usage": {"input_tokens": 
13, "output_tokens": 5}, + } + ).encode() + ) + + def log_message(self, *_args): + return None + + upstream = ThreadingHTTPServer(("127.0.0.1", 0), Upstream) + upstream_thread = threading.Thread(target=upstream.serve_forever, daemon=True) + upstream_thread.start() + + runtime_dir = tmp_path / "proxy" + runtime_dir.mkdir() + script = runtime_dir / "proxy.js" + state = runtime_dir / "state.json" + log_path = runtime_dir / "captures.jsonl" + pid_path = runtime_dir / "proxy.pid" + script.write_text(_NODE_PROXY_SOURCE) + env = { + **os.environ, + "BENCHFLOW_USAGE_PROXY_TARGET": ( + f"http://127.0.0.1:{upstream.server_address[1]}" + ), + "BENCHFLOW_USAGE_PROXY_STATE_PATH": str(state), + "BENCHFLOW_USAGE_PROXY_LOG_PATH": str(log_path), + "BENCHFLOW_USAGE_PROXY_PID_PATH": str(pid_path), + "BENCHFLOW_USAGE_PROXY_SESSION_ID": "rollout-1", + "BENCHFLOW_USAGE_PROXY_AGENT_NAME": "codex-acp", + } + proc = subprocess.Popen([node, str(script)], env=env) + try: + deadline = time.monotonic() + 5 + while time.monotonic() < deadline and not state.exists(): + time.sleep(0.05) + assert state.exists() + proxy_port = json.loads(state.read_text())["port"] + + def post(path, payload, headers=None): + body = ( + payload if isinstance(payload, bytes) else json.dumps(payload).encode() + ) + request = Request( + f"http://127.0.0.1:{proxy_port}{path}", + data=body, + headers={"content-type": "application/json", **(headers or {})}, + method="POST", + ) + try: + with urlopen(request, timeout=5) as response: + return response.status, response.read() + except HTTPError as exc: + return exc.code, exc.read() + + gzipped_request = gzip.compress(json.dumps({"stream": False}).encode()) + assert ( + post( + "/v1/messages?key=secret-query&safe=1", + gzipped_request, + { + "authorization": "Bearer secret", + "x-api-key": "secret", + "content-encoding": "gzip", + }, + )[0] + == 200 + ) + assert post("/v1/error", {"stream": True})[0] == 400 + assert post("/v1/stream", {"stream": True})[0] == 200 + assert 
post("/v1/gzip", {"stream": False})[0] == 200 + + capture_lines = [] + deadline = time.monotonic() + 5 + while time.monotonic() < deadline: + capture_lines = log_path.read_text().splitlines() + if len(capture_lines) >= 4: + break + time.sleep(0.05) + captures = [json.loads(line) for line in capture_lines] + exchanges = [exchange_from_raw_capture(record) for record in captures] + + assert len(exchanges) == 4 + first = exchanges[0] + assert first.request.path == "/v1/messages?key=__BENCHFLOW_REDACTED__&safe=1" + assert first.request.headers["authorization"] == "__BENCHFLOW_REDACTED__" + assert first.request.headers["x-api-key"] == "__BENCHFLOW_REDACTED__" + assert first.request.body["stream"] is False + assert first.response.body["usage"]["input_tokens"] == 13 + + error_exchange = exchanges[1] + assert error_exchange.response.status_code == 400 + assert ( + error_exchange.response.body["error"]["message"] + == "Budget has been exceeded" + ) + assert error_exchange.response.headers["set-cookie"] == "__BENCHFLOW_REDACTED__" + + stream_exchange = exchanges[2] + assert stream_exchange.response.body["choices"][0]["message"]["content"] == "hi" + assert stream_exchange.response.body["usage"]["total_tokens"] == 5 + + gzip_exchange = exchanges[3] + assert gzip_exchange.response.body["usage"]["prompt_tokens"] == 7 + finally: + with contextlib.suppress(ProcessLookupError, FileNotFoundError): + os.kill(int(pid_path.read_text()), signal.SIGTERM) + proc.wait(timeout=5) + upstream.shutdown() diff --git a/tests/test_sdk_internals.py b/tests/test_sdk_internals.py index 04f0f8a4..f1e11f18 100644 --- a/tests/test_sdk_internals.py +++ b/tests/test_sdk_internals.py @@ -40,6 +40,7 @@ def test_env_mapping_applied_after_provider(self, monkeypatch): "ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN", + "CLAUDE_OAUTH_TOKEN", ): monkeypatch.delenv(key, raising=False) result = self._resolve( @@ -114,6 +115,7 @@ def 
test_cross_provider_host_native_key_does_not_bypass_required_key( "ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN", + "CLAUDE_OAUTH_TOKEN", "GOOGLE_API_KEY", "GEMINI_API_KEY", ): @@ -207,6 +209,7 @@ def test_required_key_missing_raises(self, monkeypatch, tmp_path): "ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN", + "CLAUDE_OAUTH_TOKEN", "CODEX_ACCESS_TOKEN", "CODEX_API_KEY", "ZAI_API_KEY", diff --git a/tests/test_subscription_auth.py b/tests/test_subscription_auth.py index 23d8f8ae..b3bf5ee3 100644 --- a/tests/test_subscription_auth.py +++ b/tests/test_subscription_auth.py @@ -95,12 +95,24 @@ def test_api_key_present_no_subscription_marker(self): ) assert "_BENCHFLOW_SUBSCRIPTION_AUTH" not in result + def test_claude_oauth_alias_satisfies_anthropic_key_requirement(self): + """Guards PR #587: CLAUDE_OAUTH_TOKEN is accepted as a Claude Code alias.""" + result = self._resolve( + model="claude-haiku-4-5-20251001", + agent_env={"CLAUDE_OAUTH_TOKEN": "oauth-test"}, + ) + + assert result["CLAUDE_CODE_OAUTH_TOKEN"] == "oauth-test" + assert "ANTHROPIC_API_KEY" not in result + def test_subscription_auth_detected(self, monkeypatch, tmp_path): """When host auth file exists and no API key, subscription auth is used.""" for k in ( "ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN", + "CLAUDE_OAUTH_TOKEN", + "CODEX_AUTH_JSON", "CODEX_ACCESS_TOKEN", "CODEX_API_KEY", "OPENAI_API_KEY", @@ -125,6 +137,8 @@ def test_no_auth_file_raises(self, monkeypatch, tmp_path): "ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN", + "CLAUDE_OAUTH_TOKEN", + "CODEX_AUTH_JSON", "CODEX_ACCESS_TOKEN", "CODEX_API_KEY", "OPENAI_API_KEY", @@ -158,6 +172,7 @@ def test_codex_subscription_auth(self, monkeypatch, tmp_path): """Codex subscription auth works with host ~/.codex/auth.json.""" for k in ( "CODEX_ACCESS_TOKEN", + "CODEX_AUTH_JSON", "CODEX_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", @@ -175,6 +190,24 @@ def 
test_codex_subscription_auth(self, monkeypatch, tmp_path): ) assert result["_BENCHFLOW_SUBSCRIPTION_AUTH"] == "1" + def test_codex_auth_json_auth(self, monkeypatch, tmp_path): + """Guards PR #587: inline Codex auth.json can auth native Codex runs.""" + for k in ("CODEX_ACCESS_TOKEN", "CODEX_API_KEY", "OPENAI_API_KEY"): + monkeypatch.delenv(k, raising=False) + _patch_expanduser(monkeypatch, tmp_path) + + result = self._resolve( + agent="codex-acp", + model="gpt-4o", + agent_env={ + "CODEX_AUTH_JSON": '{"tokens": {"access_token": "access-token"}}' + }, + ) + + assert result["CODEX_AUTH_JSON"].startswith("{") + assert "OPENAI_API_KEY" not in result + assert "_BENCHFLOW_SUBSCRIPTION_AUTH" not in result + def test_codex_access_token_auth(self, monkeypatch, tmp_path): """Guards PR #296: Blocks-style Codex auth via CODEX_ACCESS_TOKEN.""" for k in ("CODEX_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY"): @@ -214,6 +247,7 @@ def test_codex_access_token_does_not_auth_custom_provider( """Guards PR #296: access tokens are not proxy API keys.""" for k in ( "CODEX_ACCESS_TOKEN", + "CODEX_AUTH_JSON", "CODEX_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", @@ -238,6 +272,7 @@ def test_codex_subscription_auth_does_not_auth_custom_base_url( """Guards PR #296: subscription auth is not custom endpoint API-key auth.""" for k in ( "CODEX_API_KEY", + "CODEX_AUTH_JSON", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", ): @@ -254,6 +289,16 @@ def test_codex_subscription_auth_does_not_auth_custom_base_url( }, ) + with pytest.raises(ValueError, match="OPENAI_API_KEY required"): + self._resolve( + agent="codex-acp", + model="gpt-4o", + agent_env={ + "CODEX_AUTH_JSON": '{"tokens": {"access_token": "access-token"}}', + base_url_key: "http://localhost:8765/v1", + }, + ) + @pytest.mark.parametrize( "base_url_key", ["BENCHFLOW_PROVIDER_BASE_URL", "OPENAI_BASE_URL"], @@ -264,6 +309,7 @@ def test_codex_host_login_does_not_auth_custom_base_url( """Guards PR #296: host login is not custom endpoint API-key 
auth.""" for k in ( "CODEX_ACCESS_TOKEN", + "CODEX_AUTH_JSON", "CODEX_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", @@ -312,6 +358,49 @@ async def upload_file(self, source: str, dest: str): class TestUploadSubscriptionAuth: + @pytest.mark.asyncio + async def test_codex_auth_json_writes_auth_file(self): + """Guards PR #587: inline Codex auth.json is uploaded for Daytona.""" + from benchflow.agents.credentials import write_credential_files + + env = _FakeEnv() + await write_credential_files( + env, + "codex-acp", + {"CODEX_AUTH_JSON": '{"tokens": {"access_token": "test"}}'}, + AGENTS["codex-acp"], + "gpt-4o", + "/home/agent", + ) + + assert len(env.uploads) == 1 + assert env.uploads[0][1:] == ( + "/home/agent/.codex/auth.json", + '{"tokens": {"access_token": "test"}}', + ) + + @pytest.mark.asyncio + async def test_openai_key_wins_over_codex_auth_json_file_write(self): + """Guards PR #587: API-key auth keeps the existing Codex file shape.""" + from benchflow.agents.credentials import write_credential_files + + env = _FakeEnv() + await write_credential_files( + env, + "codex-acp", + { + "OPENAI_API_KEY": "sk-test", + "CODEX_AUTH_JSON": '{"tokens": {"access_token": "test"}}', + }, + AGENTS["codex-acp"], + "gpt-4o", + "/home/agent", + ) + + assert len(env.uploads) == 1 + assert env.uploads[0][1] == "/home/agent/.codex/auth.json" + assert env.uploads[0][2] == '{"OPENAI_API_KEY": "sk-test"}' + @pytest.mark.asyncio async def test_subscription_auth_chowns_uploaded_home_file( self, monkeypatch, tmp_path diff --git a/tests/test_trajectory_proxy_path_prefix.py b/tests/test_trajectory_proxy_path_prefix.py index 82a00835..37026c59 100644 --- a/tests/test_trajectory_proxy_path_prefix.py +++ b/tests/test_trajectory_proxy_path_prefix.py @@ -1,4 +1,4 @@ -"""Regression tests for external usage-proxy path-prefix routing.""" +"""Regression tests for optional usage-proxy path-prefix routing.""" from __future__ import annotations @@ -30,7 +30,7 @@ async def _read_request( 
@pytest.mark.asyncio async def test_path_prefix_gates_health_and_provider_requests(): - """Guards PR #568: external tunnel traffic must match the secret path prefix.""" + """Prefixed proxy traffic must match the configured secret path prefix.""" upstream_requests: list[tuple[str, str, bytes]] = [] async def upstream_handler( diff --git a/tests/test_usage_proxy.py b/tests/test_usage_proxy.py index 49d44ad7..bb013f4f 100644 --- a/tests/test_usage_proxy.py +++ b/tests/test_usage_proxy.py @@ -8,6 +8,7 @@ import pytest +from benchflow.providers import usage_proxy_runtime as usage_runtime_mod from benchflow.trajectories.types import ( LLMExchange, LLMRequest, @@ -171,6 +172,33 @@ def test_extract_usage_none_proxy(): } +@pytest.mark.parametrize( + "body", + [ + {"id": "msg_123", "content": [{"type": "text", "text": "ok"}]}, + {"error": {"message": "Budget has been exceeded"}}, + {"usage": {"prompt_tokens_details": {}}}, + ], +) +def test_extract_usage_requires_provider_usage_fields(body): + """Guards PR #587: captured HTTP without tokens is not usage telemetry.""" + from benchflow.providers.runtime import ProviderRuntime, extract_usage + + runtime = ProviderRuntime( + kind="usage-proxy", + agent_base_url="http://host.docker.internal:12345", + backend_model="claude-haiku-4-5-20251001", + server=_ProxyLike(_trajectory(body)), + ) + + usage = extract_usage(runtime) + + assert usage["usage_source"] == "unavailable" + assert usage["n_input_tokens"] is None + assert usage["n_output_tokens"] is None + assert usage["total_tokens"] is None + + def test_extract_usage_with_anthropic_exchanges(): from benchflow.providers.runtime import ProviderRuntime, extract_usage @@ -323,7 +351,7 @@ async def start(self): async def stop(self): return None - monkeypatch.setattr(provider_runtime_mod, "TrajectoryProxy", FakeTrajectoryProxy) + monkeypatch.setattr(usage_runtime_mod, "TrajectoryProxy", FakeTrajectoryProxy) updated, runtime = await ensure_usage_proxy_runtime( agent="claude-agent-acp", @@ 
-376,7 +404,7 @@ async def start(self): async def stop(self): return None - monkeypatch.setattr(provider_runtime_mod, "TrajectoryProxy", FakeTrajectoryProxy) + monkeypatch.setattr(usage_runtime_mod, "TrajectoryProxy", FakeTrajectoryProxy) updated, runtime = await ensure_usage_proxy_runtime( agent="openhands", @@ -398,14 +426,13 @@ async def stop(self): @pytest.mark.asyncio async def test_usage_proxy_can_be_disabled_for_operator_recovery(monkeypatch): """Guards v0.5-integration@e55219d recovery runs when telemetry proxying blocks rollouts.""" - from benchflow.providers import runtime as provider_runtime_mod from benchflow.providers.runtime import ensure_usage_proxy_runtime def _fail_start(*_args, **_kwargs): raise AssertionError("TrajectoryProxy must not start when disabled") monkeypatch.setenv("BENCHFLOW_DISABLE_USAGE_PROXY", "1") - monkeypatch.setattr(provider_runtime_mod, "TrajectoryProxy", _fail_start) + monkeypatch.setattr(usage_runtime_mod, "TrajectoryProxy", _fail_start) env = { "BENCHFLOW_PROVIDER_BASE_URL": "http://host.docker.internal:32123", @@ -457,7 +484,7 @@ async def start(self): async def stop(self): return None - monkeypatch.setattr(provider_runtime_mod, "TrajectoryProxy", FakeTrajectoryProxy) + monkeypatch.setattr(usage_runtime_mod, "TrajectoryProxy", FakeTrajectoryProxy) updated, runtime = await ensure_usage_proxy_runtime( agent="codex-acp", @@ -506,7 +533,7 @@ async def start(self): async def stop(self): return None - monkeypatch.setattr(provider_runtime_mod, "TrajectoryProxy", FakeTrajectoryProxy) + monkeypatch.setattr(usage_runtime_mod, "TrajectoryProxy", FakeTrajectoryProxy) _updated, runtime = await ensure_usage_proxy_runtime( agent="codex-acp", @@ -542,46 +569,10 @@ async def test_no_proxy_for_oracle(): assert runtime is None -@pytest.mark.asyncio -async def test_no_proxy_for_daytona_remote_sandbox(monkeypatch): - """Daytona runs the agent on a remote host the host proxy cannot reach. 
- - Guards the fix from PR #327: the usage proxy must be skipped so the agent - talks to the provider directly instead of being pointed at an unreachable - 127.0.0.1 address (the regression that produced ACP ECONNREFUSED errors). - """ - from benchflow.providers import runtime as provider_runtime_mod - from benchflow.providers.runtime import ensure_usage_proxy_runtime - - def _fail_start(*_args, **_kwargs): - raise AssertionError("TrajectoryProxy must not start for daytona") - - monkeypatch.setattr(provider_runtime_mod, "TrajectoryProxy", _fail_start) - - env = { - "ANTHROPIC_BASE_URL": "https://api.anthropic.com", - "ANTHROPIC_API_KEY": "sk-real-key", - } - updated, runtime = await ensure_usage_proxy_runtime( - agent="claude-agent-acp", - agent_env=env, - model="claude-haiku-4-5-20251001", - runtime=None, - environment="daytona", - session_id="rollout-1", - ) - - # Proxy skipped: env left untouched (no loopback rewrite), no runtime. - assert runtime is None - assert updated == env - assert updated["ANTHROPIC_BASE_URL"] == "https://api.anthropic.com" - - @pytest.mark.asyncio async def test_daytona_runtime_retired_when_environment_unreachable(monkeypatch): """Guards the fix from PR #327: a stale runtime from an earlier env must be stopped, not reused.""" - from benchflow.providers import runtime as provider_runtime_mod from benchflow.providers.runtime import ( ProviderRuntime, ensure_usage_proxy_runtime, @@ -602,7 +593,7 @@ async def stop(self): ) monkeypatch.setattr( - provider_runtime_mod, + usage_runtime_mod, "TrajectoryProxy", lambda *a, **k: (_ for _ in ()).throw(AssertionError("must not start")), ) @@ -918,7 +909,7 @@ async def start(self): async def stop(self): stopped.append(self._target) - monkeypatch.setattr(provider_runtime_mod, "TrajectoryProxy", FakeTrajectoryProxy) + monkeypatch.setattr(usage_runtime_mod, "TrajectoryProxy", FakeTrajectoryProxy) _env1, runtime1 = await ensure_usage_proxy_runtime( agent="codex-acp", diff --git 
a/tests/test_usage_proxy_smoke.py b/tests/test_usage_proxy_smoke.py index 0fc45dc8..2c294e04 100644 --- a/tests/test_usage_proxy_smoke.py +++ b/tests/test_usage_proxy_smoke.py @@ -1,7 +1,10 @@ """Optional real-provider smoke test for provider usage telemetry. -Run explicitly with: +Run explicitly with Docker: BENCHFLOW_RUN_TELEMETRY_SMOKE=1 uv run pytest tests/test_usage_proxy_smoke.py -q + +Run explicitly with Daytona: + BENCHFLOW_RUN_DAYTONA_TELEMETRY_SMOKE=1 uv run pytest tests/test_usage_proxy_smoke.py -q """ from __future__ import annotations @@ -94,12 +97,46 @@ async def test_real_acp_rollout_records_provider_usage(tmp_path): model=_smoke_model(agent), jobs_dir=tmp_path, job_name="telemetry-smoke", - trial_name="demo", + rollout_name="demo", environment=_smoke_setting("BENCHFLOW_TELEMETRY_SMOKE_ENV", "docker"), agent_env=_smoke_agent_env(), + usage_tracking="required", + ) + + _assert_provider_usage_recorded( + tmp_path / "telemetry-smoke" / "demo" / "result.json", result + ) + + +@pytest.mark.asyncio +async def test_real_daytona_acp_rollout_records_provider_usage(tmp_path): + if os.environ.get("BENCHFLOW_RUN_DAYTONA_TELEMETRY_SMOKE") != "1": + pytest.skip( + "set BENCHFLOW_RUN_DAYTONA_TELEMETRY_SMOKE=1 to run Daytona telemetry smoke" + ) + + from benchflow.sdk import SDK + + agent = _smoke_agent() + result = await SDK().run( + task_path="src/benchflow/demo_task", + agent=agent, + model=_smoke_model(agent), + jobs_dir=tmp_path, + job_name="daytona-telemetry-smoke", + rollout_name="demo", + environment="daytona", + agent_env=_smoke_agent_env(), + usage_tracking="required", ) - result_json = tmp_path / "telemetry-smoke" / "demo" / "result.json" + _assert_provider_usage_recorded( + tmp_path / "daytona-telemetry-smoke" / "demo" / "result.json", + result, + ) + + +def _assert_provider_usage_recorded(result_json, result) -> None: data = json.loads(result_json.read_text()) agent_result = data["agent_result"] diff --git a/tests/test_usage_required.py 
b/tests/test_usage_required.py new file mode 100644 index 00000000..6702b0d5 --- /dev/null +++ b/tests/test_usage_required.py @@ -0,0 +1,70 @@ +"""Tests for required provider token usage enforcement.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from benchflow.trajectories.types import ( + LLMExchange, + LLMRequest, + LLMResponse, + Trajectory, +) + + +def _trajectory(body: dict) -> Trajectory: + trajectory = Trajectory(session_id="s1", agent_name="agent") + trajectory.exchanges.append( + LLMExchange( + request=LLMRequest(body={"model": "gpt-5.5", "messages": []}), + response=LLMResponse(body=body), + ) + ) + return trajectory + + +@pytest.mark.asyncio +async def test_required_usage_tracking_fails_when_provider_usage_missing(tmp_path): + """Guards PR #587: required usage must not silently pass without tokens.""" + from benchflow.providers.runtime import ProviderRuntime, extract_usage + from benchflow.rollout import Rollout, RolloutConfig + from benchflow.usage_tracking import UsageTrackingConfig + + class FakeServer: + trajectory = _trajectory({"error": {"type": "budget_exceeded"}}) + + async def stop(self): + return None + + rollout = Rollout.__new__(Rollout) + rollout._config = RolloutConfig( + task_path=tmp_path / "task", + usage_tracking=UsageTrackingConfig(mode="required"), + ) + rollout._error = None + rollout._trajectory = [] + rollout._acp_client = None + rollout._agent_launch = "" + rollout._env = SimpleNamespace(stop=AsyncMock()) + rollout._environment = None + rollout._usage_runtime = ProviderRuntime( + kind="usage-proxy", + agent_base_url="http://host.docker.internal:32124", + backend_model="gpt-5.5", + server=FakeServer(), + ) + rollout._planes = SimpleNamespace( + stop_provider_runtime=lambda runtime: runtime.server.stop(), + extract_usage=extract_usage, + ) + rollout._rollout_dir = tmp_path + + await rollout.cleanup() + + assert 
rollout._usage_metrics["usage_source"] == "unavailable" + assert rollout._error == ( + "Token usage tracking is required, but no provider token usage was captured." + ) diff --git a/tests/test_usage_tracking.py b/tests/test_usage_tracking.py index 7c8d9e7f..e04d451e 100644 --- a/tests/test_usage_tracking.py +++ b/tests/test_usage_tracking.py @@ -4,16 +4,17 @@ import pytest +from benchflow.providers import usage_proxy_runtime as usage_runtime_mod from benchflow.trajectories.types import Trajectory @pytest.mark.asyncio -async def test_daytona_required_usage_tracking_requires_external_endpoint(): - """Guards PR #568: required remote tracking must fail closed.""" +async def test_daytona_required_usage_tracking_requires_sandbox_handle(): + """Guards the Daytona sandbox-local proxy path: required still fails closed.""" from benchflow.providers.runtime import ensure_usage_proxy_runtime from benchflow.usage_tracking import UsageTrackingConfig - with pytest.raises(RuntimeError, match="Token usage tracking is required"): + with pytest.raises(RuntimeError, match="sandbox-local usage proxy"): await ensure_usage_proxy_runtime( agent="claude-agent-acp", agent_env={ @@ -29,32 +30,28 @@ async def test_daytona_required_usage_tracking_requires_external_endpoint(): @pytest.mark.asyncio -async def test_daytona_external_usage_proxy_advertises_tunnel_url(monkeypatch): - """Guards PR #568: remote tracking must not inject local-only addresses.""" - from benchflow.providers import runtime as provider_runtime_mod +async def test_daytona_usage_tracking_starts_sandbox_local_proxy(monkeypatch): + """Daytona auto telemetry should use a proxy inside the agent sandbox.""" from benchflow.providers.runtime import ensure_usage_proxy_runtime from benchflow.usage_tracking import UsageTrackingConfig - class FakeTrajectoryProxy: + class FakeSandboxUsageProxy: def __init__( self, + sandbox, target, - session_id="", - agent_name="", - host="127.0.0.1", - port=0, + session_id, + agent_name, 
prompt_cache_retention=None, - path_prefix="", ): + self.sandbox = sandbox self.target = target self.session_id = session_id self.agent_name = agent_name - self.host = host - self.port = port self.prompt_cache_retention = prompt_cache_retention - self.path_prefix = path_prefix self.trajectory = Trajectory(session_id=session_id, agent_name=agent_name) self.started = False + self.base_url = "http://127.0.0.1:49152" async def start(self): self.started = True @@ -62,13 +59,8 @@ async def start(self): async def stop(self): return None - async def reachable(_url): - return True - - monkeypatch.setattr(provider_runtime_mod, "TrajectoryProxy", FakeTrajectoryProxy) - monkeypatch.setattr( - provider_runtime_mod, "_external_usage_proxy_reachable", reachable - ) + monkeypatch.setattr(usage_runtime_mod, "SandboxUsageProxy", FakeSandboxUsageProxy) + sandbox = object() updated, runtime = await ensure_usage_proxy_runtime( agent="openhands", @@ -80,25 +72,21 @@ async def reachable(_url): runtime=None, environment="daytona", session_id="rollout-1", - usage_tracking=UsageTrackingConfig( - mode="required", - advertised_base_url="https://usage-proxy.example.test", - port=18081, - ), + usage_tracking=UsageTrackingConfig(mode="required"), + sandbox=sandbox, ) assert runtime is not None assert runtime.server.started is True - assert runtime.server.host == "127.0.0.1" - assert runtime.server.port == 18081 - assert runtime.server.path_prefix.startswith("/__benchflow/") - assert runtime.base_url.startswith("https://usage-proxy.example.test/__benchflow/") + assert runtime.server.sandbox is sandbox + assert runtime.server.target == "https://llm-proxy.example.test" + assert runtime.base_url == "http://127.0.0.1:49152" assert updated["LLM_BASE_URL"] == runtime.base_url assert updated["BENCHFLOW_PROVIDER_BASE_URL"] == runtime.base_url def test_evaluation_yaml_loads_required_usage_tracking(tmp_path): - """Guards PR #568: eval YAML should preserve required usage tracking.""" + """Usage policy 
should round-trip through eval YAML.""" from benchflow.evaluation import Evaluation tasks_dir = tmp_path / "tasks" @@ -112,9 +100,6 @@ def test_evaluation_yaml_loads_required_usage_tracking(tmp_path): "model: gpt-4.1-mini", "environment: daytona", "usage_tracking: required", - "usage_proxy:", - " advertised_base_url: https://usage-proxy.example.test", - " port: 18081", ] ) ) @@ -122,15 +107,10 @@ def test_evaluation_yaml_loads_required_usage_tracking(tmp_path): evaluation = Evaluation.from_yaml(config) assert evaluation._config.usage_tracking.mode == "required" - assert ( - evaluation._config.usage_tracking.advertised_base_url - == "https://usage-proxy.example.test" - ) - assert evaluation._config.usage_tracking.port == 18081 -def test_evaluation_preflight_fails_required_daytona_without_endpoint(tmp_path): - """Guards PR #568: required Daytona tracking fails before agent launch.""" +def test_evaluation_preflight_allows_required_daytona(tmp_path): + """Daytona required tracking is checked when the sandbox proxy is started.""" from benchflow.evaluation import Evaluation, EvaluationConfig from benchflow.usage_tracking import UsageTrackingConfig @@ -143,54 +123,7 @@ def test_evaluation_preflight_fails_required_daytona_without_endpoint(tmp_path): ), ) - with pytest.raises(RuntimeError, match="no external usage proxy endpoint"): - evaluation._preflight_usage_tracking() - - -def test_evaluation_preflight_rejects_external_proxy_port_zero(tmp_path): - """Guards PR #568: external proxy tracking needs a stable local port.""" - from benchflow.evaluation import Evaluation, EvaluationConfig - from benchflow.usage_tracking import UsageTrackingConfig - - evaluation = Evaluation( - tasks_dir=tmp_path, - jobs_dir=tmp_path / "jobs", - config=EvaluationConfig( - concurrency=1, - environment="daytona", - usage_tracking=UsageTrackingConfig( - mode="required", - advertised_base_url="https://usage-proxy.example.test", - port=0, - ), - ), - ) - - with pytest.raises(RuntimeError, 
match="fixed positive local proxy port"): - evaluation._preflight_usage_tracking() - - -def test_evaluation_preflight_rejects_external_proxy_concurrency(tmp_path): - """Guards PR #568: one fixed external proxy port cannot host concurrency.""" - from benchflow.evaluation import Evaluation, EvaluationConfig - from benchflow.usage_tracking import UsageTrackingConfig - - evaluation = Evaluation( - tasks_dir=tmp_path, - jobs_dir=tmp_path / "jobs", - config=EvaluationConfig( - concurrency=2, - environment="daytona", - usage_tracking=UsageTrackingConfig( - mode="required", - advertised_base_url="https://usage-proxy.example.test", - port=18081, - ), - ), - ) - - with pytest.raises(ValueError, match="supports only one rollout"): - evaluation._preflight_usage_tracking() + evaluation._preflight_usage_tracking() def test_explicit_auto_usage_tracking_beats_env_default(monkeypatch): @@ -229,7 +162,7 @@ def test_usage_tracking_shard_payload_preserves_implicit_env_mode(monkeypatch): def test_usage_tracking_shard_payload_uses_flat_yaml_shape(): - """Guards PR #568: worker payload must not nest usage_tracking twice.""" + """Worker payload must preserve the flat usage_tracking policy shape.""" from benchflow.eval_sharding import EvalShard, _config_payload from benchflow.eval_worker import _evaluation_config from benchflow.evaluation import EvaluationConfig @@ -237,11 +170,7 @@ def test_usage_tracking_shard_payload_uses_flat_yaml_shape(): parent_config = EvaluationConfig( environment="daytona", - usage_tracking=UsageTrackingConfig( - mode="required", - advertised_base_url="https://usage-proxy.example.test", - port=18081, - ), + usage_tracking=UsageTrackingConfig(mode="required"), ) payload = _config_payload( @@ -252,79 +181,46 @@ def test_usage_tracking_shard_payload_uses_flat_yaml_shape(): worker_config = _evaluation_config(payload) assert payload["usage_tracking"] == "required" - assert payload["usage_proxy"] == { - "advertised_base_url": "https://usage-proxy.example.test", - 
"port": 18081, - } assert worker_config.usage_tracking.mode == "required" - assert ( - worker_config.usage_tracking.advertised_base_url - == "https://usage-proxy.example.test" - ) - assert worker_config.usage_tracking.port == 18081 -def test_usage_tracking_overlay_preserves_yaml_fields_for_partial_cli_override(): - """Guards PR #568: partial CLI usage overrides must not erase YAML policy.""" +def test_usage_tracking_overlay_preserves_existing_mode_for_partial_cli_override(): + """A partial CLI override should not erase YAML usage policy.""" from benchflow.usage_tracking import UsageTrackingConfig - yaml_config = UsageTrackingConfig( - mode="required", - advertised_base_url="https://old-proxy.example.test", - port=18081, - ) - cli_override = UsageTrackingConfig( - advertised_base_url="https://new-proxy.example.test", - ) + yaml_config = UsageTrackingConfig(mode="required") + cli_override = UsageTrackingConfig() merged = yaml_config.overlay(cli_override) assert merged.mode == "required" - assert merged.advertised_base_url == "https://new-proxy.example.test" - assert merged.port == 18081 - - -def test_external_usage_tracking_rejects_multiple_shard_workers(): - """Guards PR #568: sharded workers cannot share one fixed proxy port.""" - from benchflow.usage_tracking import UsageTrackingConfig - - config = UsageTrackingConfig( - advertised_base_url="https://usage-proxy.example.test", - port=18081, - ) - with pytest.raises(ValueError, match="supports only one rollout"): - config.validate_parallelism(concurrency=1, worker_count=2) - -def test_usage_proxy_advertised_base_url_rejects_path(): - """Guards PR #568: advertised proxy URLs must be root base URLs.""" +@pytest.mark.parametrize( + "legacy_key", + [ + "usage_proxy", + "usage_proxy_advertised_base_url", + "usage_proxy_bind_host", + "usage_proxy_port", + "usage_proxy_url", + ], +) +def test_usage_tracking_mapping_rejects_legacy_usage_proxy_keys(legacy_key): + """Guards PR #587: legacy usage proxy keys fail instead of 
being ignored.""" from benchflow.usage_tracking import UsageTrackingConfig - with pytest.raises(ValueError, match="must not include a path"): - UsageTrackingConfig( - advertised_base_url="https://usage-proxy.example.test/benchflow" + with pytest.raises(ValueError, match=f"{legacy_key} is no longer supported"): + UsageTrackingConfig.from_mapping( + { + "usage_tracking": "required", + legacy_key: { + "ignored": "value", + }, + } ) -def test_usage_tracking_mapping_preserves_zero_port(): - """Guards PR #568: config sharding must preserve explicit port=0.""" - from benchflow.usage_tracking import UsageTrackingConfig - - config = UsageTrackingConfig.from_mapping( - { - "usage_tracking": "required", - "usage_proxy": { - "advertised_base_url": "https://usage-proxy.example.test", - "port": 0, - }, - } - ) - - assert config.port == 0 - assert config.has_fixed_proxy_port is False - - @pytest.mark.asyncio async def test_completed_eval_resume_skips_usage_preflight(tmp_path, monkeypatch): """Guards PR #568: completed resumes should not require a live usage proxy."""