
[mlir][python] fix linalg.pack/unpack #127729

Merged · 2 commits merged into llvm:main from makslevental/fix-tensor-pack on Feb 20, 2025

Conversation

@makslevental (Contributor) commented Feb 19, 2025:

PR #123902 broke the Python bindings for tensor.pack/unpack; this PR fixes that. It also:

  1. adds convenience wrappers for pack/unpack (see the usage sketch after this list)
  2. cleans up matmul-like ops in the linalg bindings
  3. fixes linalg docs missing pack/unpack
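
Here is a usage sketch of the new wrapper, closely modeled on the test added in this PR; the import lines are assumptions based on the standard mlir Python bindings, and the test itself uses the usual wildcard dialect imports instead.

from mlir.ir import Context, Location, Module, InsertionPoint, F32Type, RankedTensorType
from mlir.dialects import arith, func, linalg

with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

        @func.FuncOp.from_py_func(
            RankedTensorType.get((129, 47, 16, 16), f32),
            RankedTensorType.get((17, 2, 16, 16, 32, 8), f32),
        )
        def tensor_pack(src, dst):
            # linalg.pack is the convenience wrapper added by this PR; it forwards
            # to the generated PackOp builder and returns the packed tensor value.
            return linalg.pack(
                src,
                dst,
                inner_dims_pos=[1, 0],
                inner_tiles=[32, 8],
                padding_value=arith.constant(f32, 0.0),
            )

    print(module)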

@makslevental force-pushed the makslevental/fix-tensor-pack branch from ddd976f to a642f31 on February 19, 2025 00:32
github-actions bot commented Feb 19, 2025:

✅ With the latest revision this PR passed the Python code formatter.

@makslevental force-pushed the makslevental/fix-tensor-pack branch from a642f31 to e288bfe on February 19, 2025 00:39
@makslevental marked this pull request as ready for review February 19, 2025 00:40
@llvmbot added the mlir:linalg, mlir:python (MLIR Python bindings), and mlir labels Feb 19, 2025
@llvmbot (Member) commented Feb 19, 2025:

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

PR #123902 broke the Python bindings for tensor.pack/unpack. This PR fixes that.

TODO: add wrapper for unpack.


Full diff: https://github.com/llvm/llvm-project/pull/127729.diff

3 Files Affected:

  • (modified) mlir/python/mlir/dialects/LinalgOps.td (+1)
  • (modified) mlir/python/mlir/dialects/linalg/__init__.py (+36-1)
  • (modified) mlir/test/python/dialects/linalg/ops.py (+31)
diff --git a/mlir/python/mlir/dialects/LinalgOps.td b/mlir/python/mlir/dialects/LinalgOps.td
index b7658c85a9c44..89fb3f219e858 100644
--- a/mlir/python/mlir/dialects/LinalgOps.td
+++ b/mlir/python/mlir/dialects/LinalgOps.td
@@ -11,5 +11,6 @@
 
 include "mlir/Dialect/Linalg/IR/LinalgOps.td"
 include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
+include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"
 
 #endif
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 5cda4769d593f..2e3424aaddb4f 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -58,7 +58,10 @@
 from .opdsl.ops.core_named_ops import *
 
 from ...ir import *
-from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from .._ods_common import (
+    get_op_result_or_value as _get_op_result_or_value,
+    _dispatch_mixed_values,
+)
 from ...extras.meta import region_op
 
 
@@ -193,3 +196,35 @@ def contract(
     )
     fill_builtin_region(op.operation)
     return op
+
+
+def pack(
+    source,
+    dest,
+    inner_dims_pos,
+    inner_tiles,
+    *,
+    padding_value=None,
+    outer_dims_perm=None,
+    loc=None,
+    ip=None,
+) -> ir.Value:
+
+    (
+        dynamic_inner_tiles,
+        # packed here means %1:2 packing (results packing)
+        _inner_tiles,
+        static_inner_tiles,
+    ) = _dispatch_mixed_values(inner_tiles)
+
+    return PackOp(
+        source=source,
+        dest=dest,
+        inner_dims_pos=inner_dims_pos,
+        inner_tiles=dynamic_inner_tiles,
+        static_inner_tiles=static_inner_tiles,
+        padding_value=padding_value,
+        outer_dims_perm=outer_dims_perm,
+        loc=loc,
+        ip=ip,
+    ).result
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 94f8ea4faf4a8..321ec9d658416 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -466,3 +466,34 @@ def matmul_as_contract_op(
                 )
 
         print(module)
+
+
+# CHECK-LABEL: TEST: testPackOp
+@run
+def testPackOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((129, 47, 16, 16), f32),
+                RankedTensorType.get((17, 2, 16, 16, 32, 8), f32),
+            )
+            def tensor_pack(src, dst):
+                return linalg.pack(
+                    src,
+                    dst,
+                    inner_dims_pos=[1, 0],
+                    inner_tiles=[32, 8],
+                    padding_value=arith.constant(f32, 0.0),
+                )
+
+        # CHECK-LABEL:   func.func @tensor_pack(
+        # CHECK-SAME:                           %[[VAL_0:.*]]: tensor<129x47x16x16xf32>,
+        # CHECK-SAME:                           %[[VAL_1:.*]]: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
+        # CHECK:           %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
+        # CHECK:           %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %[[VAL_1]] : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32>
+        # CHECK:           return %[[VAL_3]] : tensor<17x2x16x16x32x8xf32>
+        # CHECK:         }
+        print(module)
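
The description above leaves a TODO to add an unpack wrapper. Done "in the same way" as the pack wrapper in this diff, it would plausibly look like the sketch below; the generated class name UnPackOp and the exact signature are assumptions, and the code that eventually landed in this PR may differ. Note that unpack takes no padding_value.

def unpack(
    source,
    dest,
    inner_dims_pos,
    inner_tiles,
    *,
    outer_dims_perm=None,
    loc=None,
    ip=None,
) -> ir.Value:
    # Split the mixed static/dynamic tile sizes exactly as the pack wrapper does.
    (
        dynamic_inner_tiles,
        _inner_tiles,
        static_inner_tiles,
    ) = _dispatch_mixed_values(inner_tiles)

    return UnPackOp(
        source=source,
        dest=dest,
        inner_dims_pos=inner_dims_pos,
        inner_tiles=dynamic_inner_tiles,
        static_inner_tiles=static_inner_tiles,
        outer_dims_perm=outer_dims_perm,
        loc=loc,
        ip=ip,
    ).result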

@makslevental force-pushed the makslevental/fix-tensor-pack branch from e288bfe to 6f78ea7 on February 19, 2025 00:41
@makslevental changed the title from "[mlir][python] fix linalg.pack" to "[mlir][python] fix linalg.pack/unpack" Feb 19, 2025
@rengolin requested a review from rolfmorel February 19, 2025 02:22
@banach-space (Contributor) commented:

Thanks for the fix and apologies for the disruption!

> PR #123902 broke the Python bindings for tensor.pack/unpack.

Just so that I can educate myself: where were the old bindings located? That PR didn't touch any Python code 🤔 Also, it looks like this PR is implementing the bindings from scratch?

Please bear with me, I'm not too familiar with this logic 😅

@makslevental (Contributor, Author) commented Feb 19, 2025:

> where were the old bindings located?

They were being emitted via dialects/TensorOps.td into a file called dialects/_tensor_ops_gen.py.

> Also, it looks like this PR is implementing the bindings from scratch?

What I've added is a convenience wrapper around the emitted binding, which is now emitted into dialects/_linalg_ops_gen.py.
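
In other words, a rough sketch of the user-facing change (the module paths are taken from the comments above; the usage shown is an assumption, not code from the PR):

# Before PR #123902: the generated builder lived under the tensor dialect,
# emitted from dialects/TensorOps.td into dialects/_tensor_ops_gen.py, e.g.:
#   from mlir.dialects import tensor
#   tensor.PackOp(...)
# After this PR: the op is generated under linalg (dialects/_linalg_ops_gen.py),
# and this PR adds a thin convenience wrapper on top of it:
from mlir.dialects import linalg

linalg.PackOp  # raw generated op builder
linalg.pack    # convenience wrapper added in this PR; wraps PackOp and returns its result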

@rolfmorel (Contributor) left a comment:

This is looking good to me.

If you do unpack in the same way, I think it would be good to go!

Inline review thread on the pack wrapper in mlir/python/mlir/dialects/linalg/__init__.py:

outer_dims_perm=outer_dims_perm,
loc=loc,
ip=ip,
).result
Contributor:

Just to check: should the other wrappers in this file also return Op(...).result/Op(...).results instead of just Op(...)?

@makslevental (Contributor, Author) commented Feb 19, 2025:

Lower-case APIs are so-called "value builders" that, in general, return a value. Here I've actually made a mistake: linalg is special, so I should return either the result or the op itself when there is no result (i.e., when the operands are memrefs).
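
A minimal sketch of the convention being described (hypothetical helper name, not the exact code in the tree):

def _op_result_or_op(op):
    # Value builders normally return the produced Value; linalg ops on memrefs
    # have no results, in which case return the op itself.
    return op.result if len(op.results) == 1 else op

# e.g. inside a lower-case wrapper:
#   return _op_result_or_op(PackOp(...))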

@makslevental (Contributor, Author) commented:

So, actually, the batch_matmul PR that just landed (#127614) is wrong because it returns ops from lower-case APIs: https://github.com/llvm/llvm-project/pull/127614/files#diff-d46e2a112be8db49e67fa428a8b422823fa9700f271248043861d4bdb1def61fR173

Member:

Would a fixup follow-up be fine here?

@makslevental (Contributor, Author) commented Feb 19, 2025:

Fine with me - I definitely prefer to combine simple PRs (like these two would be). So I'll do that. I'll keep that change in a separate commit.

Member:

A fix in this one would also be OK, if you're happy to do it.

@makslevental (Contributor, Author) commented:

> A fix in this one would also be OK, if you're happy to do it.

Sorry, I misread/misunderstood your question - yeah, no need for a revert (I'm probably the only person in the world using linalg as a frontend, so I doubt it's a backward-compatibility concern). And I'll fix it up here.

@makslevental force-pushed the makslevental/fix-tensor-pack branch from 6f78ea7 to 9623fa6 on February 19, 2025 15:55
@makslevental force-pushed the makslevental/fix-tensor-pack branch 3 times, most recently from 8001c7a to 395a528 on February 20, 2025 02:44
@makslevental force-pushed the makslevental/fix-tensor-pack branch from 395a528 to 6e772de on February 20, 2025 03:08
@rolfmorel (Contributor) left a comment:

LGTM - thanks for the changes, especially the fix-ups for the matmul-like ops!

If you could update the description accordingly, it's good to go IMO.

@makslevental merged commit a72616d into llvm:main Feb 20, 2025
8 checks passed
@makslevental deleted the makslevental/fix-tensor-pack branch February 20, 2025 16:02
@banach-space (Contributor) commented:

Thanks @makslevental, I really appreciate your help with this 🙏🏻
