Skip to content

add validation sample logging#1713

Open
erictang000 wants to merge 1 commit into
mainfrom
log_val_samples
Open

add validation sample logging#1713
erictang000 wants to merge 1 commit into
mainfrom
log_val_samples

Conversation

@erictang000
Copy link
Copy Markdown
Collaborator

@erictang000 erictang000 commented May 28, 2026

Originally stacked on log_ooms_to_wandb (#1706) since this PR builds on Tracking.log_exception and the _setup_trainer hook. #1706 merged into main shortly after this branch was created, so this PR was rebased onto current main (which now contains those changes) and targets main directly. Includes resolutions for benign overlaps with the recently-merged #1711 (tokens_per_second_per_gpu) and #1712 (RayGpuMonitor) in trainer.py — both RayGpuMonitor and the new TrajectoryLogger are kept.

Summary

  • Rename trainer.log_example_intervalprint_example_interval; add trainer.num_logger_eval_samples / num_logger_train_samples.
  • Replace logging_utils.log_example with trajectory_logging.pretty_print_example; add a TrajectoryLogger class that uploads (prompt, response, reward, num_turns, trajectory) rows to a wandb table during eval and (optionally) training. Pluggable via BasePPOExp.get_trajectory_logger.
  • Simplify Tracking from multi-backend (backends: List[str]) to single backend: str; add log_samples_to_table for accumulating wandb tables. Drops the old ValidationGenerationsLogger dataclass.
  • Update evaluate.py, trainer.py, sft_trainer.py, main_base.py, tests, docs, and the gsm8k example.

Test plan

  • uv run --extra dev pytest tests/train/utils/test_logging_utils.py
  • Manual eval run with trainer.num_logger_eval_samples=10 and confirm a trajectories/eval table appears in wandb

Refactor example/trajectory logging:
- Rename `trainer.log_example_interval` -> `print_example_interval`; add
  `trainer.num_logger_eval_samples` / `num_logger_train_samples` knobs.
- Replace `logging_utils.log_example` with
  `trajectory_logging.pretty_print_example` and add a `TrajectoryLogger`
  class that uploads (prompt, response, reward, num_turns, trajectory)
  rows to a wandb table during eval and (optionally) training. Pluggable
  via `BasePPOExp.get_trajectory_logger`.
- Simplify `Tracking` from multi-backend (`backends: List[str]`) to a
  single `backend: str`; add `log_samples_to_table` for accumulating
  wandb tables.
- Update evaluate.py / trainer.py / sft_trainer.py / main_base.py /
  tests / docs / gsm8k example.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new TrajectoryLogger utility to log training and evaluation trajectories (prompts, responses, and rewards) to wandb tables. It also refactors the Tracking class to support a single backend instead of multiple backends, and replaces the old log_example utility with pretty_print_example inside the new trajectory_logging.py module. The review feedback highlights two important issues: a potential data mutation/race condition when appending rows to wandb.Table in tracking.py (which can be resolved by copying the data list), and a robust turn-detection logic improvement in trajectory_logging.py to handle non-one values (like -100) in the loss mask.

if key not in self._sample_tables:
self._sample_tables[key] = wandb.Table(columns=columns)
# Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
new_table = wandb.Table(columns=columns, data=self._sample_tables[key].data)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Passing self._sample_tables[key].data directly to wandb.Table passes a reference to the underlying list of rows. When new_table.add_data(*row) is called, it appends to this shared list, mutating the data of the table logged in the previous step. This can cause race conditions or data corruption during asynchronous wandb uploads. Creating a shallow copy of the list prevents this issue.

Suggested change
new_table = wandb.Table(columns=columns, data=self._sample_tables[key].data)
new_table = wandb.Table(columns=columns, data=list(self._sample_tables[key].data))

turns = 0
prev = 0
for m in loss_mask:
if m == 1 and prev == 0:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Checking prev == 0 to detect the start of an assistant turn can fail if the loss mask contains other non-one values (such as -100 for ignored/padding tokens). Checking prev != 1 is more robust as it correctly identifies any transition from a non-assistant token to an assistant token.

Suggested change
if m == 1 and prev == 0:
if m == 1 and prev != 1:

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant