Skip to content

fix(dist-ckpt): handle bare BytesIO _extra_state in key recovery (TE FP8)#53

Merged
Zhichenzzz merged 1 commit into
miles-mainfrom
fix/1293-te-fp8-bytesio
Jun 19, 2026
Merged

fix(dist-ckpt): handle bare BytesIO _extra_state in key recovery (TE FP8)#53
Zhichenzzz merged 1 commit into
miles-mainfrom
fix/1293-te-fp8-bytesio

Conversation

@Zhichenzzz

Copy link
Copy Markdown

Problem

Loading a TransformerEngine FP8 distributed checkpoint crashes in _replace_sharded_keys_with_state_dict_keys: TE stores _extra_state as a bare io.BytesIO, which _unwrap_pyt_sharded_tensor returns unwrapped, so assert len(tensors) == ... raises TypeError: object of type 'BytesIO' has no len(). (Upstream NVIDIA/Megatron-LM has the same unpatched code — this is a carried fix, not a sync gap.)

Fix

When the value is a bare io.BytesIO, substitute an empty uint8 tensor. TE's set_extra_state treats an empty tensor as "no FP8 state to restore" (early return), which is correct for RL training where FP8 is not used.

Validation

Routed a realistic BytesIO _extra_state (torch.save into io.BytesIO) through the exact recovery function:

  • before: len(BytesIO) → TypeError reproduced
  • after: substituted empty uint8 tensor, no crash; a real TE module accepts it via set_extra_state.

Fixes radixark/miles#1293

…FP8)

TE FP8 _extra_state can arrive as a bare io.BytesIO in the state dict, so
len(tensors) in _replace_sharded_keys_with_state_dict_keys raises. Wrap it
as a single empty uint8 tensor, which makes TE.set_extra_state skip
restoring FP8 scaling state instead of crashing checkpoint key recovery.
@Zhichenzzz Zhichenzzz force-pushed the fix/1293-te-fp8-bytesio branch from 7f52ebe to e56a184 Compare June 18, 2026 21:39
@Zhichenzzz Zhichenzzz merged commit 87d1155 into miles-main Jun 19, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

TE FP8 dist-checkpoint load crashes on BytesIO extra_state (len() TypeError) — not fixed upstream either

1 participant