diff --git a/docs/content/docs/configuration/config.mdx b/docs/content/docs/configuration/config.mdx
index 51eae082fa..c95ccc745c 100644
--- a/docs/content/docs/configuration/config.mdx
+++ b/docs/content/docs/configuration/config.mdx
@@ -118,7 +118,7 @@ run_name: "test_run"
log_path: "/tmp/skyrl-logs"
dump_data_batch: false
dump_eval_results: true
-log_example_interval: 1
+print_example_interval: 1
```
@@ -128,7 +128,7 @@ log_example_interval: 1
- `log_path`: Path for infrastructure log files. Infrastructure logs (vLLM engine startup, model loading, worker initialization) are written to `{log_path}/infra-YYMMDD_HHMMSS.log`. For multi-node training, use a shared filesystem path to consolidate logs into a single file. See the [logging guide](../checkpointing-logging/logging) for details.
- `dump_data_batch`: Whether to dump the data batch to a file. This is useful for debugging. When `true`, the data batch will be dumped to a file in the `export_path` directory. The training batch at global step `N` is saved to `self.cfg.trainer.export_path / "dumped_data" / global_step_N_training_input`
- `dump_eval_results`: Whether to dump the evaluation results to a file. When `true`, the full evaluation results will be dumped to a file in the `export_path` directory. The evaluation results at global step `N` is saved to `self.cfg.trainer.export_path / "dumped_eval" / global_step_N_eval_results`
-- `log_example_interval`: Log an example prompt every N training steps, `0`/`-1` to disable.
+- `print_example_interval`: Pretty-print an example prompt/response/reward to stdout every N training steps, `0`/`-1` to disable.
## Training Backends
diff --git a/examples/train/gsm8k/run_gsm8k.sh b/examples/train/gsm8k/run_gsm8k.sh
index 973b0d2595..8695b0b759 100755
--- a/examples/train/gsm8k/run_gsm8k.sh
+++ b/examples/train/gsm8k/run_gsm8k.sh
@@ -20,7 +20,8 @@ uv run --isolated --extra fsdp -m skyrl.train.entrypoints.main_base \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.algorithm.advantage_estimator="grpo" \
- trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \
+ trainer.policy.model.path="Qwen/Qwen3-0.6B" \
+ trainer.num_logger_eval_samples=10 \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py
index e4ac7bb655..ac49b23d9d 100644
--- a/skyrl/train/config/config.py
+++ b/skyrl/train/config/config.py
@@ -669,8 +669,21 @@ class TrainerConfig(BaseConfig):
dump_eval_results: bool = True
rope_scaling: Optional[Dict[str, Any]] = None
rope_theta: Optional[float] = None
- log_example_interval: int = 1
- """Log an example prompt every N training steps, ``0``/``-1`` to disable"""
+ print_example_interval: int = 1
+ """Pretty-print an example prompt/response/reward to stdout every N
+ training steps; ``0``/``-1`` disables. Renamed from ``log_example_interval``."""
+ num_logger_eval_samples: int = -1
+ """Number of evaluation trajectory (prompt, response, score) tuples to upload to a wandb
+ table on each eval. ``-1`` (default) or ``0`` disables. When positive,
+ up to this many samples are taken from the start of each eval pass and
+ logged via :class:`TrajectoryLogger`. Column count is fixed
+ by the first call, so keep the eval set size and this value stable."""
+ num_logger_train_samples: int = -1
+ """Number of training trajectory (prompt, response, score) tuples to upload to a wandb
+ table on each training step. ``-1`` (default) or ``0`` disables. When positive,
+ up to this many samples are taken from the start of each training step and
+ logged via :class:`TrajectoryLogger`. Column count is fixed
+ by the first call, so keep the training set size and this value stable."""
logprobs_chunk_size: Optional[int] = 1024
"""Chunk size along the sequence dimension when computing log-probs from logits.
This lowers peak GPU memory at the cost of ~2x wall-clock time.
diff --git a/skyrl/train/entrypoints/main_base.py b/skyrl/train/entrypoints/main_base.py
index 4aa53acac2..9c7bd3df01 100644
--- a/skyrl/train/entrypoints/main_base.py
+++ b/skyrl/train/entrypoints/main_base.py
@@ -29,6 +29,7 @@
from skyrl.train.trainer import RayPPOTrainer
from skyrl.train.utils import validate_cfg
from skyrl.train.utils.tracking import Tracking
+from skyrl.train.utils.trajectory_logging import TrajectoryLogger
from skyrl.train.utils.utils import (
ResolvedPlacementGroup,
get_ray_pg_ready_with_timeout,
@@ -278,10 +279,24 @@ def get_tracker(self):
return Tracking(
project_name=self.cfg.trainer.project_name,
experiment_name=self.cfg.trainer.run_name,
- backends=self.cfg.trainer.logger,
+ backend=self.cfg.trainer.logger,
config=self.cfg,
)
+ def get_trajectory_logger(self) -> TrajectoryLogger:
+ """Initializes the trajectory logger used during eval (and optionally
+ training) to upload (prompt, response, reward) samples to wandb.
+
+ Override in a subclass to swap in a project-specific
+ :class:`TrajectoryLogger` (custom columns, trajectory rendering, or
+ wandb key). The instance is cheap and statefully accumulates rows
+ across evals via the wandb-Table re-create workaround.
+
+ Returns:
+ TrajectoryLogger: The trajectory logger.
+ """
+ return TrajectoryLogger()
+
def get_inference_client(self) -> InferenceEngineInterface:
"""Setup and return the inference engine client.
@@ -387,6 +402,8 @@ def _setup_trainer(self):
generator=generator,
colocate_pg=self.colocate_pg,
)
+ # Install the trajectory logger after construction
+ trainer.trajectory_logger = self.get_trajectory_logger()
# Expose the trainer on self so callers can log exceptions raised
# during `build_models` (which happens before _setup_trainer returns).
self.trainer = trainer
diff --git a/skyrl/train/evaluate.py b/skyrl/train/evaluate.py
index f24c84f2bd..1e46c2d0b8 100644
--- a/skyrl/train/evaluate.py
+++ b/skyrl/train/evaluate.py
@@ -1,6 +1,9 @@
-from collections import defaultdict
+from collections import Counter, defaultdict
from pathlib import Path
-from typing import Any, Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+if TYPE_CHECKING:
+ from skyrl.train.utils.tracking import Tracking
import torch
from loguru import logger
@@ -22,12 +25,12 @@
prepare_generator_input,
)
from skyrl.train.utils import Timer
-from skyrl.train.utils.logging_utils import log_example
from skyrl.train.utils.trainer_utils import (
calculate_per_dataset_metrics,
dump_per_dataset_eval_results,
validate_generator_output,
)
+from skyrl.train.utils.trajectory_logging import TrajectoryLogger, pretty_print_example
@torch.no_grad()
@@ -37,6 +40,8 @@ async def evaluate(
cfg: SkyRLTrainConfig,
global_step: int | None,
tokenizer: AutoTokenizer,
+ trajectory_logger: Optional[TrajectoryLogger] = None,
+ tracker: Optional["Tracking"] = None,
) -> Dict[str, float]:
"""Runs generation and evaluation of trajectories.
@@ -57,6 +62,7 @@ async def evaluate(
concat_all_envs: List[str] = []
concat_env_extras: List[Dict[str, Any]] = []
concat_uids: List[str] = []
+ concat_prompts: List[str] = []
sampling_params = cfg.generator.eval_sampling_params
pbar = tqdm(total=len(eval_dataloader), initial=0, desc="Evaluation Progress")
for _, prompts in enumerate(eval_dataloader):
@@ -75,20 +81,34 @@ async def evaluate(
concat_all_envs.extend(generator_input["env_classes"])
concat_env_extras.extend(generator_input["env_extras"])
concat_uids.extend(uids)
+ concat_prompts.extend(generator_input["prompts"])
concat_generator_outputs: GeneratorOutput = concatenate_generator_outputs(generator_outputs)
# Extract data_sources from env_extras
concat_data_sources = [env_extra.get("data_source") for env_extra in concat_env_extras]
- if cfg.trainer.log_example_interval > 0:
+ if cfg.trainer.print_example_interval > 0:
vis = tokenizer.decode(generator_output["response_ids"][0])
- log_example(
+ pretty_print_example(
logger,
prompt=generator_input["prompts"][0],
response=vis,
reward=generator_output["rewards"][0],
)
+ # Optionally upload up to `num_logger_eval_samples` samples to tracker (wandb)
+ if trajectory_logger is not None:
+ with Timer("log_eval_results"):
+ trajectory_logger.log(
+ tracker=tracker,
+ num_samples=cfg.trainer.num_logger_eval_samples,
+ prompts=concat_prompts,
+ generator_output=concat_generator_outputs,
+ tokenizer=tokenizer,
+ global_step=global_step,
+ wandb_key="trajectories/eval",
+ )
+
# 2. Group data by data source and calculate per-dataset metrics
eval_metrics = calculate_per_dataset_metrics(
concat_generator_outputs, concat_uids, concat_data_sources, cfg.generator.eval_n_samples_per_prompt
@@ -137,6 +157,8 @@ async def evaluate_step_wise(
cfg: SkyRLTrainConfig,
global_step: int | None,
tokenizer: AutoTokenizer,
+ trajectory_logger: Optional[TrajectoryLogger] = None,
+ tracker: Optional["Tracking"] = None,
) -> Dict[str, float]:
"""Runs generation and evaluation of trajectories for step-wise training.
@@ -159,6 +181,7 @@ async def evaluate_step_wise(
concat_all_envs: List[str] = []
concat_env_extras: List[Dict[str, Any]] = []
concat_uids: List[str] = []
+ concat_prompts: List[str] = []
sampling_params = cfg.generator.eval_sampling_params
pbar = tqdm(total=len(eval_dataloader), initial=0, desc="Evaluation Progress")
for _, prompts in enumerate(eval_dataloader):
@@ -173,9 +196,16 @@ async def evaluate_step_wise(
)
generator_output: GeneratorOutput = await generator.generate(generator_input)
traj_id_to_input = {
- traj_id.instance_id: {"env_class": env_class, "env_extras": env_extra}
- for traj_id, env_class, env_extra in zip(
- generator_input["trajectory_ids"], generator_input["env_classes"], generator_input["env_extras"]
+ traj_id.instance_id: {
+ "env_class": env_class,
+ "env_extras": env_extra,
+ "prompt": prompt,
+ }
+ for traj_id, env_class, env_extra, prompt in zip(
+ generator_input["trajectory_ids"],
+ generator_input["env_classes"],
+ generator_input["env_extras"],
+ generator_input["prompts"],
)
}
for traj_id in generator_output["trajectory_ids"]:
@@ -183,6 +213,7 @@ async def evaluate_step_wise(
concat_all_envs.append(traj_id_to_input[traj_id.instance_id]["env_class"])
concat_env_extras.append(traj_id_to_input[traj_id.instance_id]["env_extras"])
concat_uids.append(traj_id.instance_id)
+ concat_prompts.append(traj_id_to_input[traj_id.instance_id]["prompt"])
validate_generator_output(generator_input, generator_output, step_wise=True)
generator_outputs.append(generator_output)
concat_generator_outputs: GeneratorOutput = concatenate_generator_outputs(generator_outputs)
@@ -190,7 +221,7 @@ async def evaluate_step_wise(
# Extract data_sources from env_extras
concat_data_sources = [env_extra.get("data_source") for env_extra in concat_env_extras]
- if cfg.trainer.log_example_interval > 0:
+ if cfg.trainer.print_example_interval > 0:
vis = tokenizer.decode(generator_output["response_ids"][0])
logger.info(f"Eval output example: {vis}")
@@ -209,6 +240,24 @@ async def evaluate_step_wise(
data_sources_last_step = [
data_source for data_source, is_last_step in zip(concat_data_sources, is_last_step_mask) if is_last_step
]
+ prompts_last_step = [prompt for prompt, is_last_step in zip(concat_prompts, is_last_step_mask) if is_last_step]
+
+ # Optionally upload up to `num_logger_eval_samples` samples to wandb.
+ # For step-wise we override the logger's default loss-mask-based
+ # num_turns with the total step count per trajectory (counted *before*
+ # the last-step filter).
+ if trajectory_logger is not None:
+ trajectory_step_counts = Counter(concat_uids)
+ trajectory_logger.log(
+ tracker=tracker,
+ num_samples=cfg.trainer.num_logger_eval_samples,
+ prompts=prompts_last_step,
+ generator_output=generator_output_last_step,
+ tokenizer=tokenizer,
+ global_step=global_step,
+ num_turns_list=[trajectory_step_counts[uid] for uid in uids_last_step],
+ wandb_key="trajectories/eval",
+ )
# 2. Group data by data source and calculate per-dataset metrics
eval_metrics = calculate_per_dataset_metrics(
diff --git a/skyrl/train/sft_trainer.py b/skyrl/train/sft_trainer.py
index cfb3443c26..3f82705866 100644
--- a/skyrl/train/sft_trainer.py
+++ b/skyrl/train/sft_trainer.py
@@ -764,7 +764,7 @@ def _init_tracker(self):
self.tracker = Tracking(
project_name=self.cfg.trainer.project_name,
experiment_name=self.cfg.trainer.run_name,
- backends=self.cfg.trainer.logger,
+ backend=self.cfg.trainer.logger,
config=self.sft_cfg,
)
diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py
index a6eb2109c6..22763667be 100644
--- a/skyrl/train/trainer.py
+++ b/skyrl/train/trainer.py
@@ -67,7 +67,6 @@
get_ray_pg_ready_with_timeout,
trainer_utils,
)
-from skyrl.train.utils.logging_utils import log_example
from skyrl.train.utils.ray_gpu_monitor import RayGpuMonitor
from skyrl.train.utils.tracking import Tracking
from skyrl.train.utils.trainer_utils import (
@@ -82,6 +81,7 @@
validate_generator_output,
zero_variance_filter,
)
+from skyrl.train.utils.trajectory_logging import TrajectoryLogger, pretty_print_example
from skyrl.train.utils.utils import ResolvedPlacementGroup, configure_ray_worker_logging
from skyrl.train.utils.vllm_metrics_scraper import VLLMMetricsScraper
@@ -127,6 +127,9 @@ def __init__(
self._ray_gpu_monitor = RayGpuMonitor() if cfg.trainer.enable_ray_gpu_monitor else None
+ # trajectory logger is installed after construction if needed
+ self.trajectory_logger: TrajectoryLogger = None
+
# initialized in `build_models`
self.policy_model: PPORayActorGroup = None
self.critic_model: Optional[PPORayActorGroup] = None
@@ -180,6 +183,8 @@ async def eval(self) -> Dict[str, float]:
cfg=self.cfg,
global_step=self.global_step,
tokenizer=self.tokenizer,
+ trajectory_logger=self.trajectory_logger,
+ tracker=self.tracker,
)
else:
eval_metrics = await evaluate(
@@ -188,6 +193,8 @@ async def eval(self) -> Dict[str, float]:
cfg=self.cfg,
global_step=self.global_step,
tokenizer=self.tokenizer,
+ trajectory_logger=self.trajectory_logger,
+ tracker=self.tracker,
)
return eval_metrics
@@ -268,16 +275,29 @@ async def train(self):
with Timer("postprocess_generator_output", self.all_timings):
generator_output, uids = self.postprocess_generator_output(generator_output, uids)
- # 2. print example just for debugging
- log_interval = self.cfg.trainer.log_example_interval
- if log_interval > 0 and self.global_step % log_interval == 0:
+ # 2.1 print example just for debugging
+ print_interval = self.cfg.trainer.print_example_interval
+ if print_interval > 0 and self.global_step % print_interval == 0:
vis = self.tokenizer.decode(generator_output["response_ids"][0])
- log_example(
+ pretty_print_example(
logger,
prompt=generator_input["prompts"][0],
response=vis,
reward=generator_output["rewards"][0],
)
+ # 2.2 Optionally upload up to `num_logger_train_samples` samples to tracker
+ if self.trajectory_logger is not None:
+ with Timer("log_train_results"):
+ self.trajectory_logger.log(
+ tracker=self.tracker,
+ num_samples=self.cfg.trainer.num_logger_train_samples,
+ prompts=generator_input["prompts"],
+ generator_output=generator_output,
+ tokenizer=self.tokenizer,
+ global_step=self.global_step,
+ wandb_key="trajectories/train",
+ include_idx=False,
+ )
# 3. Convert GeneratorOutput to TrainingInputBatch
with Timer("convert_to_training_input", self.all_timings):
diff --git a/skyrl/train/utils/logging_utils.py b/skyrl/train/utils/logging_utils.py
deleted file mode 100644
index 863d31d788..0000000000
--- a/skyrl/train/utils/logging_utils.py
+++ /dev/null
@@ -1,84 +0,0 @@
-from typing import Any, Dict, List, Optional, Union
-
-POSITIVE_RESPONSE_COLOR = "green"
-NEGATIVE_RESPONSE_COLOR = "yellow"
-BASE_PROMPT_COLOR = "cyan"
-
-
-def _color_block_format_and_kwargs(
- text: str,
- color: str,
- field_prefix: str,
-) -> tuple[str, dict]:
- """Build a format string and kwargs for a multi-line colored block.
-
- The format string will look like:
- "{p0}\n{p1}\n..."
-
- where "p0", "p1", ... are placeholder names starting with `field_prefix`.
- """
- # Ensure at least one line
- lines = text.splitlines() or [""]
-
- fmt_lines = []
- kwargs: dict[str, str] = {}
-
- for i, line in enumerate(lines):
- key = f"{field_prefix}{i}"
- # NOTE: double braces {{ }} so that {key} survives into str.format
- fmt_lines.append(f"<{color}>{{{key}}}{color}>")
- kwargs[key] = line
-
- fmt = "\n".join(fmt_lines)
- return fmt, kwargs
-
-
-def log_example(
- logger: Any,
- prompt: List[Dict[str, Any]],
- response: str,
- reward: Optional[Union[float, List[float]]] = None,
-) -> None:
- """
- Log a single example prompt and response with formatting and colors.
-
- Args:
- logger: The logger instance to use (expected to be loguru logger or compatible).
- prompt: The input prompt in OpenAI message format.
- response: The output response string.
- reward: The reward value(s) associated with the response.
- """
- reward_val = 0.0
- reward_str = "N/A"
- try:
- prompt_str = str(prompt)
- response_str = str(response)
- # --- Reward handling ---
- if reward is not None:
- if isinstance(reward, list):
- reward_val = float(sum(reward))
- else:
- reward_val = float(reward)
- reward_str = f"{reward_val:.4f}"
-
- # --- Color selection ---
- if reward is not None and reward_val > 0:
- response_color = POSITIVE_RESPONSE_COLOR
- else:
- response_color = NEGATIVE_RESPONSE_COLOR
-
- # --- Build per-line colored blocks in the *format string* ---
- prompt_fmt, prompt_kwargs = _color_block_format_and_kwargs(prompt_str, BASE_PROMPT_COLOR, "p")
- response_fmt, response_kwargs = _color_block_format_and_kwargs(response_str, response_color, "r")
-
- # Single format string with only our own markup and placeholders
- log_format = "Example:\n" f" Input: {prompt_fmt}\n" " Output (Total Reward: {reward}):\n" f"{response_fmt}"
-
- # Merge all args for str.format
- format_kwargs = {**prompt_kwargs, **response_kwargs, "reward": reward_str}
-
- # Let Loguru parse tags in log_format and then substitute arguments.
- logger.opt(colors=True).info(log_format, **format_kwargs)
- except Exception as e:
- print(f"Error pretty printing example, debug printing instead: {e}")
- print(f"Example:\n Input: {prompt}\n Output (Total Reward: {reward_str}):\n{response}")
diff --git a/skyrl/train/utils/tracking.py b/skyrl/train/utils/tracking.py
index 2fead5342e..f2f6f3dde4 100644
--- a/skyrl/train/utils/tracking.py
+++ b/skyrl/train/utils/tracking.py
@@ -21,7 +21,7 @@
from enum import Enum
from functools import partial
from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
from loguru import logger
from omegaconf import DictConfig, OmegaConf
@@ -37,26 +37,20 @@ def __init__(
self,
project_name,
experiment_name,
- backends: Union[str, List[str]] = "console",
+ backend: str = "console",
config: Optional[Union[SkyRLTrainConfig, DictConfig]] = None,
):
- if isinstance(backends, str):
- backends = [backends]
- for backend in backends:
- assert backend in self.supported_backends, f"{backend} is not supported"
+ assert backend in self.supported_backends, f"{backend} is not supported"
+ self.backend = backend
- self.logger = {}
-
- if "wandb" in backends:
+ if backend == "wandb":
import wandb
wandb.init(project=project_name, name=experiment_name, config=get_config_as_dict(config))
- self.logger["wandb"] = wandb
-
- if "mlflow" in backends:
- self.logger["mlflow"] = _MlflowLoggingAdapter(project_name, experiment_name, config)
-
- if "swanlab" in backends:
+ self.logger: Any = wandb
+ elif backend == "mlflow":
+ self.logger = _MlflowLoggingAdapter(project_name, experiment_name, config)
+ elif backend == "swanlab":
import os
import swanlab
@@ -73,44 +67,41 @@ def __init__(
logdir=SWANLAB_LOG_DIR,
mode=SWANLAB_MODE,
)
- self.logger["swanlab"] = swanlab
-
- if "tensorboard" in backends:
- self.logger["tensorboard"] = _TensorboardAdapter()
-
- if "console" in backends:
- self.console_logger = ConsoleLogger()
- self.logger["console"] = self.console_logger
+ self.logger = swanlab
+ elif backend == "tensorboard":
+ self.logger = _TensorboardAdapter()
+ else: # "console"
+ self.logger = ConsoleLogger()
self._exception_logged = False
def log(self, data, step, commit=False):
- for logger_name, logger_instance in self.logger.items():
- if logger_name == "wandb":
- logger_instance.log(data=data, step=step, commit=commit)
- else:
- logger_instance.log(data=data, step=step)
+ if self.backend == "wandb":
+ self.logger.log(data=data, step=step, commit=commit)
+ else:
+ self.logger.log(data=data, step=step)
def finish(self):
- for logger_name, logger_instance in self.logger.items():
- # NOTE (sumanthrh): We use a try-except block here while finishing tracking.
- # This is because wandb often errors out with a BrokenPipeError when closing.
- # https://github.com/wandb/wandb/issues/6449
- try:
- if logger_name == "wandb":
- logger_instance.finish(exit_code=0)
- elif logger_name != "console":
- logger_instance.finish()
- except Exception as e:
- logger.warning(f"Attempted to finish tracking with logger {logger_name} but got error {e}")
+ if self.backend == "console":
+ return
+ # NOTE (sumanthrh): We use a try-except block here while finishing tracking.
+ # This is because wandb often errors out with a BrokenPipeError when closing.
+ # https://github.com/wandb/wandb/issues/6449
+ try:
+ if self.backend == "wandb":
+ self.logger.finish(exit_code=0)
+ else:
+ self.logger.finish()
+ except Exception as e:
+ logger.warning(f"Attempted to finish tracking with backend {self.backend} but got error {e}")
def log_exception(self, e: BaseException, step: int = 0) -> None:
- """Log the active exception's traceback to all configured backends.
+ """Log the active exception's traceback to the configured backend.
Always prints the traceback on the driver via loguru (so it lands in
- Ray driver logs instead of being swallowed). If wandb is configured,
- also logs a row to an `error/tracebacks` wandb.Table and calls
- `finish()` to flush the async upload before the process re-raises.
+ Ray driver logs instead of being swallowed). If the wandb backend is
+ active, also logs a row to an `error/tracebacks` wandb.Table and calls
+ `finish()` to flush the async upload before the caller re-raises.
Ray-wrapped worker errors (e.g. OOMs raised inside actors) include
both local and remote frames in `traceback.format_exc()`.
@@ -123,7 +114,7 @@ def log_exception(self, e: BaseException, step: int = 0) -> None:
self._exception_logged = True
tb_str = traceback.format_exc()[-10000:]
logger.error(f"Training failed at step {step} with {type(e).__name__}:\n{tb_str}")
- if "wandb" in self.logger:
+ if self.backend == "wandb":
try:
import wandb
@@ -132,7 +123,7 @@ def log_exception(self, e: BaseException, step: int = 0) -> None:
# Note: omit `step=` here. Per-step logs use commit=True, so
# re-logging at the same step would be dropped. The step value
# is also embedded in the table row itself.
- self.logger["wandb"].log({"error/tracebacks": error_table})
+ self.logger.log({"error/tracebacks": error_table})
# Tables upload asynchronously. Finish the run so the upload
# completes before the caller re-raises and the process dies.
try:
@@ -142,6 +133,41 @@ def log_exception(self, e: BaseException, step: int = 0) -> None:
except Exception as log_exc:
logger.warning(f"Failed to log exception traceback to wandb: {log_exc}")
+ def log_samples_to_table(
+ self,
+ key: str,
+ columns: List[str],
+ samples: List[Tuple[Any, ...]],
+ step: int,
+ ) -> None:
+ """Append rows to an accumulating wandb table at ``key``.
+
+ Each call extends the existing table at ``key`` (or creates one on
+ first call) with ``samples`` and logs the new table to wandb at the
+ given ``step``. ``columns`` defines the table schema and must stay
+ consistent across calls for the same ``key``; each row in ``samples``
+ must have ``len(columns)`` values in the matching order.
+
+ No-op for non-wandb backends -- only the wandb backend supports
+ ``wandb.Table``.
+ """
+ if self.backend != "wandb":
+ return
+ import wandb
+
+ # Cache one table per key so different callers (e.g. eval vs train
+ # trajectory loggers, error traceback table) don't trample each other.
+ if not hasattr(self, "_sample_tables"):
+ self._sample_tables: Dict[str, Any] = {}
+ if key not in self._sample_tables:
+ self._sample_tables[key] = wandb.Table(columns=columns)
+ # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
+ new_table = wandb.Table(columns=columns, data=self._sample_tables[key].data)
+ for row in samples:
+ new_table.add_data(*row)
+ self.logger.log({key: new_table}, step=step)
+ self._sample_tables[key] = new_table
+
def __del__(self):
try:
self.finish()
@@ -260,87 +286,3 @@ def _flatten_dict(raw: Dict[str, Any], *, sep: str) -> Dict[str, Any]:
ans = pd.json_normalize(raw, sep=sep).to_dict(orient="records")[0]
assert isinstance(ans, dict)
return ans
-
-
-@dataclasses.dataclass
-class ValidationGenerationsLogger:
- def log(self, loggers, samples, step):
- if "wandb" in loggers:
- self.log_generations_to_wandb(samples, step)
- if "swanlab" in loggers:
- self.log_generations_to_swanlab(samples, step)
- if "mlflow" in loggers:
- self.log_generations_to_mlflow(samples, step)
-
- def log_generations_to_wandb(self, samples, step):
- """Log samples to wandb as a table"""
- import wandb
-
- # Create column names for all samples
- columns = ["step"] + sum(
- [[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []
- )
-
- if not hasattr(self, "validation_table"):
- # Initialize the table on first call
- self.validation_table = wandb.Table(columns=columns)
-
- # Create a new table with same columns and existing data
- # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
- new_table = wandb.Table(columns=columns, data=self.validation_table.data)
-
- # Add new row with all data
- row_data = []
- row_data.append(step)
- for sample in samples:
- row_data.extend(sample)
-
- new_table.add_data(*row_data)
-
- # Update reference and log
- wandb.log({"val/generations": new_table}, step=step)
- self.validation_table = new_table
-
- def log_generations_to_swanlab(self, samples, step):
- """Log samples to swanlab as text"""
- import swanlab
-
- swanlab_text_list = []
- for i, sample in enumerate(samples):
- row_text = f"""
- input: {sample[0]}
-
- ---
-
- output: {sample[1]}
-
- ---
-
- score: {sample[2]}
- """
- swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}"))
-
- # Log to swanlab
- swanlab.log({"val/generations": swanlab_text_list}, step=step)
-
- def log_generations_to_mlflow(self, samples, step):
- """Log validation generation to mlflow as artifacts"""
- # https://mlflow.org/docs/latest/api_reference/python_api/mlflow.html?highlight=log_artifact#mlflow.log_artifact
-
- import json
- import tempfile
-
- import mlflow
-
- try:
- with tempfile.TemporaryDirectory() as tmp_dir:
- validation_gen_step_file = Path(tmp_dir, f"val_step{step}.json")
- row_data = []
- for sample in samples:
- data = {"input": sample[0], "output": sample[1], "score": sample[2]}
- row_data.append(data)
- with open(validation_gen_step_file, "w") as file:
- json.dump(row_data, file)
- mlflow.log_artifact(validation_gen_step_file)
- except Exception as e:
- logger.warning(f"save validation generation file to mlflow failed with error {e}")
diff --git a/skyrl/train/utils/trajectory_logging.py b/skyrl/train/utils/trajectory_logging.py
new file mode 100644
index 0000000000..dddd254fd9
--- /dev/null
+++ b/skyrl/train/utils/trajectory_logging.py
@@ -0,0 +1,347 @@
+"""Utils for trajectory logging."""
+
+import dataclasses
+import random
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
+
+if TYPE_CHECKING:
+ from skyrl.train.utils.tracking import Tracking
+
+POSITIVE_RESPONSE_COLOR = "green"
+NEGATIVE_RESPONSE_COLOR = "yellow"
+BASE_PROMPT_COLOR = "cyan"
+
+
+def _color_block_format_and_kwargs(
+ text: str,
+ color: str,
+ field_prefix: str,
+) -> tuple[str, dict]:
+ """Build a format string and kwargs for a multi-line colored block.
+
+ The format string will look like:
+ "{p0}\\n{p1}\\n..."
+
+ where "p0", "p1", ... are placeholder names starting with `field_prefix`.
+ """
+ # Ensure at least one line
+ lines = text.splitlines() or [""]
+
+ fmt_lines = []
+ kwargs: dict[str, str] = {}
+
+ for i, line in enumerate(lines):
+ key = f"{field_prefix}{i}"
+ # NOTE: double braces {{ }} so that {key} survives into str.format
+ fmt_lines.append(f"<{color}>{{{key}}}{color}>")
+ kwargs[key] = line
+
+ fmt = "\n".join(fmt_lines)
+ return fmt, kwargs
+
+
+def pretty_print_example(
+ logger: Any,
+ prompt: List[Dict[str, Any]],
+ response: str,
+ reward: Optional[Union[float, List[float]]] = None,
+) -> None:
+ """
+ Log a single example prompt and response with formatting and colors.
+
+ Args:
+ logger: The logger instance to use (expected to be loguru logger or compatible).
+ prompt: The input prompt in OpenAI message format.
+ response: The output response string.
+ reward: The reward value(s) associated with the response.
+ """
+ reward_val = 0.0
+ reward_str = "N/A"
+ try:
+ prompt_str = str(prompt)
+ response_str = str(response)
+ # --- Reward handling ---
+ if reward is not None:
+ if isinstance(reward, list):
+ reward_val = float(sum(reward))
+ else:
+ reward_val = float(reward)
+ reward_str = f"{reward_val:.4f}"
+
+ # --- Color selection ---
+ if reward is not None and reward_val > 0:
+ response_color = POSITIVE_RESPONSE_COLOR
+ else:
+ response_color = NEGATIVE_RESPONSE_COLOR
+
+ # --- Build per-line colored blocks in the *format string* ---
+ prompt_fmt, prompt_kwargs = _color_block_format_and_kwargs(prompt_str, BASE_PROMPT_COLOR, "p")
+ response_fmt, response_kwargs = _color_block_format_and_kwargs(response_str, response_color, "r")
+
+ # Single format string with only our own markup and placeholders
+ log_format = "Example:\n" f" Input: {prompt_fmt}\n" " Output (Total Reward: {reward}):\n" f"{response_fmt}"
+
+ # Merge all args for str.format
+ format_kwargs = {**prompt_kwargs, **response_kwargs, "reward": reward_str}
+
+ # Let Loguru parse tags in log_format and then substitute arguments.
+ logger.opt(colors=True).info(log_format, **format_kwargs)
+ except Exception as e:
+ print(f"Error pretty printing example, debug printing instead: {e}")
+ print(f"Example:\n Input: {prompt}\n Output (Total Reward: {reward_str}):\n{response}")
+
+
+@dataclasses.dataclass
+class TrajectoryLogger:
+ """Logs rollout samples to tracker backends as a table.
+
+ Accepts a full ``GeneratorOutput``-shaped dict plus a parallel list of
+ prompts, and derives the per-sample fields (``num_turns`` from the loss
+ mask, the trajectory string from prompt + response tokens) internally.
+ A caller with a better signal for ``num_turns`` (e.g. step-wise eval
+ counting trajectory steps) can pass it in explicitly to override.
+
+ Designed for subclassing. Override the granular pieces to customize:
+
+ - :meth:`build_samples` -- which columns each row contains. If you
+ change the tuple shape, also update :attr:`columns`.
+ - :meth:`format_trajectory` -- how the trajectory string is rendered.
+ - :meth:`count_assistant_turns` -- default loss-mask -> turn-count.
+ - :meth:`log` -- top-level dispatch (e.g. add a new backend, change the
+ wandb key, prepend different per-row metadata).
+
+ All hooks accept ``**kwargs`` so callers can plumb extra fields through
+ a subclass (e.g. ``data_source`` for a per-dataset column) without
+ changing the base API.
+ """
+
+ columns: Tuple[str, ...] = ("step", "idx", "reward", "num_turns", "trajectory")
+ sample_seed: int = 0
+ """Seed for the random picks in :meth:`select_sample_indices`. Fixed so
+ consecutive runs surface the same prompt/response pairs in the table."""
+
+ def log(
+ self,
+ *,
+ tracker: Optional["Tracking"],
+ num_samples: int,
+ prompts: List[Any],
+ generator_output: Dict[str, Any],
+ tokenizer: Any,
+ global_step: Optional[int],
+ num_turns_list: Optional[List[int]] = None,
+ wandb_key: str,
+ include_idx: bool = True,
+ **kwargs: Any,
+ ) -> None:
+ """Build sample rows from a GeneratorOutput-shaped dict and dispatch.
+
+ ``generator_output`` must contain at least ``response_ids``,
+ ``rewards`` and ``loss_masks``. ``prompts`` is passed separately
+ because it lives on :class:`GeneratorInput`, not output.
+
+ ``tracker`` is the active :class:`Tracking` instance. Today only the
+ wandb backend gets a trajectory table written (via
+ :meth:`Tracking.log_samples_to_table`); other backends are a no-op.
+
+ ``num_turns_list`` defaults to per-response assistant-span counts
+ derived from ``loss_masks`` via :meth:`count_assistant_turns`. Pass
+ an explicit value (e.g. trajectory step counts in step-wise eval)
+ to override.
+
+ Extra ``**kwargs`` are forwarded to :meth:`build_samples` for
+ subclass extensibility.
+
+ No-op when ``num_samples <= 0`` or the tracker is not a backend we
+ know how to write to.
+ """
+ response_ids = generator_output.get("response_ids") or []
+ if num_samples <= 0 or tracker is None or tracker.backend != "wandb" or not response_ids:
+ return
+ loss_masks = generator_output.get("loss_masks") or []
+ rewards = generator_output.get("rewards") or []
+ if num_turns_list is None:
+ num_turns_list = [self.count_assistant_turns(m) for m in loss_masks]
+ samples = self.build_samples(
+ num_samples=num_samples,
+ prompts=prompts,
+ response_ids=response_ids,
+ rewards=rewards,
+ loss_masks=loss_masks,
+ num_turns_list=num_turns_list,
+ tokenizer=tokenizer,
+ **kwargs,
+ )
+ if not samples:
+ return
+ # `global_step` may be None (eval-only context); the table API wants
+ # a numeric step.
+ step = 0 if global_step is None else global_step
+ # ``build_samples`` always emits ``(idx, reward, num_turns, trajectory)``
+ # tuples; drop the leading idx when the caller doesn't want it logged.
+ columns = list(self.columns)
+ if not include_idx:
+ columns = [c for c in columns if c != "idx"]
+ samples = [sample[1:] for sample in samples]
+ # Prepend `step` to each row so rows from different calls remain
+ # distinguishable in the accumulating table.
+ tracker.log_samples_to_table(
+ key=wandb_key,
+ columns=columns,
+ samples=[(step, *sample) for sample in samples],
+ step=step,
+ )
+
+ def build_samples(
+ self,
+ *,
+ num_samples: int,
+ prompts: List[Any],
+ response_ids: List[List[int]],
+ rewards: List[float],
+ loss_masks: List[List[int]],
+ num_turns_list: List[int],
+ tokenizer: Any,
+ **kwargs: Any,
+ ) -> List[Tuple[Any, ...]]:
+ """Build the per-row tuples to be written.
+
+ Default shape: ``(idx, reward, num_turns, trajectory)`` (matching
+ :attr:`columns` after the ``step`` column is prepended by :meth:`log`).
+ ``idx`` is the position in the input arrays, *not* a sequential row
+ number, so it points back to the original sample.
+
+ Indices are chosen by :meth:`select_sample_indices`, which by default
+ anchors on the min- and max-reward samples and fills the rest at
+ random. Override either method to customize selection or column shape;
+ ``**kwargs`` are whatever extra values the caller passed to
+ :meth:`log`.
+ """
+ total = min(
+ len(response_ids),
+ len(prompts),
+ len(rewards),
+ len(loss_masks),
+ len(num_turns_list),
+ )
+ # Per-token rewards arrive as lists; collapse to scalars so the
+ # min/max picks and the wandb column are both well-typed.
+ scalar_rewards = [float(sum(r)) if isinstance(r, list) else float(r) for r in rewards[:total]]
+ indices = self.select_sample_indices(num_samples=num_samples, rewards=scalar_rewards, total=total)
+ return [
+ (
+ i,
+ scalar_rewards[i],
+ num_turns_list[i],
+ self.format_trajectory(prompts[i], response_ids[i], loss_masks[i], tokenizer, **kwargs),
+ )
+ for i in indices
+ ]
+
+ def select_sample_indices(
+ self,
+ *,
+ num_samples: int,
+ rewards: Sequence[float],
+ total: int,
+ ) -> List[int]:
+ """Pick up to ``num_samples`` indices from ``[0, total)`` to log.
+
+ Guarantees that the min- and max-reward samples are included when
+ ``num_samples >= 2`` and ``total >= 2``; remaining slots are filled by
+ random sampling without replacement. With ``num_samples == 1`` only
+ the min-reward sample is kept (arbitrary choice when we can fit one).
+ If all rewards tie, ``min`` and ``max`` may resolve to the same index;
+ the duplicate is dropped and the rest of the budget goes to random
+ picks. Returned indices are sorted ascending so the output table reads
+ in input order.
+ """
+ if total <= 0 or num_samples <= 0:
+ return []
+ n = min(num_samples, total)
+ if n >= total:
+ return list(range(total))
+
+ # Anchor on extremes. `min`/`max` return the first occurrence on ties,
+ # which is fine -- we just need one of each.
+ min_idx = min(range(total), key=lambda i: rewards[i])
+ max_idx = max(range(total), key=lambda i: rewards[i])
+ anchors: List[int] = []
+ for cand in (min_idx, max_idx):
+ if cand not in anchors:
+ anchors.append(cand)
+ anchors = anchors[:n]
+
+ rest_needed = n - len(anchors)
+ if rest_needed > 0:
+ pool = [i for i in range(total) if i not in anchors]
+ rng = random.Random(self.sample_seed)
+ anchors.extend(rng.sample(pool, min(rest_needed, len(pool))))
+ return sorted(anchors)
+
+ def format_trajectory(
+ self,
+ prompt: Any,
+ response_token_ids: Optional[List[int]],
+ loss_mask: Optional[List[int]],
+ tokenizer: Any,
+ **kwargs: Any,
+ ) -> str:
+ """Render a trajectory as a human-readable string with role separators.
+
+ The initial prompt (a list of ``{"role", "content"}`` chat messages,
+ or a plain string) is rendered with one ``[ROLE]\\n{content}`` block
+ per message. The generated response is then split into runs based on
+ ``loss_mask`` -- mask=1 spans are ``[ASSISTANT]``, mask=0 spans are
+ ``[USER/TOOL]`` (the mask alone can't distinguish the two). Override
+ for env-specific formatting (e.g. parsing structured tool calls).
+ """
+ parts: List[str] = []
+ if isinstance(prompt, list) and prompt and isinstance(prompt[0], dict):
+ for msg in prompt:
+ role = str(msg.get("role", "user")).upper()
+ content = msg.get("content", "")
+ parts.append(f"[{role}]\n{content}")
+ elif prompt is not None:
+ parts.append(f"[USER]\n{prompt}")
+
+ if response_token_ids:
+ if loss_mask and len(loss_mask) == len(response_token_ids):
+ cur_role: Optional[str] = None
+ cur_tokens: List[int] = []
+ for tok, m in zip(response_token_ids, loss_mask):
+ new_role = "ASSISTANT" if m == 1 else "USER/TOOL"
+ if cur_role is None:
+ cur_role = new_role
+ cur_tokens = [tok]
+ elif new_role == cur_role:
+ cur_tokens.append(tok)
+ else:
+ decoded = tokenizer.decode(cur_tokens, skip_special_tokens=True)
+ parts.append(f"[{cur_role}]\n{decoded}")
+ cur_role = new_role
+ cur_tokens = [tok]
+ if cur_tokens and cur_role is not None:
+ decoded = tokenizer.decode(cur_tokens, skip_special_tokens=True)
+ parts.append(f"[{cur_role}]\n{decoded}")
+ else:
+ decoded = tokenizer.decode(response_token_ids, skip_special_tokens=True)
+ parts.append(f"[ASSISTANT]\n{decoded}")
+
+ return "\n\n".join(parts)
+
+ def count_assistant_turns(self, loss_mask: Optional[List[int]]) -> int:
+ """Count contiguous 1-spans in ``loss_mask`` (each = one assistant turn).
+
+ Default for the ``num_turns`` column when the caller doesn't supply
+ one. Returns 0 for empty/None masks.
+ """
+ if not loss_mask:
+ return 0
+ turns = 0
+ prev = 0
+ for m in loss_mask:
+ if m == 1 and prev == 0:
+ turns += 1
+ prev = m
+ return turns
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_trainer_full_checkpointing.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_trainer_full_checkpointing.py
index 12eb937819..35ff9fa310 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/test_trainer_full_checkpointing.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_trainer_full_checkpointing.py
@@ -111,7 +111,7 @@ def create_minimal_trainer(cfg: SkyRLTrainConfig):
tracker = Tracking(
project_name=cfg.trainer.project_name,
experiment_name=cfg.trainer.run_name,
- backends=cfg.trainer.logger,
+ backend=cfg.trainer.logger,
config=cfg,
)
diff --git a/tests/train/utils/test_logging_utils.py b/tests/train/utils/test_logging_utils.py
index 784e847843..dea39d62d7 100644
--- a/tests/train/utils/test_logging_utils.py
+++ b/tests/train/utils/test_logging_utils.py
@@ -4,12 +4,12 @@
import pytest
-from skyrl.train.utils.logging_utils import (
+from skyrl.train.utils.trajectory_logging import (
BASE_PROMPT_COLOR,
NEGATIVE_RESPONSE_COLOR,
POSITIVE_RESPONSE_COLOR,
_color_block_format_and_kwargs,
- log_example,
+ pretty_print_example,
)
@@ -56,13 +56,13 @@ def test_color_block_format_and_kwargs_multi_line():
([0.1, 0.2], POSITIVE_RESPONSE_COLOR),
],
)
-def test_log_example_uses_expected_colors_and_reward_string(reward, expected_color):
+def test_pretty_print_example_uses_expected_colors_and_reward_string(reward, expected_color):
logger = StubLogger()
prompt = [{"role": "user", "content": "line1\nline2"}]
response = "out1\nout2"
- log_example(logger, prompt=prompt, response=response, reward=reward)
+ pretty_print_example(logger, prompt=prompt, response=response, reward=reward)
# Basic structure checks
assert logger.last_message.startswith("Example:\n Input: ")
@@ -84,7 +84,7 @@ def test_log_example_uses_expected_colors_and_reward_string(reward, expected_col
if reward is None:
assert reward_str == "N/A"
else:
- # log_example normalizes rewards to a single float
+ # pretty_print_example normalizes rewards to a single float
if isinstance(reward, list):
expected_val = float(sum(reward))
else:
@@ -98,20 +98,20 @@ def test_log_example_uses_expected_colors_and_reward_string(reward, expected_col
assert f"{expected_color}>" in logger.last_message
-def test_log_example_handles_exceptions_gracefully(monkeypatch, capsys):
- """Force an exception inside log_example and ensure the fallback path prints."""
+def test_pretty_print_example_handles_exceptions_gracefully(monkeypatch, capsys):
+ """Force an exception inside pretty_print_example and ensure the fallback path prints."""
def broken_color_block(*args, **kwargs):
raise RuntimeError("boom")
# Patch the helper to raise
monkeypatch.setattr(
- "skyrl.train.utils.logging_utils._color_block_format_and_kwargs",
+ "skyrl.train.utils.trajectory_logging._color_block_format_and_kwargs",
broken_color_block,
)
logger = StubLogger()
- log_example(logger, prompt=[{"role": "user", "content": "p"}], response="r", reward=None)
+ pretty_print_example(logger, prompt=[{"role": "user", "content": "p"}], response="r", reward=None)
# And the plain-text fallback should be printed to stdout
captured = capsys.readouterr()