From 7b23a0abd769449c06b243907dab644f5bb90a9f Mon Sep 17 00:00:00 2001
From: Oscar Andersson <oscar.andersson@arm.com>
Date: Fri, 28 Feb 2025 08:50:28 +0100
Subject: [PATCH] Make passes preserve and update node metadata

When creating or updating nodes in passes, the metadata is neither
preserved nor updated correctly. This patch adds an ArmPass base class
whose call_operator() updates the node metadata when
super().call_operator(..., updated=True) is used. It also adds a
from_node argument to arm_pass_utils.create_node() so that created
nodes inherit metadata from the original node. Only the 'stack_trace'
field is updated; all other fields are preserved from the original
node.

Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
Change-Id: I725dd057716ae5a1fac0f97b522df22196f00bdb
---
 backends/arm/_passes/arm_pass.py              | 33 ++++++++++++++++
 backends/arm/_passes/arm_pass_utils.py        | 18 ++++++++-
 .../arm/_passes/decompose_layernorm_pass.py   | 38 +++++++++++++++----
 .../arm/_passes/decompose_meandim_pass.py     | 12 +++---
 .../decompose_softmax_unstable_pass.py        | 14 +++----
 backends/arm/_passes/decompose_var_pass.py    | 19 ++++++----
 backends/arm/_passes/mm_to_bmm_pass.py        |  8 ++--
 backends/arm/tosa_utils.py                    | 24 ++++++++----
 8 files changed, 124 insertions(+), 42 deletions(-)
 create mode 100644 backends/arm/_passes/arm_pass.py

diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py
new file mode 100644
index 00000000000..085267a174e
--- /dev/null
+++ b/backends/arm/_passes/arm_pass.py
@@ -0,0 +1,33 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import traceback
+from typing import Optional
+
+import torch
+from executorch.exir.pass_base import ExportPass, NodeMetadata
+
+
+class ArmPass(ExportPass):
+    """Base class for Arm passes"""
+
+    def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = None):
+        super(ArmPass, self).__init__()
+        self.exported_program = exported_program
+
+    def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
+        if not updated:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # if updated we should update metadata
+        new_meta = {}
+        keys = meta.data.keys()
+        for key in keys:
+            new_meta[key] = meta[key]
+        old_stack_trace = new_meta.get("stack_trace", "")
+        new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
+        return super().call_operator(op, args, kwargs, NodeMetadata(new_meta))
diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py
index a8d06713678..081400f02cc 100644
--- a/backends/arm/_passes/arm_pass_utils.py
+++ b/backends/arm/_passes/arm_pass_utils.py
@@ -7,12 +7,12 @@
 
 # pyre-unsafe
 
+import traceback
 from inspect import isclass
 from typing import Optional
 
 import torch
 import torch.fx
-
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 
@@ -96,6 +96,7 @@ def create_node(
     kwargs: Optional[dict] = None,
     quantize: bool = False,
     q_params: Optional[tuple] = None,
+    from_node: Optional[torch.fx.Node] = None,
 ):
     """
     Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node. 
@@ -108,8 +109,18 @@ def create_node( args=args, kwargs=kwargs or {}, ) + + new_meta = {} + if from_node: + keys = from_node.meta.keys() + for key in keys: + new_meta[key] = from_node.meta[key] + old_stack_trace = new_meta.get("stack_trace", "") + new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}" + node.meta = new_meta + if quantize and q_params: - return insert_q_dq_pair(graph, node, q_params) + return insert_q_dq_pair(graph, node, q_params, from_node) return node @@ -117,6 +128,7 @@ def insert_q_dq_pair( graph: torch.fx.Graph, anchor: torch.fx.Node, q_params: tuple, + from_node: Optional[torch.fx.Node] = None, ): """ Inserts a q dq node pair after the node 'anchor'. @@ -127,6 +139,7 @@ def insert_q_dq_pair( graph=graph, op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, args=(), # We add the argument last + from_node=from_node if from_node else anchor, ) q.meta = anchor.meta with graph.inserting_after(q): @@ -134,6 +147,7 @@ def insert_q_dq_pair( graph=graph, op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(q,) + q_params, + from_node=from_node if from_node else anchor, ) dq.meta = q.meta anchor.replace_all_uses_with(dq) diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py index cc4a81caae0..5d132e50c84 100644 --- a/backends/arm/_passes/decompose_layernorm_pass.py +++ b/backends/arm/_passes/decompose_layernorm_pass.py @@ -9,9 +9,10 @@ import operator import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.pass_base import PassResult def get_layer_norm_decomposition(op) -> tuple: @@ -40,7 +41,7 @@ def get_layer_norm_decomposition(op) -> tuple: raise RuntimeError(f"Can't get layer_norm composition for op {op}") -class DecomposeLayerNormPass(ExportPass): +class DecomposeLayerNormPass(ArmPass): """ layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of: @@ -111,24 +112,39 @@ def call(self, graph_module: torch.fx.GraphModule): var_op, args=(x, dims), kwargs={"correction": 0, "keepdim": keepdim}, + from_node=node, ) full = create_node( graph_module.graph, full_op, args=(epsilon_reshaped_shape, epsilon), kwargs={"dtype": dtype}, + from_node=node, + ) + add0 = create_node( + graph_module.graph, add_op, args=(var, full), from_node=node + ) + rsqrt = create_node( + graph_module.graph, rsqrt_op, args=(add0,), from_node=node + ) + mul0 = create_node( + graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node ) - add0 = create_node(graph_module.graph, add_op, args=(var, full)) - rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,)) - mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt)) if weights is not None: weights_reshaped = create_node( graph_module.graph, view_op, args=(weights, weights_reshaped_shape), + from_node=node, ) mul1 = create_node( - graph_module.graph, mul_op, args=(mul0, weights_reshaped) + graph_module.graph, + mul_op, + args=( + mul0, + weights_reshaped, + ), + from_node=node, ) else: mul1 = mul0 @@ -136,10 +152,16 @@ def call(self, graph_module: torch.fx.GraphModule): if bias is not None: bias_reshaped_shape = weights_reshaped_shape bias_reshaped = create_node( - graph_module.graph, 
view_op, args=(bias, bias_reshaped_shape) + graph_module.graph, + view_op, + args=(bias, bias_reshaped_shape), + from_node=node, ) output = create_node( - graph_module.graph, add_op, args=(mul1, bias_reshaped) + graph_module.graph, + add_op, + args=(mul1, bias_reshaped), + from_node=node, ) users = [user for user in node.users if node != user] diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index abf5c8f363d..9bcfb72916a 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -7,9 +7,9 @@ # pyre-unsafe import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass def get_meandim_decomposition(op) -> tuple: @@ -28,7 +28,7 @@ def get_meandim_decomposition(op) -> tuple: raise RuntimeError(f"Can't get meandim decomposition for op {op}") -class DecomposeMeanDimPass(ExportPass): +class DecomposeMeanDimPass(ArmPass): """ This pass decomposes meandim into a sum and mul node. @@ -62,8 +62,8 @@ def call_operator(self, op, args, kwargs, meta): sum_op, full_op, mul_op = get_meandim_decomposition(op) - sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta) + sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta, True) full = super().call_operator( - full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta + full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta, True ) - return super().call_operator(mul_op, (sum, full), {}, meta) + return super().call_operator(mul_op, (sum, full), {}, meta, True) diff --git a/backends/arm/_passes/decompose_softmax_unstable_pass.py b/backends/arm/_passes/decompose_softmax_unstable_pass.py index 4a2ce712ab7..e4c9113da13 100644 --- a/backends/arm/_passes/decompose_softmax_unstable_pass.py +++ b/backends/arm/_passes/decompose_softmax_unstable_pass.py @@ -6,8 +6,8 @@ # pyre-unsafe import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass # For BI case torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int) @@ -45,7 +45,7 @@ def get_logsoftmax_ops(op) -> tuple: raise RuntimeError(f"Can't get softmax decomposition ops for op {op}") -class DecomposeSoftmaxUnstablePass(ExportPass): +class DecomposeSoftmaxUnstablePass(ArmPass): """ This pass decomposes log softmax or softmax into more primitive ops. 
@@ -66,10 +66,10 @@ def call_operator(self, op, args, kwargs, meta): _input = args[0] dim = [args[1]] - op1 = super().call_operator(exp_op, (_input,), {}, meta) - op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta) - op3 = super().call_operator(reciprocal_op, (op2,), {}, meta) - op4 = super().call_operator(mul_op, (op1, op3), {}, meta) + op1 = super().call_operator(exp_op, (_input,), {}, meta, True) + op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta, True) + op3 = super().call_operator(reciprocal_op, (op2,), {}, meta, True) + op4 = super().call_operator(mul_op, (op1, op3), {}, meta, True) if op in log_softmax: - op4 = super().call_operator(log_op, (op4,), {}, meta) + op4 = super().call_operator(log_op, (op4,), {}, meta, True) return op4 diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index 73747d8313d..0c43cd1b9cb 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -8,9 +8,9 @@ import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass def get_var_decomposition(op) -> tuple: @@ -33,7 +33,7 @@ def get_var_decomposition(op) -> tuple: raise RuntimeError(f"Can't get var decomposition for op {op}") -class DecomposeVarPass(ExportPass): +class DecomposeVarPass(ArmPass): """ This pass decomposes var.correction and var.dim into smaller ops (see https://pytorch.org/docs/stable/generated/torch.var.html) @@ -77,14 +77,17 @@ def call_operator(self, op, args, kwargs, meta): N *= input_shape[d] mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op) - mean = super().call_operator(mean_op, (x, dim, True), {}, meta) - diff = super().call_operator(diff_op, (x, mean), {}, meta) - squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta) - sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta) + mean = super().call_operator(mean_op, (x, dim, True), {}, meta, True) + diff = super().call_operator(diff_op, (x, mean), {}, meta, True) + squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta, True) + sum = super().call_operator( + sum_op, (squared_diff, dim, keepdim), {}, meta, True + ) full = super().call_operator( full_op, ([], 1 / max(0, N - correction)), {"dtype": dtype}, meta, + True, ) - return super().call_operator(mul_op, (sum, full), {}, meta) + return super().call_operator(mul_op, (sum, full), {}, meta, True) diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py index 602d4a007a6..34ac7553212 100644 --- a/backends/arm/_passes/mm_to_bmm_pass.py +++ b/backends/arm/_passes/mm_to_bmm_pass.py @@ -47,7 +47,7 @@ def call(self, graph_module: torch.fx.GraphModule): with graph.inserting_before(node): unsqueeze_before = create_node( - graph, exir_ops.edge.aten.unsqueeze_copy.default + graph, exir_ops.edge.aten.unsqueeze_copy.default, from_node=node ) unsqueeze_before.args = ( input_node, # Input is node's original input @@ -58,13 +58,14 @@ def call(self, graph_module: torch.fx.GraphModule): # If Quantized we must insert unsqueeze --> q --> dq --> node if input_node.target == 
dq_op: q_params = input_node.args[1:] - insert_q_dq_pair(graph, unsqueeze_before, q_params) + insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node) # Replace mm node with bmm with graph.inserting_before(node): bmm_node = create_node( graph, exir_ops.edge.aten.bmm.default, + from_node=node, ) bmm_node.args = node.args node.replace_all_uses_with(bmm_node) @@ -75,6 +76,7 @@ def call(self, graph_module: torch.fx.GraphModule): squeeze_after = create_node( graph, exir_ops.edge.aten.squeeze_copy.dims, + from_node=node, ) squeeze_after.args = ( bmm_node, @@ -89,7 +91,7 @@ def call(self, graph_module: torch.fx.GraphModule): # If quantized, insert mm --> q --> dq --> squeeze if all(original_user.target == q_op for original_user in original_users): q_params = original_users[0].args[1:] - insert_q_dq_pair(graph, bmm_node, q_params) + insert_q_dq_pair(graph, bmm_node, q_params, from_node=node) modified_graph = True diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 788ebf39696..e265f81dc88 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -14,7 +14,7 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.print_program import inspect_node +from executorch.exir.print_program import add_cursor_to_graph from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -32,7 +32,8 @@ def dbg_node(node: torch.fx.Node, graph_module: torch.fx.GraphModule): def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> str: output = ( - f" {inspect_node(graph=graph_module.graph, node=node)}\n" + " Here is the node in the graph:\n" + f" {add_cursor_to_graph(graph=graph_module.graph, finding_node=node)}\n" "-- NODE DEBUG INFO --\n" f" Op is {node.op}\n" f" Name is {node.name}\n" @@ -43,10 +44,17 @@ def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule) " Node.meta = \n" ) for k, v in node.meta.items(): - output += f" '{k}' = {v}\n" - if isinstance(v, list): - for i in v: - output += f" {i}\n" + if k == "stack_trace": + matches = v.split("\n") + output += " 'stack_trace =\n" + for m in matches: + output += f" {m}\n" + else: + output += f" '{k}' = {v}\n" + + if isinstance(v, list): + for i in v: + output += f" {i}\n" return output @@ -78,10 +86,10 @@ def dbg_fail( tosa_graph: Optional[ts.TosaSerializer] = None, path: Optional[str] = None, ): - logger.warning("Internal error due to poorly handled node:") + logger.warning(f" Internal error due to poorly handled node: {node.name}") if tosa_graph is not None and path is not None: dbg_tosa_dump(tosa_graph, path) - logger.warning(f"Debug output captured in '{path}'.") + logger.warning(f" Debug output captured in '{path}'.") dbg_node(node, graph_module)
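
-- 
Usage sketch (illustrative only, not part of the change): the snippet below
shows how a pass built on the new ArmPass base class would opt in to the
metadata handling. The pass name ExampleDecomposeDivPass and the chosen
decomposition are assumptions for this example; the call pattern mirrors
DecomposeMeanDimPass in this patch. Passing the trailing True (the new
'updated' flag) to super().call_operator() copies the incoming metadata and
appends the current stack frame to its 'stack_trace' field. Graph-mutating
passes get the equivalent behaviour by passing from_node=node to
create_node(), as DecomposeLayerNormPass does above.

    from executorch.backends.arm._passes.arm_pass import ArmPass
    from executorch.exir.dialects._ops import ops as exir_ops


    class ExampleDecomposeDivPass(ArmPass):
        """Hypothetical example pass: rewrites div as reciprocal + mul."""

        def call_operator(self, op, args, kwargs, meta):
            if op != exir_ops.edge.aten.div.Tensor:
                # Untouched ops keep the default behaviour (metadata passed through).
                return super().call_operator(op, args, kwargs, meta)
            numerator, denominator = args[0], args[1]
            # updated=True: the new nodes inherit 'meta' and get an extended
            # 'stack_trace' entry pointing at this decomposition.
            recip = super().call_operator(
                exir_ops.edge.aten.reciprocal.default, (denominator,), {}, meta, True
            )
            return super().call_operator(
                exir_ops.edge.aten.mul.Tensor, (numerator, recip), {}, meta, True
            )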