add validation sample logging#1713
Conversation
Refactor example/trajectory logging: - Rename `trainer.log_example_interval` -> `print_example_interval`; add `trainer.num_logger_eval_samples` / `num_logger_train_samples` knobs. - Replace `logging_utils.log_example` with `trajectory_logging.pretty_print_example` and add a `TrajectoryLogger` class that uploads (prompt, response, reward, num_turns, trajectory) rows to a wandb table during eval and (optionally) training. Pluggable via `BasePPOExp.get_trajectory_logger`. - Simplify `Tracking` from multi-backend (`backends: List[str]`) to a single `backend: str`; add `log_samples_to_table` for accumulating wandb tables. - Update evaluate.py / trainer.py / sft_trainer.py / main_base.py / tests / docs / gsm8k example. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request introduces a new TrajectoryLogger utility to log training and evaluation trajectories (prompts, responses, and rewards) to wandb tables. It also refactors the Tracking class to support a single backend instead of multiple backends, and replaces the old log_example utility with pretty_print_example inside the new trajectory_logging.py module. The review feedback highlights two important issues: a potential data mutation/race condition when appending rows to wandb.Table in tracking.py (which can be resolved by copying the data list), and a robust turn-detection logic improvement in trajectory_logging.py to handle non-one values (like -100) in the loss mask.
| if key not in self._sample_tables: | ||
| self._sample_tables[key] = wandb.Table(columns=columns) | ||
| # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737 | ||
| new_table = wandb.Table(columns=columns, data=self._sample_tables[key].data) |
There was a problem hiding this comment.
Passing self._sample_tables[key].data directly to wandb.Table passes a reference to the underlying list of rows. When new_table.add_data(*row) is called, it appends to this shared list, mutating the data of the table logged in the previous step. This can cause race conditions or data corruption during asynchronous wandb uploads. Creating a shallow copy of the list prevents this issue.
| new_table = wandb.Table(columns=columns, data=self._sample_tables[key].data) | |
| new_table = wandb.Table(columns=columns, data=list(self._sample_tables[key].data)) |
| turns = 0 | ||
| prev = 0 | ||
| for m in loss_mask: | ||
| if m == 1 and prev == 0: |
There was a problem hiding this comment.
Checking prev == 0 to detect the start of an assistant turn can fail if the loss mask contains other non-one values (such as -100 for ignored/padding tokens). Checking prev != 1 is more robust as it correctly identifies any transition from a non-assistant token to an assistant token.
| if m == 1 and prev == 0: | |
| if m == 1 and prev != 1: |
Summary
trainer.log_example_interval→print_example_interval; addtrainer.num_logger_eval_samples/num_logger_train_samples.logging_utils.log_examplewithtrajectory_logging.pretty_print_example; add aTrajectoryLoggerclass that uploads(prompt, response, reward, num_turns, trajectory)rows to a wandb table during eval and (optionally) training. Pluggable viaBasePPOExp.get_trajectory_logger.Trackingfrom multi-backend (backends: List[str]) to singlebackend: str; addlog_samples_to_tablefor accumulating wandb tables. Drops the oldValidationGenerationsLoggerdataclass.evaluate.py,trainer.py,sft_trainer.py,main_base.py, tests, docs, and the gsm8k example.Test plan
uv run --extra dev pytest tests/train/utils/test_logging_utils.pytrainer.num_logger_eval_samples=10and confirm atrajectories/evaltable appears in wandb