From 150a9666b2e2b260104475b856679046caa5cd24 Mon Sep 17 00:00:00 2001
From: Pingtian Li
Date: Thu, 22 Jan 2026 01:02:04 -0800
Subject: [PATCH 1/4] add get_backward_dw_params

---
 transformer_engine/pytorch/module/base.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 841cdf04ca..8efe6a53b5 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -1529,6 +1529,16 @@ def backward_dw(self):
         for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
             wgrad_accumulation_and_reduce_hook()
 
+    def get_backward_dw_params(self):
+        """
+        Get the parameters for the backward weight gradient computation.
+        """
+        params = []
+        params.append(noop_cat(self._get_weight_tensors()))
+        if self.use_bias:
+            params.append(noop_cat([getattr(self, name) for name in self.bias_names]))
+        return params
+
     def is_debug_iter(self) -> bool:
         """
         This function checks if the debug should be enabled for this layer.

From d04c008e87174a4cdbaafe47cd807fcdf79553e1 Mon Sep 17 00:00:00 2001
From: Pingtian Li
Date: Sun, 25 Jan 2026 18:49:01 -0800
Subject: [PATCH 2/4] revert get_backward_dw_params and trigger hook after wgrad graph execution

---
 transformer_engine/pytorch/graph.py       |  9 +++++++--
 transformer_engine/pytorch/module/base.py | 14 +++++---------
 2 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index f587ca9946..4dd62b07bd 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -852,13 +852,18 @@ def functionalized(*user_args, **user_kwargs):
 
         return functionalized
 
-    def make_graphed_attribute_functions(graph_idx):
+    def make_graphed_attribute_functions(graph_idx, te_modules):
 
         # Attach backward_dw as an attribute to the graphed callable.
         def backward_dw():
             if need_bwd_dw_graph.get(graph_idx, False):
                 bwd_dw_graphs[graph_idx].replay()
 
+            # Trigger the grad accumulation hook for wgrad graphs.
+            for module in te_modules:
+                if isinstance(module, TransformerEngineBaseModule) and module.need_backward_dw():
+                    module._trigger_wgrad_accumulation_and_reduce_hooks()
+
         # Attach reset as an attribute to the graphed callable.
         def reset():
             fwd_graphs[graph_idx].reset()
@@ -942,7 +947,7 @@ def new_fwd(*user_args, **user_kwargs):
             else:
                 ret.append(graphed)
 
-        backward_dw_func, reset_func = make_graphed_attribute_functions(i)
+        backward_dw_func, reset_func = make_graphed_attribute_functions(i, te_modules)
         setattr(ret[-1], "backward_dw", backward_dw_func)
         setattr(ret[-1], "reset", reset_func)
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 8efe6a53b5..09b12afa21 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -1526,18 +1526,14 @@ def backward_dw(self):
             bias_tensor.grad = bgrad.to(bias_tensor.dtype)
         del wgrad
         del bgrad
-        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
-            wgrad_accumulation_and_reduce_hook()
+        self._trigger_wgrad_accumulation_and_reduce_hooks()
 
-    def get_backward_dw_params(self):
+    def _trigger_wgrad_accumulation_and_reduce_hooks(self):
         """
-        Get the parameters for the backward weight gradient computation.
+        Trigger the wgrad accumulation and reduce hooks.
         """
-        params = []
-        params.append(noop_cat(self._get_weight_tensors()))
-        if self.use_bias:
-            params.append(noop_cat([getattr(self, name) for name in self.bias_names]))
-        return params
+        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
+            wgrad_accumulation_and_reduce_hook()
 
     def is_debug_iter(self) -> bool:
         """

From 9894713de2d9e68d5a88e16f704013cea3419e01 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 26 Jan 2026 02:50:09 +0000
Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/pytorch/graph.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index 4dd62b07bd..6bebc9f0e5 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -861,7 +861,10 @@ def backward_dw():
 
             # Trigger the grad accumulation hook for wgrad graphs.
             for module in te_modules:
-                if isinstance(module, TransformerEngineBaseModule) and module.need_backward_dw():
+                if (
+                    isinstance(module, TransformerEngineBaseModule)
+                    and module.need_backward_dw()
+                ):
                     module._trigger_wgrad_accumulation_and_reduce_hooks()
 

From 52985d3168be0d1c0c5b25e55b34705963469363 Mon Sep 17 00:00:00 2001
From: Pingtian Li
Date: Mon, 26 Jan 2026 00:09:51 -0800
Subject: [PATCH 4/4] simpler api

---
 transformer_engine/pytorch/graph.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index 6bebc9f0e5..37fff943d6 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -852,7 +852,9 @@ def functionalized(*user_args, **user_kwargs):
 
         return functionalized
 
-    def make_graphed_attribute_functions(graph_idx, te_modules):
+    def make_graphed_attribute_functions(graph_idx):
+        # Get te modules for current graph
+        te_modules = visited_te_modules.get(graph_idx, set())
 
         # Attach backward_dw as an attribute to the graphed callable.
         def backward_dw():
@@ -950,7 +952,7 @@ def new_fwd(*user_args, **user_kwargs):
             else:
                 ret.append(graphed)
 
-        backward_dw_func, reset_func = make_graphed_attribute_functions(i, te_modules)
+        backward_dw_func, reset_func = make_graphed_attribute_functions(i)
         setattr(ret[-1], "backward_dw", backward_dw_func)
         setattr(ret[-1], "reset", reset_func)