[loss] Add 'prompt_mean' loss aggregation #1718
Code Review
This pull request introduces a new prompt_mean loss reduction option for PPO training, which scales advantages at the prompt level. It adds utility functions to compute prompt boundaries from sequence UIDs, updates the trainer to rebase these boundaries for each mini-batch, and includes corresponding unit tests. The review feedback suggests adding defensive validation to ensure prompt boundaries are contiguous and fully cover the batch range to prevent silent correctness bugs. Additionally, it recommends optimizing the trainer by only rebasing prompt boundaries when the prompt_mean loss reduction is active.
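As a rough illustration of the boundary utility mentioned in the summary, the sketch below groups adjacent equal UIDs into half-open index ranges. The function name `compute_prompt_boundaries` and the assumption that samples from the same prompt sit next to each other in the batch are hypothetical; the PR's actual helper may differ.

```python
from itertools import groupby
from typing import List, Sequence, Tuple


def compute_prompt_boundaries(uids: Sequence[str]) -> List[Tuple[int, int]]:
    """Return half-open (start, end) index ranges, one per run of equal UIDs.

    Hypothetical helper: assumes samples of the same prompt are adjacent.
    """
    boundaries = []
    start = 0
    for _, run in groupby(uids):
        length = sum(1 for _ in run)  # number of sequences in this run
        boundaries.append((start, start + length))
        start += length
    return boundaries


uids = ["a", "a", "a", "b", "b", "c"]
print(compute_prompt_boundaries(uids))  # [(0, 3), (3, 5), (5, 6)]
```

Ranges built this way are contiguous and cover the whole batch by construction, which is the property the review asks the loss code to verify.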
    elif loss_reduction == "prompt_mean":
        if prompt_boundaries is None:
            raise ValueError("`prompt_mean` loss reduction requires `prompt_boundaries`")
        num_prompts = len(prompt_boundaries)
        for p_start, p_end in prompt_boundaries:
            prompt_tokens = loss_mask[p_start:p_end].sum().clamp(min=1)
            normalized_advantages[p_start:p_end] = advantages[p_start:p_end] / (num_prompts * prompt_tokens)
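To make the scaling in this hunk concrete, here is a small plain-Python mirror of it, with nested lists standing in for tensors. The function name `prompt_mean_scale` is hypothetical, and the final check assumes the downstream loss is a masked sum of the scaled values, which the hunk itself does not show.

```python
from typing import List, Tuple


def prompt_mean_scale(
    advantages: List[List[float]],
    loss_mask: List[List[int]],
    prompt_boundaries: List[Tuple[int, int]],
) -> List[List[float]]:
    """Divide each prompt's advantages by (num_prompts * unmasked tokens in that prompt)."""
    num_prompts = len(prompt_boundaries)
    scaled = [[0.0] * len(row) for row in advantages]
    for p_start, p_end in prompt_boundaries:
        # Count unmasked tokens across every sequence of this prompt, floored at 1.
        prompt_tokens = max(sum(sum(row) for row in loss_mask[p_start:p_end]), 1)
        for i in range(p_start, p_end):
            scaled[i] = [a / (num_prompts * prompt_tokens) for a in advantages[i]]
    return scaled


# Two prompts: the first has two sequences (5 unmasked tokens), the second has one (2 tokens).
advantages = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [4.0, 4.0, 4.0]]
loss_mask = [[1, 1, 1], [1, 1, 0], [1, 1, 0]]
scaled = prompt_mean_scale(advantages, loss_mask, [(0, 2), (2, 3)])

# Masked sum of the scaled values: each prompt contributes its token-mean divided by num_prompts.
total = sum(a * m for row_a, row_m in zip(scaled, loss_mask) for a, m in zip(row_a, row_m))
print(total)
```

With these inputs the per-prompt token means are 1 and 4, so the masked sum comes out at roughly 2.5: each prompt carries equal weight regardless of how many tokens it produced.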
Without defensive validation on prompt_boundaries, any misalignment between mini-batch boundaries and prompt boundaries (or malformed inputs) can lead to silent correctness bugs. For example, if a prompt is excluded or partially sliced, some sequences will silently receive 0.0 normalized advantages, leading to incorrect gradients that are extremely difficult to debug in RL.
Adding contiguity and coverage checks ensures that the prompt boundaries partition the entire mini-batch exactly.
    elif loss_reduction == "prompt_mean":
        if prompt_boundaries is None:
            raise ValueError("`prompt_mean` loss reduction requires `prompt_boundaries`")
        if not prompt_boundaries:
            raise ValueError("`prompt_boundaries` cannot be empty for `prompt_mean` loss reduction")
        if prompt_boundaries[0][0] != 0 or prompt_boundaries[-1][1] != advantages.shape[0]:
            raise ValueError(
                f"prompt_boundaries {prompt_boundaries} must cover the entire advantages batch range (0, {advantages.shape[0]})"
            )
        for i in range(len(prompt_boundaries) - 1):
            if prompt_boundaries[i][1] != prompt_boundaries[i + 1][0]:
                raise ValueError(
                    f"prompt_boundaries must be contiguous, got gap between {prompt_boundaries[i]} and {prompt_boundaries[i+1]}"
                )
        num_prompts = len(prompt_boundaries)
        for p_start, p_end in prompt_boundaries:
            prompt_tokens = loss_mask[p_start:p_end].sum().clamp(min=1)
            normalized_advantages[p_start:p_end] = advantages[p_start:p_end] / (num_prompts * prompt_tokens)

    mb_prompt_boundaries = None
    if prompt_boundaries is not None:
        mb_prompt_boundaries = [
            (p_start - start_idx, p_end - start_idx)
            for p_start, p_end in prompt_boundaries
            if start_idx <= p_start < end_idx
        ]
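The rebasing filter above keeps a prompt whenever its start index falls inside the mini-batch and does not clamp the end index. The sketch below reproduces that filter in isolation to show what it yields when mini-batches do and do not line up with prompt groups. The helper names `rebase_as_in_hunk` and `is_exact_partition` are hypothetical, and whether misaligned mini-batches can occur in the actual trainer is not established by this thread.

```python
from typing import List, Tuple

Boundary = Tuple[int, int]


def rebase_as_in_hunk(boundaries: List[Boundary], start_idx: int, end_idx: int) -> List[Boundary]:
    """Same filter as the hunk: keep prompts whose start lies in [start_idx, end_idx)."""
    return [(s - start_idx, e - start_idx) for s, e in boundaries if start_idx <= s < end_idx]


def is_exact_partition(boundaries: List[Boundary], size: int) -> bool:
    """True when the ranges are contiguous and cover [0, size) exactly."""
    if not boundaries or boundaries[0][0] != 0 or boundaries[-1][1] != size:
        return False
    return all(prev[1] == nxt[0] for prev, nxt in zip(boundaries, boundaries[1:]))


# Three prompts with four samples each, 12 sequences in total.
boundaries = [(0, 4), (4, 8), (8, 12)]

# Mini-batch size 4 is aligned with the prompt groups.
aligned = rebase_as_in_hunk(boundaries, 4, 8)
print(aligned, is_exact_partition(aligned, 4))  # [(0, 4)] True

# Mini-batch size 6 splits the second prompt across two mini-batches.
first = rebase_as_in_hunk(boundaries, 0, 6)
second = rebase_as_in_hunk(boundaries, 6, 12)
print(first, is_exact_partition(first, 6))    # [(0, 4), (4, 8)] False
print(second, is_exact_partition(second, 6))  # [(2, 6)] False
```

In the misaligned case the first mini-batch gets a range that runs past its end, and the second leaves its first two rows uncovered, which is the situation the coverage and contiguity checks are meant to catch.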
The list comprehension to filter and rebase prompt_boundaries runs for every mini-batch, even when loss_reduction is not set to "prompt_mean". Since policy_prompt_boundaries is always populated in convert_to_training_input, this introduces unnecessary overhead for other loss reduction types (e.g., "token_mean", "sequence_mean").
We can optimize this by only executing the list comprehension when loss_reduction is actually "prompt_mean".
    -    mb_prompt_boundaries = None
    -    if prompt_boundaries is not None:
    -        mb_prompt_boundaries = [
    -            (p_start - start_idx, p_end - start_idx)
    -            for p_start, p_end in prompt_boundaries
    -            if start_idx <= p_start < end_idx
    -        ]
    +    mb_prompt_boundaries = None
    +    if prompt_boundaries is not None and self.cfg.trainer.algorithm.loss_reduction == "prompt_mean":
    +        mb_prompt_boundaries = [
    +            (p_start - start_idx, p_end - start_idx)
    +            for p_start, p_end in prompt_boundaries
    +            if start_idx <= p_start < end_idx
    +        ]
Closing in favor of #1719.
Add prompt-mean loss aggregation as described in the ScaleRL paper. Details are in #495.