🐛 Describe the bug
I am encountering an issue where a simple reduce_mean operation is not fully delegated to the Ethos-U NPU. Specifically, for an input of shape (1, 256, 400), the partitioner fails with:
No support for shape=[1, 1, 256, 400], dtype=torch.int8. Product of axes must be <65536
However, the equivalent graph is fully delegated to the NPU when using the TFLite -> Vela path. This suggests it is a lowering implementation issue rather than a Vela/hardware limitation?
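For reference, the rejected reduction covers 256 × 400 = 102,400 elements, which exceeds the 65,536 limit in the message. As a purely illustrative sketch (not part of the reproduction, and no claim that it changes what the partitioner accepts), the same mean is numerically equivalent to two chained reductions that each reduce far fewer than 65,536 elements:

import torch

x = torch.randn(1, 256, 400)
# Product of the reduced axes: 256 * 400 = 102,400 (>= 65,536, hence the rejection).
single = x.mean(dim=(1, 2), keepdim=True)
# Chained means reduce 256 elements, then 400, each well below the limit.
chained = x.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
print(torch.allclose(single, chained))  # True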
Reproduce
Here is the ExecuTorch script to reproduce:
import logging
from pathlib import Path
import torch
from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
from executorch.backends.arm.quantizer import (
    EthosUQuantizer,
    get_symmetric_quantization_config,
)
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)
from executorch.extension.export_util.utils import save_pte_program
from torch import nn
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
OUTPUT_DIR = f"out_{Path(__file__).stem}"
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
class MeanModel(nn.Module):
    def forward(self, x):
        return x.mean(dim=(1, 2), keepdim=True)
model = MeanModel()
print(model)
batch_size = 1
example_inputs = (torch.randn(batch_size, 256, 400),)
COMPILE_SPEC = EthosUCompileSpec(
    target="ethos-u55-128",
    system_config="Ethos_U55_High_End_Embedded",
    memory_mode="Shared_Sram",
    extra_flags=["--output-format=raw", "--debug-force-regor"],
)
log_file = f"{OUTPUT_DIR}/partitioner.log"
logger = logging.getLogger("executorch.backends.arm.tosa.partitioner")
handler = logging.FileHandler(log_file, mode="w")
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)
exported_program = torch.export.export(model, example_inputs, strict=True)
torch.export.save(exported_program, f"{OUTPUT_DIR}/exported_program.pt2")
graph_module = exported_program.module()
# Quantize
quantizer = EthosUQuantizer(COMPILE_SPEC)
operator_config = get_symmetric_quantization_config()
quantizer.set_global(operator_config)
quantized_graph_module = prepare_pt2e(graph_module, quantizer)
quantized_graph_module(*example_inputs)
quantized_graph_module = convert_pt2e(quantized_graph_module)
quantized_exported_program = torch.export.export(quantized_graph_module, example_inputs)
torch.export.save(quantized_exported_program, f"{OUTPUT_DIR}/quantized_exported_program.pt2")
# Lower the quantized program to edge and compile to executorch
partitioner = EthosUPartitioner(COMPILE_SPEC)
edge_program_manager = to_edge_transform_and_lower(
    quantized_exported_program,
    partitioner=[partitioner],
    compile_config=EdgeCompileConfig(
        _check_ir_validity=True,
    ),
)
executorch_program_manager = edge_program_manager.to_executorch(config=ExecutorchBackendConfig())
save_pte_program(executorch_program_manager, f"{OUTPUT_DIR}/ethos_u.pte")

This is the partitioner.log (ACTUAL BEHAVIOR):
TOSAPartitioner::partition
Partitioning for EthosUBackend: TOSA-1.0+INT+int16+u55
The following nodes were rejected for TOSA-1.0+INT+int16+u55:
╒════════════════════════╤════════════════════════╤════════════════════════════╤═════════════════════════════════════╕
│ Node name              │ Target                 │ Torch func                 │ Reason                              │
╞════════════════════════╪════════════════════════╪════════════════════════════╪═════════════════════════════════════╡
│ aten_view_copy_default │ aten.view_copy.default │ ('mean_1',                 │ No support for shape=[1, 1, 256,    │
│                        │                        │ 'method_descriptor.mean') │ 400], dtype=torch.int8. Product of  │
│                        │                        │                            │ axes must be <65536                 │
╘════════════════════════╧════════════════════════╧════════════════════════════╧═════════════════════════════════════╛
(Placeholders and outputs are not included in this list)
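For completeness, the partial delegation can also be inspected programmatically after lowering. A minimal sketch, assuming the edge_program_manager from the script above and ExecuTorch's delegation-debug helper (the module path may vary across versions):

from executorch.devtools.backend_debug import get_delegation_info
from tabulate import tabulate

# Assumes edge_program_manager from the reproduction script above.
graph_module = edge_program_manager.exported_program().graph_module
delegation_info = get_delegation_info(graph_module)
print(delegation_info.get_summary())
# Per-operator breakdown of delegated vs. non-delegated occurrences.
df = delegation_info.get_operator_delegation_dataframe()
print(tabulate(df, headers="keys", tablefmt="fancy_grid"))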
For comparison, here is the equivalent graph conversion via the TFLite -> Vela flow:
import subprocess
from collections.abc import Iterable
from pathlib import Path
import numpy as np
import tensorflow as tf
OUTPUT_DIR = Path(f"out_{Path(__file__).stem}")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
class MeanLayer(tf.keras.layers.Layer):
    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        return tf.reduce_mean(inputs, axis=(1, 2), keepdims=True)
input_shape = (256, 400)
inputs = tf.keras.Input(shape=input_shape)
outputs = MeanLayer()(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()
model_fn = tf.function(func=model)
batch_size = 1
cf = model_fn.get_concrete_function(tf.TensorSpec(shape=(batch_size, *input_shape), dtype=tf.float32))
def representative_dataset(num_samples: int = 10) -> Iterable[list[np.ndarray]]:
    rng = np.random.default_rng(seed=0)
    for _ in range(num_samples):
        sample = rng.normal(loc=0.0, scale=1.0, size=input_shape).astype(np.float32)
        yield [sample]
converter = tf.lite.TFLiteConverter.from_concrete_functions([cf], trackable_obj=model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_bytes = converter.convert()
int8_path = OUTPUT_DIR / "mean_int8.tflite"
int8_path.write_bytes(tflite_bytes)
cmd = [
    "vela",
    str(int8_path),
    "--output-dir",
    OUTPUT_DIR.as_posix(),
    "--accelerator-config",
    "ethos-u55-128",
    "--system-config",
    "Ethos_U55_High_End_Embedded",
    "--memory-mode",
    "Shared_Sram",
    "--config",
    "Arm/vela.ini",
]
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
log_path = OUTPUT_DIR / "vela.log"
log_path.write_text(result.stdout)
print(result.stdout)

Inspecting mean_int8_vela.tflite with Netron or Model Explorer shows everything in a single ethos-u block (EXPECTED BEHAVIOR).
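To double-check the expected behavior without a GUI, here is a minimal sketch that lists the operators in the compiled model via the TFLite analyzer; it assumes Vela wrote mean_int8_vela.tflite into OUTPUT_DIR, as in the run above:

import tensorflow as tf

# After a successful Vela compile, the whole graph should appear as a single
# ethos-u custom op in the operator listing.
vela_model = OUTPUT_DIR / "mean_int8_vela.tflite"
tf.lite.experimental.Analyzer.analyze(model_path=str(vela_model))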
Versions
Collecting environment information...
PyTorch version: 2.11.0.dev20251222+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.31.10
Libc version: glibc-2.35
Python version: 3.12.11 (main, Sep 18 2025, 19:47:19) [Clang 20.1.4 ] (64-bit runtime)
Python platform: Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 8
On-line CPU(s) list: 0-7
Vendor ID: GenuineIntel
Model name: 11th Gen Intel(R) Core(TM) i7-1185G7 @ 3.00GHz
CPU family: 6
Model: 140
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 1
Stepping: 1
BogoMIPS: 5990.42
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves vnmi avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid movdiri movdir64b fsrm avx512_vp2intersect md_clear flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 192 KiB (4 instances)
L1i cache: 128 KiB (4 instances)
L2 cache: 5 MiB (4 instances)
L3 cache: 12 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-7
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] executorch==1.1.0a0+0391fe7
[pip3] numpy==2.4.0
[pip3] pytorch_tokenizers==1.0.1
[pip3] torch==2.11.0.dev20251222+cpu
[pip3] torchao==0.14.0+git01849b2b1
[pip3] torchaudio==2.10.0.dev20251222+cpu
[pip3] torchdata==0.11.0
[pip3] torchsr==1.0.4
[pip3] torchtune==0.6.1
[pip3] torchvision==0.25.0.dev20251222+cpu
[conda] Could not collect