
Conversation

@Wohox (Contributor) commented on Jan 22, 2026

Description

This PR adds get_backward_dw_params to TE modules, which exposes the parameters involved in the weight-gradient (wgrad) computation so that their hooks can be managed.

In Megatron-LM, get_backward_dw_params will be called once the wgrad CUDA graph has been executed. Currently, the backward post-hook of the wgrad computation is discarded, which causes the affected parameters to skip grad reduce.
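
As a rough illustration of the intended call pattern (a hypothetical sketch based only on the description above, not actual Megatron-LM code; the hooks_for_param lookup and the refire_wgrad_hooks helper are assumptions):

    # Hypothetical sketch, not actual Megatron-LM code.
    # After the wgrad CUDA graph is replayed, the gradients exist, but the
    # backward post-hooks never fired, so grad reduce would be skipped.
    def refire_wgrad_hooks(te_modules, hooks_for_param):
        for module in te_modules:
            for param in module.get_backward_dw_params():
                # Re-run each accumulation/reduce hook for the parameter so
                # its gradient is not silently left out of the reduction.
                for hook in hooks_for_param(param):
                    hook(param)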

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps (bot) commented on Jan 22, 2026

Greptile Overview

Greptile Summary

This PR fixes a bug where weight gradient accumulation hooks were not being triggered when using CUDA graphs in Megatron-LM. The fix extracts the hook triggering logic into a reusable method _trigger_wgrad_accumulation_and_reduce_hooks() and calls it after the wgrad CUDA graph replay to ensure gradient reduction happens correctly.

  • Extracted the hook-triggering logic into a _trigger_wgrad_accumulation_and_reduce_hooks() method in transformer_engine/pytorch/module/base.py (sketched after this list)
  • Added hook triggering after the wgrad graph replay in transformer_engine/pytorch/graph.py for all TE modules that need backward_dw
  • The approach changed from the initial implementation (adding a get_backward_dw_params API) to the simpler solution of triggering the hooks directly
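
A minimal sketch of what the extracted helper plausibly looks like (the class name and hook-list name come from the sequence diagram below; the body itself is an assumption, not the verbatim implementation):

    # Sketch only. Assumes the module stores its registered hooks in a list
    # named wgrad_accumulation_and_reduce_hooks, as the sequence diagram shows.
    class TransformerEngineBaseModule:
        def _trigger_wgrad_accumulation_and_reduce_hooks(self):
            # Fire every registered gradient accumulation/reduce hook, e.g.
            # after a wgrad CUDA graph replay that bypassed autograd hooks.
            for hook in self.wgrad_accumulation_and_reduce_hooks:
                hook()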

Confidence Score: 4/5

  • Safe to merge with minimal risk: it fixes a legitimate bug in hook triggering for CUDA graphs
  • The change is focused and addresses a specific bug where hooks were not triggered during CUDA graph replay. The refactoring to extract _trigger_wgrad_accumulation_and_reduce_hooks() improves code reusability, and the isinstance check properly filters modules. The score is 4 rather than 5 due to the lack of tests validating the fix.
  • No files require special attention

Important Files Changed

  • transformer_engine/pytorch/graph.py: added hook-triggering logic after the wgrad CUDA graph replay so that gradient accumulation hooks are called for TE modules
  • transformer_engine/pytorch/module/base.py: extracted the hook-triggering logic into a separate _trigger_wgrad_accumulation_and_reduce_hooks() method for reusability

Sequence Diagram

sequenceDiagram
    participant ML as Megatron-LM
    participant GC as CUDA Graph (backward_dw)
    participant TEModule as TransformerEngineBaseModule
    participant Hooks as wgrad_accumulation_and_reduce_hooks
    
    ML->>GC: Call backward_dw() on graphed callable
    GC->>GC: Check if need_bwd_dw_graph is True
    GC->>GC: Replay bwd_dw_graphs[graph_idx]
    
    Note over GC,TEModule: New logic added in this PR
    
    GC->>GC: Iterate through visited_te_modules[graph_idx]
    loop For each module in te_modules
        GC->>TEModule: Check isinstance(module, TransformerEngineBaseModule)
        GC->>TEModule: Call module.need_backward_dw()
        alt module needs backward_dw
            GC->>TEModule: Call _trigger_wgrad_accumulation_and_reduce_hooks()
            TEModule->>Hooks: Trigger each registered hook
            Hooks-->>TEModule: Execute grad accumulation/reduce
        end
    end
    
    GC-->>ML: Return from backward_dw()
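
In code form, the replay path in the diagram amounts to roughly the following (need_bwd_dw_graph, bwd_dw_graphs, and visited_te_modules are taken from the diagram; everything else is assumed):

    # Sketch of the backward_dw() replay path in graph.py, reconstructed
    # from the sequence diagram above; not the verbatim implementation.
    def backward_dw(graph_idx):
        if need_bwd_dw_graph:
            bwd_dw_graphs[graph_idx].replay()
        # New in this PR: graph replay alone does not fire parameter hooks,
        # so trigger them explicitly for every TE module in the graph.
        for module in visited_te_modules[graph_idx]:
            if isinstance(module, TransformerEngineBaseModule) and module.need_backward_dw():
                module._trigger_wgrad_accumulation_and_reduce_hooks()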

@greptile-apps (bot) commented on Jan 22, 2026

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment confirming that the review found nothing to flag. You can still verify that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

@Wohox (author) commented on Jan 22, 2026

@buptzyb @lhb8125 Please help review this PR, thanks!

@greptile-apps (bot) left a review: 1 file reviewed, 1 comment

Quoted diff context (from the method under review):

        Get the parameters for the backward weight gradient computation.
        """
        params = []
        params.append(noop_cat(self._get_weight_tensors()))

@greptile-apps (bot) commented on the diff:

Logic: in backward_dw() (lines 1520-1522), weight tensors are only accessed when not self.fuse_wgrad_accumulation, but this method unconditionally returns the weight parameters. Depending on Megatron-LM's usage, this could cause hooks to be registered on parameters that should not have them when fuse_wgrad_accumulation=True.
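
For context, the guard the comment refers to has roughly this shape (paraphrased from the comment itself, not the actual source):

    # Inside backward_dw(); paraphrased from the review comment above.
    if not self.fuse_wgrad_accumulation:
        # Weight tensors are only accessed here, so when
        # fuse_wgrad_accumulation=True the weights should not get hooks either.
        weights = noop_cat(self._get_weight_tensors())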

@Wohox (author) replied:

Commit content reverted.

@greptile-apps (bot) left a review: no files reviewed, no comments (this identical review was posted three times)


