Skip to content

[feat] Add prompt_mean loss reduction#1719

Open
erictang000 wants to merge 1 commit into
NovaSky-AI:mainfrom
erictang000:prompt_mean_loss_reduction
Open

[feat] Add prompt_mean loss reduction#1719
erictang000 wants to merge 1 commit into
NovaSky-AI:mainfrom
erictang000:prompt_mean_loss_reduction

Conversation

@erictang000
Copy link
Copy Markdown
Collaborator

Add a prompt_mean option to algorithm.loss_reduction: compute the token-mean within each prompt group (the n_samples_per_prompt responses sampled for a prompt), then average over prompts. Each token [i, t] in prompt p is scaled by 1 / (num_prompts * tokens_in_prompt_p) so that summing the per-token policy loss yields mean_p(token_mean within prompt p). Unlike token_mean, every prompt contributes equally regardless of total token count.

  • preprocess: add compute_prompt_boundaries(uids) for per-prompt slices (works for step-wise and non-step-wise training).
  • trainer: thread per-prompt boundaries through metadata into _normalize_advantages, rebased to mini-batch-relative indices.
  • ppo_utils: implement the prompt_mean branch in apply_loss_reduction_to_advantages_minibatch.
  • config validation + docs updated for the new option.
  • unit tests for compute_prompt_boundaries and prompt_mean.

Add a `prompt_mean` option to `algorithm.loss_reduction`: compute the
token-mean within each prompt group (the `n_samples_per_prompt` responses
sampled for a prompt), then average over prompts. Each token [i, t] in
prompt p is scaled by 1 / (num_prompts * tokens_in_prompt_p) so that summing
the per-token policy loss yields mean_p(token_mean within prompt p). Unlike
`token_mean`, every prompt contributes equally regardless of total token count.

- preprocess: add `compute_prompt_boundaries(uids)` for per-prompt slices
  (works for step-wise and non-step-wise training).
- trainer: thread per-prompt boundaries through metadata into
  `_normalize_advantages`, rebased to mini-batch-relative indices.
- ppo_utils: implement the `prompt_mean` branch in
  `apply_loss_reduction_to_advantages_minibatch`.
- config validation + docs updated for the new option.
- unit tests for `compute_prompt_boundaries` and `prompt_mean`.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new loss reduction strategy called prompt_mean, which computes the average token loss within each prompt group and then averages over all prompts. The changes include implementing the prompt boundary computation, updating the advantage normalization logic in the trainer, adding configuration validation, and providing comprehensive unit tests. There are no review comments, so I have no feedback to provide.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant