From a72616de18c0814ad37b5748d6bdc60b825dd889 Mon Sep 17 00:00:00 2001
From: Maksim Levental
Date: Thu, 20 Feb 2025 10:02:36 -0600
Subject: [PATCH] [mlir][python] fix linalg.pack/unpack (#127729)

This PR https://github.com/llvm/llvm-project/pull/123902 broke the Python
bindings for `tensor.pack`/`unpack`. This PR fixes that. It also

1. adds convenience wrappers for pack/unpack;
2. cleans up the matmul-like ops in the linalg bindings;
3. fixes the linalg docs, which were missing pack/unpack.
---
 mlir/docs/Dialects/Linalg/_index.md          |  1 +
 mlir/python/mlir/dialects/LinalgOps.td       |  1 +
 mlir/python/mlir/dialects/linalg/__init__.py | 90 ++++++++++++++++++--
 mlir/test/python/dialects/linalg/ops.py      | 40 +++++++++
 4 files changed, 125 insertions(+), 7 deletions(-)

diff --git a/mlir/docs/Dialects/Linalg/_index.md b/mlir/docs/Dialects/Linalg/_index.md
index 976f0fd3c7e9..b519e4159f18 100644
--- a/mlir/docs/Dialects/Linalg/_index.md
+++ b/mlir/docs/Dialects/Linalg/_index.md
@@ -695,3 +695,4 @@ the same IR.
 ## Operations
 
 [include "Dialects/LinalgOps.md"]
+[include "Dialects/LinalgRelayoutOps.md"]
diff --git a/mlir/python/mlir/dialects/LinalgOps.td b/mlir/python/mlir/dialects/LinalgOps.td
index b7658c85a9c4..89fb3f219e85 100644
--- a/mlir/python/mlir/dialects/LinalgOps.td
+++ b/mlir/python/mlir/dialects/LinalgOps.td
@@ -11,5 +11,6 @@
 
 include "mlir/Dialect/Linalg/IR/LinalgOps.td"
 include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
+include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"
 
 #endif
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index c5fbb833ee39..63586a5bb8bb 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -58,7 +58,11 @@ from .opdsl.ops.core_named_ops import *
 
 
 from ...ir import *
-from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from .._ods_common import (
+    get_op_result_or_value as _get_op_result_or_value,
+    get_op_result_or_op_results as _get_op_result_or_op_results,
+    _dispatch_mixed_values,
+)
 from ...extras.meta import region_op
 
 
@@ -149,7 +153,7 @@ def __init__(
 generic = region_op(GenericOp_, terminator=YieldOp)
 
 
-def create_op(
+def _create_matmul_like_op(
     op_type,
     *ins: Union[Operation, OpView, Value],
     outs: Sequence[Union[Operation, OpView, Value]],
@@ -179,7 +183,11 @@ def matmul(
     indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
     cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
-    return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
+    return _get_op_result_or_op_results(
+        _create_matmul_like_op(
+            MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+        )
+    )
 
 
 def batch_matmul(
@@ -188,8 +196,10 @@ def batch_matmul(
     indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
     cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
-    return create_op(
-        BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+    return _get_op_result_or_op_results(
+        _create_matmul_like_op(
+            BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+        )
     )
 
 
@@ -199,6 +209,72 @@ def contract(
     indexing_maps: Sequence[AffineMapAttr],
     cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
-    return create_op(
-        ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+    return _get_op_result_or_op_results(
+        _create_matmul_like_op(
+            ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+        )
     )
+
+
+def pack(
+    source,
+    dest,
+    inner_dims_pos,
+    inner_tiles,
+    *,
+    padding_value=None,
+    outer_dims_perm=None,
+    loc=None,
+    ip=None,
+) -> ir.Value:
+    (
+        dynamic_inner_tiles,
+        # packed here means %1:2 packing (results packing)
+        _inner_tiles,
+        static_inner_tiles,
+    ) = _dispatch_mixed_values(inner_tiles)
+
+    return _get_op_result_or_op_results(
+        PackOp(
+            source=source,
+            dest=dest,
+            inner_dims_pos=inner_dims_pos,
+            inner_tiles=dynamic_inner_tiles,
+            static_inner_tiles=static_inner_tiles,
+            padding_value=padding_value,
+            outer_dims_perm=outer_dims_perm,
+            loc=loc,
+            ip=ip,
+        )
+    )
+
+
+def unpack(
+    source,
+    dest,
+    inner_dims_pos,
+    inner_tiles,
+    *,
+    outer_dims_perm=None,
+    loc=None,
+    ip=None,
+) -> ir.Value:
+    (
+        dynamic_inner_tiles,
+        # packed here means %1:2 packing (results packing)
+        _inner_tiles,
+        static_inner_tiles,
+    ) = _dispatch_mixed_values(inner_tiles)
+
+    return _get_op_result_or_op_results(
+        UnPackOp(
+            source=source,
+            dest=dest,
+            inner_dims_pos=inner_dims_pos,
+            inner_tiles=dynamic_inner_tiles,
+            static_inner_tiles=static_inner_tiles,
+            outer_dims_perm=outer_dims_perm,
+            loc=loc,
+            ip=ip,
+        )
+    )
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 307a88709ad5..e32a911b24b1 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -566,3 +566,43 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
         )
 
     print(module)
+
+
+# CHECK-LABEL: TEST: testPackUnPackOp
+@run
+def testPackUnPackOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((128, 128), f32),
+                RankedTensorType.get((16, 16, 8, 8), f32),
+            )
+            def tensor_pack(src, dst):
+                packed = linalg.pack(
+                    src,
+                    dst,
+                    inner_dims_pos=[1, 0],
+                    inner_tiles=[8, 8],
+                    padding_value=arith.constant(f32, 0.0),
+                )
+
+                unpacked = linalg.unpack(
+                    packed,
+                    src,
+                    inner_dims_pos=[0, 1],
+                    inner_tiles=[8, 8],
+                )
+
+                return unpacked
+
+        # CHECK-LABEL: func.func @tensor_pack(
+        # CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> {
+        # CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
+        # CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+        # CHECK: %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+        # CHECK: return %[[VAL_4]] : tensor<128x128xf32>
+        # CHECK: }
+        print(module)
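
Usage note (not part of the patch): a minimal standalone sketch of the new
convenience wrappers, mirroring the lit test above. It assumes an MLIR build
with the Python bindings on PYTHONPATH; the function name
`pack_unpack_roundtrip` and the 128x128 / 16x16x8x8 shapes are illustrative
only.

    from mlir.ir import (
        Context,
        F32Type,
        InsertionPoint,
        Location,
        Module,
        RankedTensorType,
    )
    from mlir.dialects import arith, func, linalg

    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                RankedTensorType.get((128, 128), f32),
                RankedTensorType.get((16, 16, 8, 8), f32),
            )
            def pack_unpack_roundtrip(src, dst):
                # linalg.pack builds a PackOp and, via
                # _get_op_result_or_op_results, hands back the packed tensor
                # Value rather than the op itself.
                packed = linalg.pack(
                    src,
                    dst,
                    inner_dims_pos=[1, 0],
                    inner_tiles=[8, 8],
                    padding_value=arith.constant(f32, 0.0),
                )
                # linalg.unpack likewise wraps UnPackOp; inner_tiles may mix
                # static ints and dynamic Values, which the wrapper splits
                # with _dispatch_mixed_values (all static here).
                return linalg.unpack(
                    packed,
                    src,
                    inner_dims_pos=[0, 1],
                    inner_tiles=[8, 8],
                )

        print(module)

Returning the result Value instead of the op keeps pack/unpack composable
with the other linalg wrappers, which accept Values as operands.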