
Fix ref batch token normalization #113

Merged

Connor1996 merged 1 commit into skyzh:main from Connor1996:codex/ref-batch-fixes
Apr 12, 2026


Conversation

@Connor1996 (Collaborator) commented Apr 12, 2026

What changed

This patch fixes two batching issues in src/tiny_llm_ref/batch.py:

  • normalize logits per request by passing axis=-1 to mx.logsumexp(...)
  • remove finished requests from every layer cache after decoding, instead of only touching the leaked last batch_cache binding
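
The axis fix can be sketched with NumPy as a stand-in for mx.logsumexp (the function name and example tensors below are illustrative, not the actual batch.py code):

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Log-softmax over the last axis, one row per request.

    NumPy analogue of subtracting mx.logsumexp(logits, axis=-1):
    reducing over the vocab dimension normalizes each request
    independently, whereas an axis-less reduction would mix all
    requests in the batch.
    """
    # Subtract the per-row max first for numerical stability.
    m = logits.max(axis=-1, keepdims=True)
    lse = m + np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
    return logits - lse

# Two requests with very different logit scales in one batch.
batch_logits = np.array([[1.0, 2.0, 3.0],
                         [10.0, 20.0, 30.0]])
probs = np.exp(log_softmax(batch_logits))
# With axis=-1 each row sums to 1 on its own; a global logsumexp
# would let the large-logit request dominate the small one.
print(np.allclose(probs.sum(axis=-1), 1.0))  # True
```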

Why

The main bug was that batched decoding normalized across the wrong dimension, so requests in the same batch interfered with each other and produced broken outputs.

While debugging that path, I also fixed request cleanup so a finished request is removed from all layer caches before its slot is reused.
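
A minimal sketch of that cleanup, assuming a hypothetical cache layout of one dict per transformer layer keyed by request id (the real batch.py structure may differ):

```python
def evict_request(layer_caches: list[dict], request_id: str) -> None:
    """Drop a finished request's entries from every layer's KV cache.

    Hypothetical structure: layer_caches[i] maps request_id -> cached
    keys/values for layer i. The bug pattern being fixed is evicting
    only from a leaked last-iteration binding (one layer's cache),
    which leaves stale entries in all other layers.
    """
    for cache in layer_caches:  # touch every layer, not just the last
        cache.pop(request_id, None)

caches = [{"req-1": "kv0", "req-2": "kv0"},
          {"req-1": "kv1", "req-2": "kv1"}]
evict_request(caches, "req-1")
print(all("req-1" not in c for c in caches))  # True
```

Evicting before the slot is reused matters: a new request assigned the freed slot would otherwise attend over the previous request's leftover keys and values.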

Signed-off-by: Connor1996 <zbk602423539@gmail.com>
@Connor1996 force-pushed the codex/ref-batch-fixes branch from 05a7f75 to 2328736 on April 12, 2026 23:49
@Connor1996 changed the title from "[codex] fix ref batch token normalization" to "Fix ref batch token normalization" on Apr 12, 2026
@Connor1996 marked this pull request as ready for review April 12, 2026 23:50
@Connor1996 merged commit 05972ff into skyzh:main Apr 12, 2026
2 checks passed