diff --git a/README.md b/README.md index 46f4679..08a8765 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,47 @@ We then extend `KernelBenchEnv` to support: - **Batching**: `KernelBenchEnvGroupBuilder` groups multiple rollouts for the same problem, enabling **GRPO-style** training where rewards are normalized within groups. - **Dataset Construction**: `KernelBenchDatasetBuilder` handles the iteration over KernelBench levels and problems, partitioning them into training and evaluation sets. You are welcome to extend it to support more problems beyond what is currently in KernelBench. +### Multi-Turn RL + +We extend the single-turn pipeline with multi-turn iterative refinement, following the approach in [Kevin](https://arxiv.org/abs/2507.11948). Instead of generating one kernel per problem, the model generates a kernel, receives evaluation feedback (compilation errors, correctness failures, or speedup results), and refines its solution over multiple turns. + +`MultiTurnKernelBenchEnv` manages the multi-turn loop: +- **History management**: Prior turns (prompt, response, feedback) are kept in context with token-based truncation to stay within the context window. +- **Evaluation feedback**: Structured feedback tells the model what went wrong (compilation error, incorrect output, or correct but slow) so it can fix specific issues. +- **Early stopping**: Optionally stop the episode when the kernel passes all correctness tests. + +Training uses GRPO with discounted returns across turns: +- Per-turn scores are computed as `S = 0.3 * correct + speedup` (only for correct kernels). +- Discounted returns: `R_t = S_t + γ * R_{t+1}` (backward recursion, γ=0.4 by default). +- Advantages are normalized across all `group_size × max_turns` turn-level samples: `(R - mean) / (std + ε)`. +- PPO with asymmetric clipping (Clip-Higher, ε_low=0.2, ε_high=0.28) and constant length normalization. 
+ +Enable multi-turn via config: +```yaml +multiturn: + enabled: true + max_turns: 4 # Refinement turns per trajectory + gamma: 0.4 # Discount factor + aggregation: "sum" # "sum" or "max" +``` + +Or via CLI: +```bash +uv run python -m kernelbench_tinker.scripts.train_kernel_rl \ + --config src/kernelbench_tinker/config/rl_kernelbench.yaml \ + multiturn.enabled=true \ + log_path=./runs/my_multiturn_experiment +``` + +Multi-turn inference is also supported via the eval script: +```bash +uv run python -m kernelbench_tinker.scripts.eval_kernel_rl \ + checkpoint_path= \ + multiturn_enabled=true \ + multiturn_max_turns=8 \ + level=1 +``` + ### Directory Structure ```text @@ -54,6 +95,7 @@ src/kernelbench_tinker/ envs/ kernelbench_client.py # KernelBench Python API wrapper kernelbench_env.py # Single-turn RL environment + multiturn_kernelbench_env.py # Multi-turn RL environment training/ models.py # Model/renderer configuration reward.py # Reward shaping @@ -282,7 +324,6 @@ Note the scope of this repo is an open-source implementation of KernelBench-Tink * More reward examples leveraging more fine-grained metrics * More reward hack checking -* Multi-turn RL to have denser reward signal like [Kevin](https://arxiv.org/abs/2507.11948) * Improve Step time and training efficiency diff --git a/src/kernelbench_tinker/config/configs.py b/src/kernelbench_tinker/config/configs.py index 0144250..2e5e987 100644 --- a/src/kernelbench_tinker/config/configs.py +++ b/src/kernelbench_tinker/config/configs.py @@ -14,9 +14,9 @@ class EvalConfig: """ Configuration for kernel evaluation. - - This config is passed to Modal for kernel evaluation and controls - how correctness and performance are measured. + + Drives correctness/performance measurement and selects the evaluator + backend (Modal cloud NVIDIA, or local subprocess for on-host AMD/HIP). 
""" # Correctness testing @@ -34,10 +34,18 @@ class EvalConfig: check_for_excessive_speedup: bool = True excessive_speedup_threshold: float = 10.0 + # Evaluator backend selection: "modal" (cloud, NVIDIA) or "local" (in-process subprocess) + evaluator_backend: str = "modal" + # Modal configuration modal_gpu_type: str = "A100" modal_timeout: float = 120.0 + # Local evaluator configuration (used when evaluator_backend == "local") + # gpu_arch is passed to set_gpu_arch() — e.g. ["gfx950"] for MI350X, ["gfx942"] for MI300X + gpu_arch: list[str] = field(default_factory=list) + local_timeout: float = 300.0 + @dataclass class PromptConfig: @@ -81,3 +89,57 @@ class DatasetConfig: # Train/test split test_fraction: float = 0.1 + + +@dataclass +class MultiTurnConfig: + """ + Configuration for multi-turn RL training. + + Controls the iterative refinement loop where the model receives + evaluation feedback and can fix errors across multiple turns. + """ + + # Enable multi-turn mode (False = single-turn) + enabled: bool = False + + # Maximum refinement turns per trajectory + max_turns: int = 4 + + # Discount factor for multi-turn returns: R_t = S_t + gamma * R_{t+1} + gamma: float = 0.4 + + # Return aggregation mode: "sum" or "max" + # sum: R_t = Σ γ^(i-t) × S_i (reward turns leading to many good kernels) + # max: R_t = max{ γ^(i-t) × S_i } (reward turns leading to one great kernel) + aggregation: str = "sum" + + # Stop the episode early when the kernel is correct. + # Default False for training: model needs post-correctness turns to + # learn speedup optimization. Set True at eval time if desired. 
early_stop_on_correct: bool = False + + # Optional: require this speedup before early stopping + speedup_threshold: float | None = None + + # Prompt + prompt_max_tokens: int | None = None # Token budget for history truncation (None = char fallback) + inject_think_token: bool = False # Append <think>\n to generation prompts + + # Generation + temperature: float = 0.9 + top_p: float = 1.0 + seed: int | None = None + + # Response length extension mid-training (0 = disabled) + max_tokens_extended: int = 22000 + max_tokens_extend_after_step: int = 30 + + # Training + loss_fn: str = "ppo" + max_grad_norm: float = 0.05 + warmup_ratio: float = 0.03 + clip_epsilon_low: float = 0.2 + clip_epsilon_high: float = 0.28 + constant_length_norm: int = 16384 + num_substeps: int = 2 diff --git a/src/kernelbench_tinker/config/rl_kernelbench.yaml b/src/kernelbench_tinker/config/rl_kernelbench.yaml index bda2995..38bc2f0 100644 --- a/src/kernelbench_tinker/config/rl_kernelbench.yaml +++ b/src/kernelbench_tinker/config/rl_kernelbench.yaml @@ -26,6 +26,33 @@ learning_rate: 0.000002 # 2e-6 as explicit float max_tokens: 16384 temperature: 1.0 +# ============================================================================= +# Multi-turn Configuration (disabled by default) +# ============================================================================= +multiturn: + enabled: false # true to enable iterative refinement + max_turns: 4 # Maximum refinement turns per trajectory + gamma: 0.4 # Discount factor for multi-turn returns + aggregation: "sum" # "sum" (reward many good kernels) or "max" (reward one great kernel) + early_stop_on_correct: false # Stop episode when kernel passes all tests + speedup_threshold: null # Required speedup before early stopping (null = any correct) + # Prompt + prompt_max_tokens: null # Token budget for history truncation (null = char fallback) + inject_think_token: false # Append <think>\n to generation prompts + # Generation + temperature: 0.9 # Generation temperature + 
top_p: 1.0 # Nucleus sampling (1.0 = disabled) + seed: null # Random seed for generation (null = random) + max_tokens_extended: 22000 # Extend max_tokens mid-training (0 = disabled) + max_tokens_extend_after_step: 30 # Step at which to switch + # Training + loss_fn: "ppo" # Loss function (single-turn uses top-level loss_fn) + max_grad_norm: 0.05 # Gradient clipping (0.0 = disabled) + warmup_ratio: 0.03 # Linear LR warmup fraction + clip_epsilon_low: 0.2 # PPO clip lower bound + clip_epsilon_high: 0.28 # PPO clip upper bound (Clip-Higher) + constant_length_norm: 16384 # GRPO constant length normalization (0 = disabled) + # ============================================================================= # Training Configuration # ============================================================================= @@ -57,6 +84,7 @@ dataset_builder: # --------------------------------------------------------------------------- # Problem Selection # --------------------------------------------------------------------------- level: 1 # KernelBench level (1, 2, 3, or 4) + levels: null # Train on multiple levels (e.g.
[1, 2]); overrides level when set start_problem: null # First problem ID (null = start from 1) end_problem: null # Last problem ID (null = all problems) dataset_src: "huggingface" # "huggingface" or "local" @@ -107,6 +135,9 @@ dataset_builder: reward_correctness_weight: 0.3 reward_speed_weight: 1.0 reward_length_weight: 0.0 + reward_speed_max_reward: 10.0 # Cap on speed reward component (set high to uncap) + reward_clip_min: null # Lower bound on total reward (null = no clipping) + reward_clip_max: null # Upper bound on total reward (null = no clipping) # --------------------------------------------------------------------------- # Reward Hacking Detection (Static Checker) diff --git a/src/kernelbench_tinker/config/rl_kernelbench_hip.yaml b/src/kernelbench_tinker/config/rl_kernelbench_hip.yaml new file mode 100644 index 0000000..2da2b2e --- /dev/null +++ b/src/kernelbench_tinker/config/rl_kernelbench_hip.yaml @@ -0,0 +1,156 @@ +# KernelBench RL Training Configuration — HIP backend on AMD MI350X +# =================================================================== +# +# Drives the kernelbench-tinker integration against AMD GPUs (MI300X / MI350X) +# using the HIP backend added in upstream KernelBench PR #135. 
+# +# Key differences from rl_kernelbench.yaml: +# - dataset_builder.backend = "hip" +# - dataset_builder.evaluator_backend = "local" (no Modal — runs in-process subprocess) +# - dataset_builder.gpu_arch = ["gfx950"] (MI350X; use ["gfx942"] for MI300X) +# - dataset_builder.precision = "fp32" +# - reward_static_checker_backend = "hip" +# +# Required Environment Variables: +# TINKER_API_KEY - Tinker distributed training API key +# PYTORCH_ROCM_ARCH - Set automatically by the local evaluator from gpu_arch +# +# Usage (on the AMD cluster, inside the apptainer container): +# srun --account=matx --partition=matx-interactive --gres=gpu:1 \ +# apptainer exec --rocm /matx/u/knatalia/rocm_pytorch.sif \ +# python -m kernelbench_tinker.scripts.train_kernel_rl \ +# --config src/kernelbench_tinker/config/rl_kernelbench_hip.yaml \ +# log_path=./runs/hip_rl_experiment + +# ============================================================================= +# Model Configuration +# ============================================================================= +model_name: "Qwen/Qwen3-30B-A3B" +lora_rank: 64 +learning_rate: 0.000002 + +# ============================================================================= +# Generation Configuration +# ============================================================================= +max_tokens: 16384 +temperature: 1.0 + +# ============================================================================= +# Multi-turn Configuration +# ============================================================================= +multiturn: + enabled: true # Multi-turn refinement is the whole point of moving to HIP RL + max_turns: 4 + gamma: 0.4 + aggregation: "sum" + early_stop_on_correct: false + speedup_threshold: null + prompt_max_tokens: null + inject_think_token: false + temperature: 0.9 + top_p: 1.0 + seed: null + max_tokens_extended: 22000 + max_tokens_extend_after_step: 30 + loss_fn: "ppo" + max_grad_norm: 0.05 + warmup_ratio: 0.03 + clip_epsilon_low: 0.2 + 
clip_epsilon_high: 0.28 + constant_length_norm: 16384 + +# ============================================================================= +# Training Configuration +# ============================================================================= +num_substeps: 2 +loss_fn: "importance_sampling" +kl_penalty_coef: 0.0 +kl_discount_factor: 0.0 +remove_constant_reward_groups: true + +# ============================================================================= +# Logging and Checkpointing +# ============================================================================= +log_path: "./runs/hip_rl_default" +save_every: 1 +wandb_project: kernelbench-tinker +wandb_name: kb-tinker-rl-hip + +# ============================================================================= +# Dataset Builder Configuration +# ============================================================================= +dataset_builder: + # --------------------------------------------------------------------------- + # Problem Selection — sweep all 300 problems across all 3 levels + # --------------------------------------------------------------------------- + level: 1 + levels: [1, 2, 3] # Iterate every task across every level + start_problem: null + end_problem: null + dataset_src: "local" # Use local KernelBench/ submodule (HF dataset is CUDA-only) + + # --------------------------------------------------------------------------- + # Kernel Backend — HIP via PR #135 + # --------------------------------------------------------------------------- + backend: "hip" + + # --------------------------------------------------------------------------- + # Evaluator Backend — local in-process subprocess (no Modal) + # --------------------------------------------------------------------------- + evaluator_backend: "local" + gpu_arch: ["gfx950"] # MI350X; use ["gfx942"] for MI300X + local_timeout: 300.0 # seconds per kernel subprocess (matches eval_hip.py) + + # 
--------------------------------------------------------------------------- + # Batching + # --------------------------------------------------------------------------- + batch_size: 8 + group_size: 16 + num_epochs: 22 + shuffle: true + test_fraction: 0.1 + + # --------------------------------------------------------------------------- + # Prompt Configuration + # --------------------------------------------------------------------------- + renderer_name: "qwen3" + prompt_option: "one_shot" + prompt_precision: null + prompt_include_hardware: false + prompt_gpu_name: null + + # --------------------------------------------------------------------------- + # Evaluation + # --------------------------------------------------------------------------- + num_correct_trials: 5 + measure_performance: true + num_perf_trials: 100 + timing_method: "cuda_event" # Reused as the HIP-event timing path internally + precision: "fp32" + check_for_excessive_speedup: true + excessive_speedup_threshold: 10.0 + + # Modal fields kept as inert defaults (ignored when evaluator_backend == "local") + modal_gpu_type: "A100" + modal_timeout: 60.0 + + # --------------------------------------------------------------------------- + # Reward Weights + # --------------------------------------------------------------------------- + reward_format_weight: 0.0 + reward_compile_weight: 0.0 + reward_correctness_weight: 0.3 + reward_speed_weight: 1.0 + reward_length_weight: 0.0 + reward_speed_max_reward: 10.0 + reward_clip_min: null + reward_clip_max: null + + # --------------------------------------------------------------------------- + # Reward Hacking Detection (Static Checker) + # --------------------------------------------------------------------------- + reward_enable_static_checker: true + reward_static_checker_backend: "hip" + reward_static_checker_precision: "fp32" + reward_static_checker_strict: null + reward_static_checker_warnings: null diff --git a/src/kernelbench_tinker/envs/env_utils.py 
b/src/kernelbench_tinker/envs/env_utils.py new file mode 100644 index 0000000..562e20b --- /dev/null +++ b/src/kernelbench_tinker/envs/env_utils.py @@ -0,0 +1,150 @@ +""" +Shared utilities for KernelBench environments. + +Contains helpers used by both the single-turn and multi-turn environments: +- System prompt construction +- Step evaluation (parse → evaluate → reward → metrics) +""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from tinker_cookbook import renderers +from tinker_cookbook.rl.types import Action, Metrics + +from kernelbench_tinker.config.configs import EvalConfig +from kernelbench_tinker.envs.kernelbench_client import ( + KernelBenchProblem, + KernelEvalResult, + ParsedResponse, + evaluate_kernel_async, + parse_structured_response, +) +from kernelbench_tinker.training.reward import ( + RewardConfig, + compute_reward, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class EvalStepResult: + """Result from evaluate_step(), shared by single-turn and multi-turn envs.""" + + parsed: ParsedResponse + eval_result: KernelEvalResult + format_ok: bool + kernel_code: str + reward: float + metrics: Metrics + response_text: str # Raw response content from renderer (before structured parsing) + + +def build_system_prompt(backend: str) -> str: + """Build a backend-specific system prompt for kernel generation. + + Used by both single-turn and multi-turn environments. + """ + return ( + f"You are an expert GPU kernel developer. 
Your task is to optimize PyTorch " + f"operations by writing efficient custom {backend.upper()} kernels.\n" + f"\n" + f"When given a PyTorch model, write an optimized kernel implementation.\n" + f"\n" + f"Your solution must:\n" + f"- Be a drop-in replacement as a class named `ModelNew`\n" + f"- Use custom {backend.upper()} kernels, not just PyTorch operations\n" + f"- Be correct and produce the same results as the reference\n" + f"\n" + f"You MUST respond in exactly this format:\n" + f"\n" + f"<kernel>\n" + f"```python\n" + f"# Your complete optimized implementation here\n" + f"class ModelNew(nn.Module):\n" + f" ...\n" + f"```\n" + f"</kernel>" + ) + + +async def evaluate_step( + problem: KernelBenchProblem, + renderer: renderers.Renderer, + action: Action, + eval_config: EvalConfig, + reward_config: RewardConfig, + step_start: float, +) -> EvalStepResult: + """Parse, evaluate, and compute reward for a single action. + + Shared by KernelBenchEnv.step() and MultiTurnKernelBenchEnv.step(). + """ + message, _ = renderer.parse_response(action) + response_text = message.get("content", "") + + parsed = parse_structured_response(response_text) + kernel_code = parsed.kernel + format_ok = parsed.format_ok + + eval_start = time.perf_counter() + cfg = eval_config + eval_result = await evaluate_kernel_async( + level=problem.level, + problem_id=problem.problem_id, + backend=problem.backend, + kernel_code=kernel_code, + dataset_src=problem.dataset_src, + num_correct_trials=cfg.num_correct_trials, + measure_performance=cfg.measure_performance, + num_perf_trials=cfg.num_perf_trials, + timing_method=cfg.timing_method, + precision=cfg.precision, + check_for_excessive_speedup=cfg.check_for_excessive_speedup, + excessive_speedup_threshold=cfg.excessive_speedup_threshold, + timeout=cfg.modal_timeout, + ) + eval_time = time.perf_counter() - eval_start + + reward = compute_reward( + eval_result, + reward_config, + kernel_code=kernel_code, + backend=problem.backend, + ) + + metrics: Metrics = { + "level": 
problem.level, + "problem_id": problem.problem_id, + "format_ok": float(format_ok), + "compiled": float(eval_result["compiled"]), + "correctness": float(eval_result["correctness"]), + "tests_passed": eval_result["tests_passed"], + "tests_total": eval_result["tests_total"], + } + if eval_result.get("speedup") is not None: + metrics["speedup"] = eval_result["speedup"] + if eval_result.get("runtime_ms") is not None: + metrics["runtime_ms"] = eval_result["runtime_ms"] + metrics["time/eval"] = eval_time + timing_metadata = (eval_result.get("metadata") or {}).get("timings", {}) + if "reference_load_s" in timing_metadata: + metrics["time/ref_load"] = timing_metadata["reference_load_s"] + if "modal_eval_s" in timing_metadata: + metrics["time/modal_eval"] = timing_metadata["modal_eval_s"] + metrics["time/step_total"] = time.perf_counter() - step_start + + return EvalStepResult( + parsed=parsed, + eval_result=eval_result, + format_ok=format_ok, + kernel_code=kernel_code, + reward=reward, + metrics=metrics, + response_text=response_text, + ) diff --git a/src/kernelbench_tinker/envs/evaluator_dispatch.py b/src/kernelbench_tinker/envs/evaluator_dispatch.py new file mode 100644 index 0000000..c5b04d1 --- /dev/null +++ b/src/kernelbench_tinker/envs/evaluator_dispatch.py @@ -0,0 +1,123 @@ +""" +Evaluator backend dispatch. + +Lets the rest of kernelbench-tinker stay agnostic to whether kernels are +evaluated on Modal (cloud, NVIDIA only) or in a local subprocess (works +with AMD ROCm via PR #135's HIP backend). + +Selection happens once per process via `set_evaluator_from_eval_config()`, +which inspects `eval_config.evaluator_backend` ("modal" or "local") and +installs the corresponding global evaluator. After that, every call site +fetches the current evaluator with `get_current_evaluator()` — which +returns whichever backend was registered. 
+ +Both backend evaluator classes (ModalKernelEvaluator, LocalKernelEvaluator) +expose the same async surface: `evaluate_single`, `evaluate_batch`, +`evaluate_single_batched`. So callers don't need conditionals. +""" + +from __future__ import annotations + +import logging +from typing import Any, Protocol + +logger = logging.getLogger(__name__) + + +# Module-level state. Tracks which backend was last registered so callers +# that need to know (e.g. for logging) can ask, but normally callers should +# just use get_current_evaluator() and treat the result as opaque. +_current_backend: str = "modal" + + +class _EvaluatorLike(Protocol): + """Structural type both ModalKernelEvaluator and LocalKernelEvaluator satisfy.""" + + async def evaluate_single(self, *args: Any, **kwargs: Any) -> dict[str, Any]: ... + async def evaluate_batch(self, evaluations: list[dict[str, Any]]) -> list[dict[str, Any]]: ... + async def evaluate_single_batched(self, *args: Any, **kwargs: Any) -> dict[str, Any]: ... + + +def set_evaluator_from_eval_config(eval_config: Any) -> _EvaluatorLike: + """ + Install the appropriate global evaluator based on `eval_config.evaluator_backend`. + + Accepts either the dataclass `EvalConfig` from + `kernelbench_tinker.config.configs` or the chz `EvalConfig` from + `kernelbench_tinker.scripts.eval_kernel_rl` — they share the relevant + field names so duck typing is fine. + + Returns the registered evaluator instance. 
+ """ + global _current_backend + + backend = getattr(eval_config, "evaluator_backend", "modal") or "modal" + backend = backend.lower() + + if backend == "local": + from kernelbench_tinker.local.evaluator import ( + LocalEvaluatorConfig, + LocalKernelEvaluator, + set_local_evaluator, + ) + + gpu_arch = list(getattr(eval_config, "gpu_arch", []) or []) + timeout = int(getattr(eval_config, "local_timeout", 300)) + local_cfg = LocalEvaluatorConfig( + enabled=True, + gpu_arch=gpu_arch, + timeout=timeout, + ) + evaluator = LocalKernelEvaluator(local_cfg) + set_local_evaluator(evaluator) + _current_backend = "local" + logger.info( + "Local evaluator configured: gpu_arch=%s, timeout=%ds", + gpu_arch, timeout, + ) + return evaluator + + if backend == "modal": + from kernelbench_tinker.modal.evaluator import ( + ModalEvaluatorConfig, + ModalKernelEvaluator, + set_modal_evaluator, + ) + + modal_cfg = ModalEvaluatorConfig( + enabled=True, + gpu_type=getattr(eval_config, "modal_gpu_type", "A100"), + timeout=int(getattr(eval_config, "modal_timeout", 120.0)), + ) + evaluator = ModalKernelEvaluator(modal_cfg) + set_modal_evaluator(evaluator) + _current_backend = "modal" + logger.info( + "Modal evaluator configured: gpu_type=%s, timeout=%ds", + modal_cfg.gpu_type, modal_cfg.timeout, + ) + return evaluator + + raise ValueError( + f"Unknown evaluator_backend={backend!r}. Expected 'modal' or 'local'." + ) + + +def get_current_evaluator() -> _EvaluatorLike: + """ + Return whichever evaluator is currently registered as the global. + + Defaults to the Modal evaluator if no explicit selection has been made, + preserving the historical behavior of the project. 
+ """ + if _current_backend == "local": + from kernelbench_tinker.local.evaluator import get_local_evaluator + return get_local_evaluator() + + from kernelbench_tinker.modal.evaluator import get_modal_evaluator + return get_modal_evaluator() + + +def get_current_backend() -> str: + """Return the name of the currently registered backend ('modal' or 'local').""" + return _current_backend diff --git a/src/kernelbench_tinker/envs/kernelbench_client.py b/src/kernelbench_tinker/envs/kernelbench_client.py index 495d44f..6a60095 100644 --- a/src/kernelbench_tinker/envs/kernelbench_client.py +++ b/src/kernelbench_tinker/envs/kernelbench_client.py @@ -9,13 +9,13 @@ import functools import hashlib +import logging import os import re import sys import time from collections import OrderedDict from dataclasses import dataclass, field -import logging from typing import Any, TypedDict, cast logger = logging.getLogger(__name__) @@ -33,11 +33,18 @@ re.DOTALL | re.IGNORECASE ) +# Summary block pattern - reasoning summary inside ... 
+SUMMARY_BLOCK_PATTERN = re.compile( + r"<summary>(.*?)</summary>", + re.DOTALL | re.IGNORECASE +) + @dataclass class ParsedResponse: """Parsed model response with kernel blocks.""" kernel: str # Kernel code (from <kernel> block or extracted code block) + cot_summary: str # Reasoning summary (from <summary> block) raw: str # Original raw response format_ok: bool # Whether we successfully extracted kernel code @@ -94,8 +101,15 @@ def parse_structured_response(text: str) -> ParsedResponse: # Check if we got valid kernel code format_ok = bool(kernel) and ("class ModelNew" in kernel or "def forward" in kernel) + # Extract CoT summary from <summary> block + cot_summary = "" + summary_match = SUMMARY_BLOCK_PATTERN.search(text) + if summary_match: + cot_summary = summary_match.group(1).strip() + return ParsedResponse( kernel=kernel, + cot_summary=cot_summary, raw=raw, format_ok=format_ok, ) @@ -281,11 +295,12 @@ async def evaluate_kernel_async( cache_results: bool = True, ) -> KernelEvalResult: """ - Evaluate a generated kernel using Modal for isolated GPU execution. + Evaluate a generated kernel via the currently-registered evaluator backend + (Modal for cloud NVIDIA, Local subprocess for on-host AMD/HIP). 
- This function provides: + Both backends provide: - Hard timeout enforcement (kills bad kernels after timeout) - - Process isolation (each kernel runs in separate container) + - Process isolation (each kernel runs in a fresh container or subprocess) - Protection against GPU corruption from bad kernels Args: @@ -302,13 +317,18 @@ async def evaluate_kernel_async( Returns: KernelEvalResult with evaluation results """ - from kernelbench_tinker.modal.evaluator import ( - ModalEvaluatorConfig, - get_modal_evaluator, + from kernelbench_tinker.envs.evaluator_dispatch import ( + get_current_backend, + get_current_evaluator, ) t_total_start = time.perf_counter() timings: dict[str, float] = {} + # Resolve backend name once so it can flow through the cache key (a switch + # between Modal and Local would otherwise return stale runtimes from the + # other backend). + backend_name = get_current_backend() + # Simple LRU cache to avoid re-evaluating identical kernels for the same problem. # We cache even failures to avoid repeatedly paying for hopeless kernels. _eval_cache = _EVAL_CACHE @@ -316,7 +336,7 @@ async def evaluate_kernel_async( def _make_cache_key(code: str) -> str: h = hashlib.sha1(code.encode("utf-8"), usedforsecurity=False).hexdigest() return ( - f"{level}:{problem_id}:{backend}:{dataset_src}:" + f"{level}:{problem_id}:{backend}:{dataset_src}:{backend_name}:" f"{num_correct_trials}:{measure_performance}:{num_perf_trials}:" f"{precision}:{timing_method}:" f"{check_for_excessive_speedup}:{excessive_speedup_threshold}:" @@ -380,12 +400,13 @@ def _prune_cache(maxsize: int = 512) -> None: return default_result timings["reference_load_s"] = time.perf_counter() - ref_start - # Get Modal evaluator with configured timeout - config = ModalEvaluatorConfig(timeout=int(timeout)) - evaluator = get_modal_evaluator(config) + # Get whichever evaluator backend was registered for this process + # (Modal for cloud NVIDIA, Local for on-host AMD/HIP). 
Falls back to + # Modal if nothing has been explicitly selected. + evaluator = get_current_evaluator() - # Run evaluation on Modal - modal_start = time.perf_counter() + # Run evaluation + eval_start = time.perf_counter() try: result = cast( KernelEvalResult, @@ -407,25 +428,27 @@ def _prune_cache(maxsize: int = 512) -> None: _eval_cache[cache_key] = result_copy _prune_cache() - timings["modal_eval_s"] = time.perf_counter() - modal_start + timings["eval_s"] = time.perf_counter() - eval_start timings["total_eval_s"] = time.perf_counter() - t_total_start result_metadata = result.get("metadata", {}) or {} result_metadata.setdefault("timings", {}).update(timings) + result_metadata.setdefault("evaluator_backend", backend_name) result["metadata"] = result_metadata logger.debug( - "Modal eval timings level=%s problem=%s ref_load=%.3fs modal=%.3fs total=%.3fs", + "Eval timings backend=%s level=%s problem=%s ref_load=%.3fs eval=%.3fs total=%.3fs", + backend_name, level, problem_id, timings.get("reference_load_s", 0.0), - timings.get("modal_eval_s", 0.0), + timings.get("eval_s", 0.0), timings.get("total_eval_s", 0.0), ) return result except Exception as e: - default_result["error_message"] = f"Modal evaluation failed: {e}" - logger.exception("Modal kernel evaluation failed") - timings["modal_eval_s"] = time.perf_counter() - modal_start + default_result["error_message"] = f"{backend_name} evaluation failed: {e}" + logger.exception("Kernel evaluation failed (backend=%s)", backend_name) + timings["eval_s"] = time.perf_counter() - eval_start timings["total_eval_s"] = time.perf_counter() - t_total_start default_result["metadata"]["timings"] = timings if cache_results: @@ -487,6 +510,7 @@ class KernelBenchProblem: prompt_gpu_name: str | None = None _prompt: str | None = field(default=None, repr=False) + _base_prompt: str | None = field(default=None, repr=False) @property def prompt(self) -> str: @@ -504,3 +528,23 @@ def prompt(self) -> str: ) return self._prompt + @property + def 
base_prompt(self) -> str: + """Get the zero-shot prompt (no examples) for refinement turns. + + In multi-turn training, the one-shot example is included only on the + first turn. Subsequent turns use this stripped-down prompt to save + context tokens. + """ + if self._base_prompt is None: + self._base_prompt = get_prompt_for_problem( + self.level, + self.problem_id, + self.backend, + option="zero_shot", + dataset_src=self.dataset_src, + precision=self.prompt_precision, + include_hardware=self.prompt_include_hardware, + gpu_name=self.prompt_gpu_name, + ) + return self._base_prompt diff --git a/src/kernelbench_tinker/envs/kernelbench_env.py b/src/kernelbench_tinker/envs/kernelbench_env.py index 209ef5b..215607b 100644 --- a/src/kernelbench_tinker/envs/kernelbench_env.py +++ b/src/kernelbench_tinker/envs/kernelbench_env.py @@ -30,50 +30,27 @@ ) from tinker_cookbook.utils import logtree +from kernelbench_tinker.config.configs import EvalConfig +from kernelbench_tinker.envs.env_utils import ( + EvalStepResult, + build_system_prompt, + evaluate_step, +) from kernelbench_tinker.envs.kernelbench_client import ( KernelBenchProblem, KernelEvalResult, ParsedResponse, - evaluate_kernel_async, get_problem_ids, - parse_structured_response, ) -from kernelbench_tinker.config.configs import EvalConfig from kernelbench_tinker.training.reward import ( - compute_reward, - compute_reward_breakdown, RewardConfig, + compute_reward_breakdown, ) from kernelbench_tinker.training.trace_logger import get_trace_logger logger = logging.getLogger(__name__) -# Default system prompt for kernel generation (structured format) -DEFAULT_SYSTEM_PROMPT = """You are an expert GPU kernel developer. Your task is to optimize PyTorch operations by writing efficient custom GPU kernels. - -When given a PyTorch model, you should: -1. Analyze the operations being performed -2. Write an optimized kernel implementation -3. 
Return your solution as a Python class named `ModelNew` that implements the same interface - -Your kernel should: -- Be functionally correct (produce the same outputs as the reference) -- Be efficient (aim for speedup over the PyTorch baseline) -- Handle edge cases properly -- Use the specified backend (Triton, CUDA, etc.) - -You MUST respond in exactly this format: - - -```python -# Your complete optimized implementation here -class ModelNew(nn.Module): - ... -``` -""" - - class KernelBenchEnv(Env): """ A single-turn RL environment for a KernelBench problem. @@ -119,8 +96,7 @@ def _build_initial_messages(self) -> list[renderers.Message]: """Build the initial conversation for the problem.""" messages: list[renderers.Message] = [] - # Add system prompt if supported - messages.append({"role": "system", "content": DEFAULT_SYSTEM_PROMPT}) + messages.append({"role": "system", "content": build_system_prompt(self.problem.backend)}) # Add the problem prompt as user message messages.append({"role": "user", "content": self.problem.prompt}) @@ -151,97 +127,40 @@ async def step(self, action: Action) -> StepResult: StepResult with reward and episode done status """ step_start = time.perf_counter() - # Parse the response to get text - message, _ = self.renderer.parse_response(action) - response_text = message.get("content", "") - - # Parse structured response (extracts block) - parsed = parse_structured_response(response_text) - kernel_code = parsed.kernel - # Check format validity - format_ok = parsed.format_ok - - # Evaluate the kernel (Modal for isolated GPU execution) - eval_start = time.perf_counter() - cfg = self.eval_config - eval_result = await evaluate_kernel_async( - level=self.problem.level, - problem_id=self.problem.problem_id, - backend=self.problem.backend, - kernel_code=kernel_code, - dataset_src=self.problem.dataset_src, - num_correct_trials=cfg.num_correct_trials, - measure_performance=cfg.measure_performance, - num_perf_trials=cfg.num_perf_trials, - 
timing_method=cfg.timing_method, - precision=cfg.precision, - check_for_excessive_speedup=cfg.check_for_excessive_speedup, - excessive_speedup_threshold=cfg.excessive_speedup_threshold, - timeout=cfg.modal_timeout, - ) - eval_time = time.perf_counter() - eval_start - - # Compute reward (pass kernel_code for static checking) - reward = compute_reward( - eval_result, - self.reward_config, - kernel_code=kernel_code, - backend=self.problem.backend, + r = await evaluate_step( + self.problem, self.renderer, action, + self.eval_config, self.reward_config, step_start, ) # Log the attempt logtree.log_text(f"Problem: Level {self.problem.level}, ID {self.problem.problem_id}") - logtree.log_text(f"Format OK: {'Yes' if format_ok else 'No'}") - logtree.log_text(f"Compiled: {'Yes' if eval_result['compiled'] else 'No'}") + logtree.log_text(f"Format OK: {'Yes' if r.format_ok else 'No'}") + logtree.log_text(f"Compiled: {'Yes' if r.eval_result['compiled'] else 'No'}") logtree.log_text( - f"Correctness: {eval_result['tests_passed']}/{eval_result['tests_total']}" + f"Correctness: {r.eval_result['tests_passed']}/{r.eval_result['tests_total']}" ) - if eval_result.get("speedup"): - logtree.log_text(f"Speedup: {eval_result['speedup']:.2f}x") - logtree.log_text(f"Reward: {reward:.3f}") - error_message = eval_result.get("error_message") + if r.eval_result.get("speedup") is not None: + logtree.log_text(f"Speedup: {r.eval_result['speedup']:.2f}x") + logtree.log_text(f"Reward: {r.reward:.3f}") + error_message = r.eval_result.get("error_message") if error_message: logtree.log_text(f"Error: {error_message[:200]}") - # Build metrics - metrics: Metrics = { - "level": self.problem.level, - "problem_id": self.problem.problem_id, - "format_ok": float(format_ok), - "compiled": float(eval_result["compiled"]), - "correctness": float(eval_result["correctness"]), - "tests_passed": eval_result["tests_passed"], - "tests_total": eval_result["tests_total"], - } - if eval_result.get("speedup"): - 
metrics["speedup"] = eval_result["speedup"] - if eval_result.get("runtime_ms"): - metrics["runtime_ms"] = eval_result["runtime_ms"] - metrics["time/eval"] = eval_time - timing_metadata = (eval_result.get("metadata") or {}).get("timings", {}) - if "reference_load_s" in timing_metadata: - metrics["time/ref_load"] = timing_metadata["reference_load_s"] - if "modal_eval_s" in timing_metadata: - metrics["time/modal_eval"] = timing_metadata["modal_eval_s"] - metrics["time/step_total"] = time.perf_counter() - step_start - # Trace logging (prompt + response + eval) await self._log_trace( - parsed=parsed, - eval_result=eval_result, - format_ok=format_ok, - reward=reward, - metrics=metrics, + parsed=r.parsed, + eval_result=r.eval_result, + format_ok=r.format_ok, + reward=r.reward, + metrics=r.metrics, ) - episode_done = True - return StepResult( - reward=reward, - episode_done=episode_done, + reward=r.reward, + episode_done=True, next_observation=tinker.ModelInput.empty(), next_stop_condition=self.stop_condition, - metrics=metrics, + metrics=r.metrics, ) async def _log_trace( @@ -349,19 +268,6 @@ def __init__( shuffle: bool = True, num_epochs: int = 1, ): - """ - Initialize the RL dataset. 
- - Args: - problems: List of KernelBench problems - renderer: Tinker renderer for formatting - batch_size: Number of problems per batch - group_size: Number of rollouts per problem - eval_config: Configuration for kernel evaluation - reward_config: Reward configuration - shuffle: Whether to shuffle problems each epoch - num_epochs: Number of training epochs - """ self.problems = problems self.renderer = renderer self.batch_size = batch_size @@ -401,15 +307,13 @@ def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]: for i in range(start_idx, end_idx): problem_idx = self._problem_indices[i] problem = self.problems[problem_idx] - - builder = KernelBenchEnvGroupBuilder( + builders.append(KernelBenchEnvGroupBuilder( problem=problem, renderer=self.renderer, group_size=self.group_size, eval_config=self.eval_config, reward_config=self.reward_config, - ) - builders.append(builder) + )) return builders @@ -425,6 +329,7 @@ class KernelBenchDatasetBuilder(RLDatasetBuilder): # Problem selection level: int = 1 + levels: list[int] | None = None # Train on multiple levels (overrides level when set) start_problem: int | None = None end_problem: int | None = None backend: str = "triton" @@ -452,6 +357,11 @@ class KernelBenchDatasetBuilder(RLDatasetBuilder): reward_speed_weight: float = 1.0 reward_length_weight: float = 0.0 + # Reward clipping and speed cap + reward_clip_min: float | None = None + reward_clip_max: float | None = None + reward_speed_max_reward: float = 10.0 # Cap on speed reward component + # Reward hacking detection (static checker) reward_enable_static_checker: bool = True reward_static_checker_backend: str = "triton" @@ -464,6 +374,9 @@ class KernelBenchDatasetBuilder(RLDatasetBuilder): # Test split test_fraction: float = 0.1 + # Explicit holdout indices per level (overrides test_fraction when set) + # Format: {level: [problem_ids]} e.g. 
{1: [3,10,25], 2: [10,20,30]} + holdout_indices: dict[int, list[int]] | None = None # Prompt configuration prompt_option: str = "one_shot" # "zero_shot", "one_shot", "few_shot" @@ -471,43 +384,72 @@ class KernelBenchDatasetBuilder(RLDatasetBuilder): prompt_include_hardware: bool = False prompt_gpu_name: str | None = None + # Evaluator backend selection: "modal" (cloud NVIDIA) or "local" (in-process subprocess for AMD/HIP) + evaluator_backend: str = "modal" + # Modal configuration (isolated GPU evaluation) modal_gpu_type: str = "A100" # GPU type to use on Modal modal_timeout: float = 120.0 # Timeout in seconds per kernel + # Local evaluator configuration (used when evaluator_backend == "local") + # gpu_arch is passed to set_gpu_arch() — e.g. ["gfx950"] for MI350X, ["gfx942"] for MI300X + gpu_arch: list[str] | None = None + local_timeout: float = 300.0 + async def __call__(self, tokenizer=None) -> tuple[RLDataset, RLDataset | None]: """Build train and optional test datasets. Args: tokenizer: The tokenizer to use for the renderer. Required for most renderers. 
""" - # Get problem IDs - problem_ids = get_problem_ids( - self.level, - start=self.start_problem, - end=self.end_problem, - dataset_src=self.dataset_src, - ) - - # Create problems - all_problems = [ - KernelBenchProblem( - level=self.level, - problem_id=pid, - backend=self.backend, + # Determine which levels to use + active_levels = self.levels if self.levels else [self.level] + + # Collect problems across all levels + all_problems: list[KernelBenchProblem] = [] + for lvl in active_levels: + # Get problem IDs + problem_ids = get_problem_ids( + lvl, + start=self.start_problem, + end=self.end_problem, dataset_src=self.dataset_src, - prompt_option=self.prompt_option, - prompt_precision=self.prompt_precision or self.precision, - prompt_include_hardware=self.prompt_include_hardware, - prompt_gpu_name=self.prompt_gpu_name or ( - self.modal_gpu_type if self.prompt_include_hardware else None - ), ) - for pid in problem_ids - ] + + # Create problems + all_problems.extend( + KernelBenchProblem( + level=lvl, + problem_id=pid, + backend=self.backend, + dataset_src=self.dataset_src, + prompt_option=self.prompt_option, + prompt_precision=self.prompt_precision or self.precision, + prompt_include_hardware=self.prompt_include_hardware, + prompt_gpu_name=self.prompt_gpu_name or ( + self.modal_gpu_type if self.prompt_include_hardware else None + ), + ) + for pid in problem_ids + ) # Split into train/test - if self.test_fraction > 0 and len(all_problems) > 1: + if self.holdout_indices: + # Explicit holdout: separate by (level, problem_id) membership + holdout_set = { + (lvl, pid) + for lvl, pids in self.holdout_indices.items() + for pid in pids + } + train_problems = [ + p for p in all_problems + if (p.level, p.problem_id) not in holdout_set + ] + test_problems = [ + p for p in all_problems + if (p.level, p.problem_id) in holdout_set + ] or None + elif self.test_fraction > 0 and len(all_problems) > 1: n_test = max(1, int(len(all_problems) * self.test_fraction)) # Use last N problems 
as test set for reproducibility train_problems = all_problems[:-n_test] @@ -528,8 +470,11 @@ async def __call__(self, tokenizer=None) -> tuple[RLDataset, RLDataset | None]: precision=self.precision, check_for_excessive_speedup=self.check_for_excessive_speedup, excessive_speedup_threshold=self.excessive_speedup_threshold, + evaluator_backend=self.evaluator_backend, modal_gpu_type=self.modal_gpu_type, modal_timeout=self.modal_timeout, + gpu_arch=list(self.gpu_arch) if self.gpu_arch else [], + local_timeout=self.local_timeout, ) # Create reward config @@ -539,6 +484,9 @@ async def __call__(self, tokenizer=None) -> tuple[RLDataset, RLDataset | None]: correctness_weight=self.reward_correctness_weight, speed_weight=self.reward_speed_weight, length_weight=self.reward_length_weight, + speed_max_reward=self.reward_speed_max_reward, + reward_clip_min=self.reward_clip_min, + reward_clip_max=self.reward_clip_max, enable_static_checker=self.reward_enable_static_checker, static_checker_backend=self.reward_static_checker_backend or self.backend, static_checker_precision=self.reward_static_checker_precision or self.precision, @@ -546,15 +494,10 @@ async def __call__(self, tokenizer=None) -> tuple[RLDataset, RLDataset | None]: static_checker_warnings=self.reward_static_checker_warnings, ) - # Configure Modal evaluator with the same config - from kernelbench_tinker.modal.evaluator import ModalEvaluatorConfig, set_modal_evaluator, ModalKernelEvaluator - modal_config = ModalEvaluatorConfig( - enabled=True, - gpu_type=eval_config.modal_gpu_type, - timeout=int(eval_config.modal_timeout), - ) - set_modal_evaluator(ModalKernelEvaluator(modal_config)) - logger.info(f"Modal evaluator configured: GPU={eval_config.modal_gpu_type}, timeout={eval_config.modal_timeout}s") + # Install the appropriate evaluator backend (modal vs local). + # Both expose the same async surface, so call sites stay agnostic. 
+ from kernelbench_tinker.envs.evaluator_dispatch import set_evaluator_from_eval_config + set_evaluator_from_eval_config(eval_config) # Create train dataset train_dataset = KernelBenchRLDataset( @@ -583,4 +526,3 @@ async def __call__(self, tokenizer=None) -> tuple[RLDataset, RLDataset | None]: ) return train_dataset, test_dataset - diff --git a/src/kernelbench_tinker/envs/multiturn_kernelbench_env.py b/src/kernelbench_tinker/envs/multiturn_kernelbench_env.py new file mode 100644 index 0000000..6e084c5 --- /dev/null +++ b/src/kernelbench_tinker/envs/multiturn_kernelbench_env.py @@ -0,0 +1,532 @@ +""" +Multi-turn KernelBench RL environment. + +Extends the single-turn KernelBenchEnv to support iterative kernel refinement. +Each episode consists of up to T turns where the model receives evaluation +feedback and can fix errors or improve performance. +""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass, field +from typing import Any, Sequence + +import tinker +from tinker.types.model_input_chunk import EncodedTextChunk +from tinker_cookbook import renderers +from tinker_cookbook.completers import StopCondition +from tinker_cookbook.rl.types import ( + Action, + Env, + EnvGroupBuilder, + Metrics, + Observation, + StepResult, + Trajectory, +) +from tinker_cookbook.utils import logtree + +from kernelbench_tinker.config.configs import EvalConfig, MultiTurnConfig +from kernelbench_tinker.envs.env_utils import build_system_prompt, evaluate_step +from kernelbench_tinker.envs.kernelbench_client import ( + KernelBenchProblem, + KernelEvalResult, + ParsedResponse, +) +from kernelbench_tinker.training.reward import ( + RewardConfig, + compute_reward_breakdown, +) +from kernelbench_tinker.training.trace_logger import get_trace_logger + +logger = logging.getLogger(__name__) + +# Limit for feedback content included in refinement prompts (char-based fallback) +MAX_HISTORY_CONTEXT_LEN = 8000 + +def 
extract_raw_content(response_text: str, eos_token: str | None = None) -> str:
+    """Extract assistant content from a response, stripping the thinking block.
+
+    Keep text after ``</think>`` when present. Strip an unclosed ``<think>``
+    prefix if the model started thinking but didn't close the tag. If the
+    response has no thinking tags and the EOS token is missing, return the EOS
+    token as a null-response marker.
+    """
+    if "</think>" in response_text:
+        return response_text.split("</think>")[-1].lstrip('\n')
+    if "<think>" in response_text:
+        # Unclosed thinking block — strip it
+        return response_text.split("<think>")[0].strip()
+    if eos_token is not None and eos_token not in response_text:
+        return eos_token
+    return response_text
+
+
+def build_eval_feedback(eval_result: KernelEvalResult) -> str:
+    """Build feedback string from an evaluation result for the next refinement turn."""
+    error_msg = eval_result.get("error_message") or ""
+    metadata = eval_result.get("metadata") or {}
+    error_type = metadata.get("error_type", "")
+
+    if not eval_result["format_ok"] or error_type == "parsing_error":
+        resp = (
+            "Your previous answer failed to be parsed due to not adhering "
+            f"to the desired formatting. Here's the error message: {error_msg}.\n"
+        )
+    elif not eval_result["compiled"]:
+        resp = (
+            "Your previous answer failed to compile. "
+            f"Here's the error message: {error_msg}.\n"
+        )
+    elif error_type == "runtime_error" or (not eval_result["correctness"] and error_msg):
+        # Runtime error: compiled successfully but had runtime errors
+        resp = (
+            "Your previous answer compiled successfully but had runtime "
+            f"errors. Here's the error message: {error_msg}.\n"
+        )
+    elif not eval_result["correctness"]:
+        # Incorrect output
+        resp = (
+            "Your previous answer was incorrect. "
+            f"Here's the error message: {error_msg}.\n"
+        )
+    else:
+        speedup = eval_result.get("speedup") or 0.0
+        resp = (
+            "Your previous answer was correct but can be made faster. 
" + "Here's the speedup you achieved relative to the baseline: " + f"{speedup:.2f}.\n" + ) + + resp += "\nRestart your reasoning process and generate new, complete code." + return resp + + +# --------------------------------------------------------------------------- +# Multi-turn state +# --------------------------------------------------------------------------- + + +@dataclass +class MultiTurnState: + """Mutable state for a multi-turn kernel refinement episode.""" + + level: int + problem_id: int + backend: str + turn_idx: int + max_turns: int + history: list[dict] # Per-turn: {raw_content, kernel, feedback, score} + step_scores: list[float] + done: bool + success: bool + + +# --------------------------------------------------------------------------- +# Multi-turn environment +# --------------------------------------------------------------------------- + + +class MultiTurnKernelBenchEnv(Env): + """ + Multi-turn RL environment for KernelBench. + + Each episode consists of up to T refinement steps: + 1. Turn 0: problem prompt (same as single-turn) + 2. Turn 1+: problem prompt + previous attempt feedback + + The episode ends when the kernel is correct (early stopping) or + max_turns is reached. 
+    """
+
+    def __init__(
+        self,
+        problem: KernelBenchProblem,
+        renderer: renderers.Renderer,
+        max_turns: int = 4,
+        eval_config: EvalConfig | None = None,
+        reward_config: RewardConfig | None = None,
+        system_prompt: str | None = None,
+        early_stop_on_correct: bool = False,
+        speedup_threshold: float | None = None,
+        tokenizer: Any | None = None,
+        prompt_max_tokens: int | None = None,
+        inject_think_token: bool = False,
+    ):
+        self.problem = problem
+        self.renderer = renderer
+        self.max_turns = max_turns
+        self.eval_config = eval_config or EvalConfig()
+        self.reward_config = reward_config or RewardConfig()
+        self.early_stop_on_correct = early_stop_on_correct
+        self.speedup_threshold = speedup_threshold
+        self.tokenizer = tokenizer
+        self.prompt_max_tokens = prompt_max_tokens
+        self.inject_think_token = inject_think_token
+
+        self._system_prompt = system_prompt or build_system_prompt(
+            problem.backend,
+        )
+
+        self._current_prompt_messages: list[renderers.Message] | None = None
+        self._state: MultiTurnState | None = None
+
+    @property
+    def stop_condition(self) -> StopCondition:
+        return self.renderer.get_stop_sequences()
+
+    def _append_think_token(self, observation: tinker.ModelInput) -> tinker.ModelInput:
+        """Append ``<think>\\n`` tokens after the chat template to force thinking mode."""
+        if not self.inject_think_token or self.tokenizer is None:
+            return observation
+        think_ids = self.tokenizer.encode("<think>\n", add_special_tokens=False)
+        return observation.append(EncodedTextChunk(tokens=think_ids))
+
+    @property
+    def state(self) -> MultiTurnState:
+        if self._state is None:
+            raise RuntimeError(
+                "Environment not initialized. Call initial_observation first."
+ ) + return self._state + + def _build_initial_messages(self) -> list[renderers.Message]: + messages: list[renderers.Message] = [] + if self._system_prompt: + messages.append({"role": "system", "content": self._system_prompt}) + messages.append({"role": "user", "content": self.problem.prompt}) + return messages + + def _count_message_tokens(self, messages: list[renderers.Message]) -> int: + """Count total tokens across all messages using the tokenizer.""" + if self.tokenizer is None: + return 0 + total = 0 + for msg in messages: + content = msg.get("content", "") + total += len(self.tokenizer.encode(content)) + return total + + def _build_refinement_messages(self) -> list[renderers.Message]: + """Build refinement prompt with history as alternating assistant/user turns. + + Uses token-based truncation (oldest-first) when tokenizer and + prompt_max_tokens are set, otherwise falls back to char-based. + """ + messages: list[renderers.Message] = [] + if self._system_prompt: + messages.append({"role": "system", "content": self._system_prompt}) + + # Initial user message: problem without one-shot example + base = self.problem.base_prompt + messages.append({ + "role": "user", + "content": base + "Here are your previous attempts:\n", + }) + + if self.state.history: + history = list(self.state.history) + + if self.tokenizer is not None and self.prompt_max_tokens is not None: + # Token-based truncation: count base tokens, fit history + # into remaining budget, keeping most recent entries. 
+ base_tokens = self._count_message_tokens(messages) + # 32-token safety buffer for tokenizer boundary effects + budget = self.prompt_max_tokens - base_tokens - 32 + + # Walk history backwards, accumulating tokens + kept: list[dict] = [] + used = 0 + for entry in reversed(history): + entry_text = entry["raw_content"] + entry["feedback"] + entry_tokens = len(self.tokenizer.encode(entry_text)) + if used + entry_tokens > budget and kept: + break + kept.append(entry) + used += entry_tokens + history = list(reversed(kept)) + else: + # Char-based fallback + total_len = sum( + len(e["raw_content"]) + len(e["feedback"]) for e in history + ) + while total_len > MAX_HISTORY_CONTEXT_LEN and len(history) > 1: + removed = history.pop(0) + total_len -= len(removed["raw_content"]) + len(removed["feedback"]) + + # Add history as assistant/user turn pairs + for entry in history: + messages.append({"role": "assistant", "content": entry["raw_content"]}) + messages.append({"role": "user", "content": entry["feedback"]}) + + return messages + + async def initial_observation(self) -> tuple[Observation, StopCondition]: + self._state = MultiTurnState( + level=self.problem.level, + problem_id=self.problem.problem_id, + backend=self.problem.backend, + turn_idx=0, + max_turns=self.max_turns, + history=[], + step_scores=[], + done=False, + success=False, + ) + messages = self._build_initial_messages() + observation = self.renderer.build_generation_prompt(messages) + observation = self._append_think_token(observation) + self._current_prompt_messages = messages + return observation, self.stop_condition + + async def step(self, action: Action) -> StepResult: + step_start = time.perf_counter() + state = self.state + + r = await evaluate_step( + self.problem, self.renderer, action, + self.eval_config, self.reward_config, step_start, + ) + state.step_scores.append(r.reward) + + # Extract content for history. 
+ # Use CoT summary when available (strips full reasoning, keeps concise + # summary + kernel code), otherwise fall back to raw content with + # thinking block stripped. + eos_token = getattr(self.tokenizer, "eos_token", None) if self.tokenizer else None + raw_content = extract_raw_content(r.response_text, eos_token) + history_content = r.parsed.cot_summary if r.parsed.cot_summary else raw_content + + # Build feedback and store in history + feedback = build_eval_feedback(r.eval_result) + state.history.append({ + "raw_content": history_content, + "kernel": r.kernel_code, + "feedback": feedback, + "score": r.reward, + }) + + # Log + logtree.log_text( + f"Multi-turn: Level {state.level}, ID {state.problem_id}, " + f"Turn {state.turn_idx}" + ) + logtree.log_text(f"Format OK: {'Yes' if r.format_ok else 'No'}") + logtree.log_text( + f"Compiled: {'Yes' if r.eval_result['compiled'] else 'No'}" + ) + logtree.log_text( + f"Correctness: {r.eval_result['tests_passed']}/{r.eval_result['tests_total']}" + ) + if r.eval_result.get("speedup") is not None: + logtree.log_text(f"Speedup: {r.eval_result['speedup']:.2f}x") + logtree.log_text(f"Step score: {r.reward:.3f}") + + # Early stopping + is_correct = r.eval_result["correctness"] + meets_speedup = ( + self.speedup_threshold is None + or (r.eval_result.get("speedup") or 0.0) >= self.speedup_threshold + ) + if self.early_stop_on_correct and is_correct and meets_speedup: + state.done = True + state.success = True + + state.turn_idx += 1 + if state.turn_idx >= state.max_turns: + state.done = True + + # Add multi-turn fields to shared metrics + metrics = r.metrics + metrics["turn"] = state.turn_idx - 1 + metrics["step_score"] = r.reward + metrics["episode_done"] = float(state.done) + metrics["episode_success"] = float(state.success) + + # Trace logging + await self._log_trace( + parsed=r.parsed, + eval_result=r.eval_result, + format_ok=r.format_ok, + reward=r.reward, + metrics=metrics, + ) + + # Next observation or done + if 
state.done: + next_observation = tinker.ModelInput.empty() + else: + messages = self._build_refinement_messages() + next_observation = self.renderer.build_generation_prompt(messages) + next_observation = self._append_think_token(next_observation) + self._current_prompt_messages = messages + + return StepResult( + reward=r.reward, + episode_done=state.done, + next_observation=next_observation, + next_stop_condition=self.stop_condition, + metrics=metrics, + ) + + async def _log_trace( + self, + parsed: ParsedResponse, + eval_result: KernelEvalResult, + format_ok: bool, + reward: float, + metrics: Metrics, + ) -> None: + trace_logger = get_trace_logger() + if trace_logger is None: + return + + trace_record = { + "mode": "multi_turn", + "level": self.problem.level, + "problem_id": self.problem.problem_id, + "backend": self.problem.backend, + "dataset_src": self.problem.dataset_src, + "prompt_option": self.problem.prompt_option, + "turn": self.state.turn_idx - 1, + "max_turns": self.state.max_turns, + "prompt_messages": self._current_prompt_messages, + "renderer": getattr( + self.renderer, "name", type(self.renderer).__name__ + ), + "response": { + "raw": parsed.raw, + "kernel": parsed.kernel, + "cot_summary": parsed.cot_summary, + "format_ok": format_ok, + }, + "eval_result": eval_result, + "reward": reward, + "reward_breakdown": compute_reward_breakdown( + eval_result, + self.reward_config, + kernel_code=parsed.kernel, + backend=self.problem.backend, + ), + "metrics": metrics, + "history": [ + { + "raw_content": entry["raw_content"], + "kernel": entry["kernel"], + "feedback": entry["feedback"], + "score": entry["score"], + } + for entry in self.state.history + ], + "state": { + "turn_idx": self.state.turn_idx, + "done": self.state.done, + "success": self.state.success, + "step_scores": list(self.state.step_scores), + }, + "timestamp": time.time(), + "stop_condition": str(self.stop_condition), + } + + await trace_logger.log(trace_record) + + def get_step_scores(self) 
-> list[float]: + """Return per-step scores for discounted return computation.""" + return list(self.state.step_scores) + + +# --------------------------------------------------------------------------- +# Group builder, dataset, dataset builder (mirrors single-turn structure) +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class MultiTurnKernelBenchEnvGroupBuilder(EnvGroupBuilder): + """Builder for groups of multi-turn KernelBench environments.""" + + problem: KernelBenchProblem + renderer: renderers.Renderer + group_size: int + max_turns: int = 4 + eval_config: EvalConfig = field(default_factory=EvalConfig) + reward_config: RewardConfig = field(default_factory=RewardConfig) + system_prompt: str | None = None + early_stop_on_correct: bool = False + speedup_threshold: float | None = None + tokenizer: Any | None = field(default=None, hash=False, compare=False) + prompt_max_tokens: int | None = None + inject_think_token: bool = False + + async def make_envs(self) -> Sequence[Env]: + return [ + MultiTurnKernelBenchEnv( + problem=self.problem, + renderer=self.renderer, + max_turns=self.max_turns, + eval_config=self.eval_config, + reward_config=self.reward_config, + system_prompt=self.system_prompt, + early_stop_on_correct=self.early_stop_on_correct, + speedup_threshold=self.speedup_threshold, + tokenizer=self.tokenizer, + prompt_max_tokens=self.prompt_max_tokens, + inject_think_token=self.inject_think_token, + ) + for _ in range(self.group_size) + ] + + async def compute_group_rewards( + self, + trajectory_group: list[Trajectory], + env_group: Sequence[Env], + ) -> list[tuple[float, Metrics]]: + # No-op: real rewards are computed per-step inside env.step() and + # overwritten by apply_discounted_returns before advantage estimation. 
+ return [(0.0, {}) for _ in trajectory_group] + + def logging_tags(self) -> list[str]: + return [ + f"level_{self.problem.level}", + f"problem_{self.problem.problem_id}", + "kernelbench", + "multiturn", + ] + + +def wrap_builders_as_multiturn( + builders: Sequence[EnvGroupBuilder], + multiturn_cfg: MultiTurnConfig, + tokenizer: Any | None = None, +) -> list[MultiTurnKernelBenchEnvGroupBuilder]: + """Wrap single-turn KernelBenchEnvGroupBuilders as multi-turn builders. + + Called by the training loop when multiturn.enabled is True. Reads + problem/renderer/group_size/eval_config/reward_config from each + single-turn builder and creates MultiTurnKernelBenchEnvGroupBuilder + instances with the multi-turn config. + """ + from kernelbench_tinker.envs.kernelbench_env import KernelBenchEnvGroupBuilder + + wrapped = [] + for b in builders: + if not isinstance(b, KernelBenchEnvGroupBuilder): + raise TypeError( + f"Expected KernelBenchEnvGroupBuilder, got {type(b).__name__}" + ) + wrapped.append(MultiTurnKernelBenchEnvGroupBuilder( + problem=b.problem, + renderer=b.renderer, + group_size=b.group_size, + max_turns=multiturn_cfg.max_turns, + eval_config=b.eval_config, + reward_config=b.reward_config, + system_prompt=build_system_prompt(b.problem.backend), + early_stop_on_correct=multiturn_cfg.early_stop_on_correct, + speedup_threshold=multiturn_cfg.speedup_threshold, + tokenizer=tokenizer, + prompt_max_tokens=multiturn_cfg.prompt_max_tokens, + inject_think_token=multiturn_cfg.inject_think_token, + )) + return wrapped diff --git a/src/kernelbench_tinker/evaluation/eval_kernelbench.py b/src/kernelbench_tinker/evaluation/eval_kernelbench.py index cf00836..f4adbd7 100644 --- a/src/kernelbench_tinker/evaluation/eval_kernelbench.py +++ b/src/kernelbench_tinker/evaluation/eval_kernelbench.py @@ -10,7 +10,7 @@ import json import logging import os -from dataclasses import dataclass, field, asdict +from dataclasses import asdict, dataclass, field from typing import Any import numpy as 
np diff --git a/src/kernelbench_tinker/local/__init__.py b/src/kernelbench_tinker/local/__init__.py new file mode 100644 index 0000000..b6f0a5f --- /dev/null +++ b/src/kernelbench_tinker/local/__init__.py @@ -0,0 +1,15 @@ +"""Local in-process kernel evaluator (subprocess-isolated, AMD ROCm friendly).""" + +from kernelbench_tinker.local.evaluator import ( + LocalEvaluatorConfig, + LocalKernelEvaluator, + get_local_evaluator, + set_local_evaluator, +) + +__all__ = [ + "LocalEvaluatorConfig", + "LocalKernelEvaluator", + "get_local_evaluator", + "set_local_evaluator", +] diff --git a/src/kernelbench_tinker/local/evaluator.py b/src/kernelbench_tinker/local/evaluator.py new file mode 100644 index 0000000..a01b14b --- /dev/null +++ b/src/kernelbench_tinker/local/evaluator.py @@ -0,0 +1,428 @@ +""" +Local in-process kernel evaluator with subprocess isolation. + +This module provides the same public API as `kernelbench_tinker.modal.evaluator` +(`evaluate_single`, `evaluate_batch`, `evaluate_single_batched`) but runs +`eval_kernel_against_ref` in an isolated subprocess on the local GPU instead of +shipping work to a Modal container. + +Why a subprocess per kernel? + KernelBench compiles each kernel via torch.utils.cpp_extension.load_inline, + which leaks memory across compilations and quickly OOMs an MI350X with 192GB + of HBM after a couple hundred problems. Running each eval in a fresh Python + interpreter is the same trick used by eval_hip.py and is the only reliable + way to sweep all 300 KernelBench problems on AMD. + +Result shape matches Modal's KernelEvalResult so the rest of the pipeline +(reward computation, multi-turn feedback parsing, eval_kernel_rl) needs no +changes when the dispatch flips between backends. 
+""" + +from __future__ import annotations + +import asyncio +import json +import logging +import sys +from dataclasses import dataclass, field +from typing import Any, cast + +logger = logging.getLogger(__name__) + + +# Sentinel markers used by the worker to delimit the JSON result so we can +# robustly extract it from a noisy stdout (HIP/ROCm prints a lot of warnings). +_RESULT_BEGIN = "__KB_TINKER_LOCAL_RESULT_BEGIN__" +_RESULT_END = "__KB_TINKER_LOCAL_RESULT_END__" + + +# Subprocess worker. Reads a JSON request from stdin, runs one eval, prints +# the result between sentinel markers, then exits. Stays self-contained so +# it can be exec'd via `python -c` without bundling files. +# +# Mirrors the result-shape conversion in src/kernelbench_tinker/modal/app.py +# (correctness_trials regex, ref_runtime → speedup, runtime > 0 guard). +_WORKER_SCRIPT = r''' +import json, os, re, sys, tempfile, traceback + +def _emit(payload): + sys.stdout.write("__KB_TINKER_LOCAL_RESULT_BEGIN__\n") + sys.stdout.write(json.dumps(payload, default=str)) + sys.stdout.write("\n__KB_TINKER_LOCAL_RESULT_END__\n") + sys.stdout.flush() + +raw = sys.stdin.read() +req = json.loads(raw) +num_correct_trials = req.get("num_correct_trials", 5) +kernel_code = req.get("kernel_code", "") + +# Per-subprocess torch extension build dir, isolated from any stale lock +# files left behind by a SIGKILLed predecessor. +_build_dir = tempfile.mkdtemp(prefix="kb_tinker_ext_") +os.environ["TORCH_EXTENSIONS_DIR"] = _build_dir + +try: + # Configure GPU arch BEFORE importing torch so AOT compile paths see it. + # set_gpu_arch is pure os.environ manipulation, no torch dependency. 
+ gpu_arch = req.get("gpu_arch") or [] + if gpu_arch: + from kernelbench.utils import set_gpu_arch + set_gpu_arch(gpu_arch) + + import torch + from kernelbench.eval import eval_kernel_against_ref, get_torch_dtype_from_string + + dtype = get_torch_dtype_from_string(req.get("precision", "fp32")) + + if not torch.cuda.is_available(): + _emit({ + "format_ok": True, + "compiled": False, "correctness": False, + "tests_passed": 0, "tests_total": num_correct_trials, + "speedup": None, "runtime_ms": None, "baseline_runtime_ms": None, + "error_message": "torch.cuda not available in worker subprocess", + "code_length": len(kernel_code), + "metadata": {}, + }) + sys.exit(0) + + device = torch.device("cuda:0") + result = eval_kernel_against_ref( + original_model_src=req["ref_code"], + custom_model_src=kernel_code, + num_correct_trials=num_correct_trials, + num_perf_trials=req.get("num_perf_trials", 100), + measure_performance=req.get("measure_performance", False), + verbose=False, + device=device, + backend=req.get("backend", "triton"), + precision=dtype, + ) +except Exception as e: + # Anything raised before/around eval_kernel_against_ref is a worker + # framework failure (OOM, driver crash, KernelBench bug, ...). Report + # it explicitly so it can't be confused with a kernel-side failure. + _emit({ + "format_ok": True, + "compiled": False, "correctness": False, + "tests_passed": 0, "tests_total": num_correct_trials, + "speedup": None, "runtime_ms": None, "baseline_runtime_ms": None, + "error_message": f"Worker framework error: {e}", + "code_length": len(kernel_code), + "metadata": {"traceback": traceback.format_exc()[-1000:]}, + }) + sys.exit(0) + +# Result marshalling — keep narrow so a marshalling bug never masquerades +# as a user-kernel failure. Mirrors modal/app.py:300-370. 
+try: + metadata = {k: str(v)[:500] for k, v in dict(getattr(result, "metadata", {}) or {}).items()} + + tests_passed = 0 + trials_str = metadata.get("correctness_trials", "(0 / 0)") + m = re.match(r"\((\d+)\s*/\s*(\d+)\)", trials_str) + if m: + tests_passed = int(m.group(1)) + elif result.correctness: + tests_passed = num_correct_trials + + runtime_raw = getattr(result, "runtime", -1.0) + runtime_ms = float(runtime_raw) if runtime_raw and runtime_raw > 0 else None + + baseline_runtime_ms = None + ref_runtime = getattr(result, "ref_runtime", None) + if ref_runtime and ref_runtime > 0: + baseline_runtime_ms = float(ref_runtime) + + speedup = None + if result.correctness and runtime_ms and baseline_runtime_ms and baseline_runtime_ms > 0: + speedup = baseline_runtime_ms / runtime_ms + + error_message = None + for key in ("runtime_error", "compilation_error", "correctness_issue", "max_difference"): + if metadata.get(key): + error_message = str(metadata[key])[:1000] + break + + _emit({ + "format_ok": True, + "compiled": bool(result.compiled), + "correctness": bool(result.correctness), + "tests_passed": tests_passed, + "tests_total": num_correct_trials, + "speedup": speedup, + "runtime_ms": runtime_ms, + "baseline_runtime_ms": baseline_runtime_ms, + "error_message": error_message, + "code_length": len(kernel_code), + "metadata": metadata, + }) +except Exception as e: + _emit({ + "format_ok": True, + "compiled": bool(getattr(result, "compiled", False)), + "correctness": bool(getattr(result, "correctness", False)), + "tests_passed": 0, "tests_total": num_correct_trials, + "speedup": None, "runtime_ms": None, "baseline_runtime_ms": None, + "error_message": f"Worker result-marshalling error: {e}", + "code_length": len(kernel_code), + "metadata": {"traceback": traceback.format_exc()[-1000:]}, + }) +''' + + +@dataclass +class LocalEvaluatorConfig: + """Configuration for the local subprocess-isolated kernel evaluator.""" + + enabled: bool = True + timeout: int = 300 # seconds 
per kernel subprocess + # gpu_arch is passed to set_gpu_arch() inside the worker. Must be one of + # the architectures KernelBench's set_gpu_arch() recognizes: + # ["gfx950"] -> MI350X + # ["gfx942"] -> MI300X + # ["Ampere"] -> A100 + # ["Hopper"] -> H100 + gpu_arch: list[str] = field(default_factory=list) + max_concurrent: int = 1 # Serialize subprocesses by default (1 GPU) + return_exceptions: bool = True + + +def _make_default_result(num_correct_trials: int, kernel_code: str, error: str) -> dict[str, Any]: + return { + "format_ok": True, + "compiled": False, + "correctness": False, + "tests_passed": 0, + "tests_total": num_correct_trials, + "speedup": None, + "runtime_ms": None, + "baseline_runtime_ms": None, + "error_message": error, + "code_length": len(kernel_code), + "metadata": {}, + } + + +def _extract_result(stdout: str) -> dict[str, Any] | None: + """Pull the JSON payload between the sentinel markers from worker stdout. + + Returns None on "no markers found". Raises ValueError on "markers found + but JSON inside is malformed" so the caller can surface the actual error + rather than a generic "no result" message. + """ + begin = stdout.rfind(_RESULT_BEGIN) + if begin == -1: + return None + end = stdout.find(_RESULT_END, begin) + if end == -1: + return None + payload = stdout[begin + len(_RESULT_BEGIN):end].strip() + try: + return cast(dict[str, Any], json.loads(payload)) + except json.JSONDecodeError as e: + raise ValueError(f"Worker emitted malformed JSON: {e}; payload={payload[:200]!r}") + + +class LocalKernelEvaluator: + """ + Kernel evaluator that runs eval_kernel_against_ref in an isolated subprocess. + + Mirrors the public surface of `ModalKernelEvaluator` so that calling code + only sees a generic evaluator interface. Each kernel is dispatched to a + fresh `python -c` worker so torch.utils.cpp_extension cache leaks die with + the worker, not with the long-running RL/eval process. 
+ """ + + def __init__(self, config: LocalEvaluatorConfig | None = None): + self.config = config or LocalEvaluatorConfig() + self._semaphore: asyncio.Semaphore | None = None + + def _get_semaphore(self) -> asyncio.Semaphore: + if self._semaphore is None: + self._semaphore = asyncio.Semaphore(max(1, self.config.max_concurrent)) + return self._semaphore + + async def _run_subprocess(self, request: dict[str, Any]) -> dict[str, Any]: + """Spawn one worker subprocess, ship the request via stdin, parse the result. + + Callers (`evaluate_single` / `evaluate_batch`) are responsible for + injecting `gpu_arch` into the request — there's no fallback here so + that the source of truth stays in one place. + """ + try: + proc = await asyncio.create_subprocess_exec( + sys.executable, + "-c", + _WORKER_SCRIPT, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + except Exception as e: + return _make_default_result( + request.get("num_correct_trials", 5), + request.get("kernel_code", ""), + f"Failed to spawn worker subprocess: {e}", + ) + + try: + stdout_bytes, stderr_bytes = await asyncio.wait_for( + proc.communicate(json.dumps(request).encode("utf-8")), + timeout=self.config.timeout, + ) + except asyncio.TimeoutError: + try: + proc.kill() + await proc.wait() + except ProcessLookupError: + pass + return _make_default_result( + request.get("num_correct_trials", 5), + request.get("kernel_code", ""), + f"Subprocess timeout after {self.config.timeout}s", + ) + + stdout = stdout_bytes.decode("utf-8", errors="replace") + stderr = stderr_bytes.decode("utf-8", errors="replace") + + try: + result = _extract_result(stdout) + except ValueError as e: + # Markers found but JSON inside is malformed — surface the parser's + # diagnostic (it already includes a 200-char payload preview) so + # debugging doesn't require digging through the worker stdout. 
+ return _make_default_result( + request.get("num_correct_trials", 5), + request.get("kernel_code", ""), + f"Worker emitted malformed result: {e}", + ) + if result is None: + tail = (stderr or stdout)[-500:] + return _make_default_result( + request.get("num_correct_trials", 5), + request.get("kernel_code", ""), + f"Worker did not emit result. tail={tail!r}", + ) + return result + + async def evaluate_single( + self, + ref_code: str, + kernel_code: str, + backend: str = "triton", + num_correct_trials: int = 5, + measure_performance: bool = False, + num_perf_trials: int = 100, + precision: str = "fp32", + timing_method: str = "cuda_event", + check_for_excessive_speedup: bool = True, + excessive_speedup_threshold: float = 10.0, + ) -> dict[str, Any]: + if not self.config.enabled: + raise RuntimeError("Local evaluator is disabled") + + request = { + "ref_code": ref_code, + "kernel_code": kernel_code, + "backend": backend, + "num_correct_trials": num_correct_trials, + "measure_performance": measure_performance, + "num_perf_trials": num_perf_trials, + "precision": precision, + "timing_method": timing_method, + "check_for_excessive_speedup": check_for_excessive_speedup, + "excessive_speedup_threshold": excessive_speedup_threshold, + "gpu_arch": self.config.gpu_arch, + } + + async with self._get_semaphore(): + return await self._run_subprocess(request) + + async def evaluate_batch( + self, + evaluations: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + if not evaluations: + return [] + if not self.config.enabled: + raise RuntimeError("Local evaluator is disabled") + + sem = self._get_semaphore() + + async def _one(e: dict[str, Any]) -> dict[str, Any]: + request = { + "ref_code": e["ref_code"], + "kernel_code": e["kernel_code"], + "backend": e.get("backend", "triton"), + "num_correct_trials": e.get("num_correct_trials", 5), + "measure_performance": e.get("measure_performance", False), + "num_perf_trials": e.get("num_perf_trials", 100), + "precision": e.get("precision", 
"fp32"), + "timing_method": e.get("timing_method", "cuda_event"), + "check_for_excessive_speedup": e.get("check_for_excessive_speedup", True), + "excessive_speedup_threshold": e.get("excessive_speedup_threshold", 10.0), + "gpu_arch": self.config.gpu_arch, + } + async with sem: + try: + return await self._run_subprocess(request) + except Exception as exc: + if self.config.return_exceptions: + return _make_default_result( + e.get("num_correct_trials", 5), + e["kernel_code"], + f"Local evaluation failed: {exc}", + ) + raise + + return await asyncio.gather(*(_one(e) for e in evaluations)) + + async def evaluate_single_batched( + self, + ref_code: str, + kernel_code: str, + backend: str = "triton", + num_correct_trials: int = 5, + measure_performance: bool = False, + num_perf_trials: int = 100, + precision: str = "fp32", + timing_method: str = "cuda_event", + check_for_excessive_speedup: bool = True, + excessive_speedup_threshold: float = 10.0, + ) -> dict[str, Any]: + # No batching benefit on a single local GPU; just delegate. 
+ return await self.evaluate_single( + ref_code=ref_code, + kernel_code=kernel_code, + backend=backend, + num_correct_trials=num_correct_trials, + measure_performance=measure_performance, + num_perf_trials=num_perf_trials, + precision=precision, + timing_method=timing_method, + check_for_excessive_speedup=check_for_excessive_speedup, + excessive_speedup_threshold=excessive_speedup_threshold, + ) + + +# Global evaluator instance (lazy initialized) +_global_local_evaluator: LocalKernelEvaluator | None = None + + +def get_local_evaluator( + config: LocalEvaluatorConfig | None = None, +) -> LocalKernelEvaluator: + """Get or create the global local evaluator instance.""" + global _global_local_evaluator + if _global_local_evaluator is None: + _global_local_evaluator = LocalKernelEvaluator(config) + elif config is not None: + _global_local_evaluator.config = config + return _global_local_evaluator + + +def set_local_evaluator(evaluator: LocalKernelEvaluator) -> None: + """Set the global local evaluator instance.""" + global _global_local_evaluator + _global_local_evaluator = evaluator diff --git a/src/kernelbench_tinker/modal/app.py b/src/kernelbench_tinker/modal/app.py index ce87066..825b987 100644 --- a/src/kernelbench_tinker/modal/app.py +++ b/src/kernelbench_tinker/modal/app.py @@ -28,7 +28,6 @@ import modal - # ============================================================================= # GPU Architecture Mapping # ============================================================================= @@ -189,9 +188,10 @@ def evaluate( """ import tempfile import time - + import modal.experimental import torch + from kernelbench.eval import eval_kernel_against_ref, get_torch_dtype_from_string from kernelbench.utils import set_gpu_arch diff --git a/src/kernelbench_tinker/scripts/eval_kernel_rl.py b/src/kernelbench_tinker/scripts/eval_kernel_rl.py index 07ba2a6..f932d9e 100644 --- a/src/kernelbench_tinker/scripts/eval_kernel_rl.py +++ 
b/src/kernelbench_tinker/scripts/eval_kernel_rl.py @@ -18,28 +18,57 @@ import json import logging import os -from dataclasses import dataclass, asdict +from dataclasses import asdict, dataclass from typing import Any import chz import tinker -from tqdm import tqdm - from tinker_cookbook import renderers, tokenizer_utils from tinker_cookbook.completers import TinkerTokenCompleter +from tqdm import tqdm from kernelbench_tinker.env import setup_environment +from kernelbench_tinker.envs.env_utils import build_system_prompt from kernelbench_tinker.envs.kernelbench_client import ( KernelBenchProblem, evaluate_kernel_async, get_problem_ids, parse_structured_response, ) -from kernelbench_tinker.training.models import get_renderer_name_for_model +from kernelbench_tinker.training.models import ( + KernelBenchTokenCompleter, + get_renderer_name_for_model, +) logger = logging.getLogger(__name__) +def pick_best_sample( + samples: list[dict[str, Any]], +) -> tuple[bool, bool, float | None]: + """Pick the best sample from a list of evaluation results. + + Returns (best_correct, best_compiled, best_speedup). 
+ """ + def speedup_value(sample: dict[str, Any]) -> float: + speedup = sample.get("speedup") + return float(speedup) if isinstance(speedup, (int, float)) else 0.0 + + correct_samples = [s for s in samples if s.get("correctness")] + if correct_samples: + best = max(correct_samples, key=speedup_value) + else: + compiled = [s for s in samples if s.get("compiled")] + best = compiled[0] if compiled else samples[0] + + best_speedup: float | None = None + speedup_obj = best.get("speedup") + if isinstance(speedup_obj, (int, float)): + best_speedup = float(speedup_obj) + + return bool(best.get("correctness")), bool(best.get("compiled")), best_speedup + + @chz.chz class EvalConfig: """Configuration for model evaluation.""" @@ -50,6 +79,7 @@ class EvalConfig: # Evaluation configuration level: int = 1 + levels: list[int] | None = None # Iterate multiple levels (overrides level when set) start_problem: int | None = None end_problem: int | None = None backend: str = "triton" @@ -58,6 +88,8 @@ class EvalConfig: # Generation configuration max_tokens: int = 4096 temperature: float = 0.0 # Greedy for eval + top_p: float = 1.0 # Nucleus sampling (1.0 = disabled) + seed: int | None = None # Random seed for generation (null = random) num_samples: int = 1 # Samples per problem # Evaluation settings @@ -69,10 +101,18 @@ class EvalConfig: check_for_excessive_speedup: bool = True excessive_speedup_threshold: float = 10.0 + # Evaluator backend selection: "modal" (cloud NVIDIA) or "local" (in-process subprocess for AMD/HIP) + evaluator_backend: str = "modal" + # Modal configuration modal_gpu_type: str = "A100" modal_timeout: float = 120.0 + # Local evaluator configuration (used when evaluator_backend == "local") + # gpu_arch is passed to set_gpu_arch() — e.g. 
["gfx950"] for MI350X, ["gfx942"] for MI300X + gpu_arch: list[str] | None = None + local_timeout: float = 300.0 + # Prompt configuration prompt_option: str = "one_shot" prompt_include_hardware: bool = False @@ -86,6 +126,12 @@ class EvalConfig: tensorboard_log_dir: str | None = None # If provided, log eval metrics to TensorBoard tensorboard_step: int = 0 # Step to log eval metrics at + # Multi-turn inference + multiturn_enabled: bool = False + multiturn_max_turns: int = 8 + inject_think_token: bool = False # Append \n to generation prompts + prompt_max_tokens: int | None = None # Token budget for history truncation + # Tinker API base_url: str | None = None @@ -109,9 +155,9 @@ async def generate_kernel( temperature: float, ) -> str: """Generate a kernel for a problem.""" - # Build prompt + # Build prompt (same system prompt as training env) messages = [ - {"role": "system", "content": "You are an expert GPU kernel developer."}, + {"role": "system", "content": build_system_prompt(problem.backend)}, {"role": "user", "content": problem.prompt}, ] observation = renderer.build_generation_prompt(messages) @@ -185,49 +231,114 @@ async def evaluate_problem( **eval_result, }) - def speedup_value(sample: dict[str, Any]) -> float: - speedup = sample.get("speedup") - return float(speedup) if isinstance(speedup, (int, float)) else 0.0 + best_correct, best_compiled, best_speedup = pick_best_sample(samples) - # Find best result - correct_samples = [s for s in samples if s.get("correctness")] - if correct_samples: - # Best by speedup - best = max(correct_samples, key=speedup_value) - else: - # Best by compilation - compiled = [s for s in samples if s.get("compiled")] - best = compiled[0] if compiled else samples[0] + return EvalResult( + level=problem.level, + problem_id=problem.problem_id, + samples=samples, + best_correct=best_correct, + best_compiled=best_compiled, + best_speedup=best_speedup, + ) - best_speedup: float | None = None - speedup_obj = best.get("speedup") - if 
isinstance(speedup_obj, (int, float)): - best_speedup = float(speedup_obj) + +async def evaluate_problem_multiturn( + sampling_client: tinker.SamplingClient, + problem: KernelBenchProblem, + renderer: renderers.Renderer, + cfg: EvalConfig, + tokenizer: object | None = None, +) -> EvalResult: + """Evaluate a single problem using multi-turn refinement. + + Reuses MultiTurnKernelBenchEnv so history truncation, feedback + construction, and think-token injection stay in one place. + """ + from kernelbench_tinker.config.configs import EvalConfig as KernelEvalConfig + from kernelbench_tinker.envs.multiturn_kernelbench_env import MultiTurnKernelBenchEnv + + eval_config = KernelEvalConfig( + num_correct_trials=cfg.num_correct_trials, + measure_performance=cfg.measure_performance, + num_perf_trials=cfg.num_perf_trials, + timing_method=cfg.timing_method, + precision=cfg.precision, + check_for_excessive_speedup=cfg.check_for_excessive_speedup, + excessive_speedup_threshold=cfg.excessive_speedup_threshold, + evaluator_backend=cfg.evaluator_backend, + modal_gpu_type=cfg.modal_gpu_type, + modal_timeout=cfg.modal_timeout, + gpu_arch=list(cfg.gpu_arch) if cfg.gpu_arch else [], + local_timeout=cfg.local_timeout, + ) + + env = MultiTurnKernelBenchEnv( + problem=problem, + renderer=renderer, + max_turns=cfg.multiturn_max_turns, + eval_config=eval_config, + system_prompt=build_system_prompt(problem.backend), + tokenizer=tokenizer, + prompt_max_tokens=cfg.prompt_max_tokens, + inject_think_token=cfg.inject_think_token, + ) + + completer = KernelBenchTokenCompleter( + sampling_client, + max_tokens=cfg.max_tokens, + temperature=cfg.temperature if cfg.num_samples == 1 else 1.0, + top_p=cfg.top_p, + seed=cfg.seed, + ) + + observation, stop = await env.initial_observation() + samples: list[dict[str, Any]] = [] + + for turn in range(cfg.multiturn_max_turns): + result = await completer(observation, stop) + step_result = await env.step(result.tokens) + m = step_result.metrics + + kernel_code = 
env.state.history[-1]["kernel"] + if cfg.max_kernel_code_chars is not None and len(kernel_code) > cfg.max_kernel_code_chars: + kernel_code = kernel_code[: cfg.max_kernel_code_chars] + "..." + + samples.append({ + "sample_id": turn, + "turn": turn, + "kernel_code": kernel_code, + "format_ok": bool(m.get("format_ok")), + "compiled": bool(m.get("compiled")), + "correctness": bool(m.get("correctness")), + "tests_passed": m.get("tests_passed", 0), + "tests_total": m.get("tests_total", 0), + "speedup": m.get("speedup"), + "runtime_ms": m.get("runtime_ms"), + }) + + if step_result.episode_done: + break + observation = step_result.next_observation + stop = step_result.next_stop_condition + + best_correct, best_compiled, best_speedup = pick_best_sample(samples) return EvalResult( level=problem.level, problem_id=problem.problem_id, samples=samples, - best_correct=bool(best.get("correctness")), - best_compiled=bool(best.get("compiled")), + best_correct=best_correct, + best_compiled=best_compiled, best_speedup=best_speedup, ) async def run_evaluation(cfg: EvalConfig) -> dict[str, Any]: """Run full evaluation.""" - from kernelbench_tinker.modal.evaluator import ( - ModalEvaluatorConfig, - ModalKernelEvaluator, - set_modal_evaluator, - ) - - modal_config = ModalEvaluatorConfig( - enabled=True, - gpu_type=cfg.modal_gpu_type, - timeout=int(cfg.modal_timeout), - ) - set_modal_evaluator(ModalKernelEvaluator(modal_config)) + # Install the appropriate evaluator backend (modal vs local). 
+ from kernelbench_tinker.envs.evaluator_dispatch import set_evaluator_from_eval_config + set_evaluator_from_eval_config(cfg) # Create Tinker client service_client = tinker.ServiceClient(base_url=cfg.base_url) @@ -246,34 +357,46 @@ async def run_evaluation(cfg: EvalConfig) -> dict[str, Any]: tokenizer = tokenizer_utils.get_tokenizer(cfg.model_name) renderer = renderers.get_renderer(renderer_name, tokenizer) - # Get problems - problem_ids = get_problem_ids( - cfg.level, - start=cfg.start_problem, - end=cfg.end_problem, - dataset_src=cfg.dataset_src, - ) + # Iterate one or more levels (the for-loop the user asked for: each task, + # each level, end-to-end on the same configured backend). + active_levels = list(cfg.levels) if cfg.levels else [cfg.level] - problems = [ - KernelBenchProblem( - level=cfg.level, - problem_id=pid, - backend=cfg.backend, + problems: list[KernelBenchProblem] = [] + for lvl in active_levels: + problem_ids = get_problem_ids( + lvl, + start=cfg.start_problem, + end=cfg.end_problem, dataset_src=cfg.dataset_src, - prompt_option=cfg.prompt_option, - prompt_precision=cfg.precision, - prompt_include_hardware=cfg.prompt_include_hardware, - prompt_gpu_name=cfg.prompt_gpu_name, ) - for pid in problem_ids - ] + problems.extend( + KernelBenchProblem( + level=lvl, + problem_id=pid, + backend=cfg.backend, + dataset_src=cfg.dataset_src, + prompt_option=cfg.prompt_option, + prompt_precision=cfg.precision, + prompt_include_hardware=cfg.prompt_include_hardware, + prompt_gpu_name=cfg.prompt_gpu_name, + ) + for pid in problem_ids + ) - logger.info(f"Evaluating {len(problems)} problems from level {cfg.level}") + logger.info( + f"Evaluating {len(problems)} problems across levels={active_levels} backend={cfg.backend}" + ) # Evaluate each problem results = [] for problem in tqdm(problems, desc="Evaluating"): try: + if cfg.multiturn_enabled: + result = await evaluate_problem_multiturn( + sampling_client, problem, renderer, cfg, tokenizer=tokenizer + ) + 
results.append(result) + continue result = await evaluate_problem( sampling_client, problem, renderer, cfg ) @@ -310,7 +433,10 @@ async def run_evaluation(cfg: EvalConfig) -> dict[str, Any]: "checkpoint_path": cfg.checkpoint_path, "model_name": cfg.model_name, "level": cfg.level, + "levels": list(cfg.levels) if cfg.levels else [cfg.level], "backend": cfg.backend, + "evaluator_backend": cfg.evaluator_backend, + "gpu_arch": list(cfg.gpu_arch) if cfg.gpu_arch else [], "num_samples": cfg.num_samples, }, "metrics": metrics, @@ -333,8 +459,8 @@ def main(): logger.info("Starting KernelBench Evaluation") logger.info(f"Checkpoint: {cfg.checkpoint_path or 'base model'}") - logger.info(f"Level: {cfg.level}") - logger.info(f"Backend: {cfg.backend}") + logger.info(f"Levels: {list(cfg.levels) if cfg.levels else [cfg.level]}") + logger.info(f"Backend: {cfg.backend} (evaluator={cfg.evaluator_backend})") # Run evaluation output = asyncio.run(run_evaluation(cfg)) @@ -353,8 +479,8 @@ def main(): # Log to TensorBoard if specified if cfg.tensorboard_log_dir: # Lazy import to avoid circular imports - from kernelbench_tinker.training.tensorboard_logger import create_tensorboard_logger from kernelbench_tinker.evaluation.eval_kernelbench import EvalResults, ProblemResult + from kernelbench_tinker.training.tensorboard_logger import create_tensorboard_logger tb_logger = create_tensorboard_logger(cfg.tensorboard_log_dir) diff --git a/src/kernelbench_tinker/scripts/train_kernel_rl.py b/src/kernelbench_tinker/scripts/train_kernel_rl.py index 28e96c6..fe2aedd 100644 --- a/src/kernelbench_tinker/scripts/train_kernel_rl.py +++ b/src/kernelbench_tinker/scripts/train_kernel_rl.py @@ -28,7 +28,8 @@ import yaml # type: ignore[import-untyped] from kernelbench_tinker.env import setup_environment -from kernelbench_tinker.training.loop import TrainingConfig, main as train_main +from kernelbench_tinker.training.loop import TrainingConfig +from kernelbench_tinker.training.loop import main as train_main 
logger = logging.getLogger(__name__) @@ -105,11 +106,16 @@ def main(): cfg = blueprint.make() logger.info("Starting KernelBench RL Training") + logger.info(f"Multi-turn: {'enabled' if cfg.multiturn.enabled else 'disabled'}") logger.info(f"Model: {cfg.model_name}") logger.info(f"Level: {cfg.dataset_builder.level}") logger.info(f"Batch size: {cfg.dataset_builder.batch_size}") logger.info(f"Group size: {cfg.dataset_builder.group_size}") logger.info(f"Log path: {cfg.log_path}") + if cfg.multiturn.enabled: + logger.info(f"Refinement turns per trajectory (n): {cfg.multiturn.max_turns}") + logger.info(f"Parallel trajectories (group_size): {cfg.dataset_builder.group_size}") + logger.info(f"Discount factor (gamma): {cfg.multiturn.gamma}") # Run training asyncio.run(train_main(cfg)) diff --git a/src/kernelbench_tinker/training/loop.py b/src/kernelbench_tinker/training/loop.py index 33bd86c..80d9676 100644 --- a/src/kernelbench_tinker/training/loop.py +++ b/src/kernelbench_tinker/training/loop.py @@ -16,16 +16,15 @@ import asyncio import logging import os -from pathlib import Path import time -from typing import Any +from pathlib import Path +from typing import Any, Sequence import chz import numpy as np import tinker import torch from tinker.types import LossFnType - from tinker_cookbook import checkpoint_utils from tinker_cookbook.completers import TinkerTokenCompleter from tinker_cookbook.rl.data_processing import ( @@ -35,14 +34,28 @@ ) from tinker_cookbook.rl.rollouts import do_group_rollout from tinker_cookbook.rl.types import ( + Env, EnvGroupBuilder, TrajectoryGroup, ) from tinker_cookbook.utils import ml_log from tinker_cookbook.utils.misc_utils import timed +from kernelbench_tinker.config.configs import MultiTurnConfig from kernelbench_tinker.envs.kernelbench_env import KernelBenchDatasetBuilder -from kernelbench_tinker.training.models import get_adam_params +from kernelbench_tinker.envs.multiturn_kernelbench_env import wrap_builders_as_multiturn +from 
kernelbench_tinker.training.models import ( + KernelBenchTokenCompleter, + build_loss_fn_config, + get_adam_params, +) +from kernelbench_tinker.training.multiturn import ( + apply_discounted_returns_to_trajectories, + compute_multiturn_advantages, + compute_multiturn_trajectory_metrics, + do_multiturn_group_rollout_and_filter, + flatten_multiturn_trajectory_groups, +) from kernelbench_tinker.training.tensorboard_logger import ( TensorBoardConfig, TensorBoardLogger, @@ -50,6 +63,8 @@ ) from kernelbench_tinker.training.trace_logger import TraceLogger, set_trace_logger +logger = logging.getLogger(__name__) + def remove_mask(datum: tinker.Datum) -> tinker.Datum: """Remove mask from datum loss_fn_inputs before sending to forward_backward. @@ -62,8 +77,6 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum: loss_fn_inputs={k: v for k, v in datum.loss_fn_inputs.items() if k != "mask"}, ) -logger = logging.getLogger(__name__) - @chz.chz class TrainingConfig: @@ -83,6 +96,9 @@ class TrainingConfig: default_factory=KernelBenchDatasetBuilder ) + # Multi-turn specific config + multiturn: MultiTurnConfig = chz.field(default_factory=MultiTurnConfig) + # Training configuration num_substeps: int = 1 # Optimizer steps per batch loss_fn: LossFnType = "importance_sampling" @@ -154,7 +170,6 @@ async def do_group_rollout_and_filter( return trajectory_group - def compute_trajectory_metrics( trajectory_groups: list[TrajectoryGroup], taglist: list[list[str]] | None = None, @@ -240,6 +255,8 @@ async def train_step( learning_rate: float, num_substeps: int, loss_fn: LossFnType, + max_grad_norm: float = 0.0, + loss_fn_config: dict[str, float] | None = None, ) -> list[torch.Tensor]: """ Perform a training step with gradient accumulation. 
@@ -250,6 +267,8 @@ async def train_step( learning_rate: Learning rate num_substeps: Number of optimizer steps loss_fn: Loss function type + max_grad_norm: Maximum gradient norm for clipping (0.0 = no clipping) + loss_fn_config: Optional loss function config (e.g. PPO clip thresholds) Returns: List of training logprobs tensors @@ -262,8 +281,11 @@ async def train_step( batch = data[i : i + substep_size] # Forward-backward pass (remove mask key from datums) + fwd_bwd_kwargs: dict[str, Any] = {"loss_fn": loss_fn} + if loss_fn_config is not None: + fwd_bwd_kwargs["loss_fn_config"] = loss_fn_config fwd_bwd_future = await training_client.forward_backward_async( - [remove_mask(d) for d in batch], loss_fn=loss_fn + [remove_mask(d) for d in batch], **fwd_bwd_kwargs ) fwd_bwd_result = await fwd_bwd_future.result_async() @@ -272,7 +294,7 @@ async def train_step( training_logprobs.append(output["logprobs"].to_torch()) # Optimizer step - adam_params = get_adam_params(learning_rate) + adam_params = get_adam_params(learning_rate, max_grad_norm=max_grad_norm) optim_future = await training_client.optim_step_async(adam_params) await optim_future.result_async() @@ -335,6 +357,13 @@ async def run_training_loop( Args: cfg: Training configuration """ + is_multiturn = cfg.multiturn.enabled + if is_multiturn: + logger.info("Running in MULTI-TURN mode") + logger.info(f" max_turns (refinement turns per trajectory): {cfg.multiturn.max_turns}") + logger.info(f" group_size (parallel trajectories, m): {cfg.dataset_builder.group_size}") + logger.info(f" gamma (discount factor): {cfg.multiturn.gamma}") + # Setup logging os.makedirs(cfg.log_path, exist_ok=True) ml_logger = ml_log.setup_logging( @@ -415,12 +444,20 @@ async def run_training_loop( # Create dataset (pass tokenizer for renderer) dataset_builder = cfg.dataset_builder - logger.info("Using KernelBenchDatasetBuilder") + if is_multiturn: + logger.info("Using KernelBenchDatasetBuilder (multi-turn, max_turns=%d)", cfg.multiturn.max_turns) + 
else: + logger.info("Using KernelBenchDatasetBuilder") train_dataset, test_dataset = await dataset_builder(tokenizer=tokenizer) num_batches = len(train_dataset) logger.info(f"Training on {num_batches} batches") + # Warmup schedule (multi-turn only) + warmup_batches = int(num_batches * cfg.multiturn.warmup_ratio) if is_multiturn else 0 + if warmup_batches > 0: + logger.info(f"Linear LR warmup for {warmup_batches} batches") + # Get initial sampling client sampling_client, _ = await save_checkpoint_and_get_sampling_client( training_client, start_batch, cfg.log_path, cfg.save_every, start_batch @@ -435,56 +472,165 @@ async def run_training_loop( "optim/lr": cfg.learning_rate, } - # Get batch of env group builders + # Get batch of env group builders (always single-turn from dataset) env_group_builders = train_dataset.get_batch(batch_idx) - # Collect rollouts (single-turn) - with timed("rollout", metrics): - try: - results = await asyncio.gather(*[ - do_group_rollout_and_filter( - sampling_client, - builder, - max_tokens=cfg.max_tokens, - temperature=cfg.temperature, - do_remove_constant_reward_groups=cfg.remove_constant_reward_groups, + if is_multiturn: + # Wrap single-turn builders as multi-turn + env_group_builders = wrap_builders_as_multiturn( + env_group_builders, cfg.multiturn, tokenizer + ) + + # ----- Multi-turn rollouts ----- + # Response length extension (multi-turn only) + effective_max_tokens = cfg.max_tokens + if ( + cfg.multiturn.max_tokens_extended > 0 + and batch_idx >= cfg.multiturn.max_tokens_extend_after_step + ): + effective_max_tokens = cfg.multiturn.max_tokens_extended + if batch_idx == cfg.multiturn.max_tokens_extend_after_step: + logger.info( + f"Extending max_tokens from {cfg.max_tokens} to " + f"{cfg.multiturn.max_tokens_extended} at step {batch_idx}" ) - for builder in env_group_builders - ], return_exceptions=True) - except Exception: - logger.exception("Group rollout failed during gather") - raise - - # Filter out None (removed constant 
reward groups) and exceptions - trajectory_groups = [] - for tg in results: - if isinstance(tg, Exception): - logger.error("Group rollout failed", exc_info=tg) - elif tg is not None: - trajectory_groups.append(tg) - if len(trajectory_groups) == 0: - logger.warning(f"Batch {batch_idx}: All groups filtered out, skipping") - continue - - # Compute metrics - traj_metrics = compute_trajectory_metrics(trajectory_groups) - metrics.update(traj_metrics) - - # Compute advantages and assemble training data - with timed("assemble_data", metrics): - advantages = compute_advantages(trajectory_groups) - data, _metadata = assemble_training_data(trajectory_groups, advantages) - - # Training step - with timed("train", metrics): - await train_step( - data, - training_client, - cfg.learning_rate, - cfg.num_substeps, - cfg.loss_fn, + with timed("rollout", metrics): + try: + results = await asyncio.gather( + *[ + do_multiturn_group_rollout_and_filter( + sampling_client, + builder, + max_tokens=effective_max_tokens, + temperature=cfg.multiturn.temperature, + do_remove_constant_reward_groups=cfg.remove_constant_reward_groups, + top_p=cfg.multiturn.top_p, + seed=cfg.multiturn.seed, + ) + for builder in env_group_builders + ], + return_exceptions=True, + ) + except Exception: + logger.exception("Group rollout failed during gather") + raise + + trajectory_groups: list[TrajectoryGroup] = [] + env_groups: list[Sequence[Env]] = [] + for r in results: + if isinstance(r, BaseException): + logger.error("Group rollout failed", exc_info=r) + continue + tg, envs = r + if tg is not None and envs is not None: + trajectory_groups.append(tg) + env_groups.append(envs) + + if len(trajectory_groups) == 0: + logger.warning( + f"Batch {batch_idx}: All groups filtered out, skipping" + ) + continue + + with timed("discount_returns", metrics): + apply_discounted_returns_to_trajectories( + trajectory_groups, env_groups, + gamma=cfg.multiturn.gamma, + aggregation=cfg.multiturn.aggregation, + ) + + traj_metrics = 
compute_multiturn_trajectory_metrics( + trajectory_groups, env_groups ) + metrics.update(traj_metrics) + + # Flatten: each turn becomes its own single-transition trajectory + # so that advantage normalization is across all group_size * n turn-level samples + with timed("flatten", metrics): + trajectory_groups = flatten_multiturn_trajectory_groups( + trajectory_groups + ) + + # Compute advantages and assemble training data + with timed("assemble_data", metrics): + advantages = compute_multiturn_advantages(trajectory_groups) + + if cfg.multiturn.constant_length_norm > 0: + for i in range(len(advantages)): + advantages[i] = advantages[i] / cfg.multiturn.constant_length_norm + + data, _metadata = assemble_training_data(trajectory_groups, advantages) + + # Learning rate warmup (multi-turn only) + if warmup_batches > 0 and batch_idx < warmup_batches: + lr = cfg.learning_rate * (batch_idx + 1) / warmup_batches + else: + lr = cfg.learning_rate + metrics["optim/lr"] = lr + + # Training step with PPO clip and grad norm + with timed("train", metrics): + await train_step( + data, + training_client, + lr, + cfg.multiturn.num_substeps, + cfg.multiturn.loss_fn, # type: ignore[arg-type] + max_grad_norm=cfg.multiturn.max_grad_norm, + loss_fn_config=build_loss_fn_config( + cfg.multiturn.clip_epsilon_low, + cfg.multiturn.clip_epsilon_high, + ), + ) + else: + # Collect rollouts (single-turn) + with timed("rollout", metrics): + try: + st_results = await asyncio.gather(*[ + do_group_rollout_and_filter( + sampling_client, + builder, + max_tokens=cfg.max_tokens, + temperature=cfg.temperature, + do_remove_constant_reward_groups=cfg.remove_constant_reward_groups, + ) + for builder in env_group_builders + ], return_exceptions=True) + except Exception: + logger.exception("Group rollout failed during gather") + raise + + # Filter out None (removed constant reward groups) and exceptions + trajectory_groups = [] + for tg in st_results: + if isinstance(tg, Exception): + logger.error("Group 
rollout failed", exc_info=tg) + elif tg is not None: + trajectory_groups.append(tg) + + if len(trajectory_groups) == 0: + logger.warning(f"Batch {batch_idx}: All groups filtered out, skipping") + continue + + # Compute metrics + traj_metrics = compute_trajectory_metrics(trajectory_groups) + metrics.update(traj_metrics) + + # Compute advantages and assemble training data + with timed("assemble_data", metrics): + advantages = compute_advantages(trajectory_groups) + data, _metadata = assemble_training_data(trajectory_groups, advantages) + + # Training step + with timed("train", metrics): + await train_step( + data, + training_client, + cfg.learning_rate, + cfg.num_substeps, + cfg.loss_fn, + ) # Save checkpoint and get new sampling client sampling_client, checkpoint_metrics = await save_checkpoint_and_get_sampling_client( @@ -501,14 +647,25 @@ async def run_training_loop( tb_logger.log_training_metrics(metrics, batch_idx) tb_logger.log_trajectory_histograms(trajectory_groups, batch_idx) tb_logger.log_per_level_metrics(trajectory_groups, batch_idx) - tb_logger.log_advantage_statistics(advantages, batch_idx) - - logger.info( - f"Batch {batch_idx}/{num_batches}: " - f"reward={metrics.get('reward/mean', 0):.3f}, " - f"compile={metrics.get('kernel/compile_rate', 0):.1%}, " - f"correct={metrics.get('kernel/correct_rate', 0):.1%}" - ) + adv_arrays = [a.numpy() if isinstance(a, torch.Tensor) else a for a in advantages] + tb_logger.log_advantage_statistics(adv_arrays, batch_idx) + + if is_multiturn: + logger.info( + f"Batch {batch_idx}/{num_batches}: " + f"raw_score={metrics.get('multiturn/raw_score_mean', 0):.3f}, " + f"compile={metrics.get('multiturn/compile_rate', 0):.1%}, " + f"correct={metrics.get('multiturn/correct_rate', 0):.1%}, " + f"success={metrics.get('multiturn/success_rate', 0):.1%}, " + f"avg_turns={metrics.get('multiturn/avg_turns', 0):.1f}" + ) + else: + logger.info( + f"Batch {batch_idx}/{num_batches}: " + f"reward={metrics.get('reward/mean', 0):.3f}, " + 
f"compile={metrics.get('kernel/compile_rate', 0):.1%}, " + f"correct={metrics.get('kernel/correct_rate', 0):.1%}" + ) # Save final checkpoint if start_batch < num_batches: diff --git a/src/kernelbench_tinker/training/models.py b/src/kernelbench_tinker/training/models.py index f1d6ee3..a7ede69 100644 --- a/src/kernelbench_tinker/training/models.py +++ b/src/kernelbench_tinker/training/models.py @@ -1,10 +1,19 @@ """ -Minimal model helpers for KernelBench ↔ Tinker integration. +Model and completer helpers for KernelBench ↔ Tinker integration. """ from __future__ import annotations import tinker +from tinker_cookbook.completers import ( + StopCondition, + TokenCompleter, + TokensWithLogprobs, +) + +# --------------------------------------------------------------------------- +# Renderer helpers +# --------------------------------------------------------------------------- def get_renderer_name_for_model(model_name: str) -> str: @@ -28,11 +37,86 @@ def get_renderer_name_for_model(model_name: str) -> str: return "role_colon" -def get_adam_params(learning_rate: float) -> tinker.AdamParams: +# --------------------------------------------------------------------------- +# Optimizer helpers +# --------------------------------------------------------------------------- + + +def get_adam_params( + learning_rate: float, + max_grad_norm: float = 0.0, +) -> tinker.AdamParams: """Get Adam optimizer parameters.""" - return tinker.AdamParams( - learning_rate=learning_rate, - beta1=0.9, - beta2=0.95, - eps=1e-8, - ) + kwargs: dict = { + "learning_rate": learning_rate, + "beta1": 0.9, + "beta2": 0.95, + "eps": 1e-8, + } + if max_grad_norm > 0: + kwargs["grad_clip_norm"] = max_grad_norm + return tinker.AdamParams(**kwargs) + + +# --------------------------------------------------------------------------- +# Token completers +# --------------------------------------------------------------------------- + + +class KernelBenchTokenCompleter(TokenCompleter): + """Token completer with 
top_p and seed support. + + TinkerTokenCompleter only accepts temperature. This subclass adds top_p + and seed, which the multi-turn training loop and eval script need. + """ + + def __init__( + self, + sampling_client: tinker.SamplingClient, + max_tokens: int, + temperature: float = 1.0, + top_p: float = 1.0, + seed: int | None = None, + ): + self.sampling_client = sampling_client + self.max_tokens = max_tokens + self.temperature = temperature + self.top_p = top_p + self.seed = seed + + async def __call__( + self, model_input: tinker.ModelInput, stop: StopCondition + ) -> TokensWithLogprobs: + sample_result = await self.sampling_client.sample_async( + prompt=model_input, + num_samples=1, + sampling_params=tinker.SamplingParams( + stop=stop, + max_tokens=self.max_tokens, + temperature=self.temperature, + top_p=self.top_p, + seed=self.seed, + ), + ) + sampled_tokens = sample_result.sequences[0].tokens + sampled_logprobs = sample_result.sequences[0].logprobs + assert sampled_logprobs is not None + return TokensWithLogprobs(tokens=sampled_tokens, maybe_logprobs=sampled_logprobs) + + +# --------------------------------------------------------------------------- +# Loss function helpers +# --------------------------------------------------------------------------- + + +def build_loss_fn_config( + clip_epsilon_low: float = 0.0, + clip_epsilon_high: float = 0.0, +) -> dict[str, float] | None: + """Build loss_fn_config for PPO clip thresholds (passed to forward_backward_async).""" + if clip_epsilon_low <= 0: + return None + return { + "clip_low_threshold": 1.0 - clip_epsilon_low, + "clip_high_threshold": 1.0 + clip_epsilon_high, + } diff --git a/src/kernelbench_tinker/training/multiturn.py b/src/kernelbench_tinker/training/multiturn.py new file mode 100644 index 0000000..af7b5c8 --- /dev/null +++ b/src/kernelbench_tinker/training/multiturn.py @@ -0,0 +1,281 @@ +""" +Multi-turn rollout, advantage estimation, and metrics for KernelBench RL. 
+ +These helpers are used by the training loop when multiturn.enabled is True. +Single-turn training does not touch this module. +""" + +from __future__ import annotations + +import asyncio +import logging +from collections import defaultdict +from typing import Any, Sequence + +import numpy as np +import tinker +import torch +from tinker_cookbook.rl.data_processing import remove_constant_reward_groups +from tinker_cookbook.rl.rollouts import do_single_rollout +from tinker_cookbook.rl.types import ( + Env, + EnvGroupBuilder, + Trajectory, + TrajectoryGroup, +) + +from kernelbench_tinker.envs.multiturn_kernelbench_env import MultiTurnKernelBenchEnv +from kernelbench_tinker.training.models import KernelBenchTokenCompleter +from kernelbench_tinker.training.reward import compute_discounted_returns + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Rollouts +# --------------------------------------------------------------------------- + + +async def do_multiturn_group_rollout_and_filter( + sampling_client: tinker.SamplingClient, + env_group_builder: EnvGroupBuilder, + max_tokens: int, + temperature: float, + do_remove_constant_reward_groups: bool, + top_p: float = 1.0, + seed: int | None = None, +) -> tuple[TrajectoryGroup | None, Sequence[Env] | None]: + """Multi-turn rollout that returns (trajectory_group, envs). + + We can't use do_group_rollout here because it doesn't return the envs, + and we need env access to read per-step scores for discounted returns. 
+ """ + policy = KernelBenchTokenCompleter( + sampling_client, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + seed=seed, + ) + + envs = await env_group_builder.make_envs() + rollout_results = await asyncio.gather( + *[do_single_rollout(policy, env) for env in envs], + return_exceptions=True, + ) + + trajectories = [] + valid_envs: list[Env] = [] + for traj, env in zip(rollout_results, envs): + if isinstance(traj, Exception): + logger.warning(f"Rollout failed: {traj}") + else: + trajectories.append(traj) + valid_envs.append(env) + + if not trajectories: + logger.warning("All rollouts in group failed") + return None, None + + # Final rewards are [0.0] because multi-turn rewards live in + # transition.reward (set by env.step) and are later overwritten by + # apply_discounted_returns. TrajectoryGroup.get_total_rewards() sums + # transition.reward + final_reward, so final_reward must be zero. + trajectory_group = TrajectoryGroup( + trajectories, + [0.0] * len(trajectories), + [{}] * len(trajectories), + ) + + if do_remove_constant_reward_groups: + trajectory_groups = remove_constant_reward_groups([trajectory_group]) + if len(trajectory_groups) == 0: + return None, None + trajectory_group = trajectory_groups[0] + + return trajectory_group, valid_envs + + +# --------------------------------------------------------------------------- +# Discounted returns +# --------------------------------------------------------------------------- + + +def apply_discounted_returns_to_trajectories( + trajectory_groups: list[TrajectoryGroup], + env_groups: list[Sequence[Env]], + gamma: float, + aggregation: str = "sum", +) -> None: + """Replace per-step rewards with discounted returns for multi-turn training.""" + for tg, envs in zip(trajectory_groups, env_groups): + for traj, env in zip(tg.trajectories_G, envs): + if isinstance(env, MultiTurnKernelBenchEnv): + step_scores = env.get_step_scores() + else: + step_scores = [t.reward for t in traj.transitions] + + if not 
step_scores: + continue + + returns = compute_discounted_returns(step_scores, gamma, aggregation) + for i, trans in enumerate(traj.transitions): + if i < len(returns): + trans.reward = returns[i] + + +# --------------------------------------------------------------------------- +# Flatten and advantage estimation +# --------------------------------------------------------------------------- + + +def flatten_multiturn_trajectory_groups( + trajectory_groups: list[TrajectoryGroup], +) -> list[TrajectoryGroup]: + """Flatten multi-turn trajectories so each turn is its own single-transition trajectory.""" + flattened = [] + for tg in trajectory_groups: + new_trajectories = [] + for traj in tg.trajectories_G: + for trans in traj.transitions: + new_trajectories.append( + Trajectory(transitions=[trans], final_ob=tinker.ModelInput.empty()) + ) + + # final_rewards must be 0.0 because get_total_rewards() sums + # transition.reward + final_reward. The real rewards already live + # in transition.reward (set by apply_discounted_returns). + new_group = TrajectoryGroup( + new_trajectories, + [0.0] * len(new_trajectories), + [{}] * len(new_trajectories), + ) + flattened.append(new_group) + return flattened + + +def compute_multiturn_advantages( + trajectory_groups: list[TrajectoryGroup], +) -> list[torch.Tensor]: + """GRPO advantage with std normalization. + + Expects flattened trajectory groups (each "trajectory" = one turn). + Normalizes across all m*n samples per problem group. 
+ """ + advantages_P = [] + for tg in trajectory_groups: + rewards = torch.tensor(tg.get_total_rewards()) + mean = rewards.mean() + std = rewards.std() + advantages = (rewards - mean) / (std + 1e-9) + advantages_P.append(advantages) + return advantages_P + + +# --------------------------------------------------------------------------- +# Metrics +# --------------------------------------------------------------------------- + + +def compute_multiturn_trajectory_metrics( + trajectory_groups: list[TrajectoryGroup], + env_groups: list[Sequence[Env]], +) -> dict[str, Any]: + """Compute aggregate metrics for multi-turn trajectories.""" + metrics: dict[str, Any] = {} + + turn_compiled: dict[int, list[float]] = defaultdict(list) + turn_correct: dict[int, list[float]] = defaultdict(list) + + all_rewards = [] + all_num_turns = [] + all_success = [] + all_best_speedup = [] + + all_format_ok = [] + all_compiled = [] + all_correct = [] + all_step_scores = [] + all_eval_times = [] + all_step_times = [] + + for tg, envs in zip(trajectory_groups, env_groups): + rewards = tg.get_total_rewards() + all_rewards.extend(rewards) + + for traj, env in zip(tg.trajectories_G, envs): + traj_speedups = [] + + for trans in traj.transitions: + if trans.metrics: + turn = trans.metrics.get("turn", 0) + compiled = trans.metrics.get("compiled", 0) + correct = trans.metrics.get("correctness", 0) + turn_compiled[turn].append(compiled) + turn_correct[turn].append(correct) + + all_format_ok.append(trans.metrics.get("format_ok", 0)) + all_compiled.append(compiled) + all_correct.append(correct) + + if "step_score" in trans.metrics: + all_step_scores.append(trans.metrics["step_score"]) + if "time/eval" in trans.metrics: + all_eval_times.append(trans.metrics["time/eval"]) + if "time/step_total" in trans.metrics: + all_step_times.append(trans.metrics["time/step_total"]) + if "speedup" in trans.metrics: + traj_speedups.append(trans.metrics["speedup"]) + + if traj_speedups: + 
all_best_speedup.append(max(traj_speedups)) + + if isinstance(env, MultiTurnKernelBenchEnv): + all_success.append(float(env.state.success)) + all_num_turns.append(env.state.turn_idx) + + if all_rewards: + metrics["reward/discounted_mean"] = float(np.mean(all_rewards)) + metrics["reward/discounted_std"] = float(np.std(all_rewards)) + metrics["reward/discounted_min"] = float(np.min(all_rewards)) + metrics["reward/discounted_max"] = float(np.max(all_rewards)) + + if all_format_ok: + metrics["multiturn/format_rate"] = float(np.mean(all_format_ok)) + if all_compiled: + metrics["multiturn/compile_rate"] = float(np.mean(all_compiled)) + if all_correct: + metrics["multiturn/correct_rate"] = float(np.mean(all_correct)) + if all_format_ok: + failures = [1.0 - (f and c and r) for f, c, r in zip(all_format_ok, all_compiled, all_correct)] + metrics["multiturn/failure_rate"] = float(np.mean(failures)) + if all_step_scores: + metrics["multiturn/raw_score_mean"] = float(np.mean(all_step_scores)) + if all_success: + metrics["multiturn/success_rate"] = float(np.mean(all_success)) + if all_num_turns: + metrics["multiturn/avg_turns"] = float(np.mean(all_num_turns)) + if all_best_speedup: + metrics["multiturn/best_speedup_mean"] = float(np.mean(all_best_speedup)) + if all_eval_times: + metrics["time/eval_mean"] = float(np.mean(all_eval_times)) + if all_step_times: + metrics["time/step_mean"] = float(np.mean(all_step_times)) + + for turn in sorted(turn_compiled.keys()): + if turn_compiled[turn]: + metrics[f"multiturn/turn_{turn}/compile_rate"] = float( + np.mean(turn_compiled[turn]) + ) + if turn_correct[turn]: + metrics[f"multiturn/turn_{turn}/correct_rate"] = float( + np.mean(turn_correct[turn]) + ) + + metrics["batch/num_groups"] = len(trajectory_groups) + metrics["batch/num_trajectories"] = sum( + len(tg.trajectories_G) for tg in trajectory_groups + ) + metrics["batch/total_steps"] = len(all_step_scores) + + return metrics diff --git a/src/kernelbench_tinker/training/reward.py 
b/src/kernelbench_tinker/training/reward.py index 4d7f9e0..de8df0a 100644 --- a/src/kernelbench_tinker/training/reward.py +++ b/src/kernelbench_tinker/training/reward.py @@ -10,7 +10,7 @@ from __future__ import annotations -import math +import logging import sys from dataclasses import dataclass from pathlib import Path @@ -30,6 +30,8 @@ if TYPE_CHECKING: from kernelbench_tinker.envs.kernelbench_client import KernelEvalResult +logger = logging.getLogger(__name__) + @dataclass class RewardConfig: @@ -58,7 +60,7 @@ class RewardConfig: # Speed reward configuration # Linear speedup: reward = T_baseline / T_kernel # ========================================================================== - speed_baseline: float = 1.0 # Speedup threshold (1.0 = same as baseline) + speed_baseline: float = 0.0 # Speedup threshold (0.0 = raw speedup as reward) speed_scale: float = 1.0 # Linear scaling (not log) speed_max_reward: float = 10.0 # Cap to prevent outliers @@ -98,6 +100,12 @@ class RewardConfig: # Default: all warning checks from static_checker.WARNING_CHECKS static_checker_warnings: list[str] | None = None # None = use defaults (all warning checks) + # ========================================================================== + # Reward clipping configuration + # ========================================================================== + reward_clip_min: float | None = None # Lower bound on total reward (None = no clipping) + reward_clip_max: float | None = None # Upper bound on total reward (None = no clipping) + def format_reward(eval_result: "KernelEvalResult", config: RewardConfig) -> float: """ @@ -189,13 +197,11 @@ def speed_reward( if speedup is None or speedup <= 0: return 0.0 - # Linear speedup, not log-scaled - # If speedup <= baseline (1.0), no speed bonus - if speedup <= config.speed_baseline: + # Linear speedup: reward = speedup for correct kernels + # With default speed_baseline=0.0: 2x speedup = 2.0 reward + if config.speed_baseline > 0 and speedup <= 
config.speed_baseline: return 0.0 - # Linear reward: speedup - 1.0 (so 2x speedup = 1.0 reward, 3x = 2.0, etc.) - # This matches the default formula where reward = speedup for correct kernels reward = config.speed_scale * (speedup - config.speed_baseline) # Clamp to max to prevent outliers @@ -337,16 +343,11 @@ def compute_reward( ) # Log warnings (don't zero reward) - if warnings: - import logging - logger = logging.getLogger(__name__) - for warning in warnings: - logger.warning(f"Static checker warning: {warning}") - + for warning in warnings: + logger.warning(f"Static checker warning: {warning}") + # Zero reward if errors detected if has_errors: - import logging - logger = logging.getLogger(__name__) for error in errors: logger.error(f"Reward hacking detected (reward set to 0): {error}") return 0.0 @@ -401,6 +402,11 @@ def compute_reward( l_reward = length_reward(eval_result, config) total += config.length_weight * l_reward + # Reward clipping + if config.reward_clip_min is not None: + total = max(total, config.reward_clip_min) + if config.reward_clip_max is not None: + total = min(total, config.reward_clip_max) return total @@ -446,3 +452,35 @@ def compute_reward_breakdown( "static_checker_errors": static_checker_errors, "static_checker_warnings": static_checker_warnings, } + + +def compute_discounted_returns( + step_scores: list[float], + gamma: float = 0.4, + aggregation: str = "sum", +) -> list[float]: + """Compute discounted returns for multi-turn RL. + + sum: R_t = S_t + gamma * R_{t+1} (backward recursion) + max: R_t = max{ gamma^(i-t) * S_i } + """ + if aggregation not in ("sum", "max"): + raise ValueError(f"Unknown aggregation mode: {aggregation!r}. 
Must be 'sum' or 'max'.") + + if not step_scores: + return [] + + T = len(step_scores) + + if aggregation == "sum": + returns = [0.0] * T + returns[-1] = step_scores[-1] + for t in range(T - 2, -1, -1): + returns[t] = step_scores[t] + gamma * returns[t + 1] + return returns + + # aggregation == "max" + return [ + max(gamma ** (i - t) * step_scores[i] for i in range(t, T)) + for t in range(T) + ] diff --git a/src/kernelbench_tinker/training/tensorboard_logger.py b/src/kernelbench_tinker/training/tensorboard_logger.py index 891aa6b..7eea9b4 100644 --- a/src/kernelbench_tinker/training/tensorboard_logger.py +++ b/src/kernelbench_tinker/training/tensorboard_logger.py @@ -17,9 +17,8 @@ from typing import Any, Sequence import numpy as np -from torch.utils.tensorboard import SummaryWriter - from tinker_cookbook.rl.types import TrajectoryGroup +from torch.utils.tensorboard import SummaryWriter logger = logging.getLogger(__name__) @@ -199,6 +198,41 @@ def log_training_metrics( if "kernel/mean_speedup" in metrics: self.writer.add_scalar("Kernel/MeanSpeedup", metrics["kernel/mean_speedup"], step) + # === Multi-turn Reward Metrics === + # Multi-turn emits reward/discounted_{mean,std,min,max} + if "reward/discounted_mean" in metrics: + self.writer.add_scalar("Reward/DiscountedMean", metrics["reward/discounted_mean"], step) + if "reward/discounted_std" in metrics: + self.writer.add_scalar("Reward/DiscountedStdDev", metrics["reward/discounted_std"], step) + if "reward/discounted_min" in metrics: + self.writer.add_scalar("Reward/DiscountedMin", metrics["reward/discounted_min"], step) + if "reward/discounted_max" in metrics: + self.writer.add_scalar("Reward/DiscountedMax", metrics["reward/discounted_max"], step) + + # === Multi-turn Kernel Quality Metrics === + if "multiturn/format_rate" in metrics: + self.writer.add_scalar("MultiTurn/FormatRate", metrics["multiturn/format_rate"], step) + if "multiturn/compile_rate" in metrics: + self.writer.add_scalar("MultiTurn/CompileRate", 
metrics["multiturn/compile_rate"], step) + if "multiturn/correct_rate" in metrics: + self.writer.add_scalar("MultiTurn/CorrectRate", metrics["multiturn/correct_rate"], step) + + # === Multi-turn Failure Rate === + if "kernel/failure_rate" in metrics: + self.writer.add_scalar("Kernel/FailureRate", metrics["kernel/failure_rate"], step) + if "multiturn/failure_rate" in metrics: + self.writer.add_scalar("MultiTurn/FailureRate", metrics["multiturn/failure_rate"], step) + + # === Multi-turn Specific Metrics === + if "multiturn/raw_score_mean" in metrics: + self.writer.add_scalar("MultiTurn/RawScoreMean", metrics["multiturn/raw_score_mean"], step) + if "multiturn/success_rate" in metrics: + self.writer.add_scalar("MultiTurn/SuccessRate", metrics["multiturn/success_rate"], step) + if "multiturn/avg_turns" in metrics: + self.writer.add_scalar("MultiTurn/AvgTurns", metrics["multiturn/avg_turns"], step) + if "multiturn/best_speedup_mean" in metrics: + self.writer.add_scalar("MultiTurn/BestSpeedupMean", metrics["multiturn/best_speedup_mean"], step) + def log_trajectory_histograms( self, trajectory_groups: list[TrajectoryGroup], diff --git a/uv.lock b/uv.lock index 19e2d97..b5445b7 100644 --- a/uv.lock +++ b/uv.lock @@ -2,9 +2,12 @@ version = 1 revision = 3 requires-python = ">=3.11" resolution-markers = [ - "python_full_version >= '3.14'", - "python_full_version >= '3.12' and python_full_version < '3.14'", - "python_full_version < '3.12'", + "python_full_version >= '3.14' and sys_platform == 'linux'", + "python_full_version >= '3.14' and sys_platform != 'linux'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'linux'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'linux'", + "python_full_version < '3.12' and sys_platform == 'linux'", + "python_full_version < '3.12' and sys_platform != 'linux'", ] [[package]] @@ -1054,13 +1057,39 @@ http = [ { name = "aiohttp" }, ] +[[package]] +name = "gitdb" 
+version = "4.0.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smmap" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684, upload-time = "2025-01-02T07:20:46.413Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794, upload-time = "2025-01-02T07:20:43.624Z" }, +] + +[[package]] +name = "gitpython" +version = "3.1.46" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitdb" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/df/b5/59d16470a1f0dfe8c793f9ef56fd3826093fc52b3bd96d6b9d6c26c7e27b/gitpython-3.1.46.tar.gz", hash = "sha256:400124c7d0ef4ea03f7310ac2fbf7151e09ff97f2a3288d64a440c584a29c37f", size = 215371, upload-time = "2026-01-01T15:37:32.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl", hash = "sha256:79812ed143d9d25b6d176a10bb511de0f9c67b1fa641d82097b0ab90398a2058", size = 208620, upload-time = "2026-01-01T15:37:30.574Z" }, +] + [[package]] name = "grpcio" version = "1.67.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12' and python_full_version < '3.14'", - "python_full_version < '3.12'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'linux'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'linux'", + "python_full_version < '3.12' and sys_platform == 'linux'", + "python_full_version < '3.12' and sys_platform != 'linux'", ] 
sdist = { url = "https://files.pythonhosted.org/packages/20/53/d9282a66a5db45981499190b77790570617a604a38f3d103d0400974aeb5/grpcio-1.67.1.tar.gz", hash = "sha256:3dc2ed4cabea4dc14d5e708c2b426205956077cc5de419b4d4079315017e9732", size = 12580022, upload-time = "2024-10-29T06:30:07.787Z" } wheels = [ @@ -1098,7 +1127,8 @@ name = "grpcio" version = "1.76.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.14'", + "python_full_version >= '3.14' and sys_platform == 'linux'", + "python_full_version >= '3.14' and sys_platform != 'linux'", ] dependencies = [ { name = "typing-extensions", marker = "python_full_version >= '3.14'" }, @@ -1675,7 +1705,7 @@ wheels = [ [[package]] name = "kernelbench" version = "0.2.0.dev0" -source = { directory = "../KernelBench" } +source = { editable = "KernelBench" } dependencies = [ { name = "datasets" }, { name = "einops" }, @@ -1703,6 +1733,7 @@ requires-dist = [ { name = "litellm", extras = ["proxy"] }, { name = "modal" }, { name = "ninja" }, + { name = "nsight-python", marker = "extra == 'gpu'" }, { name = "numpy" }, { name = "nvidia-cutlass-dsl", marker = "extra == 'gpu'" }, { name = "openai" }, @@ -1724,7 +1755,7 @@ provides-extras = ["gpu", "dev"] [[package]] name = "kernelbench-tinker" -version = "0.1.0" +version = "0.2.0" source = { editable = "." 
} dependencies = [ { name = "chz" }, @@ -1741,6 +1772,7 @@ dependencies = [ { name = "tomli" }, { name = "torch" }, { name = "tqdm" }, + { name = "wandb" }, ] [package.optional-dependencies] @@ -1757,7 +1789,7 @@ requires-dist = [ { name = "chz" }, { name = "datasets" }, { name = "isort", marker = "extra == 'dev'" }, - { name = "kernelbench", directory = "../KernelBench" }, + { name = "kernelbench", editable = "KernelBench" }, { name = "litellm" }, { name = "modal", specifier = ">=0.64.0" }, { name = "mypy", marker = "extra == 'dev'" }, @@ -1771,6 +1803,7 @@ requires-dist = [ { name = "tomli" }, { name = "torch", specifier = ">=2.9.0" }, { name = "tqdm" }, + { name = "wandb" }, ] provides-extras = ["dev"] @@ -4097,6 +4130,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a6/24/4d91e05817e92e3a61c8a21e08fd0f390f5301f1c448b137c57c4bc6e543/semver-3.0.4-py3-none-any.whl", hash = "sha256:9c824d87ba7f7ab4a1890799cec8596f15c1241cb473404ea1cb0c55e4b04746", size = 17912, upload-time = "2025-01-24T13:19:24.949Z" }, ] +[[package]] +name = "sentry-sdk" +version = "2.53.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d3/06/66c8b705179bc54087845f28fd1b72f83751b6e9a195628e2e9af9926505/sentry_sdk-2.53.0.tar.gz", hash = "sha256:6520ef2c4acd823f28efc55e43eb6ce2e6d9f954a95a3aa96b6fd14871e92b77", size = 412369, upload-time = "2026-02-16T11:11:14.743Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/d4/2fdf854bc3b9c7f55219678f812600a20a138af2dd847d99004994eada8f/sentry_sdk-2.53.0-py2.py3-none-any.whl", hash = "sha256:46e1ed8d84355ae54406c924f6b290c3d61f4048625989a723fd622aab838899", size = 437908, upload-time = "2026-02-16T11:11:13.227Z" }, +] + [[package]] name = "setuptools" version = "80.9.0" @@ -4133,6 +4179,15 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "smmap" +version = "5.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/cd/a040c4b3119bbe532e5b0732286f805445375489fceaec1f48306068ee3b/smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5", size = 22329, upload-time = "2025-01-02T07:14:40.909Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303, upload-time = "2025-01-02T07:14:38.724Z" }, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -4741,6 +4796,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/63/9a/0962b05b308494e3202d3f794a6e85abe471fe3cafdbcf95c2e8c713aabd/uvloop-0.21.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a5c39f217ab3c663dc699c04cbd50c13813e31d917642d459fdcec07555cc553", size = 4660018, upload-time = "2024-10-14T23:38:10.888Z" }, ] +[[package]] +name = "wandb" +version = "0.25.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "gitpython" }, + { name = "packaging" }, + { name = "platformdirs" }, + { name = "protobuf" }, + { name = "pydantic" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "sentry-sdk" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fd/60/d94952549920469524b689479c864c692ca47eca4b8c2fe3389b64a58778/wandb-0.25.0.tar.gz", hash = "sha256:45840495a288e34245d69d07b5a0b449220fbc5b032e6b51c4f92ec9026d2ad1", size = 43951335, 
upload-time = "2026-02-13T00:17:45.515Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/7d/0c131db3ec9deaabbd32263d90863cbfbe07659527e11c35a5c738cecdc5/wandb-0.25.0-py3-none-macosx_12_0_arm64.whl", hash = "sha256:5eecb3c7b5e60d1acfa4b056bfbaa0b79a482566a9db58c9f99724b3862bc8e5", size = 23287536, upload-time = "2026-02-13T00:17:20.265Z" }, + { url = "https://files.pythonhosted.org/packages/c3/95/31bb7f76a966ec87495e5a72ac7570685be162494c41757ac871768dbc4f/wandb-0.25.0-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:daeedaadb183dc466e634fba90ab2bab1d4e93000912be0dee95065a0624a3fd", size = 25196062, upload-time = "2026-02-13T00:17:23.356Z" }, + { url = "https://files.pythonhosted.org/packages/d9/a1/258cdedbf30cebc692198a774cf0ef945b7ed98ee64bdaf62621281c95d8/wandb-0.25.0-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:5e0127dbcef13eea48f4b84268da7004d34d3120ebc7b2fa9cefb72b49dbb825", size = 22799744, upload-time = "2026-02-13T00:17:26.437Z" }, + { url = "https://files.pythonhosted.org/packages/de/91/ec9465d014cfd199c5b2083d271d31b3c2aedeae66f3d8a0712f7f54bdf3/wandb-0.25.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:6c4c38077836f9b7569a35b0e1dcf1f0c43616fcd936d182f475edbfea063665", size = 25262839, upload-time = "2026-02-13T00:17:28.8Z" }, + { url = "https://files.pythonhosted.org/packages/c7/95/cb2d1c7143f534544147fb53fe87944508b8cb9a058bc5b6f8a94adbee15/wandb-0.25.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6edd8948d305cb73745bf564b807bd73da2ccbd47c548196b8a362f7df40aed8", size = 22853714, upload-time = "2026-02-13T00:17:31.68Z" }, + { url = "https://files.pythonhosted.org/packages/d7/94/68163f70c1669edcf130822aaaea782d8198b5df74443eca0085ec596774/wandb-0.25.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:ada6f08629bb014ad6e0a19d5dec478cdaa116431baa3f0a4bf4ab8d9893611f", size = 25358037, upload-time = "2026-02-13T00:17:34.676Z" }, + { url = 
"https://files.pythonhosted.org/packages/cc/fb/9578eed2c01b2fc6c8b693da110aa9c73a33d7bb556480f5cfc42e48c94e/wandb-0.25.0-py3-none-win32.whl", hash = "sha256:020b42ca4d76e347709d65f59b30d4623a115edc28f462af1c92681cb17eae7c", size = 24604118, upload-time = "2026-02-13T00:17:37.641Z" }, + { url = "https://files.pythonhosted.org/packages/25/97/460f6cb738aaa39b4eb2e6b4c630b2ae4321cdd70a79d5955ea75a878981/wandb-0.25.0-py3-none-win_amd64.whl", hash = "sha256:78307ac0b328f2dc334c8607bec772851215584b62c439eb320c4af4fb077a00", size = 24604122, upload-time = "2026-02-13T00:17:39.991Z" }, + { url = "https://files.pythonhosted.org/packages/27/6c/5847b4dda1dfd52630dac08711d4348c69ed657f0698fc2d949c7f7a6622/wandb-0.25.0-py3-none-win_arm64.whl", hash = "sha256:c6174401fd6fb726295e98d57b4231c100eca96bd17de51bfc64038a57230aaf", size = 21785298, upload-time = "2026-02-13T00:17:42.475Z" }, +] + [[package]] name = "watchfiles" version = "1.1.1"