Skip to content

[ET-VK][ops] Add eq.Scalar operator#20383

Open
SS-JIA wants to merge 7 commits into
gh/SS-JIA/562/basefrom
gh/SS-JIA/562/head
Open

[ET-VK][ops] Add eq.Scalar operator#20383
SS-JIA wants to merge 7 commits into
gh/SS-JIA/562/basefrom
gh/SS-JIA/562/head

Conversation

@SS-JIA

@SS-JIA SS-JIA commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

Stack from ghstack (oldest at bottom):

Adds Vulkan support for aten.eq.Scalar. This is the second of two ops needed to collapse the Llama4-mini TISO en_US backbone export to a single Vulkan partition (after bitwise_or): the discrete-speech mask compares the int token-id tensor against scalar constants via aten.eq.Scalar, which previously had no Vulkan implementation and forced a CPU fallback that split the delegated graph.

Implemented by extending the existing tensor-scalar binary-op path with a comparison-output variant: binary_scalar_buffer.glsl / binary_scalar_texture.glsl gain an IS_COMPARISON_OP code path that writes a uint8 (bool) output while leaving the existing arithmetic (e.g. pow) path unchanged; binary_scalar_buffer.yaml / binary_scalar_texture.yaml add an eq_scalar variant (half/float/int32 — the texture variant uses equal(X, Y) for per-lane bvec4, the buffer variant uses scalar X == Y); BinaryScalarOp.cpp adds an eq_tensor_scalar dispatch and VK_REGISTER_OP(aten.eq.Scalar, eq_tensor_scalar); op_registry.py registers aten.eq.Scalar OpFeatures (FP/INT tensor input, bool output). The int64 token tensor is serialized to int32 via the existing downcast_64_bit path, so the dispatch resolves to the int32 shader variant; no dtype-conversion pass is added.

This change was authored with Claude.

Differential Revision: D108457791

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Jun 18, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20383

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 5f3b21b with merge base 1227757 (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla

linux-foundation-easycla Bot commented Jun 18, 2026

Copy link
Copy Markdown

CLA Missing ID

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 18, 2026
@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

[ghstack-poisoned]
SS-JIA pushed a commit that referenced this pull request Jun 24, 2026
Pull Request resolved: #20383

Adds Vulkan support for `aten.eq.Scalar`. This is the second of two ops needed to collapse the Llama4-mini TISO en_US backbone export to a single Vulkan partition (after `bitwise_or`): the discrete-speech mask compares the int token-id tensor against scalar constants via `aten.eq.Scalar`, which previously had no Vulkan implementation and forced a CPU fallback that split the delegated graph.

Implemented by extending the existing tensor-scalar binary-op path with a comparison-output variant: `binary_scalar_buffer.glsl` / `binary_scalar_texture.glsl` gain an `IS_COMPARISON_OP` code path that writes a `uint8` (bool) output while leaving the existing arithmetic (e.g. `pow`) path unchanged; `binary_scalar_buffer.yaml` / `binary_scalar_texture.yaml` add an `eq_scalar` variant (half/float/int32 — the texture variant uses `equal(X, Y)` for per-lane `bvec4`, the buffer variant uses scalar `X == Y`); `BinaryScalarOp.cpp` adds an `eq_tensor_scalar` dispatch and `VK_REGISTER_OP(aten.eq.Scalar, eq_tensor_scalar)`; `op_registry.py` registers `aten.eq.Scalar` `OpFeatures` (FP/INT tensor input, bool output). The int64 token tensor is serialized to int32 via the existing `downcast_64_bit` path, so the dispatch resolves to the int32 shader variant; no dtype-conversion pass is added.

This change was authored with Claude.
ghstack-source-id: 396618180
@exported-using-ghexport

Differential Revision: [D108457791](https://our.internmc.facebook.com/intern/diff/D108457791/)
[ghstack-poisoned]
SS-JIA pushed a commit that referenced this pull request Jun 24, 2026
Pull Request resolved: #20383

Adds Vulkan support for `aten.eq.Scalar`. This is the second of two ops needed to collapse the Llama4-mini TISO en_US backbone export to a single Vulkan partition (after `bitwise_or`): the discrete-speech mask compares the int token-id tensor against scalar constants via `aten.eq.Scalar`, which previously had no Vulkan implementation and forced a CPU fallback that split the delegated graph.

Implemented by extending the existing tensor-scalar binary-op path with a comparison-output variant: `binary_scalar_buffer.glsl` / `binary_scalar_texture.glsl` gain an `IS_COMPARISON_OP` code path that writes a `uint8` (bool) output while leaving the existing arithmetic (e.g. `pow`) path unchanged; `binary_scalar_buffer.yaml` / `binary_scalar_texture.yaml` add an `eq_scalar` variant (half/float/int32 — the texture variant uses `equal(X, Y)` for per-lane `bvec4`, the buffer variant uses scalar `X == Y`); `BinaryScalarOp.cpp` adds an `eq_tensor_scalar` dispatch and `VK_REGISTER_OP(aten.eq.Scalar, eq_tensor_scalar)`; `op_registry.py` registers `aten.eq.Scalar` `OpFeatures` (FP/INT tensor input, bool output). The int64 token tensor is serialized to int32 via the existing `downcast_64_bit` path, so the dispatch resolves to the int32 shader variant; no dtype-conversion pass is added.

This change was authored with Claude.
ghstack-source-id: 396719904
@exported-using-ghexport

Differential Revision: [D108457791](https://our.internmc.facebook.com/intern/diff/D108457791/)
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@metascroy

Copy link
Copy Markdown
Contributor

@claude review this code

@claude

claude Bot commented Jun 26, 2026

Copy link
Copy Markdown

Claude finished @metascroy's task in 14m 31s —— View job


Code Review: eq.Scalar for Vulkan

I read the full diff against origin/gh/SS-JIA/562/base and traced the shader codegen, the C++ dispatch, and the test harness. Overall this is a clean, well-scoped extension of the existing tensor-scalar path — the IS_COMPARISON_OP split mirrors the established pattern in binary_op_texture.glsl, the COMPUTE_T promotion is applied consistently across buffer/texture variants, and push-constant sizing matches SCALAR_T in each generated variant. A few things worth addressing before landing.


1. Stale/incorrect doc comment in resolve_scalar_extract_dtype (medium)

BinaryScalarOp.cpp:33-37 states:

"Lower precision" tensor dtype (i.e. int32) + "higher precision" scalar dtype (i.e. float/half) are currently not supported because the shader currently casts the scalar to the tensor dtype before computation.

Both claims are now false for this PR:

  • The yaml does generate an int32 + float variant (binary_scalar_buffer.yaml / .texture.yaml), and resolve_scalar_extract_dtype happily routes an int32 tensor + float scalar to it (returns kFloat, suffix resolves to ..._int32_float).
  • The shader no longer casts the scalar to the tensor dtype; it promotes both operands to COMPUTE_T = get_higher_precision_dtype(DTYPE, SCALAR_VALUE_TYPE) (i.e. float for the int32+float case).

The comment predates the promotion logic it now sits above and will mislead the next reader. Please update it to describe the actual COMPUTE_T promotion + the four generated combos. Fix this →

2. int64 is in the registered dtype set but has no shader variant (low / latent)

op_registry.py:334 registers inputs_dtypes=utils.FP_INT_T, and FP_INT_T includes torch.int64 (utils.py:65,72). The generated variants only cover {half, float, int32} tensor dtypes. The PR relies on the downcast_64_bit pass converting the int64 token tensor to int32 before dispatch — if that pass is ever skipped or a model reaches this op with a live kLong tensor, add_dtype_suffix will produce eq_scalar_..._int64_... and the shader lookup will fail at runtime rather than at partitioning time. This matches the existing pow behavior, so it's not a regression, but a one-line comment noting the int64→int32 downcast dependency (or a registry note) would save a future debugging session.

3. Float-promotion precision for large int32 values (low / latent)

get_higher_precision_dtype makes the float family always outrank int, so an int32 tensor compared against a float scalar computes in float32. float32's 24-bit mantissa cannot exactly represent |x| > 2^24, so eq could report a false match for large integer tensor values against a float scalar. This is fine for the actual use case here — the discrete-speech mask compares against integer constants, which routes to the exact int32_int32 variant — and the only float case in the tests (((M1,), 3.5)) is all-false by construction. But it's a latent correctness trap for the generic int32 + float path; the docstring's claim that "integer values in these shaders are small enough to be exactly representable" is an assumption, not an invariant. Worth a brief note at the registration site.

4. Test coverage skips the half/float comparison variants (nit)

cases.py registers eq.Scalar with dtypes = ["at::kInt"] only. The half/float comparison shader variants (eq_scalar_..._half_float, ..._float_float) are generated but never exercised by op_tests — only the arithmetic pow path covers float. Adding a float-tensor case would close the gap and validate the equal()/== → bool output path for FP inputs too.

5. Minor

  • define_required_extensions(STORAGE, DTYPE) + (STORAGE, PROMOTED_DTYPE) can emit duplicate #extension ... : require lines when they coincide. Harmless in GLSL, just noise.
  • The AT_SCALAR branch in gen_computegraph.py:478-480 emits add_scalar_to_graph(*graph, ...), which assumes self.dot == "->" (pointer graph). That holds for both correctness and benchmark generators today, but the * deref is silently coupled to that — consistent with the existing deref handling at line 509, so acceptable.

Things I checked that are correct:

  • PushConstantDataInfo(&scalar_val, sizeof(...)) copies via memcpy in its constructor (PushConstantData.h:50-59), so the scoped local scalar_val going out of scope is safe.
  • bvec4u8vec4 (VEC4_OUT_T(equal(X,Y))) and scalar booluint8 conversions are well-defined.
  • The buffer shader correctly omits binary_op_defs.glslh for comparison ops and never defines COMPUTE_VEC4_T, so the vector power_of overload isn't compiled there.
  • Push-constant SCALAR_T width matches the pushed scalar in every variant (int32→int, float→float).

None of the findings are blocking; #1 is the one I'd most want fixed since the comment actively contradicts the code.

[ghstack-poisoned]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants