Support PyTorch 2.12 (fx codegen, autograd_function_apply)#2825
Draft
bhimrazy wants to merge 4 commits into
Draft
Support PyTorch 2.12 (fx codegen, autograd_function_apply)#2825bhimrazy wants to merge 4 commits into
bhimrazy wants to merge 4 commits into
Conversation
torch 2.12 changed fx code generation in ways the interpreter traces through
when thunder compiles a GraphModule:
- CALL_FUNCTION_EX may unpack a generator as *args (e.g. fx emits
`"{}".format(*(_get_repr(a) for a in node.args))`). wrap_args_from_list
called len() on it, raising "object of type 'generator' has no len()".
Materialize non-sequence iterables via interpreted iteration before indexing.
- fx codegen calls list.insert (e.g. `free_vars.insert(0, "self")`), which
was an unimplemented Sequence.insert stub. Implement it following the
append/pop item-wrapper pattern.
- MappingKeysView / MappingValuesWrapper / MappingItemsWrapper lacked __len__,
so len() on a dict view raised during tracing. Delegate to the mapping.
Adds focused interpreter tests for each pattern.
torch 2.12 changed torch.ops.higher_order.autograd_function_apply and the GraphModule dynamo produces for it: - A new required `saved_for_backward_idx` kwarg. Accept (and ignore) it on the thunder symbol and augmented forward for API parity; supply it in the tests (mirroring the existing _detect_has_args_tensor_mask pattern). - fx GraphModule recompile() (triggered while interpreting the module) writes framework-internal state -- `_code` (a source string) and `_graph` (an fx.Graph) -- into the module __dict__. These were recorded as module-member modifications and corrupted prologue/epilogue provenance unpacking. Only record writes whose value is trackable computational state. - dynamo now wraps the forward output in a tuple and indexes it (`autograd_function_apply(...)[0]`). The lookaside unpacked a single-element output, collapsing the tuple, so the index then sliced into the tensor and produced a wrong (scalar) result. Return the output preserving its structure.
torchaudio is not imported anywhere in the codebase.
The temporary xfail added in Lightning-AI#2805 (tracked by Lightning-AI#2807) is no longer needed: the preceding commits make test_splitter_autograd_function pass on torch 2.12 (it was xfailing, now xpasses). Drop the marker, its import, and the now-unused _pytorch_removed_args_tensor_mask / xfail_if_args_tensor_mask_removed helpers. Closes Lightning-AI#2807.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Before submitting
xfails once PyTorch autograd stabilizes #2807)What does this PR do?
Restores PyTorch 2.12 compatibility. On
torch >= 2.12thecoreCI suites go red (latest/nightly) whileoldeststays green — affecting alltest_nanogpt_*_DynamoThunder_*, ~40test_dynamo.pycases,test_autograd_function_apply*, andtest_higher_order_inplace_alias_update.Root cause: tracing the dynamo
GraphModulemakes thunder interpret torch's fx codegen andautograd_function_apply, both of which changed internals in 2.12.Interpreter (
thunder/core/interpreter.py)generator has no len()— fx unpacks a generator as*args; materialize it before indexing.NotImplementedError: Sequence.insert— implement the stubbedlist.insert.MappingKeysView has no len()— add__len__to the mapping views.autograd_function_apply (
thunder/core/jit_ext.py,thunder/torch/__init__.py)KeyError: 'saved_for_backward_idx'— accept (and ignore) the new kwarg.trying to set ._code …— fxrecompile()writes_code/_graphinto__dict__; stop recording non-trackable module-member writes.[]vs[2]) — dynamo wraps the output in a tuple and indexes it (…apply(...)[0]); preserve the output structure instead of collapsing it.Cleanup: drop the unused
torchaudiodev dependency.All fixes are version-agnostic (no version gating): inputs are just handled more generally, the new kwarg is optional, and a bare output stays bare while a tuple stays a tuple.
oldest/latest/nightlyCI covers both ends.Tests: new focused interpreter tests (generator unpack,
list.insert, mapping-viewlen), each verified to fail pre-fix; autograd tests extended forsaved_for_backward_idx(mirroring_detect_has_args_tensor_mask).CI: ❌ before → ✅ after.
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃