Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arm backend: Make passes preserve and update node metadata #9362

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions backends/arm/_passes/arm_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import traceback
from typing import Optional

import torch
from executorch.exir.pass_base import ExportPass, NodeMetadata


class ArmPass(ExportPass):
"""Base class for Arm passes"""

def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = None):
super(ArmPass, self).__init__()
self.exported_program = exported_program

def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
if not updated:
return super().call_operator(op, args, kwargs, meta)

# if updated we should update metadata
new_meta = {}
keys = meta.data.keys()
for key in keys:
new_meta[key] = meta[key]
old_stack_trace = new_meta.get("stack_trace", "")
new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
return super().call_operator(op, args, kwargs, NodeMetadata(new_meta))
18 changes: 16 additions & 2 deletions backends/arm/_passes/arm_pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

# pyre-unsafe

import traceback
from inspect import isclass
from typing import Optional

import torch
import torch.fx

from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops

Expand Down Expand Up @@ -96,6 +96,7 @@ def create_node(
kwargs: Optional[dict] = None,
quantize: bool = False,
q_params: Optional[tuple] = None,
from_node: Optional[torch.fx.Node] = None,
):
"""
Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node.
Expand All @@ -108,15 +109,26 @@ def create_node(
args=args,
kwargs=kwargs or {},
)

new_meta = {}
if from_node:
keys = from_node.meta.keys()
for key in keys:
new_meta[key] = from_node.meta[key]
old_stack_trace = new_meta.get("stack_trace", "")
new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
node.meta = new_meta

if quantize and q_params:
return insert_q_dq_pair(graph, node, q_params)
return insert_q_dq_pair(graph, node, q_params, from_node)
return node


def insert_q_dq_pair(
graph: torch.fx.Graph,
anchor: torch.fx.Node,
q_params: tuple,
from_node: Optional[torch.fx.Node] = None,
):
"""
Inserts a q dq node pair after the node 'anchor'.
Expand All @@ -127,13 +139,15 @@ def insert_q_dq_pair(
graph=graph,
op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(), # We add the argument last
from_node=from_node if from_node else anchor,
)
q.meta = anchor.meta
with graph.inserting_after(q):
dq = create_node(
graph=graph,
op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
args=(q,) + q_params,
from_node=from_node if from_node else anchor,
)
dq.meta = q.meta
anchor.replace_all_uses_with(dq)
Expand Down
38 changes: 30 additions & 8 deletions backends/arm/_passes/decompose_layernorm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import operator

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.pass_base import PassResult


def get_layer_norm_decomposition(op) -> tuple:
Expand Down Expand Up @@ -40,7 +41,7 @@ def get_layer_norm_decomposition(op) -> tuple:
raise RuntimeError(f"Can't get layer_norm composition for op {op}")


class DecomposeLayerNormPass(ExportPass):
class DecomposeLayerNormPass(ArmPass):
"""
layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
Expand Down Expand Up @@ -111,35 +112,56 @@ def call(self, graph_module: torch.fx.GraphModule):
var_op,
args=(x, dims),
kwargs={"correction": 0, "keepdim": keepdim},
from_node=node,
)
full = create_node(
graph_module.graph,
full_op,
args=(epsilon_reshaped_shape, epsilon),
kwargs={"dtype": dtype},
from_node=node,
)
add0 = create_node(
graph_module.graph, add_op, args=(var, full), from_node=node
)
rsqrt = create_node(
graph_module.graph, rsqrt_op, args=(add0,), from_node=node
)
mul0 = create_node(
graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
)
add0 = create_node(graph_module.graph, add_op, args=(var, full))
rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,))
mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt))
if weights is not None:
weights_reshaped = create_node(
graph_module.graph,
view_op,
args=(weights, weights_reshaped_shape),
from_node=node,
)
mul1 = create_node(
graph_module.graph, mul_op, args=(mul0, weights_reshaped)
graph_module.graph,
mul_op,
args=(
mul0,
weights_reshaped,
),
from_node=node,
)
else:
mul1 = mul0
output = mul1
if bias is not None:
bias_reshaped_shape = weights_reshaped_shape
bias_reshaped = create_node(
graph_module.graph, view_op, args=(bias, bias_reshaped_shape)
graph_module.graph,
view_op,
args=(bias, bias_reshaped_shape),
from_node=node,
)
output = create_node(
graph_module.graph, add_op, args=(mul1, bias_reshaped)
graph_module.graph,
add_op,
args=(mul1, bias_reshaped),
from_node=node,
)

users = [user for user in node.users if node != user]
Expand Down
12 changes: 6 additions & 6 deletions backends/arm/_passes/decompose_meandim_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand All @@ -7,9 +7,9 @@
# pyre-unsafe

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def get_meandim_decomposition(op) -> tuple:
Expand All @@ -28,7 +28,7 @@ def get_meandim_decomposition(op) -> tuple:
raise RuntimeError(f"Can't get meandim decomposition for op {op}")


class DecomposeMeanDimPass(ExportPass):
class DecomposeMeanDimPass(ArmPass):
"""
This pass decomposes meandim into a sum and mul node.

Expand Down Expand Up @@ -62,8 +62,8 @@ def call_operator(self, op, args, kwargs, meta):

sum_op, full_op, mul_op = get_meandim_decomposition(op)

sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta)
sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta, True)
full = super().call_operator(
full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta
full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta, True
)
return super().call_operator(mul_op, (sum, full), {}, meta)
return super().call_operator(mul_op, (sum, full), {}, meta, True)
14 changes: 7 additions & 7 deletions backends/arm/_passes/decompose_softmax_unstable_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
# pyre-unsafe

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

# For BI case
torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
Expand Down Expand Up @@ -45,7 +45,7 @@ def get_logsoftmax_ops(op) -> tuple:
raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")


class DecomposeSoftmaxUnstablePass(ExportPass):
class DecomposeSoftmaxUnstablePass(ArmPass):
"""
This pass decomposes log softmax or softmax into more primitive ops.

Expand All @@ -66,10 +66,10 @@ def call_operator(self, op, args, kwargs, meta):
_input = args[0]
dim = [args[1]]

op1 = super().call_operator(exp_op, (_input,), {}, meta)
op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta)
op3 = super().call_operator(reciprocal_op, (op2,), {}, meta)
op4 = super().call_operator(mul_op, (op1, op3), {}, meta)
op1 = super().call_operator(exp_op, (_input,), {}, meta, True)
op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta, True)
op3 = super().call_operator(reciprocal_op, (op2,), {}, meta, True)
op4 = super().call_operator(mul_op, (op1, op3), {}, meta, True)
if op in log_softmax:
op4 = super().call_operator(log_op, (op4,), {}, meta)
op4 = super().call_operator(log_op, (op4,), {}, meta, True)
return op4
19 changes: 11 additions & 8 deletions backends/arm/_passes/decompose_var_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand All @@ -8,9 +8,9 @@


import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def get_var_decomposition(op) -> tuple:
Expand All @@ -33,7 +33,7 @@ def get_var_decomposition(op) -> tuple:
raise RuntimeError(f"Can't get var decomposition for op {op}")


class DecomposeVarPass(ExportPass):
class DecomposeVarPass(ArmPass):
"""
This pass decomposes var.correction and var.dim into smaller ops (see https://pytorch.org/docs/stable/generated/torch.var.html)

Expand Down Expand Up @@ -77,14 +77,17 @@ def call_operator(self, op, args, kwargs, meta):
N *= input_shape[d]

mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
mean = super().call_operator(mean_op, (x, dim, True), {}, meta)
diff = super().call_operator(diff_op, (x, mean), {}, meta)
squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
mean = super().call_operator(mean_op, (x, dim, True), {}, meta, True)
diff = super().call_operator(diff_op, (x, mean), {}, meta, True)
squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta, True)
sum = super().call_operator(
sum_op, (squared_diff, dim, keepdim), {}, meta, True
)
full = super().call_operator(
full_op,
([], 1 / max(0, N - correction)),
{"dtype": dtype},
meta,
True,
)
return super().call_operator(mul_op, (sum, full), {}, meta)
return super().call_operator(mul_op, (sum, full), {}, meta, True)
8 changes: 5 additions & 3 deletions backends/arm/_passes/mm_to_bmm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def call(self, graph_module: torch.fx.GraphModule):

with graph.inserting_before(node):
unsqueeze_before = create_node(
graph, exir_ops.edge.aten.unsqueeze_copy.default
graph, exir_ops.edge.aten.unsqueeze_copy.default, from_node=node
)
unsqueeze_before.args = (
input_node, # Input is node's original input
Expand All @@ -58,13 +58,14 @@ def call(self, graph_module: torch.fx.GraphModule):
# If Quantized we must insert unsqueeze --> q --> dq --> node
if input_node.target == dq_op:
q_params = input_node.args[1:]
insert_q_dq_pair(graph, unsqueeze_before, q_params)
insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node)

# Replace mm node with bmm
with graph.inserting_before(node):
bmm_node = create_node(
graph,
exir_ops.edge.aten.bmm.default,
from_node=node,
)
bmm_node.args = node.args
node.replace_all_uses_with(bmm_node)
Expand All @@ -75,6 +76,7 @@ def call(self, graph_module: torch.fx.GraphModule):
squeeze_after = create_node(
graph,
exir_ops.edge.aten.squeeze_copy.dims,
from_node=node,
)
squeeze_after.args = (
bmm_node,
Expand All @@ -89,7 +91,7 @@ def call(self, graph_module: torch.fx.GraphModule):
# If quantized, insert mm --> q --> dq --> squeeze
if all(original_user.target == q_op for original_user in original_users):
q_params = original_users[0].args[1:]
insert_q_dq_pair(graph, bmm_node, q_params)
insert_q_dq_pair(graph, bmm_node, q_params, from_node=node)

modified_graph = True

Expand Down
Loading
Loading