diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py new file mode 100644 index 00000000000..085267a174e --- /dev/null +++ b/backends/arm/_passes/arm_pass.py @@ -0,0 +1,33 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import traceback +from typing import Optional + +import torch +from executorch.exir.pass_base import ExportPass, NodeMetadata + + +class ArmPass(ExportPass): + """Base class for Arm passes""" + + def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = None): + super(ArmPass, self).__init__() + self.exported_program = exported_program + + def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False): + if not updated: + return super().call_operator(op, args, kwargs, meta) + + # if updated we should update metadata + new_meta = {} + keys = meta.data.keys() + for key in keys: + new_meta[key] = meta[key] + old_stack_trace = new_meta.get("stack_trace", "") + new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}" + return super().call_operator(op, args, kwargs, NodeMetadata(new_meta)) diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index a8d06713678..081400f02cc 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -7,12 +7,12 @@ # pyre-unsafe +import traceback from inspect import isclass from typing import Optional import torch import torch.fx - from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops @@ -96,6 +96,7 @@ def create_node( kwargs: Optional[dict] = None, quantize: bool = False, q_params: Optional[tuple] = None, + from_node: Optional[torch.fx.Node] = None, ): """ Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node. @@ -108,8 +109,18 @@ def create_node( args=args, kwargs=kwargs or {}, ) + + new_meta = {} + if from_node: + keys = from_node.meta.keys() + for key in keys: + new_meta[key] = from_node.meta[key] + old_stack_trace = new_meta.get("stack_trace", "") + new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}" + node.meta = new_meta + if quantize and q_params: - return insert_q_dq_pair(graph, node, q_params) + return insert_q_dq_pair(graph, node, q_params, from_node) return node @@ -117,6 +128,7 @@ def insert_q_dq_pair( graph: torch.fx.Graph, anchor: torch.fx.Node, q_params: tuple, + from_node: Optional[torch.fx.Node] = None, ): """ Inserts a q dq node pair after the node 'anchor'. @@ -127,6 +139,7 @@ def insert_q_dq_pair( graph=graph, op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, args=(), # We add the argument last + from_node=from_node if from_node else anchor, ) q.meta = anchor.meta with graph.inserting_after(q): @@ -134,6 +147,7 @@ def insert_q_dq_pair( graph=graph, op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(q,) + q_params, + from_node=from_node if from_node else anchor, ) dq.meta = q.meta anchor.replace_all_uses_with(dq) diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py index cc4a81caae0..5d132e50c84 100644 --- a/backends/arm/_passes/decompose_layernorm_pass.py +++ b/backends/arm/_passes/decompose_layernorm_pass.py @@ -9,9 +9,10 @@ import operator import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.pass_base import PassResult def get_layer_norm_decomposition(op) -> tuple: @@ -40,7 +41,7 @@ def get_layer_norm_decomposition(op) -> tuple: raise RuntimeError(f"Can't get layer_norm composition for op {op}") -class DecomposeLayerNormPass(ExportPass): +class DecomposeLayerNormPass(ArmPass): """ layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of: @@ -111,24 +112,39 @@ def call(self, graph_module: torch.fx.GraphModule): var_op, args=(x, dims), kwargs={"correction": 0, "keepdim": keepdim}, + from_node=node, ) full = create_node( graph_module.graph, full_op, args=(epsilon_reshaped_shape, epsilon), kwargs={"dtype": dtype}, + from_node=node, + ) + add0 = create_node( + graph_module.graph, add_op, args=(var, full), from_node=node + ) + rsqrt = create_node( + graph_module.graph, rsqrt_op, args=(add0,), from_node=node + ) + mul0 = create_node( + graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node ) - add0 = create_node(graph_module.graph, add_op, args=(var, full)) - rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,)) - mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt)) if weights is not None: weights_reshaped = create_node( graph_module.graph, view_op, args=(weights, weights_reshaped_shape), + from_node=node, ) mul1 = create_node( - graph_module.graph, mul_op, args=(mul0, weights_reshaped) + graph_module.graph, + mul_op, + args=( + mul0, + weights_reshaped, + ), + from_node=node, ) else: mul1 = mul0 @@ -136,10 +152,16 @@ def call(self, graph_module: torch.fx.GraphModule): if bias is not None: bias_reshaped_shape = weights_reshaped_shape bias_reshaped = create_node( - graph_module.graph, view_op, args=(bias, bias_reshaped_shape) + graph_module.graph, + view_op, + args=(bias, bias_reshaped_shape), + from_node=node, ) output = create_node( - graph_module.graph, add_op, args=(mul1, bias_reshaped) + graph_module.graph, + add_op, + args=(mul1, bias_reshaped), + from_node=node, ) users = [user for user in node.users if node != user] diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index abf5c8f363d..9bcfb72916a 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -7,9 +7,9 @@ # pyre-unsafe import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass def get_meandim_decomposition(op) -> tuple: @@ -28,7 +28,7 @@ def get_meandim_decomposition(op) -> tuple: raise RuntimeError(f"Can't get meandim decomposition for op {op}") -class DecomposeMeanDimPass(ExportPass): +class DecomposeMeanDimPass(ArmPass): """ This pass decomposes meandim into a sum and mul node. @@ -62,8 +62,8 @@ def call_operator(self, op, args, kwargs, meta): sum_op, full_op, mul_op = get_meandim_decomposition(op) - sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta) + sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta, True) full = super().call_operator( - full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta + full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta, True ) - return super().call_operator(mul_op, (sum, full), {}, meta) + return super().call_operator(mul_op, (sum, full), {}, meta, True) diff --git a/backends/arm/_passes/decompose_softmax_unstable_pass.py b/backends/arm/_passes/decompose_softmax_unstable_pass.py index 4a2ce712ab7..e4c9113da13 100644 --- a/backends/arm/_passes/decompose_softmax_unstable_pass.py +++ b/backends/arm/_passes/decompose_softmax_unstable_pass.py @@ -6,8 +6,8 @@ # pyre-unsafe import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass # For BI case torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int) @@ -45,7 +45,7 @@ def get_logsoftmax_ops(op) -> tuple: raise RuntimeError(f"Can't get softmax decomposition ops for op {op}") -class DecomposeSoftmaxUnstablePass(ExportPass): +class DecomposeSoftmaxUnstablePass(ArmPass): """ This pass decomposes log softmax or softmax into more primitive ops. @@ -66,10 +66,10 @@ def call_operator(self, op, args, kwargs, meta): _input = args[0] dim = [args[1]] - op1 = super().call_operator(exp_op, (_input,), {}, meta) - op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta) - op3 = super().call_operator(reciprocal_op, (op2,), {}, meta) - op4 = super().call_operator(mul_op, (op1, op3), {}, meta) + op1 = super().call_operator(exp_op, (_input,), {}, meta, True) + op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta, True) + op3 = super().call_operator(reciprocal_op, (op2,), {}, meta, True) + op4 = super().call_operator(mul_op, (op1, op3), {}, meta, True) if op in log_softmax: - op4 = super().call_operator(log_op, (op4,), {}, meta) + op4 = super().call_operator(log_op, (op4,), {}, meta, True) return op4 diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index 73747d8313d..0c43cd1b9cb 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -8,9 +8,9 @@ import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass def get_var_decomposition(op) -> tuple: @@ -33,7 +33,7 @@ def get_var_decomposition(op) -> tuple: raise RuntimeError(f"Can't get var decomposition for op {op}") -class DecomposeVarPass(ExportPass): +class DecomposeVarPass(ArmPass): """ This pass decomposes var.correction and var.dim into smaller ops (see https://pytorch.org/docs/stable/generated/torch.var.html) @@ -77,14 +77,17 @@ def call_operator(self, op, args, kwargs, meta): N *= input_shape[d] mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op) - mean = super().call_operator(mean_op, (x, dim, True), {}, meta) - diff = super().call_operator(diff_op, (x, mean), {}, meta) - squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta) - sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta) + mean = super().call_operator(mean_op, (x, dim, True), {}, meta, True) + diff = super().call_operator(diff_op, (x, mean), {}, meta, True) + squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta, True) + sum = super().call_operator( + sum_op, (squared_diff, dim, keepdim), {}, meta, True + ) full = super().call_operator( full_op, ([], 1 / max(0, N - correction)), {"dtype": dtype}, meta, + True, ) - return super().call_operator(mul_op, (sum, full), {}, meta) + return super().call_operator(mul_op, (sum, full), {}, meta, True) diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py index 602d4a007a6..34ac7553212 100644 --- a/backends/arm/_passes/mm_to_bmm_pass.py +++ b/backends/arm/_passes/mm_to_bmm_pass.py @@ -47,7 +47,7 @@ def call(self, graph_module: torch.fx.GraphModule): with graph.inserting_before(node): unsqueeze_before = create_node( - graph, exir_ops.edge.aten.unsqueeze_copy.default + graph, exir_ops.edge.aten.unsqueeze_copy.default, from_node=node ) unsqueeze_before.args = ( input_node, # Input is node's original input @@ -58,13 +58,14 @@ def call(self, graph_module: torch.fx.GraphModule): # If Quantized we must insert unsqueeze --> q --> dq --> node if input_node.target == dq_op: q_params = input_node.args[1:] - insert_q_dq_pair(graph, unsqueeze_before, q_params) + insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node) # Replace mm node with bmm with graph.inserting_before(node): bmm_node = create_node( graph, exir_ops.edge.aten.bmm.default, + from_node=node, ) bmm_node.args = node.args node.replace_all_uses_with(bmm_node) @@ -75,6 +76,7 @@ def call(self, graph_module: torch.fx.GraphModule): squeeze_after = create_node( graph, exir_ops.edge.aten.squeeze_copy.dims, + from_node=node, ) squeeze_after.args = ( bmm_node, @@ -89,7 +91,7 @@ def call(self, graph_module: torch.fx.GraphModule): # If quantized, insert mm --> q --> dq --> squeeze if all(original_user.target == q_op for original_user in original_users): q_params = original_users[0].args[1:] - insert_q_dq_pair(graph, bmm_node, q_params) + insert_q_dq_pair(graph, bmm_node, q_params, from_node=node) modified_graph = True diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 788ebf39696..e265f81dc88 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -14,7 +14,7 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.print_program import inspect_node +from executorch.exir.print_program import add_cursor_to_graph from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -32,7 +32,8 @@ def dbg_node(node: torch.fx.Node, graph_module: torch.fx.GraphModule): def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> str: output = ( - f" {inspect_node(graph=graph_module.graph, node=node)}\n" + " Here is the node in the graph:\n" + f" {add_cursor_to_graph(graph=graph_module.graph, finding_node=node)}\n" "-- NODE DEBUG INFO --\n" f" Op is {node.op}\n" f" Name is {node.name}\n" @@ -43,10 +44,17 @@ def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule) " Node.meta = \n" ) for k, v in node.meta.items(): - output += f" '{k}' = {v}\n" - if isinstance(v, list): - for i in v: - output += f" {i}\n" + if k == "stack_trace": + matches = v.split("\n") + output += " 'stack_trace =\n" + for m in matches: + output += f" {m}\n" + else: + output += f" '{k}' = {v}\n" + + if isinstance(v, list): + for i in v: + output += f" {i}\n" return output @@ -78,10 +86,10 @@ def dbg_fail( tosa_graph: Optional[ts.TosaSerializer] = None, path: Optional[str] = None, ): - logger.warning("Internal error due to poorly handled node:") + logger.warning(f" Internal error due to poorly handled node: {node.name}") if tosa_graph is not None and path is not None: dbg_tosa_dump(tosa_graph, path) - logger.warning(f"Debug output captured in '{path}'.") + logger.warning(f" Debug output captured in '{path}'.") dbg_node(node, graph_module)