Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/content/docs/configuration/config.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ run_name: "test_run"
log_path: "/tmp/skyrl-logs"
dump_data_batch: false
dump_eval_results: true
log_example_interval: 1
print_example_interval: 1

```

Expand All @@ -128,7 +128,7 @@ log_example_interval: 1
- `log_path`: Path for infrastructure log files. Infrastructure logs (vLLM engine startup, model loading, worker initialization) are written to `{log_path}/infra-YYMMDD_HHMMSS.log`. For multi-node training, use a shared filesystem path to consolidate logs into a single file. See the [logging guide](../checkpointing-logging/logging) for details.
- `dump_data_batch`: Whether to dump the data batch to a file. This is useful for debugging. When `true`, the data batch will be dumped to a file in the `export_path` directory. The training batch at global step `N` is saved to `self.cfg.trainer.export_path / "dumped_data" / global_step_N_training_input`.
- `dump_eval_results`: Whether to dump the evaluation results to a file. When `true`, the full evaluation results will be dumped to a file in the `export_path` directory. The evaluation results at global step `N` are saved to `self.cfg.trainer.export_path / "dumped_eval" / global_step_N_eval_results`.
- `log_example_interval`: Log an example prompt every N training steps, `0`/`-1` to disable.
- `print_example_interval`: Pretty-print an example prompt/response/reward to stdout every N training steps, `0`/`-1` to disable.

## Training Backends

Expand Down
3 changes: 2 additions & 1 deletion examples/train/gsm8k/run_gsm8k.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ uv run --isolated --extra fsdp -m skyrl.train.entrypoints.main_base \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.algorithm.advantage_estimator="grpo" \
trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \
trainer.policy.model.path="Qwen/Qwen3-0.6B" \
trainer.num_logger_eval_samples=10 \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
Expand Down
17 changes: 15 additions & 2 deletions skyrl/train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,8 +669,21 @@ class TrainerConfig(BaseConfig):
dump_eval_results: bool = True
rope_scaling: Optional[Dict[str, Any]] = None
rope_theta: Optional[float] = None
log_example_interval: int = 1
"""Log an example prompt every N training steps, ``0``/``-1`` to disable"""
print_example_interval: int = 1
"""Pretty-print an example prompt/response/reward to stdout every N
training steps; ``0``/``-1`` disables. Renamed from ``log_example_interval``."""
num_logger_eval_samples: int = -1
"""Number of evaluation trajectory (prompt, response, score) tuples to upload to a wandb
table on each eval. ``-1`` (default) or ``0`` disables. When positive,
up to this many samples are taken from the start of each eval pass and
logged via :class:`TrajectoryLogger`. Column count is fixed
by the first call, so keep the eval set size and this value stable."""
num_logger_train_samples: int = -1
"""Number of training trajectory (prompt, response, score) tuples to upload to a wandb
table on each training step. ``-1`` (default) or ``0`` disables. When positive,
up to this many samples are taken from the start of each training step and
logged via :class:`TrajectoryLogger`. Column count is fixed
by the first call, so keep the training set size and this value stable."""
logprobs_chunk_size: Optional[int] = 1024
"""Chunk size along the sequence dimension when computing log-probs from logits.
This lowers peak GPU memory at the cost of ~2x wall-clock time.
Expand Down
19 changes: 18 additions & 1 deletion skyrl/train/entrypoints/main_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from skyrl.train.trainer import RayPPOTrainer
from skyrl.train.utils import validate_cfg
from skyrl.train.utils.tracking import Tracking
from skyrl.train.utils.trajectory_logging import TrajectoryLogger
from skyrl.train.utils.utils import (
ResolvedPlacementGroup,
get_ray_pg_ready_with_timeout,
Expand Down Expand Up @@ -278,10 +279,24 @@ def get_tracker(self):
return Tracking(
project_name=self.cfg.trainer.project_name,
experiment_name=self.cfg.trainer.run_name,
backends=self.cfg.trainer.logger,
backend=self.cfg.trainer.logger,
config=self.cfg,
)

def get_trajectory_logger(self) -> TrajectoryLogger:
    """Build the trajectory logger that gets attached to the trainer.

    The logger is used during eval (and optionally during training) to
    upload (prompt, response, reward) samples to wandb.

    This is an extension hook: a subclass can override it to return a
    project-specific :class:`TrajectoryLogger` (custom columns, trajectory
    rendering, or wandb key). The instance is cheap to create and
    statefully accumulates rows across evals via the wandb-Table
    re-create workaround.

    Returns:
        TrajectoryLogger: A default-constructed trajectory logger.
    """
    trajectory_logger = TrajectoryLogger()
    return trajectory_logger

def get_inference_client(self) -> InferenceEngineInterface:
"""Setup and return the inference engine client.

Expand Down Expand Up @@ -387,6 +402,8 @@ def _setup_trainer(self):
generator=generator,
colocate_pg=self.colocate_pg,
)
# Install the trajectory logger after construction
trainer.trajectory_logger = self.get_trajectory_logger()
# Expose the trainer on self so callers can log exceptions raised
# during `build_models` (which happens before _setup_trainer returns).
self.trainer = trainer
Expand Down
67 changes: 58 additions & 9 deletions skyrl/train/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from collections import defaultdict
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Optional

if TYPE_CHECKING:
from skyrl.train.utils.tracking import Tracking

import torch
from loguru import logger
Expand All @@ -22,12 +25,12 @@
prepare_generator_input,
)
from skyrl.train.utils import Timer
from skyrl.train.utils.logging_utils import log_example
from skyrl.train.utils.trainer_utils import (
calculate_per_dataset_metrics,
dump_per_dataset_eval_results,
validate_generator_output,
)
from skyrl.train.utils.trajectory_logging import TrajectoryLogger, pretty_print_example


@torch.no_grad()
Expand All @@ -37,6 +40,8 @@ async def evaluate(
cfg: SkyRLTrainConfig,
global_step: int | None,
tokenizer: AutoTokenizer,
trajectory_logger: Optional[TrajectoryLogger] = None,
tracker: Optional["Tracking"] = None,
) -> Dict[str, float]:
"""Runs generation and evaluation of trajectories.

Expand All @@ -57,6 +62,7 @@ async def evaluate(
concat_all_envs: List[str] = []
concat_env_extras: List[Dict[str, Any]] = []
concat_uids: List[str] = []
concat_prompts: List[str] = []
sampling_params = cfg.generator.eval_sampling_params
pbar = tqdm(total=len(eval_dataloader), initial=0, desc="Evaluation Progress")
for _, prompts in enumerate(eval_dataloader):
Expand All @@ -75,20 +81,34 @@ async def evaluate(
concat_all_envs.extend(generator_input["env_classes"])
concat_env_extras.extend(generator_input["env_extras"])
concat_uids.extend(uids)
concat_prompts.extend(generator_input["prompts"])
concat_generator_outputs: GeneratorOutput = concatenate_generator_outputs(generator_outputs)

# Extract data_sources from env_extras
concat_data_sources = [env_extra.get("data_source") for env_extra in concat_env_extras]

if cfg.trainer.log_example_interval > 0:
if cfg.trainer.print_example_interval > 0:
vis = tokenizer.decode(generator_output["response_ids"][0])
log_example(
pretty_print_example(
logger,
prompt=generator_input["prompts"][0],
response=vis,
reward=generator_output["rewards"][0],
)

# Optionally upload up to `num_logger_eval_samples` samples to tracker (wandb)
if trajectory_logger is not None:
with Timer("log_eval_results"):
trajectory_logger.log(
tracker=tracker,
num_samples=cfg.trainer.num_logger_eval_samples,
prompts=concat_prompts,
generator_output=concat_generator_outputs,
tokenizer=tokenizer,
global_step=global_step,
wandb_key="trajectories/eval",
)

# 2. Group data by data source and calculate per-dataset metrics
eval_metrics = calculate_per_dataset_metrics(
concat_generator_outputs, concat_uids, concat_data_sources, cfg.generator.eval_n_samples_per_prompt
Expand Down Expand Up @@ -137,6 +157,8 @@ async def evaluate_step_wise(
cfg: SkyRLTrainConfig,
global_step: int | None,
tokenizer: AutoTokenizer,
trajectory_logger: Optional[TrajectoryLogger] = None,
tracker: Optional["Tracking"] = None,
) -> Dict[str, float]:
"""Runs generation and evaluation of trajectories for step-wise training.

Expand All @@ -159,6 +181,7 @@ async def evaluate_step_wise(
concat_all_envs: List[str] = []
concat_env_extras: List[Dict[str, Any]] = []
concat_uids: List[str] = []
concat_prompts: List[str] = []
sampling_params = cfg.generator.eval_sampling_params
pbar = tqdm(total=len(eval_dataloader), initial=0, desc="Evaluation Progress")
for _, prompts in enumerate(eval_dataloader):
Expand All @@ -173,24 +196,32 @@ async def evaluate_step_wise(
)
generator_output: GeneratorOutput = await generator.generate(generator_input)
traj_id_to_input = {
traj_id.instance_id: {"env_class": env_class, "env_extras": env_extra}
for traj_id, env_class, env_extra in zip(
generator_input["trajectory_ids"], generator_input["env_classes"], generator_input["env_extras"]
traj_id.instance_id: {
"env_class": env_class,
"env_extras": env_extra,
"prompt": prompt,
}
for traj_id, env_class, env_extra, prompt in zip(
generator_input["trajectory_ids"],
generator_input["env_classes"],
generator_input["env_extras"],
generator_input["prompts"],
)
}
for traj_id in generator_output["trajectory_ids"]:
assert traj_id.instance_id in traj_id_to_input, f"Trajectory ID {traj_id.instance_id} not found in input"
concat_all_envs.append(traj_id_to_input[traj_id.instance_id]["env_class"])
concat_env_extras.append(traj_id_to_input[traj_id.instance_id]["env_extras"])
concat_uids.append(traj_id.instance_id)
concat_prompts.append(traj_id_to_input[traj_id.instance_id]["prompt"])
validate_generator_output(generator_input, generator_output, step_wise=True)
generator_outputs.append(generator_output)
concat_generator_outputs: GeneratorOutput = concatenate_generator_outputs(generator_outputs)

# Extract data_sources from env_extras
concat_data_sources = [env_extra.get("data_source") for env_extra in concat_env_extras]

if cfg.trainer.log_example_interval > 0:
if cfg.trainer.print_example_interval > 0:
vis = tokenizer.decode(generator_output["response_ids"][0])
logger.info(f"Eval output example: {vis}")

Expand All @@ -209,6 +240,24 @@ async def evaluate_step_wise(
data_sources_last_step = [
data_source for data_source, is_last_step in zip(concat_data_sources, is_last_step_mask) if is_last_step
]
prompts_last_step = [prompt for prompt, is_last_step in zip(concat_prompts, is_last_step_mask) if is_last_step]

# Optionally upload up to `num_logger_eval_samples` samples to wandb.
# For step-wise we override the logger's default loss-mask-based
# num_turns with the total step count per trajectory (counted *before*
# the last-step filter).
if trajectory_logger is not None:
trajectory_step_counts = Counter(concat_uids)
trajectory_logger.log(
tracker=tracker,
num_samples=cfg.trainer.num_logger_eval_samples,
prompts=prompts_last_step,
generator_output=generator_output_last_step,
tokenizer=tokenizer,
global_step=global_step,
num_turns_list=[trajectory_step_counts[uid] for uid in uids_last_step],
wandb_key="trajectories/eval",
)

# 2. Group data by data source and calculate per-dataset metrics
eval_metrics = calculate_per_dataset_metrics(
Expand Down
2 changes: 1 addition & 1 deletion skyrl/train/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ def _init_tracker(self):
self.tracker = Tracking(
project_name=self.cfg.trainer.project_name,
experiment_name=self.cfg.trainer.run_name,
backends=self.cfg.trainer.logger,
backend=self.cfg.trainer.logger,
config=self.sft_cfg,
)

Expand Down
30 changes: 25 additions & 5 deletions skyrl/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
get_ray_pg_ready_with_timeout,
trainer_utils,
)
from skyrl.train.utils.logging_utils import log_example
from skyrl.train.utils.ray_gpu_monitor import RayGpuMonitor
from skyrl.train.utils.tracking import Tracking
from skyrl.train.utils.trainer_utils import (
Expand All @@ -82,6 +81,7 @@
validate_generator_output,
zero_variance_filter,
)
from skyrl.train.utils.trajectory_logging import TrajectoryLogger, pretty_print_example
from skyrl.train.utils.utils import ResolvedPlacementGroup, configure_ray_worker_logging
from skyrl.train.utils.vllm_metrics_scraper import VLLMMetricsScraper

Expand Down Expand Up @@ -127,6 +127,9 @@ def __init__(

self._ray_gpu_monitor = RayGpuMonitor() if cfg.trainer.enable_ray_gpu_monitor else None

# trajectory logger is installed after construction if needed
self.trajectory_logger: TrajectoryLogger = None

# initialized in `build_models`
self.policy_model: PPORayActorGroup = None
self.critic_model: Optional[PPORayActorGroup] = None
Expand Down Expand Up @@ -180,6 +183,8 @@ async def eval(self) -> Dict[str, float]:
cfg=self.cfg,
global_step=self.global_step,
tokenizer=self.tokenizer,
trajectory_logger=self.trajectory_logger,
tracker=self.tracker,
)
else:
eval_metrics = await evaluate(
Expand All @@ -188,6 +193,8 @@ async def eval(self) -> Dict[str, float]:
cfg=self.cfg,
global_step=self.global_step,
tokenizer=self.tokenizer,
trajectory_logger=self.trajectory_logger,
tracker=self.tracker,
)
return eval_metrics

Expand Down Expand Up @@ -268,16 +275,29 @@ async def train(self):
with Timer("postprocess_generator_output", self.all_timings):
generator_output, uids = self.postprocess_generator_output(generator_output, uids)

# 2. print example just for debugging
log_interval = self.cfg.trainer.log_example_interval
if log_interval > 0 and self.global_step % log_interval == 0:
# 2.1 print example just for debugging
print_interval = self.cfg.trainer.print_example_interval
if print_interval > 0 and self.global_step % print_interval == 0:
vis = self.tokenizer.decode(generator_output["response_ids"][0])
log_example(
pretty_print_example(
logger,
prompt=generator_input["prompts"][0],
response=vis,
reward=generator_output["rewards"][0],
)
# 2.2 Optionally upload up to `num_logger_train_samples` samples to tracker
if self.trajectory_logger is not None:
with Timer("log_train_results"):
self.trajectory_logger.log(
tracker=self.tracker,
num_samples=self.cfg.trainer.num_logger_train_samples,
prompts=generator_input["prompts"],
generator_output=generator_output,
tokenizer=self.tokenizer,
global_step=self.global_step,
wandb_key="trajectories/train",
include_idx=False,
)

# 3. Convert GeneratorOutput to TrainingInputBatch
with Timer("convert_to_training_input", self.all_timings):
Expand Down
Loading
Loading