From ddd976fc803f07b12ee813fc79f91c1d4518378b Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 18 Feb 2025 19:29:36 -0500 Subject: [PATCH] [mlir][python] fix linalg.pack --- mlir/python/mlir/dialects/LinalgOps.td | 1 + mlir/python/mlir/dialects/linalg/__init__.py | 37 +++++++++++++++++++- mlir/test/python/dialects/linalg/ops.py | 21 +++++++++++ 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/LinalgOps.td b/mlir/python/mlir/dialects/LinalgOps.td index b7658c85a9c44..89fb3f219e858 100644 --- a/mlir/python/mlir/dialects/LinalgOps.td +++ b/mlir/python/mlir/dialects/LinalgOps.td @@ -11,5 +11,6 @@ include "mlir/Dialect/Linalg/IR/LinalgOps.td" include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td" +include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td" #endif diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index 5cda4769d593f..2e3424aaddb4f 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -58,7 +58,10 @@ from .opdsl.ops.core_named_ops import * from ...ir import * -from .._ods_common import get_op_result_or_value as _get_op_result_or_value +from .._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + _dispatch_mixed_values, +) from ...extras.meta import region_op @@ -193,3 +196,35 @@ def contract( ) fill_builtin_region(op.operation) return op + + +def pack( + source, + dest, + inner_dims_pos, + inner_tiles, + *, + padding_value=None, + outer_dims_perm=None, + loc=None, + ip=None, +) -> ir.Value: + + ( + dynamic_inner_tiles, + # packed here means %1:2 packing (results packing) + _inner_tiles, + static_inner_tiles, + ) = _dispatch_mixed_values(inner_tiles) + + return PackOp( + source=source, + dest=dest, + inner_dims_pos=inner_dims_pos, + inner_tiles=dynamic_inner_tiles, + static_inner_tiles=static_inner_tiles, + padding_value=padding_value, + outer_dims_perm=outer_dims_perm, + loc=loc, + ip=ip, + ).result diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index 94f8ea4faf4a8..cda73e0c6ce9a 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -466,3 +466,24 @@ def matmul_as_contract_op( ) print(module) + + +# CHECK-LABEL: TEST: testPackOp +@run +def testPackOp(): + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + + @func + def tensor_pack( + src: RankedTensorType.get((129, 47, 16, 16), f32), + dst: RankedTensorType((17, 2, 16, 16, 32, 8), f32), + ): + return linalg.pack( + src, + dst, + inner_dims_pos=[1, 0], + inner_tiles=[32, 8], + padding_value=arith.constant(0.0), + )