diff --git a/mlir/docs/Dialects/Linalg/_index.md b/mlir/docs/Dialects/Linalg/_index.md index 976f0fd3c7e91..b519e4159f186 100644 --- a/mlir/docs/Dialects/Linalg/_index.md +++ b/mlir/docs/Dialects/Linalg/_index.md @@ -695,3 +695,4 @@ the same IR. ## Operations [include "Dialects/LinalgOps.md"] +[include "Dialects/LinalgRelayoutOps.md"] diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index b99344d34db89..63586a5bb8bbb 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -153,7 +153,7 @@ def __init__( generic = region_op(GenericOp_, terminator=YieldOp) -def create_op( +def _create_matmul_like_op( op_type, *ins: Union[Operation, OpView, Value], outs: Sequence[Union[Operation, OpView, Value]], @@ -183,7 +183,11 @@ def matmul( indexing_maps: Optional[Sequence[AffineMapAttr]] = None, cast: Optional[Union[TypeFn, Attribute]] = None, ): - return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast) + return _get_op_result_or_op_results( + _create_matmul_like_op( + MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) + ) def batch_matmul( @@ -192,8 +196,10 @@ def batch_matmul( indexing_maps: Optional[Sequence[AffineMapAttr]] = None, cast: Optional[Union[TypeFn, Attribute]] = None, ): - return create_op( - BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + return _get_op_result_or_op_results( + _create_matmul_like_op( + BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) ) @@ -203,8 +209,10 @@ def contract( indexing_maps: Sequence[AffineMapAttr], cast: Optional[Union[TypeFn, Attribute]] = None, ): - return create_op( - ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + return _get_op_result_or_op_results( + _create_matmul_like_op( + ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) ) @@ -239,3 +247,34 @@ def pack( ip=ip, ) ) + + +def 
unpack( + source, + dest, + inner_dims_pos, + inner_tiles, + *, + outer_dims_perm=None, + loc=None, + ip=None, +) -> ir.Value: + ( + dynamic_inner_tiles, + # packed here means %1:2 packing (results packing) + _inner_tiles, + static_inner_tiles, + ) = _dispatch_mixed_values(inner_tiles) + + return _get_op_result_or_op_results( + UnPackOp( + source=source, + dest=dest, + inner_dims_pos=inner_dims_pos, + inner_tiles=dynamic_inner_tiles, + static_inner_tiles=static_inner_tiles, + outer_dims_perm=outer_dims_perm, + loc=loc, + ip=ip, + ) + ) diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index d199558750e1e..e32a911b24b11 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -568,32 +568,41 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem): print(module) -# CHECK-LABEL: TEST: testPackOp +# CHECK-LABEL: TEST: testPackUnPackOp @run -def testPackOp(): +def testPackUnPackOp(): with Context(), Location.unknown(): module = Module.create() f32 = F32Type.get() with InsertionPoint(module.body): @func.FuncOp.from_py_func( - RankedTensorType.get((129, 47, 16, 16), f32), - RankedTensorType.get((17, 2, 16, 16, 32, 8), f32), + RankedTensorType.get((128, 128), f32), + RankedTensorType.get((16, 16, 8, 8), f32), ) def tensor_pack(src, dst): - return linalg.pack( + packed = linalg.pack( src, dst, inner_dims_pos=[1, 0], - inner_tiles=[32, 8], + inner_tiles=[8, 8], padding_value=arith.constant(f32, 0.0), ) + unpacked = linalg.unpack( + packed, + src, + inner_dims_pos=[0, 1], + inner_tiles=[8, 8], + ) + + return unpacked + # CHECK-LABEL: func.func @tensor_pack( - # CHECK-SAME: %[[VAL_0:.*]]: tensor<129x47x16x16xf32>, - # CHECK-SAME: %[[VAL_1:.*]]: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> { + # CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> { # CHECK: %[[VAL_2:.*]] = arith.constant 
0.000000e+00 : f32 - # CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %[[VAL_1]] : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32> - # CHECK: return %[[VAL_3]] : tensor<17x2x16x16x32x8xf32> + # CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<16x16x8x8xf32> + # CHECK: %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32> + # CHECK: return %[[VAL_4]] : tensor<128x128xf32> # CHECK: } print(module)