[Pytorch] Add get_backward_dw_params api for TE module #2614
base: main
Conversation
Greptile Overview
Greptile Summary
This PR fixes a bug where weight gradient accumulation hooks were not triggered when using CUDA graphs in Megatron-LM. The fix extracts the hook-triggering logic into a reusable method (_trigger_wgrad_accumulation_and_reduce_hooks).
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant ML as Megatron-LM
participant GC as CUDA Graph (backward_dw)
participant TEModule as TransformerEngineBaseModule
participant Hooks as wgrad_accumulation_and_reduce_hooks
ML->>GC: Call backward_dw() on graphed callable
GC->>GC: Check if need_bwd_dw_graph is True
GC->>GC: Replay bwd_dw_graphs[graph_idx]
Note over GC,TEModule: New logic added in this PR
GC->>GC: Iterate through visited_te_modules[graph_idx]
loop For each module in te_modules
GC->>TEModule: Check isinstance(module, TransformerEngineBaseModule)
GC->>TEModule: Call module.need_backward_dw()
alt module needs backward_dw
GC->>TEModule: Call _trigger_wgrad_accumulation_and_reduce_hooks()
TEModule->>Hooks: Trigger each registered hook
Hooks-->>TEModule: Execute grad accumulation/reduce
end
end
GC-->>ML: Return from backward_dw()
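The diagram maps to roughly the following Python. This is a minimal sketch reconstructed from the diagram, not the PR's exact diff: the bookkeeping names (need_bwd_dw_graph, bwd_dw_graphs, visited_te_modules, graph_idx) and the module calls (need_backward_dw, _trigger_wgrad_accumulation_and_reduce_hooks) come from the diagram, while the enclosing helper function is hypothetical.

```python
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule


def replay_wgrad_graph_and_fire_hooks(graphed, graph_idx):
    """Hypothetical helper mirroring the sequence diagram above."""
    if not graphed.need_bwd_dw_graph:
        return
    # Replay the captured weight-gradient (dW) CUDA graph.
    graphed.bwd_dw_graphs[graph_idx].replay()
    # New logic in this PR: graph replay bypasses the Python-side
    # backward_post_hooks, so re-trigger them for every visited TE module.
    for module in graphed.visited_te_modules[graph_idx]:
        if isinstance(module, TransformerEngineBaseModule) and module.need_backward_dw():
            module._trigger_wgrad_accumulation_and_reduce_hooks()
```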
1 file reviewed, 1 comment
Get the parameters for the backward weight gradient computation.
"""
params = []
params.append(noop_cat(self._get_weight_tensors()))
logic: In backward_dw() (lines 1520-1522), weight tensors are only accessed when not self.fuse_wgrad_accumulation, but this method unconditionally returns weight parameters. Depending on Megatron-LM's usage, this could cause hooks to be registered on parameters that shouldn't have them when fuse_wgrad_accumulation=True.
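To make the concern concrete, here is a hedged sketch of the guarded variant the comment suggests. noop_cat and _get_weight_tensors are taken from the diff excerpt above and the body assumes the same class scope; the fuse_wgrad_accumulation check is the reviewer's proposal, not the PR's current behavior.

```python
def get_backward_dw_params(self):
    """Return the parameters touched by the backward weight-gradient computation."""
    params = []
    # Reviewer's suggestion: mirror backward_dw(), which only touches the
    # weight tensors when fused wgrad accumulation is disabled.
    if not self.fuse_wgrad_accumulation:
        params.append(noop_cat(self._get_weight_tensors()))
    return params
```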
commit content reverted.
No files reviewed, no comments
Description
This PR adds get_backward_dw_params for TE modules, which helps manage the hooks of parameters.
For Megatron-LM, get_backward_dw_params will be called once the wgrad CUDA graph is executed. Currently, the backward_post_hook of the wgrad computation is discarded, which causes parameters to skip grad reduce. A hypothetical usage sketch is included after the checklist below.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist:
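As referenced in the description, below is a hypothetical sketch (not Megatron-LM's actual code) of how a trainer could consume get_backward_dw_params after replaying the wgrad CUDA graph. fire_wgrad_hooks and grad_reduce_fn are placeholder names; backward_dw and get_backward_dw_params are the calls discussed in this PR.

```python
def fire_wgrad_hooks(graphed_callable, te_modules, grad_reduce_fn):
    """Placeholder helper: replay the wgrad graph, then reduce grads per parameter.

    grad_reduce_fn stands in for the framework's gradient accumulation/reduce hook.
    """
    # Replay the captured weight-gradient (dW) graph.
    graphed_callable.backward_dw()
    for module in te_modules:
        # get_backward_dw_params is the API added by this PR; it returns the
        # parameters that participate in the dW computation.
        for param in module.get_backward_dw_params():
            grad_reduce_fn(param)
```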