Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API #126377

Merged
merged 11 commits into from
Feb 10, 2025

Conversation

rolfmorel
Copy link
Contributor

@rolfmorel rolfmorel commented Feb 8, 2025

Now that linalg.matmul is in tablegen, "hand write" the Python wrapper that OpDSL used to derive. Similarly, add a Python wrapper for the new linalg.contract op.

Required following misc. fixes:

  1. make linalg.matmul's parsing and printing consistent w.r.t. whether indexing_maps occurs before or after operands, i.e. per the tests cases it comes before.
  2. tablegen for linalg.contract did not state it accepted an optional cast attr.
  3. In ODS's C++-generating code, expand partial support for $_builder access in Attr::defaultValue to full support. This enables access to the current MlirContext when constructing the default value (as is required when the default value consists of affine maps).

Now that linalg.matmul is in tablegen, "hand write" the Python wrapper
that OpDSL used to derive. Similarly, add a Python wrapper for the new
linalg.contract op.

Required following misc. fixes:
1) make linalg.matmul consistent in whether indexing_maps occurs before
   or after operands, i.e. per the tests case it comes _before_.
   TODO: fix linalg.batch_matmul as well
2) tablegen for linalg.contract did not state it accepted an optional
   cast attr.
@llvmbot
Copy link
Member

llvmbot commented Feb 8, 2025

@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-linalg

Author: Rolf Morel (rolfmorel)

Changes

Now that linalg.matmul is in tablegen, "hand write" the Python wrapper that OpDSL used to derive. Similarly, add a Python wrapper for the new linalg.contract op.

Required following misc. fixes:

  1. make linalg.matmul consistent in whether indexing_maps occurs before
    or after operands, i.e. per the tests case it comes before.
    TODO: fix linalg.batch_matmul as well
  2. tablegen for linalg.contract did not state it accepted an optional
    cast attr.

Full diff: https://github.com/llvm/llvm-project/pull/126377.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+5-5)
  • (modified) mlir/python/mlir/dialects/linalg/init.py (+48)
  • (modified) mlir/test/python/dialects/linalg/ops.py (+186)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 110ed7d2fc00e2a..6146ff09482fbad 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -752,7 +752,8 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
   let arguments = (ins
     Variadic<AnyType>:$inputs,
     Variadic<AnyShaped>:$outputs,
-    AffineMapArrayAttr:$indexing_maps
+    AffineMapArrayAttr:$indexing_maps,
+    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
   );
   let results = (outs Variadic<AnyShaped>:$result_tensors);
   // NB: The only reason this op has a region - and it get populated at op build
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b50931f15826ce2..d40cec02df6338d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3666,11 +3666,6 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void MatmulOp::print(OpAsmPrinter &p) {
-  SmallVector<StringRef, 3> elidedAttrs = {
-      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
-  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
-                         elidedAttrs);
-
   SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
       MatmulOp::getDefaultIndexingMaps(getContext()),
       [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -3680,6 +3675,11 @@ void MatmulOp::print(OpAsmPrinter &p) {
                           [&](Attribute attr) { p.printAttribute(attr); });
     p << "]";
   }
+
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                         elidedAttrs);
 }
 
 /// Verify the user defined indexing maps.
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 742262a9c496952..07e7c0ccd70025d 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -147,3 +147,51 @@ def __init__(
 
 
 generic = region_op(GenericOp_, terminator=YieldOp)
+
+
+def matmul(
+    inputs: Sequence[Union[Operation, OpView, Value]],
+    *,
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Sequence[AffineMapAttr],
+    cast: Optional[Union[TypeFn, Attribute]]=None
+):
+    inputs = [_get_op_result_or_value(input) for input in inputs]
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+    op = MatmulOp(
+        result_tensors=result_types,
+        inputs=inputs,
+        outputs=[init],
+        indexing_maps=indexing_maps,
+        cast=cast
+    )
+    fill_builtin_region(op.operation)
+    return op
+
+
+def contract(
+    inputs: Sequence[Union[Operation, OpView, Value]],
+    *,
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Sequence[AffineMapAttr],
+    cast: Optional[Union[TypeFn, Attribute]]=None
+):
+    inputs = [_get_op_result_or_value(input) for input in inputs]
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+    op = ContractOp(
+        result_tensors=result_types,
+        inputs=inputs,
+        outputs=[init],
+        indexing_maps=indexing_maps,
+        cast=cast
+    )
+    fill_builtin_region(op.operation)
+    return op
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index ac7186c24bed84e..6baea4f917c128c 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -256,3 +256,189 @@ def f(a, b):
 
     module.operation.verify()
     print(module)
+
+
+# CHECK-LABEL: TEST: testMatmulOp
+@run
+def testMatmulOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (4, 8)
+            b_shape = (8, 12)
+            b_transposed_shape = (12, 8)
+            c_shape = (4, 12)
+
+            dimM = ir.AffineDimExpr.get(0)
+            dimN = ir.AffineDimExpr.get(1)
+            dimK = ir.AffineDimExpr.get(2)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+            a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+            b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+            c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+            b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+            # CHECK: func.func @matmul_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME:                         %[[A:.*]]: tensor<4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[Amem:.*]]: memref<4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[B:.*]]: tensor<8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[Bmem:.*]]: memref<8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[BTrans:.*]]: tensor<12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[BTransmem:.*]]: memref<12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[C:.*]]: tensor<4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME:                         %[[Cmem:.*]]: memref<4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+                # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.matmul((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.matmul((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+                # And now with memrefs...
+
+                # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Bmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.matmul((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Btransposedmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.matmul((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+        print(module)
+
+
+# CHECK-LABEL: TEST: testContractOp
+@run
+def testContractOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (4, 8)
+            b_shape = (8, 12)
+            b_transposed_shape = (12, 8)
+            c_shape = (4, 12)
+
+            dimM = ir.AffineDimExpr.get(0)
+            dimN = ir.AffineDimExpr.get(1)
+            dimK = ir.AffineDimExpr.get(2)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+            # CHECK: #[[$B_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+            a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+            b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+            c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+            b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+            # CHECK: func.func @matmul_as_contract_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME:                         %[[A:.*]]: tensor<4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[Amem:.*]]: memref<4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[B:.*]]: tensor<8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[Bmem:.*]]: memref<8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[BTrans:.*]]: tensor<12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[BTransmem:.*]]: memref<12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[C:.*]]: tensor<4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME:                         %[[Cmem:.*]]: memref<4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def matmul_as_contract_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.contract((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.contract((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+                # And now with memrefs...
+
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=[],
+                    inputs=(Amem, Bmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.contract((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=[],
+                    inputs=(Amem, Btransposedmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.contract((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+        print(module)

@llvmbot
Copy link
Member

llvmbot commented Feb 8, 2025

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

Changes

Now that linalg.matmul is in tablegen, "hand write" the Python wrapper that OpDSL used to derive. Similarly, add a Python wrapper for the new linalg.contract op.

Required following misc. fixes:

  1. make linalg.matmul consistent in whether indexing_maps occurs before
    or after operands, i.e. per the tests case it comes before.
    TODO: fix linalg.batch_matmul as well
  2. tablegen for linalg.contract did not state it accepted an optional
    cast attr.

Full diff: https://github.com/llvm/llvm-project/pull/126377.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+5-5)
  • (modified) mlir/python/mlir/dialects/linalg/init.py (+48)
  • (modified) mlir/test/python/dialects/linalg/ops.py (+186)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 110ed7d2fc00e2a..6146ff09482fbad 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -752,7 +752,8 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
   let arguments = (ins
     Variadic<AnyType>:$inputs,
     Variadic<AnyShaped>:$outputs,
-    AffineMapArrayAttr:$indexing_maps
+    AffineMapArrayAttr:$indexing_maps,
+    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
   );
   let results = (outs Variadic<AnyShaped>:$result_tensors);
   // NB: The only reason this op has a region - and it get populated at op build
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b50931f15826ce2..d40cec02df6338d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3666,11 +3666,6 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void MatmulOp::print(OpAsmPrinter &p) {
-  SmallVector<StringRef, 3> elidedAttrs = {
-      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
-  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
-                         elidedAttrs);
-
   SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
       MatmulOp::getDefaultIndexingMaps(getContext()),
       [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -3680,6 +3675,11 @@ void MatmulOp::print(OpAsmPrinter &p) {
                           [&](Attribute attr) { p.printAttribute(attr); });
     p << "]";
   }
+
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                         elidedAttrs);
 }
 
 /// Verify the user defined indexing maps.
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 742262a9c496952..07e7c0ccd70025d 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -147,3 +147,51 @@ def __init__(
 
 
 generic = region_op(GenericOp_, terminator=YieldOp)
+
+
+def matmul(
+    inputs: Sequence[Union[Operation, OpView, Value]],
+    *,
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Sequence[AffineMapAttr],
+    cast: Optional[Union[TypeFn, Attribute]]=None
+):
+    inputs = [_get_op_result_or_value(input) for input in inputs]
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+    op = MatmulOp(
+        result_tensors=result_types,
+        inputs=inputs,
+        outputs=[init],
+        indexing_maps=indexing_maps,
+        cast=cast
+    )
+    fill_builtin_region(op.operation)
+    return op
+
+
+def contract(
+    inputs: Sequence[Union[Operation, OpView, Value]],
+    *,
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Sequence[AffineMapAttr],
+    cast: Optional[Union[TypeFn, Attribute]]=None
+):
+    inputs = [_get_op_result_or_value(input) for input in inputs]
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+    op = ContractOp(
+        result_tensors=result_types,
+        inputs=inputs,
+        outputs=[init],
+        indexing_maps=indexing_maps,
+        cast=cast
+    )
+    fill_builtin_region(op.operation)
+    return op
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index ac7186c24bed84e..6baea4f917c128c 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -256,3 +256,189 @@ def f(a, b):
 
     module.operation.verify()
     print(module)
+
+
+# CHECK-LABEL: TEST: testMatmulOp
+@run
+def testMatmulOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (4, 8)
+            b_shape = (8, 12)
+            b_transposed_shape = (12, 8)
+            c_shape = (4, 12)
+
+            dimM = ir.AffineDimExpr.get(0)
+            dimN = ir.AffineDimExpr.get(1)
+            dimK = ir.AffineDimExpr.get(2)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+            a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+            b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+            c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+            b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+            # CHECK: func.func @matmul_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME:                         %[[A:.*]]: tensor<4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[Amem:.*]]: memref<4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[B:.*]]: tensor<8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[Bmem:.*]]: memref<8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[BTrans:.*]]: tensor<12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[BTransmem:.*]]: memref<12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[C:.*]]: tensor<4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME:                         %[[Cmem:.*]]: memref<4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+                # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.matmul((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.matmul((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+                # And now with memrefs...
+
+                # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Bmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.matmul((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Btransposedmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.matmul((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+        print(module)
+
+
+# CHECK-LABEL: TEST: testContractOp
+@run
+def testContractOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (4, 8)
+            b_shape = (8, 12)
+            b_transposed_shape = (12, 8)
+            c_shape = (4, 12)
+
+            dimM = ir.AffineDimExpr.get(0)
+            dimN = ir.AffineDimExpr.get(1)
+            dimK = ir.AffineDimExpr.get(2)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+            # CHECK: #[[$B_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+            a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+            b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+            c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+            b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+            # CHECK: func.func @matmul_as_contract_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME:                         %[[A:.*]]: tensor<4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[Amem:.*]]: memref<4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[B:.*]]: tensor<8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[Bmem:.*]]: memref<8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[BTrans:.*]]: tensor<12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[BTransmem:.*]]: memref<12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[C:.*]]: tensor<4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME:                         %[[Cmem:.*]]: memref<4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def matmul_as_contract_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.contract((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.contract((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+                # And now with memrefs...
+
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=[],
+                    inputs=(Amem, Bmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.contract((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=[],
+                    inputs=(Amem, Btransposedmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.contract((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+        print(module)

@rolfmorel rolfmorel removed the request for review from stellaraccident February 8, 2025 16:51
Copy link

github-actions bot commented Feb 8, 2025

✅ With the latest revision this PR passed the Python code formatter.

@makslevental
Copy link
Contributor

Lol did you really whip this up following the discussion yesterday on discourse yesterday or had you been working on this? Either way nice work - I'll take a close look soon.

@makslevental
Copy link
Contributor

Ya so

 Assertion failed: idx < size(),

is the error I alluded to in the discourse discussion.

@rolfmorel
Copy link
Contributor Author

Fixed now. Will explain tomorrow why this works / it has to be this way.

Same fix will need to happen for linalg.batch_matmul which I will do before this PR is ready to merge.

@@ -606,7 +622,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
DefaultValuedMatmulIndexingMapsAttr:$indexing_maps, // DONOTMERGE(rolfmorel): explain why this is necessary
Copy link
Contributor

@makslevental makslevental Feb 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get what you're going for here but the right way to do this is just add a custom builder that constructs this attribute rather than adding a whole new type. That won't work for python because python can only call the default builder (which is maybe why you went this direction) but there it's the same story - the MatmulOp builder just needs to construct the correct indexing maps.

Copy link
Contributor Author

@rolfmorel rolfmorel Feb 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thing is, as far as I understand now no builder (really no op-specific C++ code at all) is called when an op gets constructed from Python (see from here to here to here to here to finally here ). The only thing one can do is pass in default values for attributes. Could you point me to an example of specifying "the default builder" for an op?

And yes, preferably I would just have DefaultValuedOptionalAttr<AffineMapArrayAttr, "MatmulOp::getDefaultIndexingMaps($_builder.getContext())">:$indexing_maps in MatmulOp's arguments (note, to my knowledge, we cannot specify custom "builders" in this list), but that doesn't work. The problem is that DefaultValuedOptionalAttr assumes that the defaultValue can be given without access to the context. However, that is not true in this case (that is, attempting the above leads to error: use of undeclared identifier '$_builder'). The workaround is to just expand what DefaultValuedOptionalAttr does to get DefaultValuedMatmulIndexingMapsAttr in which we can modify constBuilderCall so that we can pass the context in case of the defaultValue. That the resulting attr is not "anonymous" - and hence we get a whole new type attr - is rather unfortunate though.

If you have suggestions for how to effect a cleaner approach, I am all ears!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, latest commit gets rid of whole new type: now no other code is needed/aware of that we aren't using DefaultValuedOptionalAttr<AffineMapArrayAttr, ...>.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Latest commit now moves the builderCall parameter to DefaultValued(Optional)Attr.

Probably an even cleaner solution is to just allow $_builder in Attr's defaultValue... (that's an exploration for another time though).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And yes, preferably I would just have DefaultValuedOptionalAttr<AffineMapArrayAttr, "MatmulOp::getDefaultIndexingMaps($_builder.getContext())">:$indexing_maps in MatmulOp's arguments (note, to my knowledge, we cannot specify custom "builders" in this list), but that doesn't work.

Rather than let my dreams be dreams, I decided to just get this to work. I doubt we will get to a cleaner solution than this.

Nevertheless, @makslevental, if you could give some pointers to info on "default builders" and how Python is supposed to invoke those, that would be much appreciated!

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:ods labels Feb 9, 2025
Reverts changes to DefaultValued(Optional)Attr.
@rolfmorel
Copy link
Contributor Author

rolfmorel commented Feb 9, 2025

Okay - decided a deep dive was actually a good use of a sunday afternoon.

I have now dealt with the "default values can only be C++ values" issue at the root: it was not possible to access the context within Attr::defaultValue. Previously I tried to work around this. Now I have just allowed access to $_builder from defaultValue. Note partial support for this already existed, see here. I have just extended that support to all cases where defaultValue gets dumped into C++. Happy to extract this commit to a separate PR if that's considered necessary.

To me, the above is an elegant solution to the "no builder method is invoked when ops are constructed from Python" problem. It is also a general quality of life improvement for others who come across this strange limitation (i.e. Attr's constBuilderCall can access the builder/context and defaultValue cannot). Let me know if other approaches are still considered preferable.

Copy link
Contributor

@nicolasvasilache nicolasvasilache left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the quick turnaround @rolfmorel !
I ran on a quick test internally and this mostly does what I expect.

One small discrepancy is how this changes the input args from opdsl, I now have the following:

if self.matmul_impl == linalg.matmul:
  self.matmul_impl([A, B], outs=[C])
else:
  self.matmul_impl(A, B, outs=[C])

Generally, I think I prefer you variant but as opdsl is gradually replaced, the size of this switch will grow.

How about accepting both variants for now? (not a blocker for me)

@rolfmorel
Copy link
Contributor Author

Hi @nicolasvasilache,

Thanks for pointing that out! I have now changed the argument format to that of OpDSL-derived ops (I didn't check carefully enough what that format was and went with what seemed more consistent with the MLIR-style).

@rengolin and I both don't care much for which syntax we go with. I personally feel there isn't much value in allowing the alternative style now for just a couple ops (versus allowing it/switching over all linalg ops at once). If others do have a preference for switching/allowing the alternative style now, I am happy to comply.

Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for now we keep the old syntax and if someone wants to change, this can be a new PR/RFC.

@rolfmorel rolfmorel merged commit f796bc6 into llvm:main Feb 10, 2025
8 checks passed
@rolfmorel
Copy link
Contributor Author

Thanks @nicolasvasilache and @rengolin - I have now gone ahead and merged the PR!

Syntax changes and doing the same thing to other ops can be done in later PRs.

@nicolasvasilache
Copy link
Contributor

Hi @nicolasvasilache,

Thanks for pointing that out! I have now changed the argument format to that of OpDSL-derived ops (I didn't check carefully enough what that format was and went with what seemed more consistent with the MLIR-style).

@rengolin and I both don't care much for which syntax we go with. I personally feel there isn't much value in allowing the alternative style now for just a couple ops (versus allowing it/switching over all linalg ops at once). If others do have a preference for switching/allowing the alternative style now, I am happy to comply.

SGTM, thanks for fixing!

Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
llvm#126377)

Now that linalg.matmul is in tablegen, "hand write" the Python wrapper
that OpDSL used to derive. Similarly, add a Python wrapper for the new
linalg.contract op.

Required following misc. fixes:
1) make linalg.matmul's parsing and printing consistent w.r.t. whether
indexing_maps occurs before or after operands, i.e. per the tests cases
it comes _before_.
2) tablegen for linalg.contract did not state it accepted an optional
cast attr.
3) In ODS's C++-generating code, expand partial support for `$_builder`
access in `Attr::defaultValue` to full support. This enables access to
the current `MlirContext` when constructing the default value (as is
required when the default value consists of affine maps).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:linalg mlir:ods mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants