From 395a528a6473f9ffd2dee91f30b7efda7e3588ad Mon Sep 17 00:00:00 2001 From: max Date: Wed, 19 Feb 2025 21:35:38 -0500 Subject: [PATCH] fixup --- mlir/docs/Dialects/Linalg/_index.md | 1 + mlir/python/mlir/dialects/linalg/__init__.py | 51 +++++++++++++++++--- mlir/test/python/dialects/linalg/ops.py | 16 ++++-- 3 files changed, 59 insertions(+), 9 deletions(-) diff --git a/mlir/docs/Dialects/Linalg/_index.md b/mlir/docs/Dialects/Linalg/_index.md index 976f0fd3c7e91..b519e4159f186 100644 --- a/mlir/docs/Dialects/Linalg/_index.md +++ b/mlir/docs/Dialects/Linalg/_index.md @@ -695,3 +695,4 @@ the same IR. ## Operations [include "Dialects/LinalgOps.md"] +[include "Dialects/LinalgRelayoutOps.td"] diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index b99344d34db89..63586a5bb8bbb 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -153,7 +153,7 @@ def __init__( generic = region_op(GenericOp_, terminator=YieldOp) -def create_op( +def _create_matmul_like_op( op_type, *ins: Union[Operation, OpView, Value], outs: Sequence[Union[Operation, OpView, Value]], @@ -183,7 +183,11 @@ def matmul( indexing_maps: Optional[Sequence[AffineMapAttr]] = None, cast: Optional[Union[TypeFn, Attribute]] = None, ): - return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast) + return _get_op_result_or_op_results( + _create_matmul_like_op( + MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) + ) def batch_matmul( @@ -192,8 +196,10 @@ def batch_matmul( indexing_maps: Optional[Sequence[AffineMapAttr]] = None, cast: Optional[Union[TypeFn, Attribute]] = None, ): - return create_op( - BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + return _get_op_result_or_op_results( + _create_matmul_like_op( + BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) ) @@ -203,8 +209,10 @@ def contract( indexing_maps: Sequence[AffineMapAttr], cast: Optional[Union[TypeFn, Attribute]] = None, ): - return create_op( - ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + return _get_op_result_or_op_results( + _create_matmul_like_op( + ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) ) @@ -239,3 +247,34 @@ def pack( ip=ip, ) ) + + +def unpack( + source, + dest, + inner_dims_pos, + inner_tiles, + *, + outer_dims_perm=None, + loc=None, + ip=None, +) -> ir.Value: + ( + dynamic_inner_tiles, + # packed here means %1:2 packing (results packing) + _inner_tiles, + static_inner_tiles, + ) = _dispatch_mixed_values(inner_tiles) + + return _get_op_result_or_op_results( + UnPackOp( + source=source, + dest=dest, + inner_dims_pos=inner_dims_pos, + inner_tiles=dynamic_inner_tiles, + static_inner_tiles=static_inner_tiles, + outer_dims_perm=outer_dims_perm, + loc=loc, + ip=ip, + ) + ) diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index d199558750e1e..55dc60cd6c902 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -568,9 +568,9 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem): print(module) -# CHECK-LABEL: TEST: testPackOp +# CHECK-LABEL: TEST: testPackUnPackOp @run -def testPackOp(): +def testPackUnPackOp(): with Context(), Location.unknown(): module = Module.create() f32 = F32Type.get() @@ -581,7 +581,7 @@ def testPackOp(): RankedTensorType.get((17, 2, 16, 16, 32, 8), f32), ) def tensor_pack(src, dst): - return linalg.pack( + packed = linalg.pack( src, dst, inner_dims_pos=[1, 0], @@ -589,6 +589,16 @@ def tensor_pack(src, dst): padding_value=arith.constant(f32, 0.0), ) + unpacked = linalg.pack( + dst, + src, + inner_dims_pos=[0, 1], + inner_tiles=[16, 16], + ) + + return unpacked + + # CHECK-LABEL: func.func @tensor_pack( # CHECK-SAME: %[[VAL_0:.*]]: tensor<129x47x16x16xf32>, # CHECK-SAME: %[[VAL_1:.*]]: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {