[feat] Add prompt_mean loss reduction #1719
Open
erictang000 wants to merge 1 commit into
Add a `prompt_mean` option to `algorithm.loss_reduction`: compute the token-mean within each prompt group (the `n_samples_per_prompt` responses sampled for a prompt), then average over prompts. Each token `[i, t]` in prompt `p` is scaled by `1 / (num_prompts * tokens_in_prompt_p)` so that summing the per-token policy loss yields `mean_p(token_mean within prompt p)`. Unlike `token_mean`, every prompt contributes equally regardless of total token count.

- preprocess: add `compute_prompt_boundaries(uids)` for per-prompt slices (works for step-wise and non-step-wise training).
- trainer: thread per-prompt boundaries through metadata into `_normalize_advantages`, rebased to mini-batch-relative indices.
- ppo_utils: implement the `prompt_mean` branch in `apply_loss_reduction_to_advantages_minibatch`.
- config validation + docs updated for the new option.
- unit tests for `compute_prompt_boundaries` and `prompt_mean`.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
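To make the scaling rule concrete, here is a minimal pure-Python sketch of the per-token weights described above. It is not the code from this PR: the helper names `prompt_mean_weights` and `weighted_sum`, the nested-list layout, and the choice to skip prompt groups with no valid tokens are all assumptions made for illustration. The PR applies the scaling to advantages inside `apply_loss_reduction_to_advantages_minibatch`, whereas this sketch weights a toy per-token loss directly.

```python
from typing import List, Sequence, Tuple


def prompt_mean_weights(
    loss_mask: Sequence[Sequence[int]],
    prompt_boundaries: Sequence[Tuple[int, int]],
) -> List[List[float]]:
    """Hypothetical helper: per-token weights for a prompt_mean reduction.

    loss_mask[i][t] is 1 for tokens that count toward the loss, else 0.
    prompt_boundaries holds half-open (start, end) row ranges, one per prompt.
    """
    num_prompts = len(prompt_boundaries)
    weights = [[0.0] * len(row) for row in loss_mask]
    for start, end in prompt_boundaries:
        tokens_in_prompt = sum(sum(row) for row in loss_mask[start:end])
        if tokens_in_prompt == 0:
            continue  # assumption: an all-masked group contributes nothing
        scale = 1.0 / (num_prompts * tokens_in_prompt)
        for i in range(start, end):
            weights[i] = [m * scale for m in loss_mask[i]]
    return weights


def weighted_sum(values, weights) -> float:
    """Sum of values[i][t] * weights[i][t] over all tokens."""
    return sum(v * w for vr, wr in zip(values, weights) for v, w in zip(vr, wr))


# Two prompts with two samples each: prompt 0 has 6 valid tokens, prompt 1 has 2.
loss_mask = [[1, 1, 1], [1, 1, 1], [1, 0, 0], [1, 0, 0]]
per_token_loss = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [3.0, 0.0, 0.0], [3.0, 0.0, 0.0]]
boundaries = [(0, 2), (2, 4)]

weights = prompt_mean_weights(loss_mask, boundaries)
prompt_mean = weighted_sum(per_token_loss, weights)

# token_mean for comparison: every valid token gets weight 1 / total_tokens.
total_tokens = sum(sum(row) for row in loss_mask)
token_mean = weighted_sum(per_token_loss, loss_mask) / total_tokens
```

In this toy batch the within-prompt token means are 1.0 and 3.0, so `prompt_mean` comes out at about 2.0, while `token_mean` is 1.5 because the longer prompt group dominates the token count.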
Contributor
Code Review
This pull request introduces a new loss reduction strategy called prompt_mean, which computes the average token loss within each prompt group and then averages over all prompts. The changes include implementing the prompt boundary computation, updating the advantage normalization logic in the trainer, adding configuration validation, and providing comprehensive unit tests. There are no review comments, so I have no feedback to provide.
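As a rough illustration of the prompt boundary computation and the rebasing to mini-batch-relative indices mentioned in the PR description, the sketch below shows one way this could look. It assumes that responses sharing a uid are stored contiguously and that a prompt group straddling a mini-batch edge is simply clipped. Both are assumptions, and `sketch_prompt_boundaries` and `rebase_to_minibatch` are hypothetical names, not the PR's `compute_prompt_boundaries` implementation.

```python
from typing import List, Sequence, Tuple


def sketch_prompt_boundaries(uids: Sequence[str]) -> List[Tuple[int, int]]:
    """Return half-open (start, end) row ranges for runs of identical uids.

    Assumes all rows for a given prompt are adjacent in the batch.
    """
    boundaries: List[Tuple[int, int]] = []
    start = 0
    for i in range(1, len(uids) + 1):
        # Close the current run at the end of the batch or when the uid changes.
        if i == len(uids) or uids[i] != uids[start]:
            boundaries.append((start, i))
            start = i
    return boundaries


def rebase_to_minibatch(
    boundaries: Sequence[Tuple[int, int]], mb_start: int, mb_end: int
) -> List[Tuple[int, int]]:
    """Clip boundaries to rows [mb_start, mb_end) and shift them to start at 0."""
    rebased: List[Tuple[int, int]] = []
    for start, end in boundaries:
        lo, hi = max(start, mb_start), min(end, mb_end)
        if lo < hi:  # keep only groups that overlap this mini-batch
            rebased.append((lo - mb_start, hi - mb_start))
    return rebased


uids = ["a", "a", "a", "b", "b", "c"]
batch_boundaries = sketch_prompt_boundaries(uids)
minibatch_boundaries = rebase_to_minibatch(batch_boundaries, mb_start=3, mb_end=6)
```

For this input the batch-level boundaries are `[(0, 3), (3, 5), (5, 6)]`, and the mini-batch covering rows 3 to 5 sees `[(0, 2), (2, 3)]`.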