[mlir][python] fix linalg.pack/unpack #127729
Conversation
✅ With the latest revision this PR passed the Python code formatter.
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

This PR #123902 broke python bindings for tensor.pack/unpack. This PR fixes that.

TODO: add wrapper for unpack.

Full diff: https://github.com/llvm/llvm-project/pull/127729.diff

3 Files Affected:
- mlir/python/mlir/dialects/LinalgOps.td
- mlir/python/mlir/dialects/linalg/__init__.py
- mlir/test/python/dialects/linalg/ops.py
diff --git a/mlir/python/mlir/dialects/LinalgOps.td b/mlir/python/mlir/dialects/LinalgOps.td
index b7658c85a9c44..89fb3f219e858 100644
--- a/mlir/python/mlir/dialects/LinalgOps.td
+++ b/mlir/python/mlir/dialects/LinalgOps.td
@@ -11,5 +11,6 @@
include "mlir/Dialect/Linalg/IR/LinalgOps.td"
include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
+include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"
#endif
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 5cda4769d593f..2e3424aaddb4f 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -58,7 +58,10 @@
from .opdsl.ops.core_named_ops import *
from ...ir import *
-from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from .._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ _dispatch_mixed_values,
+)
from ...extras.meta import region_op
@@ -193,3 +196,35 @@ def contract(
)
fill_builtin_region(op.operation)
return op
+
+
+def pack(
+ source,
+ dest,
+ inner_dims_pos,
+ inner_tiles,
+ *,
+ padding_value=None,
+ outer_dims_perm=None,
+ loc=None,
+ ip=None,
+) -> ir.Value:
+
+ (
+ dynamic_inner_tiles,
+ # packed here means %1:2 packing (results packing)
+ _inner_tiles,
+ static_inner_tiles,
+ ) = _dispatch_mixed_values(inner_tiles)
+
+ return PackOp(
+ source=source,
+ dest=dest,
+ inner_dims_pos=inner_dims_pos,
+ inner_tiles=dynamic_inner_tiles,
+ static_inner_tiles=static_inner_tiles,
+ padding_value=padding_value,
+ outer_dims_perm=outer_dims_perm,
+ loc=loc,
+ ip=ip,
+ ).result
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 94f8ea4faf4a8..321ec9d658416 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -466,3 +466,34 @@ def matmul_as_contract_op(
)
print(module)
+
+
+# CHECK-LABEL: TEST: testPackOp
+@run
+def testPackOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((129, 47, 16, 16), f32),
+ RankedTensorType.get((17, 2, 16, 16, 32, 8), f32),
+ )
+ def tensor_pack(src, dst):
+ return linalg.pack(
+ src,
+ dst,
+ inner_dims_pos=[1, 0],
+ inner_tiles=[32, 8],
+ padding_value=arith.constant(f32, 0.0),
+ )
+
+ # CHECK-LABEL: func.func @tensor_pack(
+ # CHECK-SAME: %[[VAL_0:.*]]: tensor<129x47x16x16xf32>,
+ # CHECK-SAME: %[[VAL_1:.*]]: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
+ # CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
+ # CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %[[VAL_1]] : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32>
+ # CHECK: return %[[VAL_3]] : tensor<17x2x16x16x32x8xf32>
+ # CHECK: }
+ print(module)
Thanks for the fix and apologies for the disruption!
Just so that I educate myself - where were the old bindings located? That PR didn't touch any Python code 🤔 Also, it looks like this PR is implementing the bindings from scratch? Please bear with me, I'm not too familiar with this logic 😅
They were being emitted via the tensor dialect's tablegen file (they were tensor.pack/unpack before #123902 moved them).
What I have added is a convenience wrapper around the emitted binding, which is now being emitted into the linalg module.
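To illustrate the splitting that the wrapper relies on (a simplified sketch of _dispatch_mixed_values-style behavior, not the actual _ods_common implementation — the sentinel value here is a stand-in):

# Sketch: split a mixed list of Python ints and SSA Values into the
# dynamic operands plus a static integer array, where a sentinel marks
# each dynamic slot (MLIR uses the ShapedType dynamic-size sentinel).
DYNAMIC = -9223372036854775808  # hypothetical stand-in for the real sentinel

def split_mixed_values(values):
    dynamic, static = [], []
    for v in values:
        if isinstance(v, int):
            static.append(v)        # compile-time constant -> attribute
        else:
            dynamic.append(v)       # SSA Value -> trailing operand
            static.append(DYNAMIC)  # sentinel marks the dynamic slot
    return dynamic, static

# e.g. split_mixed_values([32, n, 8]) -> ([n], [32, DYNAMIC, 8])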
This is looking good to me.
If you do unpack in the same way, I think it would be good to go!
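For reference, an unpack wrapper done "in the same way" might look like the following (a sketch mirroring the pack() wrapper above, assuming the same module imports; UnPackOp and the parameter names are assumptions based on the pack diff, not code from this PR):

# Hedged sketch: mirrors pack() above; assumes _dispatch_mixed_values
# and ir are already imported as in linalg/__init__.py.
def unpack(
    source,
    dest,
    inner_dims_pos,
    inner_tiles,
    *,
    outer_dims_perm=None,
    loc=None,
    ip=None,
) -> ir.Value:
    (
        dynamic_inner_tiles,
        _inner_tiles,
        static_inner_tiles,
    ) = _dispatch_mixed_values(inner_tiles)

    return UnPackOp(
        source=source,
        dest=dest,
        inner_dims_pos=inner_dims_pos,
        inner_tiles=dynamic_inner_tiles,
        static_inner_tiles=static_inner_tiles,
        outer_dims_perm=outer_dims_perm,
        loc=loc,
        ip=ip,
    ).result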
outer_dims_perm=outer_dims_perm,
loc=loc,
ip=ip,
).result
Just to check: should the other wrappers in this file also return Op(...).result / Op(...).results instead of just Op(...)?
Lowercase APIs are so-called "value builders" that in general return a value. Here I've actually made a mistake, because linalg is special: I should be returning either the result, or the op itself in the case when there is no result (when the operands are memrefs).
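In other words, the convention described here could be captured by a helper like this (a minimal sketch of the stated rule, not the code that landed in the PR):

def value_builder_result(op):
    # Return the single result Value if there is one, all results if
    # there are several, and the op itself if there are none (e.g. when
    # the operands are memrefs and the op produces no results).
    if len(op.results) == 1:
        return op.result
    if len(op.results) > 1:
        return op.results
    return op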
So actually the batch_matmul PR that just landed is wrong, because it returns ops from lowercase APIs: https://github.com/llvm/llvm-project/pull/127614/files#diff-d46e2a112be8db49e67fa428a8b422823fa9700f271248043861d4bdb1def61fR173
Would a fixup follow-up be fine here?
Fine with me - I definitely prefer to combine simple PRs (like these two would be). So I'll do that. I'll keep that change in a separate commit.
a fix in this one would also be ok, if you're happy to do it.
> a fix in this one would also be ok, if you're happy to do it.

Sorry, I misread/misunderstood your question - yeah, no need for a revert (I'm probably the only person in the world using linalg as a frontend, so I doubt it's a BC concern). And I'll fix it up here.
LGTM - thanks for the changes, especially the fix-ups for the matmul-like ops!
If you could update the description accordingly, it's good to go IMO.
Thanks @makslevental, I really appreciate your help with this 🙏🏻
This PR #123902 broke python bindings for tensor.pack/unpack. This PR fixes that. It also fixes up the matmul-like ops so that their lowercase value builders return results instead of ops.