From 7b23a0abd769449c06b243907dab644f5bb90a9f Mon Sep 17 00:00:00 2001
From: Oscar Andersson <oscar.andersson@arm.com>
Date: Fri, 28 Feb 2025 08:50:28 +0100
Subject: [PATCH] Make passes preserve and update node metadata

When creating or updating nodes in passes, the metadata is neither
preserved nor updated correctly. This patch adds an ArmPass base class
which updates the node metadata when
super().call_operator(..., updated=True) is used. It also adds a
from_node argument to arm_pass_utils.create_node() to update the node
metadata. Only the 'stack_trace' field is updated; all other fields
are preserved from the original node.

Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
Change-Id: I725dd057716ae5a1fac0f97b522df22196f00bdb
---
 backends/arm/_passes/arm_pass.py              | 33 ++++++++++++++++
 backends/arm/_passes/arm_pass_utils.py        | 18 ++++++++-
 .../arm/_passes/decompose_layernorm_pass.py   | 38 +++++++++++++++----
 .../arm/_passes/decompose_meandim_pass.py     | 12 +++---
 .../decompose_softmax_unstable_pass.py        | 14 +++----
 backends/arm/_passes/decompose_var_pass.py    | 19 ++++++----
 backends/arm/_passes/mm_to_bmm_pass.py        |  8 ++--
 backends/arm/tosa_utils.py                    | 24 ++++++++----
 8 files changed, 124 insertions(+), 42 deletions(-)
 create mode 100644 backends/arm/_passes/arm_pass.py

diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py
new file mode 100644
index 00000000000..085267a174e
--- /dev/null
+++ b/backends/arm/_passes/arm_pass.py
@@ -0,0 +1,33 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import traceback
+from typing import Optional
+
+import torch
+from executorch.exir.pass_base import ExportPass, NodeMetadata
+
+
+class ArmPass(ExportPass):
+    """Base class for Arm passes; call_operator can extend meta 'stack_trace'."""
+
+    def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = None):
+        super(ArmPass, self).__init__()
+        self.exported_program = exported_program
+
+    def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
+        if not updated:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Shallow-copy meta, then append the caller's frame to "stack_trace".
+        new_meta = {}
+        keys = meta.data.keys()
+        for key in keys:
+            new_meta[key] = meta[key]
+        old_stack_trace = new_meta.get("stack_trace", "")
+        new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
+        return super().call_operator(op, args, kwargs, NodeMetadata(new_meta))
diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py
index a8d06713678..081400f02cc 100644
--- a/backends/arm/_passes/arm_pass_utils.py
+++ b/backends/arm/_passes/arm_pass_utils.py
@@ -7,12 +7,12 @@
 
 # pyre-unsafe
 
+import traceback
 from inspect import isclass
 from typing import Optional
 
 import torch
 import torch.fx
-
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 
@@ -96,6 +96,7 @@ def create_node(
     kwargs: Optional[dict] = None,
     quantize: bool = False,
     q_params: Optional[tuple] = None,
+    from_node: Optional[torch.fx.Node] = None,
 ):
     """
     Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node.
@@ -108,8 +109,18 @@ def create_node(
         args=args,
         kwargs=kwargs or {},
     )
+
+    new_meta = {}
+    if from_node:
+        keys = from_node.meta.keys()
+        for key in keys:
+            new_meta[key] = from_node.meta[key]
+    old_stack_trace = new_meta.get("stack_trace", "")
+    new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
+    node.meta = new_meta
+
     if quantize and q_params:
-        return insert_q_dq_pair(graph, node, q_params)
+        return insert_q_dq_pair(graph, node, q_params, from_node)
     return node
 
 
@@ -117,6 +128,7 @@ def insert_q_dq_pair(
     graph: torch.fx.Graph,
     anchor: torch.fx.Node,
     q_params: tuple,
+    from_node: Optional[torch.fx.Node] = None,
 ):
     """
     Inserts a q dq node pair after the node 'anchor'.
@@ -127,6 +139,7 @@ def insert_q_dq_pair(
             graph=graph,
             op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             args=(),  # We add the argument last
+            from_node=from_node if from_node else anchor,
         )
         q.meta = anchor.meta
     with graph.inserting_after(q):
@@ -134,6 +147,7 @@ def insert_q_dq_pair(
             graph=graph,
             op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
             args=(q,) + q_params,
+            from_node=from_node if from_node else anchor,
         )
         dq.meta = q.meta
     anchor.replace_all_uses_with(dq)
diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py
index cc4a81caae0..5d132e50c84 100644
--- a/backends/arm/_passes/decompose_layernorm_pass.py
+++ b/backends/arm/_passes/decompose_layernorm_pass.py
@@ -9,9 +9,10 @@
 import operator
 
 import torch
+from executorch.backends.arm._passes.arm_pass import ArmPass
 from executorch.backends.arm._passes.arm_pass_utils import create_node
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.pass_base import PassResult
 
 
 def get_layer_norm_decomposition(op) -> tuple:
@@ -40,7 +41,7 @@ def get_layer_norm_decomposition(op) -> tuple:
     raise RuntimeError(f"Can't get layer_norm composition for op {op}")
 
 
-class DecomposeLayerNormPass(ExportPass):
+class DecomposeLayerNormPass(ArmPass):
     """
     layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
     Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
@@ -111,24 +112,39 @@ def call(self, graph_module: torch.fx.GraphModule):
                     var_op,
                     args=(x, dims),
                     kwargs={"correction": 0, "keepdim": keepdim},
+                    from_node=node,
                 )
                 full = create_node(
                     graph_module.graph,
                     full_op,
                     args=(epsilon_reshaped_shape, epsilon),
                     kwargs={"dtype": dtype},
+                    from_node=node,
+                )
+                add0 = create_node(
+                    graph_module.graph, add_op, args=(var, full), from_node=node
+                )
+                rsqrt = create_node(
+                    graph_module.graph, rsqrt_op, args=(add0,), from_node=node
+                )
+                mul0 = create_node(
+                    graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
                 )
-                add0 = create_node(graph_module.graph, add_op, args=(var, full))
-                rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,))
-                mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt))
                 if weights is not None:
                     weights_reshaped = create_node(
                         graph_module.graph,
                         view_op,
                         args=(weights, weights_reshaped_shape),
+                        from_node=node,
                     )
                     mul1 = create_node(
-                        graph_module.graph, mul_op, args=(mul0, weights_reshaped)
+                        graph_module.graph,
+                        mul_op,
+                        args=(
+                            mul0,
+                            weights_reshaped,
+                        ),
+                        from_node=node,
                     )
                 else:
                     mul1 = mul0
@@ -136,10 +152,16 @@ def call(self, graph_module: torch.fx.GraphModule):
                 if bias is not None:
                     bias_reshaped_shape = weights_reshaped_shape
                     bias_reshaped = create_node(
-                        graph_module.graph, view_op, args=(bias, bias_reshaped_shape)
+                        graph_module.graph,
+                        view_op,
+                        args=(bias, bias_reshaped_shape),
+                        from_node=node,
                     )
                     output = create_node(
-                        graph_module.graph, add_op, args=(mul1, bias_reshaped)
+                        graph_module.graph,
+                        add_op,
+                        args=(mul1, bias_reshaped),
+                        from_node=node,
                     )
 
                 users = [user for user in node.users if node != user]
diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py
index abf5c8f363d..9bcfb72916a 100644
--- a/backends/arm/_passes/decompose_meandim_pass.py
+++ b/backends/arm/_passes/decompose_meandim_pass.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
@@ -7,9 +7,9 @@
 # pyre-unsafe
 
 import torch
+from executorch.backends.arm._passes.arm_pass import ArmPass
 from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass
 
 
 def get_meandim_decomposition(op) -> tuple:
@@ -28,7 +28,7 @@ def get_meandim_decomposition(op) -> tuple:
     raise RuntimeError(f"Can't get meandim decomposition for op {op}")
 
 
-class DecomposeMeanDimPass(ExportPass):
+class DecomposeMeanDimPass(ArmPass):
     """
     This pass decomposes meandim into a sum and mul node.
 
@@ -62,8 +62,8 @@ def call_operator(self, op, args, kwargs, meta):
 
         sum_op, full_op, mul_op = get_meandim_decomposition(op)
 
-        sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta)
+        sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta, True)
         full = super().call_operator(
-            full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta
+            full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta, True
         )
-        return super().call_operator(mul_op, (sum, full), {}, meta)
+        return super().call_operator(mul_op, (sum, full), {}, meta, True)
diff --git a/backends/arm/_passes/decompose_softmax_unstable_pass.py b/backends/arm/_passes/decompose_softmax_unstable_pass.py
index 4a2ce712ab7..e4c9113da13 100644
--- a/backends/arm/_passes/decompose_softmax_unstable_pass.py
+++ b/backends/arm/_passes/decompose_softmax_unstable_pass.py
@@ -6,8 +6,8 @@
 # pyre-unsafe
 
 import torch
+from executorch.backends.arm._passes.arm_pass import ArmPass
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass
 
 # For BI case
 torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
@@ -45,7 +45,7 @@ def get_logsoftmax_ops(op) -> tuple:
     raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")
 
 
-class DecomposeSoftmaxUnstablePass(ExportPass):
+class DecomposeSoftmaxUnstablePass(ArmPass):
     """
     This pass decomposes log softmax or softmax into more primitive ops.
 
@@ -66,10 +66,10 @@ def call_operator(self, op, args, kwargs, meta):
         _input = args[0]
         dim = [args[1]]
 
-        op1 = super().call_operator(exp_op, (_input,), {}, meta)
-        op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta)
-        op3 = super().call_operator(reciprocal_op, (op2,), {}, meta)
-        op4 = super().call_operator(mul_op, (op1, op3), {}, meta)
+        op1 = super().call_operator(exp_op, (_input,), {}, meta, True)
+        op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta, True)
+        op3 = super().call_operator(reciprocal_op, (op2,), {}, meta, True)
+        op4 = super().call_operator(mul_op, (op1, op3), {}, meta, True)
         if op in log_softmax:
-            op4 = super().call_operator(log_op, (op4,), {}, meta)
+            op4 = super().call_operator(log_op, (op4,), {}, meta, True)
         return op4
diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py
index 73747d8313d..0c43cd1b9cb 100644
--- a/backends/arm/_passes/decompose_var_pass.py
+++ b/backends/arm/_passes/decompose_var_pass.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
@@ -8,9 +8,9 @@
 
 
 import torch
+from executorch.backends.arm._passes.arm_pass import ArmPass
 from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass
 
 
 def get_var_decomposition(op) -> tuple:
@@ -33,7 +33,7 @@ def get_var_decomposition(op) -> tuple:
     raise RuntimeError(f"Can't get var decomposition for op {op}")
 
 
-class DecomposeVarPass(ExportPass):
+class DecomposeVarPass(ArmPass):
     """
     This pass decomposes var.correction and var.dim into smaller ops (see https://pytorch.org/docs/stable/generated/torch.var.html)
 
@@ -77,14 +77,17 @@ def call_operator(self, op, args, kwargs, meta):
             N *= input_shape[d]
 
         mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
-        mean = super().call_operator(mean_op, (x, dim, True), {}, meta)
-        diff = super().call_operator(diff_op, (x, mean), {}, meta)
-        squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
-        sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
+        mean = super().call_operator(mean_op, (x, dim, True), {}, meta, True)
+        diff = super().call_operator(diff_op, (x, mean), {}, meta, True)
+        squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta, True)
+        sum = super().call_operator(
+            sum_op, (squared_diff, dim, keepdim), {}, meta, True
+        )
         full = super().call_operator(
             full_op,
             ([], 1 / max(0, N - correction)),
             {"dtype": dtype},
             meta,
+            True,
         )
-        return super().call_operator(mul_op, (sum, full), {}, meta)
+        return super().call_operator(mul_op, (sum, full), {}, meta, True)
diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py
index 602d4a007a6..34ac7553212 100644
--- a/backends/arm/_passes/mm_to_bmm_pass.py
+++ b/backends/arm/_passes/mm_to_bmm_pass.py
@@ -47,7 +47,7 @@ def call(self, graph_module: torch.fx.GraphModule):
 
                 with graph.inserting_before(node):
                     unsqueeze_before = create_node(
-                        graph, exir_ops.edge.aten.unsqueeze_copy.default
+                        graph, exir_ops.edge.aten.unsqueeze_copy.default, from_node=node
                     )
                     unsqueeze_before.args = (
                         input_node,  # Input is node's original input
@@ -58,13 +58,14 @@ def call(self, graph_module: torch.fx.GraphModule):
                 # If Quantized we must insert unsqueeze --> q --> dq --> node
                 if input_node.target == dq_op:
                     q_params = input_node.args[1:]
-                    insert_q_dq_pair(graph, unsqueeze_before, q_params)
+                    insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node)
 
             # Replace mm node with bmm
             with graph.inserting_before(node):
                 bmm_node = create_node(
                     graph,
                     exir_ops.edge.aten.bmm.default,
+                    from_node=node,
                 )
                 bmm_node.args = node.args
                 node.replace_all_uses_with(bmm_node)
@@ -75,6 +76,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                 squeeze_after = create_node(
                     graph,
                     exir_ops.edge.aten.squeeze_copy.dims,
+                    from_node=node,
                 )
                 squeeze_after.args = (
                     bmm_node,
@@ -89,7 +91,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             # If quantized, insert mm --> q --> dq --> squeeze
             if all(original_user.target == q_op for original_user in original_users):
                 q_params = original_users[0].args[1:]
-                insert_q_dq_pair(graph, bmm_node, q_params)
+                insert_q_dq_pair(graph, bmm_node, q_params, from_node=node)
 
             modified_graph = True
 
diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py
index 788ebf39696..e265f81dc88 100644
--- a/backends/arm/tosa_utils.py
+++ b/backends/arm/tosa_utils.py
@@ -14,7 +14,7 @@
 from executorch.backends.arm.tosa_mapping import TosaArg
 
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.print_program import inspect_node
+from executorch.exir.print_program import add_cursor_to_graph
 from serializer.tosa_serializer import TosaOp
 from torch.fx import Node
 
@@ -32,7 +32,8 @@ def dbg_node(node: torch.fx.Node, graph_module: torch.fx.GraphModule):
 
 def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> str:
     output = (
-        f"  {inspect_node(graph=graph_module.graph, node=node)}\n"
+        " Here is the node in the graph:\n"
+        f"  {add_cursor_to_graph(graph=graph_module.graph, finding_node=node)}\n"
         "-- NODE DEBUG INFO --\n"
         f"  Op is {node.op}\n"
         f"  Name is {node.name}\n"
@@ -43,10 +44,17 @@ def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule)
         "  Node.meta = \n"
     )
     for k, v in node.meta.items():
-        output += f"    '{k}' = {v}\n"
-        if isinstance(v, list):
-            for i in v:
-                output += f"      {i}\n"
+        if k == "stack_trace":
+            matches = v.split("\n")
+            output += "      'stack_trace =\n"
+            for m in matches:
+                output += f"      {m}\n"
+        else:
+            output += f"    '{k}' = {v}\n"
+
+            if isinstance(v, list):
+                for i in v:
+                    output += f"      {i}\n"
     return output
 
 
@@ -78,10 +86,10 @@ def dbg_fail(
     tosa_graph: Optional[ts.TosaSerializer] = None,
     path: Optional[str] = None,
 ):
-    logger.warning("Internal error due to poorly handled node:")
+    logger.warning(f" Internal error due to poorly handled node: {node.name}")
     if tosa_graph is not None and path is not None:
         dbg_tosa_dump(tosa_graph, path)
-        logger.warning(f"Debug output captured in '{path}'.")
+        logger.warning(f" Debug output captured in '{path}'.")
     dbg_node(node, graph_module)