From 6f78ea7c05bb0dcdee133790eec902b0a7f01656 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 18 Feb 2025 19:29:36 -0500 Subject: [PATCH] [mlir][python] fix linalg.pack --- mlir/python/mlir/dialects/LinalgOps.td | 1 + mlir/python/mlir/dialects/linalg/__init__.py | 36 +++++++++++++++++++- mlir/test/python/dialects/linalg/ops.py | 31 +++++++++++++++++ 3 files changed, 67 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/LinalgOps.td b/mlir/python/mlir/dialects/LinalgOps.td index b7658c85a9c44..89fb3f219e858 100644 --- a/mlir/python/mlir/dialects/LinalgOps.td +++ b/mlir/python/mlir/dialects/LinalgOps.td @@ -11,5 +11,6 @@ include "mlir/Dialect/Linalg/IR/LinalgOps.td" include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td" +include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td" #endif diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index 5cda4769d593f..ee2d6f0eb493c 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -58,7 +58,10 @@ from .opdsl.ops.core_named_ops import * from ...ir import * -from .._ods_common import get_op_result_or_value as _get_op_result_or_value +from .._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + _dispatch_mixed_values, +) from ...extras.meta import region_op @@ -193,3 +196,34 @@ def contract( ) fill_builtin_region(op.operation) return op + + +def pack( + source, + dest, + inner_dims_pos, + inner_tiles, + *, + padding_value=None, + outer_dims_perm=None, + loc=None, + ip=None, +) -> ir.Value: + ( + dynamic_inner_tiles, + # packed here means %1:2 packing (results packing) + _inner_tiles, + static_inner_tiles, + ) = _dispatch_mixed_values(inner_tiles) + + return PackOp( + source=source, + dest=dest, + inner_dims_pos=inner_dims_pos, + inner_tiles=dynamic_inner_tiles, + static_inner_tiles=static_inner_tiles, + padding_value=padding_value, + outer_dims_perm=outer_dims_perm, + loc=loc, + ip=ip, + ).result diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index 94f8ea4faf4a8..321ec9d658416 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -466,3 +466,34 @@ def matmul_as_contract_op( ) print(module) + + +# CHECK-LABEL: TEST: testPackOp +@run +def testPackOp(): + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + RankedTensorType.get((129, 47, 16, 16), f32), + RankedTensorType.get((17, 2, 16, 16, 32, 8), f32), + ) + def tensor_pack(src, dst): + return linalg.pack( + src, + dst, + inner_dims_pos=[1, 0], + inner_tiles=[32, 8], + padding_value=arith.constant(f32, 0.0), + ) + + # CHECK-LABEL: func.func @tensor_pack( + # CHECK-SAME: %[[VAL_0:.*]]: tensor<129x47x16x16xf32>, + # CHECK-SAME: %[[VAL_1:.*]]: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> { + # CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32 + # CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %[[VAL_1]] : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32> + # CHECK: return %[[VAL_3]] : tensor<17x2x16x16x32x8xf32> + # CHECK: } + print(module)