Skip to content

Commit

Permalink
[mlir][python] fix linalg.pack
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Feb 19, 2025
1 parent 3430bc3 commit ddd976f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 1 deletion.
1 change: 1 addition & 0 deletions mlir/python/mlir/dialects/LinalgOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@

include "mlir/Dialect/Linalg/IR/LinalgOps.td"
include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"

#endif
37 changes: 36 additions & 1 deletion mlir/python/mlir/dialects/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@
from .opdsl.ops.core_named_ops import *

from ...ir import *
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
from .._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
_dispatch_mixed_values,
)
from ...extras.meta import region_op


Expand Down Expand Up @@ -193,3 +196,35 @@ def contract(
)
fill_builtin_region(op.operation)
return op


def pack(
source,
dest,
inner_dims_pos,
inner_tiles,
*,
padding_value=None,
outer_dims_perm=None,
loc=None,
ip=None,
) -> ir.Value:

(
dynamic_inner_tiles,
# packed here means %1:2 packing (results packing)
_inner_tiles,
static_inner_tiles,
) = _dispatch_mixed_values(inner_tiles)

return PackOp(
source=source,
dest=dest,
inner_dims_pos=inner_dims_pos,
inner_tiles=dynamic_inner_tiles,
static_inner_tiles=static_inner_tiles,
padding_value=padding_value,
outer_dims_perm=outer_dims_perm,
loc=loc,
ip=ip,
).result
21 changes: 21 additions & 0 deletions mlir/test/python/dialects/linalg/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,24 @@ def matmul_as_contract_op(
)

print(module)


# CHECK-LABEL: TEST: testPackOp
@run
def testPackOp():
with Context(), Location.unknown():
module = Module.create()
f32 = F32Type.get()

@func
def tensor_pack(
src: RankedTensorType.get((129, 47, 16, 16), f32),
dst: RankedTensorType((17, 2, 16, 16, 32, 8), f32),
):
return linalg.pack(
src,
dst,
inner_dims_pos=[1, 0],
inner_tiles=[32, 8],
padding_value=arith.constant(0.0),
)

0 comments on commit ddd976f

Please sign in to comment.