Skip to content

Commit 46db175

Browse files
authored
Merge branch 'pytorch:main' into etdump-file-data-sink-test-coverage
2 parents 47f1c53 + a05c4da commit 46db175

40 files changed

+719
-234
lines changed

.lintrunner.toml

+26
Original file line numberDiff line numberDiff line change
@@ -343,3 +343,29 @@ init_command = [
343343
'--dry-run={{DRYRUN}}',
344344
'--requirement=requirements-lintrunner.txt',
345345
]
346+
347+
[[linter]]
348+
code = 'LICENSELINT'
349+
include_patterns = [
350+
'**/*',
351+
]
352+
exclude_patterns = [
353+
'**/fb/**',
354+
'.lintrunner.toml',
355+
]
356+
command = [
357+
'python',
358+
'-m',
359+
'lintrunner_adapters',
360+
'run',
361+
'grep_linter',
362+
'--pattern=Confidential and proprietary',
363+
'--linter-name=LICENSELINT',
364+
'--error-name=Wrong license',
365+
"""--error-description=\
366+
Code contributed to ExecuTorch open source repo should have \
367+
BSD-license header \
368+
""",
369+
'--',
370+
'@{{PATHSFILE}}',
371+
]

backends/arm/_passes/arm_pass_manager.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
199199
)
200200

201201
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
202-
self.add_pass(ScalarsToAttributePass())
203202
self.add_pass(ReplaceScalarWithTensorArgPass())
203+
self.add_pass(ScalarsToAttributePass())
204204
self.add_pass(DecomposeLayerNormPass())
205205
self.add_pass(DecomposeVarPass())
206206
self.add_pass(DecomposeMeanDimPass())

backends/arm/_passes/arm_pass_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# pyre-unsafe
99

1010
from inspect import isclass
11-
from typing import Optional
11+
from typing import Optional, Sequence
1212

1313
import torch
1414
import torch.fx
@@ -149,7 +149,7 @@ def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
149149
If the node contains many fake tensors, return the first one.
150150
"""
151151
if isinstance(
152-
node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
152+
node.meta["val"], (Sequence, torch.fx.immutable_collections.immutable_list)
153153
):
154154
fake_tensor = node.meta["val"][0]
155155
else:

backends/arm/ethosu_backend.py

-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
# debug functionality
2525
logger = logging.getLogger(__name__)
26-
logger.setLevel(logging.WARNING)
2726

2827

2928
@final

backends/arm/operator_support/tosa_supported_operators.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
FuseQuantizedActivationPass,
1919
)
2020
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
21+
from executorch.exir import ExportedProgram
2122
from executorch.exir.dialects._ops import ops as exir_ops
23+
from torch.export.graph_signature import InputKind
2224
from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase
2325
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
2426

@@ -84,9 +86,10 @@ def get_registered_tosa_support_checks(
8486

8587
def tosa_support_factory(
8688
tosa_spec: TosaSpecification,
89+
exported_program: ExportedProgram,
8790
additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
8891
) -> OperatorSupportBase:
89-
negative_checks: list[OperatorSupportBase] = []
92+
negative_checks: list[OperatorSupportBase] = [CheckInt64Inputs(exported_program)]
9093
if not tosa_spec.support_float():
9194
negative_checks.append(NeedsDecompositionCheck())
9295
negative_checks.append(CheckProperQuantization())
@@ -247,6 +250,10 @@ def is_node_supported(
247250
exir_ops.edge.aten._log_softmax.default,
248251
exir_ops.edge.aten.var.correction,
249252
exir_ops.edge.aten.var.dim,
253+
exir_ops.edge.aten.add.Scalar,
254+
exir_ops.edge.aten.sub.Scalar,
255+
exir_ops.edge.aten.mul.Scalar,
256+
exir_ops.edge.aten.div.Scalar,
250257
]
251258
return not needs_decomp
252259

@@ -312,6 +319,8 @@ def is_node_supported(
312319
exir_ops.edge.aten.bmm.default,
313320
exir_ops.edge.aten.convolution.default,
314321
exir_ops.edge.aten.exp.default,
322+
exir_ops.edge.aten.full.default,
323+
exir_ops.edge.aten.full_like.default,
315324
exir_ops.edge.aten.hardtanh.default,
316325
exir_ops.edge.aten.linear.default,
317326
exir_ops.edge.aten.log.default,
@@ -371,3 +380,29 @@ def is_node_supported(
371380
if not output_quantized:
372381
return False
373382
return True
383+
384+
385+
class CheckInt64Inputs(OperatorSupportBase):
386+
387+
def __init__(self, exported_program: ExportedProgram):
388+
self.input_names = [
389+
spec.arg.name
390+
for spec in exported_program.graph_signature.input_specs
391+
if spec.kind == InputKind.USER_INPUT
392+
]
393+
super().__init__()
394+
395+
def is_node_supported(
396+
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
397+
) -> bool:
398+
399+
for input_node in node.all_input_nodes:
400+
# We can cast constant placeholders AOT, not call_functions.
401+
if (
402+
input_node.name in self.input_names
403+
or not input_node.op == "placeholder"
404+
):
405+
tensor = get_first_fake_tensor(input_node)
406+
if tensor.dtype == torch.int64:
407+
return False
408+
return True

backends/arm/process_node.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,7 @@
1414
from executorch.backends.arm.operators.node_visitor import NodeVisitor
1515
from executorch.backends.arm.tosa_mapping import TosaArg
1616
from executorch.backends.arm.tosa_specification import TosaSpecification
17-
from executorch.backends.arm.tosa_utils import (
18-
get_node_debug_info,
19-
getNodeArgs,
20-
tosa_shape,
21-
)
17+
from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
2218
from torch.export.exported_program import ExportedProgram
2319

2420

@@ -36,7 +32,7 @@ def process_call_function(
3632
output = TosaArg(node)
3733
except ValueError as e:
3834
raise ValueError(
39-
f"Failed processing call_function:\n{get_node_debug_info(node)}"
35+
f"Failed processing call_function: {node.name}. "
4036
"Is the original torch function supported?"
4137
) from e
4238
tosa_graph.currRegion.currBasicBlock.addTensor(
@@ -74,7 +70,7 @@ def process_inputs(
7470
tosa_arg = TosaArg(node)
7571
except ValueError as e:
7672
raise ValueError(
77-
f"Failed processing input placeholder:\n{get_node_debug_info(node)}"
73+
f"Failed processing input placeholder: {node.name}. "
7874
"Is the original torch function supported?"
7975
) from e
8076
input_shape = tosa_arg.shape
@@ -100,7 +96,7 @@ def process_inputs_to_parameters(
10096
tosa_arg = TosaArg(node)
10197
except ValueError as e:
10298
raise ValueError(
103-
f"Failed processing parameter placeholder:\n{get_node_debug_info(node)}"
99+
f"Failed processing parameter placeholder: {node.name}. "
104100
"Is the original torch function supported?"
105101
) from e
106102
parameter_name = edge_program.graph_signature.inputs_to_parameters[tosa_arg.name]
@@ -129,7 +125,7 @@ def process_inputs_to_buffers(
129125
tosa_arg = TosaArg(node)
130126
except ValueError as e:
131127
raise ValueError(
132-
f"Failed processing buffer placeholder:\n{get_node_debug_info(node)}"
128+
f"Failed processing buffer placeholder: {node.name}. "
133129
"Is the original torch function supported?"
134130
) from e
135131
buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
@@ -157,7 +153,7 @@ def process_inputs_to_lifted_tensor_constants(
157153
tosa_arg = TosaArg(node)
158154
except ValueError as e:
159155
raise ValueError(
160-
f"Failed processing lifted tensor constant placeholder:\n{get_node_debug_info(node)}"
156+
f"Failed processing lifted tensor constant placeholder: {node.name}. "
161157
"Is the original torch function supported?"
162158
) from e
163159
tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[

backends/arm/scripts/build_executorch_runner.sh

+2-1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ source ${setup_path_script}
7878
pte_file=$(realpath ${pte_file})
7979
ethosu_tools_dir=$(realpath ${ethosu_tools_dir})
8080
ethos_u_root_dir="$ethosu_tools_dir/ethos-u"
81+
mkdir -p "${ethos_u_root_dir}"
8182
ethosu_tools_dir=$(realpath ${ethos_u_root_dir})
8283

8384
et_build_dir=${et_build_root}/cmake-out
@@ -106,6 +107,7 @@ then
106107
fi
107108
fi
108109

110+
mkdir -p "${output_folder}"
109111
output_folder=$(realpath ${output_folder})
110112

111113
if [[ ${target} == *"ethos-u55"* ]]; then
@@ -128,7 +130,6 @@ if [ "$build_with_etdump" = true ] ; then
128130
fi
129131

130132
echo "Building with BundleIO/etdump/extra flags: ${build_bundleio_flags} ${build_with_etdump_flags} ${extra_build_flags}"
131-
mkdir -p "${output_folder}"
132133

133134
cmake \
134135
-DCMAKE_BUILD_TYPE=${build_type} \

backends/arm/test/misc/test_debug_feats.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import shutil
1010
import tempfile
1111
import unittest
12+
from importlib.metadata import version
1213

1314
import torch
1415
from executorch.backends.arm.test import common
@@ -192,15 +193,16 @@ def test_collate_tosa_BI_tests(self):
192193
.to_edge_transform_and_lower()
193194
.to_executorch()
194195
)
196+
et_version = version("executorch")
195197
# test that the output directory is created and contains the expected files
196198
assert os.path.exists(
197199
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests"
198200
)
199201
assert os.path.exists(
200-
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag6.tosa"
202+
f"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag6_TOSA-0.80+BI_{et_version}.tosa"
201203
)
202204
assert os.path.exists(
203-
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag6.json"
205+
f"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag6_TOSA-0.80+BI_{et_version}.json"
204206
)
205207

206208
os.environ.pop("TOSA_TESTCASES_BASE_PATH")

backends/arm/test/models/test_conformer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_conformer_tosa_BI(self):
9393
)
9494
)
9595

96-
@unittest.expectedFailure # TODO(MLETORCH-635)
96+
@conftest.expectedFailureOnFVP # TODO(MLETORCH-635)
9797
def test_conformer_u55_BI(self):
9898
tester = (
9999
ArmTester(
@@ -115,7 +115,7 @@ def test_conformer_u55_BI(self):
115115
inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
116116
)
117117

118-
@unittest.expectedFailure # TODO(MLETORCH-635)
118+
@conftest.expectedFailureOnFVP # TODO(MLETORCH-635)
119119
def test_conformer_u85_BI(self):
120120
tester = (
121121
ArmTester(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""
7+
Tests 10 popular torch.nn.functional not tested in other ways or training related
8+
- normalize
9+
- grid_sample
10+
- one_hot
11+
- softplus
12+
- cosine_similarity
13+
- unfold
14+
- elu
15+
- fold
16+
- affine_grid
17+
- max_pool1d
18+
- threshold
19+
"""
20+
from typing import Callable
21+
22+
import torch
23+
from executorch.backends.arm.test.common import parametrize
24+
from executorch.backends.arm.test.tester.test_pipeline import (
25+
TosaPipelineBI,
26+
TosaPipelineMI,
27+
)
28+
29+
30+
def module_factory(function: Callable) -> torch.nn.Module:
31+
class ModuleWrapper(torch.nn.Module):
32+
def forward(self, *args):
33+
return function(*args)
34+
35+
return ModuleWrapper()
36+
37+
38+
example_input = torch.rand(1, 6, 16, 16)
39+
40+
module_tests = {
41+
"normalize": (module_factory(torch.nn.functional.normalize), (example_input,)),
42+
"grid_sample": (
43+
module_factory(torch.nn.functional.grid_sample),
44+
(torch.rand(1, 1, 4, 4), torch.rand(1, 5, 5, 2)),
45+
),
46+
"one_hot": (
47+
module_factory(torch.nn.functional.one_hot),
48+
(torch.randint(0, 5, (2, 2, 5, 5)), 5),
49+
),
50+
"softplus": (module_factory(torch.nn.functional.softplus), (example_input,)),
51+
"cosine_similarity": (
52+
module_factory(torch.nn.functional.cosine_similarity),
53+
(example_input, example_input),
54+
),
55+
"unfold": (
56+
module_factory(torch.nn.functional.unfold),
57+
(torch.randn(1, 3, 10, 12), (4, 5)),
58+
),
59+
"elu": (module_factory(torch.nn.functional.elu), (example_input,)),
60+
"fold": (
61+
module_factory(torch.nn.functional.fold),
62+
(torch.randn(1, 12, 12), (4, 5), (2, 2)),
63+
),
64+
"affine_grid": (
65+
module_factory(torch.nn.functional.affine_grid),
66+
(torch.rand(1, 2, 3), (1, 2, 10, 10)),
67+
),
68+
"max_pool1d": (
69+
module_factory(torch.nn.functional.max_pool1d),
70+
(torch.randn(20, 16, 50), 4),
71+
),
72+
"threshold": (
73+
module_factory(torch.nn.functional.threshold),
74+
(example_input, 0.5, 0.1),
75+
),
76+
}
77+
78+
input_t = tuple[torch.Tensor]
79+
80+
81+
@parametrize(
82+
"test_data", module_tests, xfails={"max_pool1d": "ValueError: Invalid TOSA graph"}
83+
)
84+
def test_nn_functional_MI(test_data):
85+
module, inputs = test_data
86+
pipeline = TosaPipelineMI[input_t](
87+
module, inputs, "", use_to_edge_transform_and_lower=True
88+
)
89+
pipeline.pop_stage("check.aten")
90+
pipeline.pop_stage("check_count.exir")
91+
try:
92+
pipeline.run()
93+
except RuntimeError as e:
94+
if (
95+
"Ran model with TosaReferenceModelDispatch but never ran TOSABackend delegate."
96+
not in str(e)
97+
):
98+
raise e
99+
100+
101+
@parametrize("test_data", module_tests)
102+
def test_nn_functional_BI(test_data):
103+
module, inputs = test_data
104+
pipeline = TosaPipelineBI[input_t](
105+
module, inputs, "", use_to_edge_transform_and_lower=True
106+
)
107+
pipeline.pop_stage("check.aten")
108+
pipeline.pop_stage("check_count.exir")
109+
pipeline.pop_stage("check.quant_nodes")
110+
pipeline.pop_stage("check_not.quant_nodes")
111+
try:
112+
pipeline.run()
113+
except RuntimeError as e:
114+
if (
115+
"Ran model with TosaReferenceModelDispatch but never ran TOSABackend delegate."
116+
not in str(e)
117+
):
118+
raise e

0 commit comments

Comments
 (0)