diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py index 76e994abdbf..84cd0faa485 100644 --- a/exir/tests/test_memory_format_ops_pass.py +++ b/exir/tests/test_memory_format_ops_pass.py @@ -24,6 +24,7 @@ from executorch.exir.pass_base import ExportPass, ProxyValue from executorch.exir.tests.test_memory_format_ops_pass_utils import ( + AmbiguousDimOrderError, MemoryFormatOpsPassTestUtils, MemoryFormatTestSet, PropagateToCopyChannalsLastModule, @@ -124,8 +125,34 @@ def test_op_dim_order_propagation(self) -> None: target_memory_format=torch.channels_last, _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, ), + check_unambiguous_dim_order=True, ) + def test_op_dim_order_propagation_ambiguous(self) -> None: + try: + MemoryFormatOpsPassTestUtils.memory_format_test_runner( + self, + MemoryFormatTestSet( + module=PropagateToCopyChannalsLastModule().eval(), + op=torch.ops.aten._to_copy.default, + sample_input=( + torch.rand_like( + torch.zeros( + [2, 1, 2, 2] + ), # Ambiguous shape should trigger AmbiguousDimOrderError! + dtype=torch.float32, + memory_format=torch.contiguous_format, + ), + ), + target_memory_format=torch.channels_last, + _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, + ), + check_unambiguous_dim_order=True, + ) + AssertionError("Should have raised AmbiguousDimOrderError") + except AmbiguousDimOrderError: + pass # Expected error + # Only test dim order replacement result in lean mode test. # This test is irrelevant with operator mode. def test_dim_order_replacement(self) -> None: diff --git a/exir/tests/test_memory_format_ops_pass_utils.py b/exir/tests/test_memory_format_ops_pass_utils.py index 8bf810e847e..b54f2f4a90a 100644 --- a/exir/tests/test_memory_format_ops_pass_utils.py +++ b/exir/tests/test_memory_format_ops_pass_utils.py @@ -20,8 +20,11 @@ is_channel_last_dim_order, is_contiguous_dim_order, ) +from executorch.exir.pass_base import ExportPass from torch.export import export + +from torch.fx.passes.infra.pass_manager import PassManager from torch.testing import FileCheck from torch.utils._pytree import tree_flatten @@ -99,10 +102,67 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return t1 * t2 +class AmbiguousDimOrderError(RuntimeError): + pass + + +def assert_unambiguous_dim_order(gm): + class ExampleNOPPass(ExportPass): + """ + Does nothing! + """ + + def call_operator(self, op, args, kwargs, meta): + return super().call_operator( + op, + args, + kwargs, + meta, + ) + + # This is an example of how one can detect ambiguous dim_order anywhere in the graph. + # You can be surgical and only detect it in the nodes you are interested in or something else. + def detect_ambiguity(gm): + """ + Check every node's output tensor dim_order and raise if it is ambiguous for a list of formats. + """ + + def get_tensors(node: torch.fx.Node) -> torch.Tensor: + val = node.meta["val"] + if isinstance(val, torch.Tensor): + return [val] + elif isinstance(val, (list, tuple)): + return [tensor for tensor in val if isinstance(tensor, torch.Tensor)] + return [] + + for node in gm.graph.nodes: + if node.op == "call_function": + for tensor in get_tensors(node): + # Let's make sure dim_order is not ambiguous, raise otherwise. + # This is raising because we can't do anything about it. + # The right course of follow up action is to ask user to try with a different example input. + try: + _ = tensor.dim_order( + ambiguity_check=[ + torch.contiguous_format, + torch.channels_last, + ] + ) + except Exception: + raise AmbiguousDimOrderError + + # any pass or passes, just using MemoryFormatOpsPass as an example + dim_order_pass_manager = PassManager(passes=[ExampleNOPPass()]) + dim_order_pass_manager.add_checks(detect_ambiguity) + dim_order_pass_manager(gm) + + class MemoryFormatOpsPassTestUtils: @staticmethod def memory_format_test_runner( - test_class: unittest.TestCase, test_set: MemoryFormatTestSet + test_class: unittest.TestCase, + test_set: MemoryFormatTestSet, + check_unambiguous_dim_order: bool = False, ): before = export( test_set.module, test_set.sample_input, strict=True @@ -121,6 +181,9 @@ def memory_format_test_runner( before, compile_config=EdgeCompileConfig(_skip_dim_order=False) ) + if check_unambiguous_dim_order: + assert_unambiguous_dim_order(epm.exported_program().graph_module) + # check memory format ops, if needed if test_set.op_level_check: aten_op_str, edge_op_str = MemoryFormatOps2Str[test_set.op]