From c60a1b914513dfa5a4c90684dd26250d0e40e05d Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 28 May 2026 02:13:33 +0000 Subject: [PATCH] add validation sample logging Refactor example/trajectory logging: - Rename `trainer.log_example_interval` -> `print_example_interval`; add `trainer.num_logger_eval_samples` / `num_logger_train_samples` knobs. - Replace `logging_utils.log_example` with `trajectory_logging.pretty_print_example` and add a `TrajectoryLogger` class that uploads (prompt, response, reward, num_turns, trajectory) rows to a wandb table during eval and (optionally) training. Pluggable via `BasePPOExp.get_trajectory_logger`. - Simplify `Tracking` from multi-backend (`backends: List[str]`) to a single `backend: str`; add `log_samples_to_table` for accumulating wandb tables. - Update evaluate.py / trainer.py / sft_trainer.py / main_base.py / tests / docs / gsm8k example. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/content/docs/configuration/config.mdx | 4 +- examples/train/gsm8k/run_gsm8k.sh | 3 +- skyrl/train/config/config.py | 17 +- skyrl/train/entrypoints/main_base.py | 19 +- skyrl/train/evaluate.py | 67 +++- skyrl/train/sft_trainer.py | 2 +- skyrl/train/trainer.py | 30 +- skyrl/train/utils/logging_utils.py | 84 ----- skyrl/train/utils/tracking.py | 200 ++++------ skyrl/train/utils/trajectory_logging.py | 347 ++++++++++++++++++ .../gpu_ci/test_trainer_full_checkpointing.py | 2 +- tests/train/utils/test_logging_utils.py | 18 +- 12 files changed, 549 insertions(+), 244 deletions(-) delete mode 100644 skyrl/train/utils/logging_utils.py create mode 100644 skyrl/train/utils/trajectory_logging.py diff --git a/docs/content/docs/configuration/config.mdx b/docs/content/docs/configuration/config.mdx index 51eae082fa..c95ccc745c 100644 --- a/docs/content/docs/configuration/config.mdx +++ b/docs/content/docs/configuration/config.mdx @@ -118,7 +118,7 @@ run_name: "test_run" log_path: "/tmp/skyrl-logs" dump_data_batch: false dump_eval_results: true -log_example_interval: 1 +print_example_interval: 1 ``` @@ -128,7 +128,7 @@ log_example_interval: 1 - `log_path`: Path for infrastructure log files. Infrastructure logs (vLLM engine startup, model loading, worker initialization) are written to `{log_path}/infra-YYMMDD_HHMMSS.log`. For multi-node training, use a shared filesystem path to consolidate logs into a single file. See the [logging guide](../checkpointing-logging/logging) for details. - `dump_data_batch`: Whether to dump the data batch to a file. This is useful for debugging. When `true`, the data batch will be dumped to a file in the `export_path` directory. The training batch at global step `N` is saved to `self.cfg.trainer.export_path / "dumped_data" / global_step_N_training_input` - `dump_eval_results`: Whether to dump the evaluation results to a file. When `true`, the full evaluation results will be dumped to a file in the `export_path` directory. The evaluation results at global step `N` is saved to `self.cfg.trainer.export_path / "dumped_eval" / global_step_N_eval_results` -- `log_example_interval`: Log an example prompt every N training steps, `0`/`-1` to disable. +- `print_example_interval`: Pretty-print an example prompt/response/reward to stdout every N training steps, `0`/`-1` to disable. ## Training Backends diff --git a/examples/train/gsm8k/run_gsm8k.sh b/examples/train/gsm8k/run_gsm8k.sh index 973b0d2595..8695b0b759 100755 --- a/examples/train/gsm8k/run_gsm8k.sh +++ b/examples/train/gsm8k/run_gsm8k.sh @@ -20,7 +20,8 @@ uv run --isolated --extra fsdp -m skyrl.train.entrypoints.main_base \ data.train_data="['$DATA_DIR/train.parquet']" \ data.val_data="['$DATA_DIR/validation.parquet']" \ trainer.algorithm.advantage_estimator="grpo" \ - trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \ + trainer.policy.model.path="Qwen/Qwen3-0.6B" \ + trainer.num_logger_eval_samples=10 \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp \ trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \ diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index e4ac7bb655..ac49b23d9d 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -669,8 +669,21 @@ class TrainerConfig(BaseConfig): dump_eval_results: bool = True rope_scaling: Optional[Dict[str, Any]] = None rope_theta: Optional[float] = None - log_example_interval: int = 1 - """Log an example prompt every N training steps, ``0``/``-1`` to disable""" + print_example_interval: int = 1 + """Pretty-print an example prompt/response/reward to stdout every N + training steps; ``0``/``-1`` disables. Renamed from ``log_example_interval``.""" + num_logger_eval_samples: int = -1 + """Number of evaluation trajectory (prompt, response, score) tuples to upload to a wandb + table on each eval. ``-1`` (default) or ``0`` disables. When positive, + up to this many samples are taken from the start of each eval pass and + logged via :class:`TrajectoryLogger`. Column count is fixed + by the first call, so keep the eval set size and this value stable.""" + num_logger_train_samples: int = -1 + """Number of training trajectory (prompt, response, score) tuples to upload to a wandb + table on each training step. ``-1`` (default) or ``0`` disables. When positive, + up to this many samples are taken from the start of each training step and + logged via :class:`TrajectoryLogger`. Column count is fixed + by the first call, so keep the training set size and this value stable.""" logprobs_chunk_size: Optional[int] = 1024 """Chunk size along the sequence dimension when computing log-probs from logits. This lowers peak GPU memory at the cost of ~2x wall-clock time. diff --git a/skyrl/train/entrypoints/main_base.py b/skyrl/train/entrypoints/main_base.py index 4aa53acac2..9c7bd3df01 100644 --- a/skyrl/train/entrypoints/main_base.py +++ b/skyrl/train/entrypoints/main_base.py @@ -29,6 +29,7 @@ from skyrl.train.trainer import RayPPOTrainer from skyrl.train.utils import validate_cfg from skyrl.train.utils.tracking import Tracking +from skyrl.train.utils.trajectory_logging import TrajectoryLogger from skyrl.train.utils.utils import ( ResolvedPlacementGroup, get_ray_pg_ready_with_timeout, @@ -278,10 +279,24 @@ def get_tracker(self): return Tracking( project_name=self.cfg.trainer.project_name, experiment_name=self.cfg.trainer.run_name, - backends=self.cfg.trainer.logger, + backend=self.cfg.trainer.logger, config=self.cfg, ) + def get_trajectory_logger(self) -> TrajectoryLogger: + """Initializes the trajectory logger used during eval (and optionally + training) to upload (prompt, response, reward) samples to wandb. + + Override in a subclass to swap in a project-specific + :class:`TrajectoryLogger` (custom columns, trajectory rendering, or + wandb key). The instance is cheap and statefully accumulates rows + across evals via the wandb-Table re-create workaround. + + Returns: + TrajectoryLogger: The trajectory logger. + """ + return TrajectoryLogger() + def get_inference_client(self) -> InferenceEngineInterface: """Setup and return the inference engine client. @@ -387,6 +402,8 @@ def _setup_trainer(self): generator=generator, colocate_pg=self.colocate_pg, ) + # Install the trajectory logger after construction + trainer.trajectory_logger = self.get_trajectory_logger() # Expose the trainer on self so callers can log exceptions raised # during `build_models` (which happens before _setup_trainer returns). self.trainer = trainer diff --git a/skyrl/train/evaluate.py b/skyrl/train/evaluate.py index f24c84f2bd..1e46c2d0b8 100644 --- a/skyrl/train/evaluate.py +++ b/skyrl/train/evaluate.py @@ -1,6 +1,9 @@ -from collections import defaultdict +from collections import Counter, defaultdict from pathlib import Path -from typing import Any, Dict, List +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +if TYPE_CHECKING: + from skyrl.train.utils.tracking import Tracking import torch from loguru import logger @@ -22,12 +25,12 @@ prepare_generator_input, ) from skyrl.train.utils import Timer -from skyrl.train.utils.logging_utils import log_example from skyrl.train.utils.trainer_utils import ( calculate_per_dataset_metrics, dump_per_dataset_eval_results, validate_generator_output, ) +from skyrl.train.utils.trajectory_logging import TrajectoryLogger, pretty_print_example @torch.no_grad() @@ -37,6 +40,8 @@ async def evaluate( cfg: SkyRLTrainConfig, global_step: int | None, tokenizer: AutoTokenizer, + trajectory_logger: Optional[TrajectoryLogger] = None, + tracker: Optional["Tracking"] = None, ) -> Dict[str, float]: """Runs generation and evaluation of trajectories. @@ -57,6 +62,7 @@ async def evaluate( concat_all_envs: List[str] = [] concat_env_extras: List[Dict[str, Any]] = [] concat_uids: List[str] = [] + concat_prompts: List[str] = [] sampling_params = cfg.generator.eval_sampling_params pbar = tqdm(total=len(eval_dataloader), initial=0, desc="Evaluation Progress") for _, prompts in enumerate(eval_dataloader): @@ -75,20 +81,34 @@ async def evaluate( concat_all_envs.extend(generator_input["env_classes"]) concat_env_extras.extend(generator_input["env_extras"]) concat_uids.extend(uids) + concat_prompts.extend(generator_input["prompts"]) concat_generator_outputs: GeneratorOutput = concatenate_generator_outputs(generator_outputs) # Extract data_sources from env_extras concat_data_sources = [env_extra.get("data_source") for env_extra in concat_env_extras] - if cfg.trainer.log_example_interval > 0: + if cfg.trainer.print_example_interval > 0: vis = tokenizer.decode(generator_output["response_ids"][0]) - log_example( + pretty_print_example( logger, prompt=generator_input["prompts"][0], response=vis, reward=generator_output["rewards"][0], ) + # Optionally upload up to `num_logger_eval_samples` samples to tracker (wandb) + if trajectory_logger is not None: + with Timer("log_eval_results"): + trajectory_logger.log( + tracker=tracker, + num_samples=cfg.trainer.num_logger_eval_samples, + prompts=concat_prompts, + generator_output=concat_generator_outputs, + tokenizer=tokenizer, + global_step=global_step, + wandb_key="trajectories/eval", + ) + # 2. Group data by data source and calculate per-dataset metrics eval_metrics = calculate_per_dataset_metrics( concat_generator_outputs, concat_uids, concat_data_sources, cfg.generator.eval_n_samples_per_prompt @@ -137,6 +157,8 @@ async def evaluate_step_wise( cfg: SkyRLTrainConfig, global_step: int | None, tokenizer: AutoTokenizer, + trajectory_logger: Optional[TrajectoryLogger] = None, + tracker: Optional["Tracking"] = None, ) -> Dict[str, float]: """Runs generation and evaluation of trajectories for step-wise training. @@ -159,6 +181,7 @@ async def evaluate_step_wise( concat_all_envs: List[str] = [] concat_env_extras: List[Dict[str, Any]] = [] concat_uids: List[str] = [] + concat_prompts: List[str] = [] sampling_params = cfg.generator.eval_sampling_params pbar = tqdm(total=len(eval_dataloader), initial=0, desc="Evaluation Progress") for _, prompts in enumerate(eval_dataloader): @@ -173,9 +196,16 @@ async def evaluate_step_wise( ) generator_output: GeneratorOutput = await generator.generate(generator_input) traj_id_to_input = { - traj_id.instance_id: {"env_class": env_class, "env_extras": env_extra} - for traj_id, env_class, env_extra in zip( - generator_input["trajectory_ids"], generator_input["env_classes"], generator_input["env_extras"] + traj_id.instance_id: { + "env_class": env_class, + "env_extras": env_extra, + "prompt": prompt, + } + for traj_id, env_class, env_extra, prompt in zip( + generator_input["trajectory_ids"], + generator_input["env_classes"], + generator_input["env_extras"], + generator_input["prompts"], ) } for traj_id in generator_output["trajectory_ids"]: @@ -183,6 +213,7 @@ async def evaluate_step_wise( concat_all_envs.append(traj_id_to_input[traj_id.instance_id]["env_class"]) concat_env_extras.append(traj_id_to_input[traj_id.instance_id]["env_extras"]) concat_uids.append(traj_id.instance_id) + concat_prompts.append(traj_id_to_input[traj_id.instance_id]["prompt"]) validate_generator_output(generator_input, generator_output, step_wise=True) generator_outputs.append(generator_output) concat_generator_outputs: GeneratorOutput = concatenate_generator_outputs(generator_outputs) @@ -190,7 +221,7 @@ async def evaluate_step_wise( # Extract data_sources from env_extras concat_data_sources = [env_extra.get("data_source") for env_extra in concat_env_extras] - if cfg.trainer.log_example_interval > 0: + if cfg.trainer.print_example_interval > 0: vis = tokenizer.decode(generator_output["response_ids"][0]) logger.info(f"Eval output example: {vis}") @@ -209,6 +240,24 @@ async def evaluate_step_wise( data_sources_last_step = [ data_source for data_source, is_last_step in zip(concat_data_sources, is_last_step_mask) if is_last_step ] + prompts_last_step = [prompt for prompt, is_last_step in zip(concat_prompts, is_last_step_mask) if is_last_step] + + # Optionally upload up to `num_logger_eval_samples` samples to wandb. + # For step-wise we override the logger's default loss-mask-based + # num_turns with the total step count per trajectory (counted *before* + # the last-step filter). + if trajectory_logger is not None: + trajectory_step_counts = Counter(concat_uids) + trajectory_logger.log( + tracker=tracker, + num_samples=cfg.trainer.num_logger_eval_samples, + prompts=prompts_last_step, + generator_output=generator_output_last_step, + tokenizer=tokenizer, + global_step=global_step, + num_turns_list=[trajectory_step_counts[uid] for uid in uids_last_step], + wandb_key="trajectories/eval", + ) # 2. Group data by data source and calculate per-dataset metrics eval_metrics = calculate_per_dataset_metrics( diff --git a/skyrl/train/sft_trainer.py b/skyrl/train/sft_trainer.py index cfb3443c26..3f82705866 100644 --- a/skyrl/train/sft_trainer.py +++ b/skyrl/train/sft_trainer.py @@ -764,7 +764,7 @@ def _init_tracker(self): self.tracker = Tracking( project_name=self.cfg.trainer.project_name, experiment_name=self.cfg.trainer.run_name, - backends=self.cfg.trainer.logger, + backend=self.cfg.trainer.logger, config=self.sft_cfg, ) diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index a6eb2109c6..22763667be 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -67,7 +67,6 @@ get_ray_pg_ready_with_timeout, trainer_utils, ) -from skyrl.train.utils.logging_utils import log_example from skyrl.train.utils.ray_gpu_monitor import RayGpuMonitor from skyrl.train.utils.tracking import Tracking from skyrl.train.utils.trainer_utils import ( @@ -82,6 +81,7 @@ validate_generator_output, zero_variance_filter, ) +from skyrl.train.utils.trajectory_logging import TrajectoryLogger, pretty_print_example from skyrl.train.utils.utils import ResolvedPlacementGroup, configure_ray_worker_logging from skyrl.train.utils.vllm_metrics_scraper import VLLMMetricsScraper @@ -127,6 +127,9 @@ def __init__( self._ray_gpu_monitor = RayGpuMonitor() if cfg.trainer.enable_ray_gpu_monitor else None + # trajectory logger is installed after construction if needed + self.trajectory_logger: TrajectoryLogger = None + # initialized in `build_models` self.policy_model: PPORayActorGroup = None self.critic_model: Optional[PPORayActorGroup] = None @@ -180,6 +183,8 @@ async def eval(self) -> Dict[str, float]: cfg=self.cfg, global_step=self.global_step, tokenizer=self.tokenizer, + trajectory_logger=self.trajectory_logger, + tracker=self.tracker, ) else: eval_metrics = await evaluate( @@ -188,6 +193,8 @@ async def eval(self) -> Dict[str, float]: cfg=self.cfg, global_step=self.global_step, tokenizer=self.tokenizer, + trajectory_logger=self.trajectory_logger, + tracker=self.tracker, ) return eval_metrics @@ -268,16 +275,29 @@ async def train(self): with Timer("postprocess_generator_output", self.all_timings): generator_output, uids = self.postprocess_generator_output(generator_output, uids) - # 2. print example just for debugging - log_interval = self.cfg.trainer.log_example_interval - if log_interval > 0 and self.global_step % log_interval == 0: + # 2.1 print example just for debugging + print_interval = self.cfg.trainer.print_example_interval + if print_interval > 0 and self.global_step % print_interval == 0: vis = self.tokenizer.decode(generator_output["response_ids"][0]) - log_example( + pretty_print_example( logger, prompt=generator_input["prompts"][0], response=vis, reward=generator_output["rewards"][0], ) + # 2.2 Optionally upload up to `num_logger_train_samples` samples to tracker + if self.trajectory_logger is not None: + with Timer("log_train_results"): + self.trajectory_logger.log( + tracker=self.tracker, + num_samples=self.cfg.trainer.num_logger_train_samples, + prompts=generator_input["prompts"], + generator_output=generator_output, + tokenizer=self.tokenizer, + global_step=self.global_step, + wandb_key="trajectories/train", + include_idx=False, + ) # 3. Convert GeneratorOutput to TrainingInputBatch with Timer("convert_to_training_input", self.all_timings): diff --git a/skyrl/train/utils/logging_utils.py b/skyrl/train/utils/logging_utils.py deleted file mode 100644 index 863d31d788..0000000000 --- a/skyrl/train/utils/logging_utils.py +++ /dev/null @@ -1,84 +0,0 @@ -from typing import Any, Dict, List, Optional, Union - -POSITIVE_RESPONSE_COLOR = "green" -NEGATIVE_RESPONSE_COLOR = "yellow" -BASE_PROMPT_COLOR = "cyan" - - -def _color_block_format_and_kwargs( - text: str, - color: str, - field_prefix: str, -) -> tuple[str, dict]: - """Build a format string and kwargs for a multi-line colored block. - - The format string will look like: - "{p0}\n{p1}\n..." - - where "p0", "p1", ... are placeholder names starting with `field_prefix`. - """ - # Ensure at least one line - lines = text.splitlines() or [""] - - fmt_lines = [] - kwargs: dict[str, str] = {} - - for i, line in enumerate(lines): - key = f"{field_prefix}{i}" - # NOTE: double braces {{ }} so that {key} survives into str.format - fmt_lines.append(f"<{color}>{{{key}}}") - kwargs[key] = line - - fmt = "\n".join(fmt_lines) - return fmt, kwargs - - -def log_example( - logger: Any, - prompt: List[Dict[str, Any]], - response: str, - reward: Optional[Union[float, List[float]]] = None, -) -> None: - """ - Log a single example prompt and response with formatting and colors. - - Args: - logger: The logger instance to use (expected to be loguru logger or compatible). - prompt: The input prompt in OpenAI message format. - response: The output response string. - reward: The reward value(s) associated with the response. - """ - reward_val = 0.0 - reward_str = "N/A" - try: - prompt_str = str(prompt) - response_str = str(response) - # --- Reward handling --- - if reward is not None: - if isinstance(reward, list): - reward_val = float(sum(reward)) - else: - reward_val = float(reward) - reward_str = f"{reward_val:.4f}" - - # --- Color selection --- - if reward is not None and reward_val > 0: - response_color = POSITIVE_RESPONSE_COLOR - else: - response_color = NEGATIVE_RESPONSE_COLOR - - # --- Build per-line colored blocks in the *format string* --- - prompt_fmt, prompt_kwargs = _color_block_format_and_kwargs(prompt_str, BASE_PROMPT_COLOR, "p") - response_fmt, response_kwargs = _color_block_format_and_kwargs(response_str, response_color, "r") - - # Single format string with only our own markup and placeholders - log_format = "Example:\n" f" Input: {prompt_fmt}\n" " Output (Total Reward: {reward}):\n" f"{response_fmt}" - - # Merge all args for str.format - format_kwargs = {**prompt_kwargs, **response_kwargs, "reward": reward_str} - - # Let Loguru parse tags in log_format and then substitute arguments. - logger.opt(colors=True).info(log_format, **format_kwargs) - except Exception as e: - print(f"Error pretty printing example, debug printing instead: {e}") - print(f"Example:\n Input: {prompt}\n Output (Total Reward: {reward_str}):\n{response}") diff --git a/skyrl/train/utils/tracking.py b/skyrl/train/utils/tracking.py index 2fead5342e..f2f6f3dde4 100644 --- a/skyrl/train/utils/tracking.py +++ b/skyrl/train/utils/tracking.py @@ -21,7 +21,7 @@ from enum import Enum from functools import partial from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from loguru import logger from omegaconf import DictConfig, OmegaConf @@ -37,26 +37,20 @@ def __init__( self, project_name, experiment_name, - backends: Union[str, List[str]] = "console", + backend: str = "console", config: Optional[Union[SkyRLTrainConfig, DictConfig]] = None, ): - if isinstance(backends, str): - backends = [backends] - for backend in backends: - assert backend in self.supported_backends, f"{backend} is not supported" + assert backend in self.supported_backends, f"{backend} is not supported" + self.backend = backend - self.logger = {} - - if "wandb" in backends: + if backend == "wandb": import wandb wandb.init(project=project_name, name=experiment_name, config=get_config_as_dict(config)) - self.logger["wandb"] = wandb - - if "mlflow" in backends: - self.logger["mlflow"] = _MlflowLoggingAdapter(project_name, experiment_name, config) - - if "swanlab" in backends: + self.logger: Any = wandb + elif backend == "mlflow": + self.logger = _MlflowLoggingAdapter(project_name, experiment_name, config) + elif backend == "swanlab": import os import swanlab @@ -73,44 +67,41 @@ def __init__( logdir=SWANLAB_LOG_DIR, mode=SWANLAB_MODE, ) - self.logger["swanlab"] = swanlab - - if "tensorboard" in backends: - self.logger["tensorboard"] = _TensorboardAdapter() - - if "console" in backends: - self.console_logger = ConsoleLogger() - self.logger["console"] = self.console_logger + self.logger = swanlab + elif backend == "tensorboard": + self.logger = _TensorboardAdapter() + else: # "console" + self.logger = ConsoleLogger() self._exception_logged = False def log(self, data, step, commit=False): - for logger_name, logger_instance in self.logger.items(): - if logger_name == "wandb": - logger_instance.log(data=data, step=step, commit=commit) - else: - logger_instance.log(data=data, step=step) + if self.backend == "wandb": + self.logger.log(data=data, step=step, commit=commit) + else: + self.logger.log(data=data, step=step) def finish(self): - for logger_name, logger_instance in self.logger.items(): - # NOTE (sumanthrh): We use a try-except block here while finishing tracking. - # This is because wandb often errors out with a BrokenPipeError when closing. - # https://github.com/wandb/wandb/issues/6449 - try: - if logger_name == "wandb": - logger_instance.finish(exit_code=0) - elif logger_name != "console": - logger_instance.finish() - except Exception as e: - logger.warning(f"Attempted to finish tracking with logger {logger_name} but got error {e}") + if self.backend == "console": + return + # NOTE (sumanthrh): We use a try-except block here while finishing tracking. + # This is because wandb often errors out with a BrokenPipeError when closing. + # https://github.com/wandb/wandb/issues/6449 + try: + if self.backend == "wandb": + self.logger.finish(exit_code=0) + else: + self.logger.finish() + except Exception as e: + logger.warning(f"Attempted to finish tracking with backend {self.backend} but got error {e}") def log_exception(self, e: BaseException, step: int = 0) -> None: - """Log the active exception's traceback to all configured backends. + """Log the active exception's traceback to the configured backend. Always prints the traceback on the driver via loguru (so it lands in - Ray driver logs instead of being swallowed). If wandb is configured, - also logs a row to an `error/tracebacks` wandb.Table and calls - `finish()` to flush the async upload before the process re-raises. + Ray driver logs instead of being swallowed). If the wandb backend is + active, also logs a row to an `error/tracebacks` wandb.Table and calls + `finish()` to flush the async upload before the caller re-raises. Ray-wrapped worker errors (e.g. OOMs raised inside actors) include both local and remote frames in `traceback.format_exc()`. @@ -123,7 +114,7 @@ def log_exception(self, e: BaseException, step: int = 0) -> None: self._exception_logged = True tb_str = traceback.format_exc()[-10000:] logger.error(f"Training failed at step {step} with {type(e).__name__}:\n{tb_str}") - if "wandb" in self.logger: + if self.backend == "wandb": try: import wandb @@ -132,7 +123,7 @@ def log_exception(self, e: BaseException, step: int = 0) -> None: # Note: omit `step=` here. Per-step logs use commit=True, so # re-logging at the same step would be dropped. The step value # is also embedded in the table row itself. - self.logger["wandb"].log({"error/tracebacks": error_table}) + self.logger.log({"error/tracebacks": error_table}) # Tables upload asynchronously. Finish the run so the upload # completes before the caller re-raises and the process dies. try: @@ -142,6 +133,41 @@ def log_exception(self, e: BaseException, step: int = 0) -> None: except Exception as log_exc: logger.warning(f"Failed to log exception traceback to wandb: {log_exc}") + def log_samples_to_table( + self, + key: str, + columns: List[str], + samples: List[Tuple[Any, ...]], + step: int, + ) -> None: + """Append rows to an accumulating wandb table at ``key``. + + Each call extends the existing table at ``key`` (or creates one on + first call) with ``samples`` and logs the new table to wandb at the + given ``step``. ``columns`` defines the table schema and must stay + consistent across calls for the same ``key``; each row in ``samples`` + must have ``len(columns)`` values in the matching order. + + No-op for non-wandb backends -- only the wandb backend supports + ``wandb.Table``. + """ + if self.backend != "wandb": + return + import wandb + + # Cache one table per key so different callers (e.g. eval vs train + # trajectory loggers, error traceback table) don't trample each other. + if not hasattr(self, "_sample_tables"): + self._sample_tables: Dict[str, Any] = {} + if key not in self._sample_tables: + self._sample_tables[key] = wandb.Table(columns=columns) + # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737 + new_table = wandb.Table(columns=columns, data=self._sample_tables[key].data) + for row in samples: + new_table.add_data(*row) + self.logger.log({key: new_table}, step=step) + self._sample_tables[key] = new_table + def __del__(self): try: self.finish() @@ -260,87 +286,3 @@ def _flatten_dict(raw: Dict[str, Any], *, sep: str) -> Dict[str, Any]: ans = pd.json_normalize(raw, sep=sep).to_dict(orient="records")[0] assert isinstance(ans, dict) return ans - - -@dataclasses.dataclass -class ValidationGenerationsLogger: - def log(self, loggers, samples, step): - if "wandb" in loggers: - self.log_generations_to_wandb(samples, step) - if "swanlab" in loggers: - self.log_generations_to_swanlab(samples, step) - if "mlflow" in loggers: - self.log_generations_to_mlflow(samples, step) - - def log_generations_to_wandb(self, samples, step): - """Log samples to wandb as a table""" - import wandb - - # Create column names for all samples - columns = ["step"] + sum( - [[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], [] - ) - - if not hasattr(self, "validation_table"): - # Initialize the table on first call - self.validation_table = wandb.Table(columns=columns) - - # Create a new table with same columns and existing data - # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737 - new_table = wandb.Table(columns=columns, data=self.validation_table.data) - - # Add new row with all data - row_data = [] - row_data.append(step) - for sample in samples: - row_data.extend(sample) - - new_table.add_data(*row_data) - - # Update reference and log - wandb.log({"val/generations": new_table}, step=step) - self.validation_table = new_table - - def log_generations_to_swanlab(self, samples, step): - """Log samples to swanlab as text""" - import swanlab - - swanlab_text_list = [] - for i, sample in enumerate(samples): - row_text = f""" - input: {sample[0]} - - --- - - output: {sample[1]} - - --- - - score: {sample[2]} - """ - swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}")) - - # Log to swanlab - swanlab.log({"val/generations": swanlab_text_list}, step=step) - - def log_generations_to_mlflow(self, samples, step): - """Log validation generation to mlflow as artifacts""" - # https://mlflow.org/docs/latest/api_reference/python_api/mlflow.html?highlight=log_artifact#mlflow.log_artifact - - import json - import tempfile - - import mlflow - - try: - with tempfile.TemporaryDirectory() as tmp_dir: - validation_gen_step_file = Path(tmp_dir, f"val_step{step}.json") - row_data = [] - for sample in samples: - data = {"input": sample[0], "output": sample[1], "score": sample[2]} - row_data.append(data) - with open(validation_gen_step_file, "w") as file: - json.dump(row_data, file) - mlflow.log_artifact(validation_gen_step_file) - except Exception as e: - logger.warning(f"save validation generation file to mlflow failed with error {e}") diff --git a/skyrl/train/utils/trajectory_logging.py b/skyrl/train/utils/trajectory_logging.py new file mode 100644 index 0000000000..dddd254fd9 --- /dev/null +++ b/skyrl/train/utils/trajectory_logging.py @@ -0,0 +1,347 @@ +"""Utils for trajectory logging.""" + +import dataclasses +import random +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union + +if TYPE_CHECKING: + from skyrl.train.utils.tracking import Tracking + +POSITIVE_RESPONSE_COLOR = "green" +NEGATIVE_RESPONSE_COLOR = "yellow" +BASE_PROMPT_COLOR = "cyan" + + +def _color_block_format_and_kwargs( + text: str, + color: str, + field_prefix: str, +) -> tuple[str, dict]: + """Build a format string and kwargs for a multi-line colored block. + + The format string will look like: + "{p0}\\n{p1}\\n..." + + where "p0", "p1", ... are placeholder names starting with `field_prefix`. + """ + # Ensure at least one line + lines = text.splitlines() or [""] + + fmt_lines = [] + kwargs: dict[str, str] = {} + + for i, line in enumerate(lines): + key = f"{field_prefix}{i}" + # NOTE: double braces {{ }} so that {key} survives into str.format + fmt_lines.append(f"<{color}>{{{key}}}") + kwargs[key] = line + + fmt = "\n".join(fmt_lines) + return fmt, kwargs + + +def pretty_print_example( + logger: Any, + prompt: List[Dict[str, Any]], + response: str, + reward: Optional[Union[float, List[float]]] = None, +) -> None: + """ + Log a single example prompt and response with formatting and colors. + + Args: + logger: The logger instance to use (expected to be loguru logger or compatible). + prompt: The input prompt in OpenAI message format. + response: The output response string. + reward: The reward value(s) associated with the response. + """ + reward_val = 0.0 + reward_str = "N/A" + try: + prompt_str = str(prompt) + response_str = str(response) + # --- Reward handling --- + if reward is not None: + if isinstance(reward, list): + reward_val = float(sum(reward)) + else: + reward_val = float(reward) + reward_str = f"{reward_val:.4f}" + + # --- Color selection --- + if reward is not None and reward_val > 0: + response_color = POSITIVE_RESPONSE_COLOR + else: + response_color = NEGATIVE_RESPONSE_COLOR + + # --- Build per-line colored blocks in the *format string* --- + prompt_fmt, prompt_kwargs = _color_block_format_and_kwargs(prompt_str, BASE_PROMPT_COLOR, "p") + response_fmt, response_kwargs = _color_block_format_and_kwargs(response_str, response_color, "r") + + # Single format string with only our own markup and placeholders + log_format = "Example:\n" f" Input: {prompt_fmt}\n" " Output (Total Reward: {reward}):\n" f"{response_fmt}" + + # Merge all args for str.format + format_kwargs = {**prompt_kwargs, **response_kwargs, "reward": reward_str} + + # Let Loguru parse tags in log_format and then substitute arguments. + logger.opt(colors=True).info(log_format, **format_kwargs) + except Exception as e: + print(f"Error pretty printing example, debug printing instead: {e}") + print(f"Example:\n Input: {prompt}\n Output (Total Reward: {reward_str}):\n{response}") + + +@dataclasses.dataclass +class TrajectoryLogger: + """Logs rollout samples to tracker backends as a table. + + Accepts a full ``GeneratorOutput``-shaped dict plus a parallel list of + prompts, and derives the per-sample fields (``num_turns`` from the loss + mask, the trajectory string from prompt + response tokens) internally. + A caller with a better signal for ``num_turns`` (e.g. step-wise eval + counting trajectory steps) can pass it in explicitly to override. + + Designed for subclassing. Override the granular pieces to customize: + + - :meth:`build_samples` -- which columns each row contains. If you + change the tuple shape, also update :attr:`columns`. + - :meth:`format_trajectory` -- how the trajectory string is rendered. + - :meth:`count_assistant_turns` -- default loss-mask -> turn-count. + - :meth:`log` -- top-level dispatch (e.g. add a new backend, change the + wandb key, prepend different per-row metadata). + + All hooks accept ``**kwargs`` so callers can plumb extra fields through + a subclass (e.g. ``data_source`` for a per-dataset column) without + changing the base API. + """ + + columns: Tuple[str, ...] = ("step", "idx", "reward", "num_turns", "trajectory") + sample_seed: int = 0 + """Seed for the random picks in :meth:`select_sample_indices`. Fixed so + consecutive runs surface the same prompt/response pairs in the table.""" + + def log( + self, + *, + tracker: Optional["Tracking"], + num_samples: int, + prompts: List[Any], + generator_output: Dict[str, Any], + tokenizer: Any, + global_step: Optional[int], + num_turns_list: Optional[List[int]] = None, + wandb_key: str, + include_idx: bool = True, + **kwargs: Any, + ) -> None: + """Build sample rows from a GeneratorOutput-shaped dict and dispatch. + + ``generator_output`` must contain at least ``response_ids``, + ``rewards`` and ``loss_masks``. ``prompts`` is passed separately + because it lives on :class:`GeneratorInput`, not output. + + ``tracker`` is the active :class:`Tracking` instance. Today only the + wandb backend gets a trajectory table written (via + :meth:`Tracking.log_samples_to_table`); other backends are a no-op. + + ``num_turns_list`` defaults to per-response assistant-span counts + derived from ``loss_masks`` via :meth:`count_assistant_turns`. Pass + an explicit value (e.g. trajectory step counts in step-wise eval) + to override. + + Extra ``**kwargs`` are forwarded to :meth:`build_samples` for + subclass extensibility. + + No-op when ``num_samples <= 0`` or the tracker is not a backend we + know how to write to. + """ + response_ids = generator_output.get("response_ids") or [] + if num_samples <= 0 or tracker is None or tracker.backend != "wandb" or not response_ids: + return + loss_masks = generator_output.get("loss_masks") or [] + rewards = generator_output.get("rewards") or [] + if num_turns_list is None: + num_turns_list = [self.count_assistant_turns(m) for m in loss_masks] + samples = self.build_samples( + num_samples=num_samples, + prompts=prompts, + response_ids=response_ids, + rewards=rewards, + loss_masks=loss_masks, + num_turns_list=num_turns_list, + tokenizer=tokenizer, + **kwargs, + ) + if not samples: + return + # `global_step` may be None (eval-only context); the table API wants + # a numeric step. + step = 0 if global_step is None else global_step + # ``build_samples`` always emits ``(idx, reward, num_turns, trajectory)`` + # tuples; drop the leading idx when the caller doesn't want it logged. + columns = list(self.columns) + if not include_idx: + columns = [c for c in columns if c != "idx"] + samples = [sample[1:] for sample in samples] + # Prepend `step` to each row so rows from different calls remain + # distinguishable in the accumulating table. + tracker.log_samples_to_table( + key=wandb_key, + columns=columns, + samples=[(step, *sample) for sample in samples], + step=step, + ) + + def build_samples( + self, + *, + num_samples: int, + prompts: List[Any], + response_ids: List[List[int]], + rewards: List[float], + loss_masks: List[List[int]], + num_turns_list: List[int], + tokenizer: Any, + **kwargs: Any, + ) -> List[Tuple[Any, ...]]: + """Build the per-row tuples to be written. + + Default shape: ``(idx, reward, num_turns, trajectory)`` (matching + :attr:`columns` after the ``step`` column is prepended by :meth:`log`). + ``idx`` is the position in the input arrays, *not* a sequential row + number, so it points back to the original sample. + + Indices are chosen by :meth:`select_sample_indices`, which by default + anchors on the min- and max-reward samples and fills the rest at + random. Override either method to customize selection or column shape; + ``**kwargs`` are whatever extra values the caller passed to + :meth:`log`. + """ + total = min( + len(response_ids), + len(prompts), + len(rewards), + len(loss_masks), + len(num_turns_list), + ) + # Per-token rewards arrive as lists; collapse to scalars so the + # min/max picks and the wandb column are both well-typed. + scalar_rewards = [float(sum(r)) if isinstance(r, list) else float(r) for r in rewards[:total]] + indices = self.select_sample_indices(num_samples=num_samples, rewards=scalar_rewards, total=total) + return [ + ( + i, + scalar_rewards[i], + num_turns_list[i], + self.format_trajectory(prompts[i], response_ids[i], loss_masks[i], tokenizer, **kwargs), + ) + for i in indices + ] + + def select_sample_indices( + self, + *, + num_samples: int, + rewards: Sequence[float], + total: int, + ) -> List[int]: + """Pick up to ``num_samples`` indices from ``[0, total)`` to log. + + Guarantees that the min- and max-reward samples are included when + ``num_samples >= 2`` and ``total >= 2``; remaining slots are filled by + random sampling without replacement. With ``num_samples == 1`` only + the min-reward sample is kept (arbitrary choice when we can fit one). + If all rewards tie, ``min`` and ``max`` may resolve to the same index; + the duplicate is dropped and the rest of the budget goes to random + picks. Returned indices are sorted ascending so the output table reads + in input order. + """ + if total <= 0 or num_samples <= 0: + return [] + n = min(num_samples, total) + if n >= total: + return list(range(total)) + + # Anchor on extremes. `min`/`max` return the first occurrence on ties, + # which is fine -- we just need one of each. + min_idx = min(range(total), key=lambda i: rewards[i]) + max_idx = max(range(total), key=lambda i: rewards[i]) + anchors: List[int] = [] + for cand in (min_idx, max_idx): + if cand not in anchors: + anchors.append(cand) + anchors = anchors[:n] + + rest_needed = n - len(anchors) + if rest_needed > 0: + pool = [i for i in range(total) if i not in anchors] + rng = random.Random(self.sample_seed) + anchors.extend(rng.sample(pool, min(rest_needed, len(pool)))) + return sorted(anchors) + + def format_trajectory( + self, + prompt: Any, + response_token_ids: Optional[List[int]], + loss_mask: Optional[List[int]], + tokenizer: Any, + **kwargs: Any, + ) -> str: + """Render a trajectory as a human-readable string with role separators. + + The initial prompt (a list of ``{"role", "content"}`` chat messages, + or a plain string) is rendered with one ``[ROLE]\\n{content}`` block + per message. The generated response is then split into runs based on + ``loss_mask`` -- mask=1 spans are ``[ASSISTANT]``, mask=0 spans are + ``[USER/TOOL]`` (the mask alone can't distinguish the two). Override + for env-specific formatting (e.g. parsing structured tool calls). + """ + parts: List[str] = [] + if isinstance(prompt, list) and prompt and isinstance(prompt[0], dict): + for msg in prompt: + role = str(msg.get("role", "user")).upper() + content = msg.get("content", "") + parts.append(f"[{role}]\n{content}") + elif prompt is not None: + parts.append(f"[USER]\n{prompt}") + + if response_token_ids: + if loss_mask and len(loss_mask) == len(response_token_ids): + cur_role: Optional[str] = None + cur_tokens: List[int] = [] + for tok, m in zip(response_token_ids, loss_mask): + new_role = "ASSISTANT" if m == 1 else "USER/TOOL" + if cur_role is None: + cur_role = new_role + cur_tokens = [tok] + elif new_role == cur_role: + cur_tokens.append(tok) + else: + decoded = tokenizer.decode(cur_tokens, skip_special_tokens=True) + parts.append(f"[{cur_role}]\n{decoded}") + cur_role = new_role + cur_tokens = [tok] + if cur_tokens and cur_role is not None: + decoded = tokenizer.decode(cur_tokens, skip_special_tokens=True) + parts.append(f"[{cur_role}]\n{decoded}") + else: + decoded = tokenizer.decode(response_token_ids, skip_special_tokens=True) + parts.append(f"[ASSISTANT]\n{decoded}") + + return "\n\n".join(parts) + + def count_assistant_turns(self, loss_mask: Optional[List[int]]) -> int: + """Count contiguous 1-spans in ``loss_mask`` (each = one assistant turn). + + Default for the ``num_turns`` column when the caller doesn't supply + one. Returns 0 for empty/None masks. + """ + if not loss_mask: + return 0 + turns = 0 + prev = 0 + for m in loss_mask: + if m == 1 and prev == 0: + turns += 1 + prev = m + return turns diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_trainer_full_checkpointing.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_trainer_full_checkpointing.py index 12eb937819..35ff9fa310 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_trainer_full_checkpointing.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_trainer_full_checkpointing.py @@ -111,7 +111,7 @@ def create_minimal_trainer(cfg: SkyRLTrainConfig): tracker = Tracking( project_name=cfg.trainer.project_name, experiment_name=cfg.trainer.run_name, - backends=cfg.trainer.logger, + backend=cfg.trainer.logger, config=cfg, ) diff --git a/tests/train/utils/test_logging_utils.py b/tests/train/utils/test_logging_utils.py index 784e847843..dea39d62d7 100644 --- a/tests/train/utils/test_logging_utils.py +++ b/tests/train/utils/test_logging_utils.py @@ -4,12 +4,12 @@ import pytest -from skyrl.train.utils.logging_utils import ( +from skyrl.train.utils.trajectory_logging import ( BASE_PROMPT_COLOR, NEGATIVE_RESPONSE_COLOR, POSITIVE_RESPONSE_COLOR, _color_block_format_and_kwargs, - log_example, + pretty_print_example, ) @@ -56,13 +56,13 @@ def test_color_block_format_and_kwargs_multi_line(): ([0.1, 0.2], POSITIVE_RESPONSE_COLOR), ], ) -def test_log_example_uses_expected_colors_and_reward_string(reward, expected_color): +def test_pretty_print_example_uses_expected_colors_and_reward_string(reward, expected_color): logger = StubLogger() prompt = [{"role": "user", "content": "line1\nline2"}] response = "out1\nout2" - log_example(logger, prompt=prompt, response=response, reward=reward) + pretty_print_example(logger, prompt=prompt, response=response, reward=reward) # Basic structure checks assert logger.last_message.startswith("Example:\n Input: ") @@ -84,7 +84,7 @@ def test_log_example_uses_expected_colors_and_reward_string(reward, expected_col if reward is None: assert reward_str == "N/A" else: - # log_example normalizes rewards to a single float + # pretty_print_example normalizes rewards to a single float if isinstance(reward, list): expected_val = float(sum(reward)) else: @@ -98,20 +98,20 @@ def test_log_example_uses_expected_colors_and_reward_string(reward, expected_col assert f"" in logger.last_message -def test_log_example_handles_exceptions_gracefully(monkeypatch, capsys): - """Force an exception inside log_example and ensure the fallback path prints.""" +def test_pretty_print_example_handles_exceptions_gracefully(monkeypatch, capsys): + """Force an exception inside pretty_print_example and ensure the fallback path prints.""" def broken_color_block(*args, **kwargs): raise RuntimeError("boom") # Patch the helper to raise monkeypatch.setattr( - "skyrl.train.utils.logging_utils._color_block_format_and_kwargs", + "skyrl.train.utils.trajectory_logging._color_block_format_and_kwargs", broken_color_block, ) logger = StubLogger() - log_example(logger, prompt=[{"role": "user", "content": "p"}], response="r", reward=None) + pretty_print_example(logger, prompt=[{"role": "user", "content": "p"}], response="r", reward=None) # And the plain-text fallback should be printed to stdout captured = capsys.readouterr()