Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic ambiguity check in the tests #9371

Merged
merged 2 commits into from
Mar 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions exir/tests/test_memory_format_ops_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from executorch.exir.pass_base import ExportPass, ProxyValue

from executorch.exir.tests.test_memory_format_ops_pass_utils import (
AmbiguousDimOrderError,
MemoryFormatOpsPassTestUtils,
MemoryFormatTestSet,
PropagateToCopyChannalsLastModule,
Expand Down Expand Up @@ -124,8 +125,34 @@ def test_op_dim_order_propagation(self) -> None:
target_memory_format=torch.channels_last,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
check_unambiguous_dim_order=True,
)

def test_op_dim_order_propagation_ambiguous(self) -> None:
try:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=PropagateToCopyChannalsLastModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(
torch.rand_like(
torch.zeros(
[2, 1, 2, 2]
), # Ambiguous shape should trigger AmbiguousDimOrderError!
dtype=torch.float32,
memory_format=torch.contiguous_format,
),
),
target_memory_format=torch.channels_last,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
check_unambiguous_dim_order=True,
)
AssertionError("Should have raised AmbiguousDimOrderError")
except AmbiguousDimOrderError:
pass # Expected error

# Only test dim order replacement result in lean mode test.
# This test is irrelevant with operator mode.
def test_dim_order_replacement(self) -> None:
Expand Down
65 changes: 64 additions & 1 deletion exir/tests/test_memory_format_ops_pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
is_channel_last_dim_order,
is_contiguous_dim_order,
)
from executorch.exir.pass_base import ExportPass

from torch.export import export

from torch.fx.passes.infra.pass_manager import PassManager
from torch.testing import FileCheck
from torch.utils._pytree import tree_flatten

Expand Down Expand Up @@ -99,10 +102,67 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return t1 * t2


class AmbiguousDimOrderError(RuntimeError):
pass


def assert_unambiguous_dim_order(gm):
class ExampleNOPPass(ExportPass):
"""
Does nothing!
"""

def call_operator(self, op, args, kwargs, meta):
return super().call_operator(
op,
args,
kwargs,
meta,
)

# This is an example of how one can detect ambiguous dim_order anywhere in the graph.
# You can be surgical and only detect it in the nodes you are interested in or something else.
def detect_ambiguity(gm):
"""
Check every node's output tensor dim_order and raise if it is ambiguous for a list of formats.
"""

def get_tensors(node: torch.fx.Node) -> torch.Tensor:
val = node.meta["val"]
if isinstance(val, torch.Tensor):
return [val]
elif isinstance(val, (list, tuple)):
return [tensor for tensor in val if isinstance(tensor, torch.Tensor)]
return []

for node in gm.graph.nodes:
if node.op == "call_function":
for tensor in get_tensors(node):
# Let's make sure dim_order is not ambiguous, raise otherwise.
# This is raising because we can't do anything about it.
# The right course of follow up action is to ask user to try with a different example input.
try:
_ = tensor.dim_order(
ambiguity_check=[
torch.contiguous_format,
torch.channels_last,
]
)
except Exception:
raise AmbiguousDimOrderError

# any pass or passes, just using MemoryFormatOpsPass as an example
dim_order_pass_manager = PassManager(passes=[ExampleNOPPass()])
dim_order_pass_manager.add_checks(detect_ambiguity)
dim_order_pass_manager(gm)


class MemoryFormatOpsPassTestUtils:
@staticmethod
def memory_format_test_runner(
test_class: unittest.TestCase, test_set: MemoryFormatTestSet
test_class: unittest.TestCase,
test_set: MemoryFormatTestSet,
check_unambiguous_dim_order: bool = False,
):
before = export(
test_set.module, test_set.sample_input, strict=True
Expand All @@ -121,6 +181,9 @@ def memory_format_test_runner(
before, compile_config=EdgeCompileConfig(_skip_dim_order=False)
)

if check_unambiguous_dim_order:
assert_unambiguous_dim_order(epm.exported_program().graph_module)

# check memory format ops, if needed
if test_set.op_level_check:
aten_op_str, edge_op_str = MemoryFormatOps2Str[test_set.op]
Expand Down
Loading