Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Feb 20, 2025
1 parent d18ca56 commit 6e772de
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 16 deletions.
1 change: 1 addition & 0 deletions mlir/docs/Dialects/Linalg/_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -695,3 +695,4 @@ the same IR.
## Operations

[include "Dialects/LinalgOps.md"]
[include "Dialects/LinalgRelayoutOps.md"]
51 changes: 45 additions & 6 deletions mlir/python/mlir/dialects/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __init__(
generic = region_op(GenericOp_, terminator=YieldOp)


def create_op(
def _create_matmul_like_op(
op_type,
*ins: Union[Operation, OpView, Value],
outs: Sequence[Union[Operation, OpView, Value]],
Expand Down Expand Up @@ -183,7 +183,11 @@ def matmul(
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
cast: Optional[Union[TypeFn, Attribute]] = None,
):
return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
return _get_op_result_or_op_results(
_create_matmul_like_op(
MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
)
)


def batch_matmul(
Expand All @@ -192,8 +196,10 @@ def batch_matmul(
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
cast: Optional[Union[TypeFn, Attribute]] = None,
):
return create_op(
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
return _get_op_result_or_op_results(
_create_matmul_like_op(
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
)
)


Expand All @@ -203,8 +209,10 @@ def contract(
indexing_maps: Sequence[AffineMapAttr],
cast: Optional[Union[TypeFn, Attribute]] = None,
):
return create_op(
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
return _get_op_result_or_op_results(
_create_matmul_like_op(
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
)
)


Expand Down Expand Up @@ -239,3 +247,34 @@ def pack(
ip=ip,
)
)


def unpack(
    source,
    dest,
    inner_dims_pos,
    inner_tiles,
    *,
    outer_dims_perm=None,
    loc=None,
    ip=None,
) -> ir.Value:
    """Emit a `linalg.unpack` op unpacking `source` into `dest`.

    `inner_tiles` may mix Python ints (static sizes) and SSA values
    (dynamic sizes); they are split into the op's `inner_tiles` operand
    segment and its `static_inner_tiles` attribute.
    """
    # Separate the mixed static/dynamic tile sizes. The middle element of
    # the returned triple is the packed (combined) form, which UnPackOp
    # does not take, so it is discarded.
    tiles_dynamic, _, tiles_static = _dispatch_mixed_values(inner_tiles)

    op = UnPackOp(
        source=source,
        dest=dest,
        inner_dims_pos=inner_dims_pos,
        inner_tiles=tiles_dynamic,
        static_inner_tiles=tiles_static,
        outer_dims_perm=outer_dims_perm,
        loc=loc,
        ip=ip,
    )
    # Unwrap to a Value when the op has a single result.
    return _get_op_result_or_op_results(op)
29 changes: 19 additions & 10 deletions mlir/test/python/dialects/linalg/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,32 +568,41 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
print(module)


# CHECK-LABEL: TEST: testPackOp
# CHECK-LABEL: TEST: testPackUnPackOp
@run
def testPackOp():
def testPackUnPackOp():
with Context(), Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):

@func.FuncOp.from_py_func(
RankedTensorType.get((129, 47, 16, 16), f32),
RankedTensorType.get((17, 2, 16, 16, 32, 8), f32),
RankedTensorType.get((128, 128), f32),
RankedTensorType.get((16, 16, 8, 8), f32),
)
def tensor_pack(src, dst):
return linalg.pack(
packed = linalg.pack(
src,
dst,
inner_dims_pos=[1, 0],
inner_tiles=[32, 8],
inner_tiles=[8, 8],
padding_value=arith.constant(f32, 0.0),
)

unpacked = linalg.unpack(
packed,
src,
inner_dims_pos=[0, 1],
inner_tiles=[8, 8],
)

return unpacked

# CHECK-LABEL: func.func @tensor_pack(
# CHECK-SAME: %[[VAL_0:.*]]: tensor<129x47x16x16xf32>,
# CHECK-SAME: %[[VAL_1:.*]]: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
# CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> {
# CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
# CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %[[VAL_1]] : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32>
# CHECK: return %[[VAL_3]] : tensor<17x2x16x16x32x8xf32>
# CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
# CHECK: %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
# CHECK: }
print(module)

0 comments on commit 6e772de

Please sign in to comment.