Skip to content

[loss] Add 'prompt_mean' loss aggregation#1718

Closed
erictang000 wants to merge 2 commits into
mainfrom
prompt_mean
Closed

[loss] Add 'prompt_mean' loss aggregation#1718
erictang000 wants to merge 2 commits into
mainfrom
prompt_mean

Conversation

@erictang000
Copy link
Copy Markdown
Collaborator

Add prompt mean loss aggregation as described in ScaleRL paper - details here: #495.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new prompt_mean loss reduction option for PPO training, which scales advantages at the prompt level. It adds utility functions to compute prompt boundaries from sequence UIDs, updates the trainer to rebase these boundaries for each mini-batch, and includes corresponding unit tests. The review feedback suggests adding defensive validation to ensure prompt boundaries are contiguous and fully cover the batch range to prevent silent correctness bugs. Additionally, it recommends optimizing the trainer by only rebasing prompt boundaries when the prompt_mean loss reduction is active.

Comment on lines +1058 to +1064
elif loss_reduction == "prompt_mean":
if prompt_boundaries is None:
raise ValueError("`prompt_mean` loss reduction requires `prompt_boundaries`")
num_prompts = len(prompt_boundaries)
for p_start, p_end in prompt_boundaries:
prompt_tokens = loss_mask[p_start:p_end].sum().clamp(min=1)
normalized_advantages[p_start:p_end] = advantages[p_start:p_end] / (num_prompts * prompt_tokens)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Without defensive validation on prompt_boundaries, any misalignment between mini-batch boundaries and prompt boundaries (or malformed inputs) can lead to silent correctness bugs. For example, if a prompt is excluded or partially sliced, some sequences will silently receive 0.0 normalized advantages, leading to incorrect gradients that are extremely difficult to debug in RL.

Adding contiguous and coverage checks ensures that the prompt boundaries partition the entire mini-batch exactly.

    elif loss_reduction == "prompt_mean":
        if prompt_boundaries is None:
            raise ValueError("`prompt_mean` loss reduction requires `prompt_boundaries`")
        if not prompt_boundaries:
            raise ValueError("`prompt_boundaries` cannot be empty for `prompt_mean` loss reduction")
        if prompt_boundaries[0][0] != 0 or prompt_boundaries[-1][1] != advantages.shape[0]:
            raise ValueError(
                f"prompt_boundaries {prompt_boundaries} must cover the entire advantages batch range (0, {advantages.shape[0]})"
            )
        for i in range(len(prompt_boundaries) - 1):
            if prompt_boundaries[i][1] != prompt_boundaries[i + 1][0]:
                raise ValueError(
                    f"prompt_boundaries must be contiguous, got gap between {prompt_boundaries[i]} and {prompt_boundaries[i+1]}"
                )
        num_prompts = len(prompt_boundaries)
        for p_start, p_end in prompt_boundaries:
            prompt_tokens = loss_mask[p_start:p_end].sum().clamp(min=1)
            normalized_advantages[p_start:p_end] = advantages[p_start:p_end] / (num_prompts * prompt_tokens)

Comment thread skyrl/train/trainer.py
Comment on lines +1244 to +1250
mb_prompt_boundaries = None
if prompt_boundaries is not None:
mb_prompt_boundaries = [
(p_start - start_idx, p_end - start_idx)
for p_start, p_end in prompt_boundaries
if start_idx <= p_start < end_idx
]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The list comprehension to filter and rebase prompt_boundaries runs for every mini-batch, even when loss_reduction is not set to "prompt_mean". Since policy_prompt_boundaries is always populated in convert_to_training_input, this introduces unnecessary overhead for other loss reduction types (e.g., "token_mean", "sequence_mean").

We can optimize this by only executing the list comprehension when loss_reduction is actually "prompt_mean".

Suggested change
mb_prompt_boundaries = None
if prompt_boundaries is not None:
mb_prompt_boundaries = [
(p_start - start_idx, p_end - start_idx)
for p_start, p_end in prompt_boundaries
if start_idx <= p_start < end_idx
]
mb_prompt_boundaries = None
if prompt_boundaries is not None and self.cfg.trainer.algorithm.loss_reduction == "prompt_mean":
mb_prompt_boundaries = [
(p_start - start_idx, p_end - start_idx)
for p_start, p_end in prompt_boundaries
if start_idx <= p_start < end_idx
]

@erictang000
Copy link
Copy Markdown
Collaborator Author

closing in favor of #1719

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant