7 changes: 6 additions & 1 deletion backends/nxp/backend/custom_delegation_options.py
@@ -1,4 +1,4 @@
# Copyright 2025 NXP
# Copyright 2025-2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -17,3 +17,8 @@ class CustomDelegationOptions:
# of `num_macs`. The `force_delegate_cat` option allows the user to turn off the defensive check when it is known
# from the model design that this constraint will be satisfied.
force_delegate_cat: bool = False

# Proposed partitions which only contain Neutron no-ops are normally not delegated, as the NeutronConverter would
# not create any NeutronGraph that can be called. This check is performed by the partitioner itself, not by the
# individual node converters.
allow_no_op_partitions: bool = False
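
As a usage illustration (not part of this diff, and assuming `CustomDelegationOptions` is a dataclass constructed with keyword arguments), the new flag would be passed alongside the existing options:

```python
from executorch.backends.nxp.backend.custom_delegation_options import (
    CustomDelegationOptions,
)

# Hypothetical usage sketch: opt in to delegating partitions that contain
# only Neutron no-ops; the default (False) rejects such partitions.
options = CustomDelegationOptions(allow_no_op_partitions=True)
```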
133 changes: 132 additions & 1 deletion backends/nxp/backend/edge_helper.py
@@ -1,8 +1,10 @@
# Copyright 2024-2025 NXP
# Copyright 2024-2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch

from executorch.exir.dialects._ops import ops as exir_ops
@@ -19,6 +21,14 @@
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
]

# A set of operators which may be no-ops under certain conditions. The operators in this set will be classified as
# no-ops (and potentially not delegated) if, when run with random data, their output matches their input.
operators_which_may_be_no_ops = {
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.sub.Tensor,
}
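
For intuition (a sketch, not part of the diff): these three operators degenerate into identities when their static operand dequantizes to all zeros (`add`/`sub`) or all ones (`mul`), which is exactly what the random-data test below detects:

```python
import torch

x = torch.rand(2, 3) * 10 - 5                 # same [-5, 5) range as the test below
assert torch.all(x + torch.zeros(2, 3) == x)  # add of 0 -> identity
assert torch.all(x - torch.zeros(2, 3) == x)  # sub of 0 -> identity
assert torch.all(x * torch.ones(2, 3) == x)   # mul by 1 -> identity
```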


def input_tensor(node: Node, input_index: int) -> torch.Tensor:
if len(node.all_input_nodes) <= input_index:
@@ -220,3 +230,124 @@ def get_non_qdq_parent(node: Node, input_index: int = 0) -> Node | None:
return None

return quant_node.args[0]


def try_get_dequantized_data(
dequantize_node: Node, parameters_mapping: dict[str, Parameter]
) -> Parameter | None:
"""Get the dequantized data from the following pattern. The dequantization formula is `r = (q - Z) * S`, where `q`
represents the static quantized data.

┌─────────────────────────┐
│ <static_quantized_data> │
└────────────┬────────────┘
┌─────▼──────┐
│ Dequantize │
└─────┬──────┘


:param dequantize_node: The Dequantize node from the pattern, which dequantizes the static quantized data.
:param parameters_mapping: Dict mapping tensor names to their static data. Should be inferred from the
`state_dict` attribute of an edge program.
:return: The dequantized static parameter, or `None` if the data is not available.
"""
if not _is_dequantize(dequantize_node):
return None

if not node_is_static_tensor(param := dequantize_node.args[0], parameters_mapping):
return None

# The pattern is correct. Dequantize the static data and return it.
scale, zp = get_quantization_parameters_for(dequantize_node)
quantized_data = parameters_mapping[param.name]

dequantized_data = (quantized_data - zp) * scale
return dequantized_data
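
A small worked example of the formula above (the int8 values and quantization parameters are made up for illustration):

```python
import torch

q = torch.tensor([10, 20, 30], dtype=torch.int8)  # static quantized data
scale, zero_point = 0.5, 10                       # S and Z from the Dequantize node
r = (q - zero_point) * scale                      # r = (q - Z) * S
# r is tensor([0.0, 5.0, 10.0])
```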


def is_no_op_on_neutron(node: Node, parameters_mapping: dict[str, Parameter]) -> bool:
"""Check if a node is a no-op operation from the perspective of Neutron."""
if node.op != "call_function":
raise ValueError(
f"is_no_op_on_neutron(): Expected call_function node, got {node.op}."
)

if node.target in [
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
exir_ops.edge.aten.clone.default,
]:
# Known operators which are always no-ops on Neutron.
return True

if node.target == exir_ops.edge.aten.cat.default and len(node.args[0]) == 1:
# Concatenation with 1 input is a no-op.
return True

# For any other operator, run it with random data and see if the output is identical to the input.
torch.manual_seed(42)
# noinspection PyBroadException
try:
input_data = None
args_with_random_data = []
for arg in node.args:
match arg:
case Node():
# `arg` is either another operator, a model input, or a static parameter.

if (
data := try_get_dequantized_data(arg, parameters_mapping)
) is not None:
# The `arg` is a static parameter. Use its actual static data during the no-op test.
args_with_random_data.append(data)

else:
# The `arg` is a compute node or a model input. Replace it with random data for the no-op test.
if input_data is not None:
# Some random input data for `node` has already been stored, which means that the node has
# more than 1 dynamic input node. Therefore, it cannot be a no-op.
return False

# Generate the random data. Use the range [-5, 5) to avoid misclassifying operations like ReLU
# as no-ops.
val = arg.meta["val"]
input_data = torch.rand(val.shape, dtype=val.dtype) * 10 - 5
args_with_random_data.append(input_data)

case list():
# List arguments (multiple input nodes) are not supported here. `aten.cat` is handled explicitly above.
return False

case _:
# Generic argument (value). Not an input from a previous node. Store it in the arguments for the
# no-op test.
args_with_random_data.append(arg)

# Run the operator with the random data. If the input equals the output, the node is considered a no-op.
output_data = node.target(*args_with_random_data)

val = node.meta["val"]
if (
output_data.dtype == val.dtype
and output_data.shape == val.shape
and torch.all(input_data == output_data)
):
# The operator preserves the shape, data type, and data. Therefore, it is a no-op from the perspective of
# Neutron.
if node.target in operators_which_may_be_no_ops:
return True
else:
logging.info(
f"Found the operator `{node.target}`, which appears to be a no-op, but is not in the "
"known no-op list. Please report this issue."
)
return False

else:
# Type, shape, or data doesn't match.
return False

except Exception:
# If execution fails, assume it's not a no-op.
return False
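
For context, one way the partitioner side might consume this helper (a sketch; the wrapper function below is an assumption, not part of the diff):

```python
def partition_contains_only_no_ops(partition_nodes, parameters_mapping):
    # A partition whose `call_function` nodes are all Neutron no-ops would
    # produce an empty NeutronGraph, so it is normally not delegated unless
    # `allow_no_op_partitions` is set.
    compute_nodes = [n for n in partition_nodes if n.op == "call_function"]
    return all(is_no_op_on_neutron(n, parameters_mapping) for n in compute_nodes)
```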
43 changes: 37 additions & 6 deletions backends/nxp/backend/edge_program_converter.py
@@ -13,13 +13,15 @@
from executorch.backends.nxp.backend.ir.converter.node_converter import (
CustomDelegationOptions,
)
from torch._subclasses import FakeTensor
from torch.export import ExportedProgram
from torch.export.graph_signature import InputKind
from torch.fx import Node
from torch.nn.parameter import Parameter
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
from executorch.exir.dialects._ops import ops as exir_ops

# noinspection PyProtectedMember
@@ -65,7 +67,7 @@ def convert_program(
conversion_config: ConversionConfig = _default_conversion_config,
neutron_target_spec: NeutronTargetSpec = _default_target_spec,
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
) -> (bytes, dict[str, NodeFormat]):
) -> tuple[bytes, dict[str, dict[str, TensorFormat]]]:
"""
Convert ExportedProgram in Edge dialect to IR (TFLite flatbuffers) as bytes.

@@ -161,20 +163,49 @@ def _process_nodes(self, nodes: list[Node], conversion_context: ConversionContex
)

@staticmethod
def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Parameter]:
def map_inputs_to_parameters(
edge_program: ExportedProgram,
post_quantization_state_dict: dict[str, Parameter] | None = None,
) -> dict[str, Parameter]:
"""
Create mapping between program parameters (input nodes & static data nodes) and their names.

:param edge_program: EdgeProgram instance.
:param post_quantization_state_dict: State-dict of the model right after quantization. During partitioning, the
`edge_program` only contains fake tensors without any data. In this case,
this state dict is used instead (if provided). Note: it may potentially
contain outdated data.
:return: Mapping from parameter name to parameter instance.
"""
result_map = {}

for input_spec in edge_program.graph_signature.input_specs:
if input_spec.kind in [InputKind.PARAMETER, InputKind.BUFFER]:
result_map[input_spec.arg.name] = edge_program.state_dict[
input_spec.target
]

# First, try to load the static data from the model.
param = edge_program.state_dict[input_spec.target]

if not isinstance(param, FakeTensor):
# Use the data from the model.
result_map[input_spec.arg.name] = param

else:
# It is the partitioning stage, which uses a FakeModel with FakeTensors (without the actual data).
# Try to load the data from the post-quantization state dict.
if (
post_quantization_state_dict is not None
and (
param := post_quantization_state_dict.get(
input_spec.target, None
)
)
is not None
):
result_map[input_spec.arg.name] = param

else:
# There is no data available.
continue

return result_map
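
A sketch of the intended call pattern (the enclosing converter class name and the point where the post-quantization state dict is captured are assumptions):

```python
# Capture the state dict right after quantization, before edge passes may
# replace real tensors with FakeTensors.
post_quant_state_dict = dict(quantized_program.state_dict)

# Later, during partitioning, pass it as the fallback data source.
params = EdgeProgramToIRConverter.map_inputs_to_parameters(
    edge_program, post_quantization_state_dict=post_quant_state_dict
)
```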

22 changes: 17 additions & 5 deletions backends/nxp/backend/ir/converter/node_converter.py
@@ -1,4 +1,4 @@
# Copyright 2024-2025 NXP
# Copyright 2024-2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -75,7 +75,10 @@ def _is_supported_in_IR(
Classes which implement conversion for individual operators must overwrite this method.

:param node: torch.Node to check.
:param parameters_mapping: Dictionary mapping tensor names to their static data (if they have it).
:param parameters_mapping: Dictionary mapping static parameter names to Parameter objects containing their data
(if they have any). During partitioning, this data is extracted from the model right
after quantization and before edge dialect passes. Therefore, it could potentially
be outdated.
:param custom_delegation_options: Custom options which affect delegation.
"""
pass
@@ -93,7 +96,10 @@ def _is_supported_on_target(

:param node: The node (edge operator) to check.
:param neutron_target_spec: Object for querying the target platform to retrieve its properties.
:param parameters_mapping: Dictionary mapping tensor names to their static data (if they have it).
:param parameters_mapping: Dictionary mapping static parameter names to Parameter objects containing their data
(if they have any). During partitioning, this data is extracted from the model right
after quantization and before edge dialect passes. Therefore, it could potentially
be outdated.
:param custom_delegation_options: Custom options which affect delegation.
"""
return True
@@ -110,7 +116,10 @@ def is_supported(

:param node: torch.Node to check.
:param neutron_target_spec: Object for querying the target platform to retrieve its properties.
:param parameters_mapping: Dict mapping tensor names to their data.
:param parameters_mapping: Dictionary mapping static parameter names to Parameter objects containing their data
(if they have any). During partitioning, this data is extracted from the model right
after quantization and before edge dialect passes. Therefore, it could potentially
be outdated.
:param custom_delegation_options: Custom user options which affect node delegation.
"""
return cls._is_supported_in_IR(
@@ -136,7 +145,10 @@ def supports_partitioning_result(
:param partition_list: List of proposed partitions.
:param custom_delegation_options: Custom user options which affect node delegation.
:param neutron_target_spec: NeutronTargetSpec instance.
:param parameters_mapping: Dictionary mapping tensor names to their static data.
:param parameters_mapping: Dictionary mapping static parameter names to Parameter objects containing their data
(if they have any). During partitioning, this data is extracted from the model right
after quantization and before edge dialect passes. Therefore, it could potentially
be outdated.
:return: Boolean indicating whether the node supports the current partitioning.
"""
return True
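
To make the documented contract concrete, a minimal converter subclass might look like this (the operator, the decorator, and the exact parameter handling are illustrative assumptions, not from this diff):

```python
import torch

class AbsConverter(NodeConverter):  # hypothetical subclass of the class above
    @staticmethod
    def _is_supported_in_IR(node, parameters_mapping, custom_delegation_options):
        # During partitioning, `parameters_mapping` holds post-quantization
        # data that may be outdated after edge dialect passes.
        val = node.meta["val"]
        return val.dtype in (torch.float32, torch.int8)
```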
@@ -1,4 +1,4 @@
# Copyright 2024-2025 NXP
# Copyright 2024-2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -20,7 +20,6 @@
)
from executorch.backends.nxp.backend.ir.converter.node_converter import (
CustomDelegationOptions,
is_not_qdq_node,
NodeConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.shared.reshape_transposition import (
@@ -59,24 +58,6 @@ def _is_supported_in_IR(

return True

@classmethod
def _partition_contains_compute_nodes(cls, view_copy_partition: Partition) -> bool:
non_q_dq_partition_nodes = list(
filter(is_not_qdq_node, view_copy_partition.nodes)
)

if len(non_q_dq_partition_nodes) == 1:
# The `view_copy` cannot be the only node in a partition.
return False

# It is common for a `clone` node to come before the `view_copy`. Make sure these are not the only two nodes
# in the partition.
if any("clone" in n.name for n in non_q_dq_partition_nodes):
if len(non_q_dq_partition_nodes) <= 2:
return False

return True

@classmethod
def supports_partitioning_result(
cls,
@@ -91,9 +72,6 @@ def supports_partitioning_result(
]
assert len(view_copy_partitions) == 1

if not cls._partition_contains_compute_nodes(view_copy_partitions[0]):
return False

input_format = node.args[0].meta[NXP_NODE_FORMAT]
output_format = node.meta[NXP_NODE_FORMAT]
input_shape = list(node.args[0].meta["val"].shape)