Skip to content

Commit

Permalink
Automerge: [mlir][python] fix linalg.pack/unpack (#127729)
Browse files Browse the repository at this point in the history
This PR llvm/llvm-project#123902 broke python
bindings for `tensor.pack`/`unpack`. This PR fixes that. It also

1. adds convenience wrappers for pack/unpack
2. cleans up matmul-like ops in the linalg bindings
3. fixes linalg docs missing pack/unpack
  • Loading branch information
makslevental authored and github-actions[bot] committed Feb 20, 2025
2 parents 2803c09 + a72616d commit 2faf4f4
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 7 deletions.
1 change: 1 addition & 0 deletions mlir/docs/Dialects/Linalg/_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -695,3 +695,4 @@ the same IR.
## Operations

[include "Dialects/LinalgOps.md"]
[include "Dialects/LinalgRelayoutOps.td"]
1 change: 1 addition & 0 deletions mlir/python/mlir/dialects/LinalgOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@

include "mlir/Dialect/Linalg/IR/LinalgOps.td"
include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"

#endif
90 changes: 83 additions & 7 deletions mlir/python/mlir/dialects/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@
from .opdsl.ops.core_named_ops import *

from ...ir import *
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
from .._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_result_or_op_results as _get_op_result_or_op_results,
_dispatch_mixed_values,
)
from ...extras.meta import region_op


Expand Down Expand Up @@ -149,7 +153,7 @@ def __init__(
generic = region_op(GenericOp_, terminator=YieldOp)


def create_op(
def _create_matmul_like_op(
op_type,
*ins: Union[Operation, OpView, Value],
outs: Sequence[Union[Operation, OpView, Value]],
Expand Down Expand Up @@ -179,7 +183,11 @@ def matmul(
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
cast: Optional[Union[TypeFn, Attribute]] = None,
):
return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
return _get_op_result_or_op_results(
_create_matmul_like_op(
MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
)
)


def batch_matmul(
Expand All @@ -188,8 +196,10 @@ def batch_matmul(
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
cast: Optional[Union[TypeFn, Attribute]] = None,
):
return create_op(
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
return _get_op_result_or_op_results(
_create_matmul_like_op(
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
)
)


Expand All @@ -199,6 +209,72 @@ def contract(
indexing_maps: Sequence[AffineMapAttr],
cast: Optional[Union[TypeFn, Attribute]] = None,
):
return create_op(
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
return _get_op_result_or_op_results(
_create_matmul_like_op(
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
)
)


def pack(
source,
dest,
inner_dims_pos,
inner_tiles,
*,
padding_value=None,
outer_dims_perm=None,
loc=None,
ip=None,
) -> ir.Value:
(
dynamic_inner_tiles,
# packed here means %1:2 packing (results packing)
_inner_tiles,
static_inner_tiles,
) = _dispatch_mixed_values(inner_tiles)

return _get_op_result_or_op_results(
PackOp(
source=source,
dest=dest,
inner_dims_pos=inner_dims_pos,
inner_tiles=dynamic_inner_tiles,
static_inner_tiles=static_inner_tiles,
padding_value=padding_value,
outer_dims_perm=outer_dims_perm,
loc=loc,
ip=ip,
)
)


def unpack(
source,
dest,
inner_dims_pos,
inner_tiles,
*,
outer_dims_perm=None,
loc=None,
ip=None,
) -> ir.Value:
(
dynamic_inner_tiles,
# packed here means %1:2 packing (results packing)
_inner_tiles,
static_inner_tiles,
) = _dispatch_mixed_values(inner_tiles)

return _get_op_result_or_op_results(
UnPackOp(
source=source,
dest=dest,
inner_dims_pos=inner_dims_pos,
inner_tiles=dynamic_inner_tiles,
static_inner_tiles=static_inner_tiles,
outer_dims_perm=outer_dims_perm,
loc=loc,
ip=ip,
)
)
40 changes: 40 additions & 0 deletions mlir/test/python/dialects/linalg/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,3 +566,43 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
)

print(module)


# CHECK-LABEL: TEST: testPackUnPackOp
@run
def testPackUnPackOp():
with Context(), Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):

@func.FuncOp.from_py_func(
RankedTensorType.get((128, 128), f32),
RankedTensorType.get((16, 16, 8, 8), f32),
)
def tensor_pack(src, dst):
packed = linalg.pack(
src,
dst,
inner_dims_pos=[1, 0],
inner_tiles=[8, 8],
padding_value=arith.constant(f32, 0.0),
)

unpacked = linalg.unpack(
packed,
src,
inner_dims_pos=[0, 1],
inner_tiles=[8, 8],
)

return unpacked

# CHECK-LABEL: func.func @tensor_pack(
# CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> {
# CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
# CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
# CHECK: %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
# CHECK: }
print(module)

0 comments on commit 2faf4f4

Please sign in to comment.