diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index f587ca9946..37fff943d6 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -853,12 +853,22 @@ def functionalized(*user_args, **user_kwargs):
             return functionalized
 
         def make_graphed_attribute_functions(graph_idx):
+            # Get te modules for current graph
+            te_modules = visited_te_modules.get(graph_idx, set())
+
             # Attach backward_dw as an attribute to the graphed callable.
             def backward_dw():
                 if need_bwd_dw_graph.get(graph_idx, False):
                     bwd_dw_graphs[graph_idx].replay()
+                    # Trigger the grad accumulation hook for wgrad graphs.
+                    for module in te_modules:
+                        if (
+                            isinstance(module, TransformerEngineBaseModule)
+                            and module.need_backward_dw()
+                        ):
+                            module._trigger_wgrad_accumulation_and_reduce_hooks()
 
             # Attach reset as an attribute to the graphed callable.
             def reset():
                 fwd_graphs[graph_idx].reset()
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 841cdf04ca..09b12afa21 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -1526,8 +1526,14 @@ def backward_dw(self):
                 bias_tensor.grad = bgrad.to(bias_tensor.dtype)
             del wgrad
             del bgrad
-        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
-            wgrad_accumulation_and_reduce_hook()
+        self._trigger_wgrad_accumulation_and_reduce_hooks()
+
+    def _trigger_wgrad_accumulation_and_reduce_hooks(self):
+        """
+        Trigger the wgrad accumulation and reduce hooks.
+        """
+        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
+            wgrad_accumulation_and_reduce_hook()
 
     def is_debug_iter(self) -> bool:
         """
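
Note (reviewer sketch, not part of the patch): a minimal, hedged illustration of how the backward_dw attribute attached by make_graphed_attribute_functions is expected to be driven from training code. It assumes make_graphed_callables returns a callable carrying the backward_dw/reset attributes built above, and that the module is configured so weight-gradient computation is deferred to backward_dw (any flags needed for that are omitted); the model shape and optimizer-free loop are purely illustrative.

    # Illustrative usage sketch; names and configuration are assumptions, not the
    # patch's API surface. backward_dw() is a no-op unless a wgrad graph was captured.
    import torch
    import transformer_engine.pytorch as te

    model = te.Linear(1024, 1024, device="cuda")
    graphed = te.make_graphed_callables(model, (torch.randn(32, 1024, device="cuda"),))

    out = graphed(torch.randn(32, 1024, device="cuda"))
    out.sum().backward()   # replays the captured fwd/bwd graphs; wgrad may be deferred
    graphed.backward_dw()  # replays the wgrad graph and, with this change, then fires
                           # each visited TE module's wgrad accumulation/reduce hooks
                           # eagerly, since hook side effects are not captured in the graph

The design point this sketch highlights is why base.py factors the hook loop into _trigger_wgrad_accumulation_and_reduce_hooks(): the same hooks must fire both on the eager backward_dw() path and after a CUDA-graph replay in graph.py, without duplicating the loop.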