[MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API #126377
Conversation
Now that linalg.matmul is in tablegen, "hand write" the Python wrapper that OpDSL used to derive. Similarly, add a Python wrapper for the new linalg.contract op. Required the following misc. fixes: 1) make linalg.matmul consistent in whether indexing_maps occurs before or after operands, i.e. per the test cases it comes _before_ (TODO: fix linalg.batch_matmul as well); 2) tablegen for linalg.contract did not state it accepted an optional cast attr.
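For orientation, here is a hedged sketch (not part of the PR) of how the new wrappers are exercised, condensed from the tests added in the diff below. The tuple-of-inputs calling convention matches this diff; the convention was later aligned with the OpDSL style during review (see the discussion further down), and the import paths assume the upstream MLIR Python bindings packaging.

from mlir import ir
from mlir.dialects import func, linalg

with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    f32 = ir.F32Type.get()
    with ir.InsertionPoint(module.body):
        # Canonical matmul maps: (m, k) x (k, n) -> (m, n).
        dim_m, dim_n, dim_k = (ir.AffineDimExpr.get(i) for i in range(3))
        a_map = ir.AffineMap.get(3, 0, [dim_m, dim_k])
        b_map = ir.AffineMap.get(3, 0, [dim_k, dim_n])
        c_map = ir.AffineMap.get(3, 0, [dim_m, dim_n])

        @func.FuncOp.from_py_func(
            ir.RankedTensorType.get((4, 8), f32),
            ir.RankedTensorType.get((8, 12), f32),
            ir.RankedTensorType.get((4, 12), f32),
        )
        def matmul_example(a, b, c):
            # Hand-written wrappers added by this PR: outs takes the single
            # init operand, indexing_maps the per-operand affine maps.
            linalg.matmul((a, b), outs=(c,), indexing_maps=[a_map, b_map, c_map])
            linalg.contract((a, b), outs=(c,), indexing_maps=[a_map, b_map, c_map])

    print(module)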
@llvm/pr-subscribers-mlir-ods @llvm/pr-subscribers-mlir-linalg Author: Rolf Morel (rolfmorel) Changes: Now that linalg.matmul is in tablegen, "hand write" the Python wrapper that OpDSL used to derive. Similarly, add a Python wrapper for the new linalg.contract op. Required the misc. fixes listed in the description above.
Full diff: https://github.com/llvm/llvm-project/pull/126377.diff (4 Files Affected):
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 110ed7d2fc00e2a..6146ff09482fbad 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -752,7 +752,8 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
- AffineMapArrayAttr:$indexing_maps
+ AffineMapArrayAttr:$indexing_maps,
+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
);
let results = (outs Variadic<AnyShaped>:$result_tensors);
// NB: The only reason this op has a region - and it get populated at op build
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b50931f15826ce2..d40cec02df6338d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3666,11 +3666,6 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
}
void MatmulOp::print(OpAsmPrinter &p) {
- SmallVector<StringRef, 3> elidedAttrs = {
- "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
- printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
- elidedAttrs);
-
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
MatmulOp::getDefaultIndexingMaps(getContext()),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -3680,6 +3675,11 @@ void MatmulOp::print(OpAsmPrinter &p) {
[&](Attribute attr) { p.printAttribute(attr); });
p << "]";
}
+
+ SmallVector<StringRef, 3> elidedAttrs = {
+ "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+ printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
}
/// Verify the user defined indexing maps.
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 742262a9c496952..07e7c0ccd70025d 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -147,3 +147,51 @@ def __init__(
generic = region_op(GenericOp_, terminator=YieldOp)
+
+
+def matmul(
+ inputs: Sequence[Union[Operation, OpView, Value]],
+ *,
+ outs: Sequence[Union[Operation, OpView, Value]],
+ indexing_maps: Sequence[AffineMapAttr],
+ cast: Optional[Union[TypeFn, Attribute]]=None
+):
+ inputs = [_get_op_result_or_value(input) for input in inputs]
+ if len(outs) > 1:
+ raise ValueError(f"{outs=} must have length 1.")
+ init = _get_op_result_or_value(outs[0])
+ result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+ op = MatmulOp(
+ result_tensors=result_types,
+ inputs=inputs,
+ outputs=[init],
+ indexing_maps=indexing_maps,
+ cast=cast
+ )
+ fill_builtin_region(op.operation)
+ return op
+
+
+def contract(
+ inputs: Sequence[Union[Operation, OpView, Value]],
+ *,
+ outs: Sequence[Union[Operation, OpView, Value]],
+ indexing_maps: Sequence[AffineMapAttr],
+ cast: Optional[Union[TypeFn, Attribute]]=None
+):
+ inputs = [_get_op_result_or_value(input) for input in inputs]
+ if len(outs) > 1:
+ raise ValueError(f"{outs=} must have length 1.")
+ init = _get_op_result_or_value(outs[0])
+ result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+ op = ContractOp(
+ result_tensors=result_types,
+ inputs=inputs,
+ outputs=[init],
+ indexing_maps=indexing_maps,
+ cast=cast
+ )
+ fill_builtin_region(op.operation)
+ return op
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index ac7186c24bed84e..6baea4f917c128c 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -256,3 +256,189 @@ def f(a, b):
module.operation.verify()
print(module)
+
+
+# CHECK-LABEL: TEST: testMatmulOp
+@run
+def testMatmulOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ a_shape = (4, 8)
+ b_shape = (8, 12)
+ b_transposed_shape = (12, 8)
+ c_shape = (4, 12)
+
+ dimM = ir.AffineDimExpr.get(0)
+ dimN = ir.AffineDimExpr.get(1)
+ dimK = ir.AffineDimExpr.get(2)
+
+ # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+ # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+ # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+ a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+ b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+ c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+ b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+ # CHECK: func.func @matmul_op(
+ @func.FuncOp.from_py_func(
+ # CHECK-SAME: %[[A:.*]]: tensor<4x8xf32>,
+ RankedTensorType.get(a_shape, f32),
+ # CHECK-SAME: %[[Amem:.*]]: memref<4x8xf32>,
+ MemRefType.get(a_shape, f32),
+ # CHECK-SAME: %[[B:.*]]: tensor<8x12xf32>,
+ RankedTensorType.get(b_shape, f32),
+ # CHECK-SAME: %[[Bmem:.*]]: memref<8x12xf32>,
+ MemRefType.get(b_shape, f32),
+ # CHECK-SAME: %[[BTrans:.*]]: tensor<12x8xf32>,
+ RankedTensorType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[BTransmem:.*]]: memref<12x8xf32>,
+ MemRefType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>,
+ RankedTensorType.get(c_shape, f32),
+ # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
+ MemRefType.get(c_shape, f32),
+ )
+ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+ # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op4 = linalg.MatmulOp(
+ result_tensors=(C.type,),
+ inputs=(A, B),
+ outputs=(C,),
+ indexing_maps=[a_map, b_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op5 = linalg.matmul((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op4 = linalg.MatmulOp(
+ result_tensors=(C.type,),
+ inputs=(A, Btransposed),
+ outputs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op5 = linalg.matmul((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+ # And now with memrefs...
+
+ # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ op4 = linalg.MatmulOp(
+ result_tensors=[],
+ inputs=(Amem, Bmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.matmul((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ op4 = linalg.MatmulOp(
+ result_tensors=[],
+ inputs=(Amem, Btransposedmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.matmul((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+ print(module)
+
+
+# CHECK-LABEL: TEST: testContractOp
+@run
+def testContractOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ a_shape = (4, 8)
+ b_shape = (8, 12)
+ b_transposed_shape = (12, 8)
+ c_shape = (4, 12)
+
+ dimM = ir.AffineDimExpr.get(0)
+ dimN = ir.AffineDimExpr.get(1)
+ dimK = ir.AffineDimExpr.get(2)
+
+ # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+ # CHECK: #[[$B_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+ # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+ # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+ a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+ b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+ c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+ b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+ # CHECK: func.func @matmul_as_contract_op(
+ @func.FuncOp.from_py_func(
+ # CHECK-SAME: %[[A:.*]]: tensor<4x8xf32>,
+ RankedTensorType.get(a_shape, f32),
+ # CHECK-SAME: %[[Amem:.*]]: memref<4x8xf32>,
+ MemRefType.get(a_shape, f32),
+ # CHECK-SAME: %[[B:.*]]: tensor<8x12xf32>,
+ RankedTensorType.get(b_shape, f32),
+ # CHECK-SAME: %[[Bmem:.*]]: memref<8x12xf32>,
+ MemRefType.get(b_shape, f32),
+ # CHECK-SAME: %[[BTrans:.*]]: tensor<12x8xf32>,
+ RankedTensorType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[BTransmem:.*]]: memref<12x8xf32>,
+ MemRefType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>,
+ RankedTensorType.get(c_shape, f32),
+ # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
+ MemRefType.get(c_shape, f32),
+ )
+ def matmul_as_contract_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=(C.type,),
+ inputs=(A, B),
+ outputs=(C,),
+ indexing_maps=[a_map, b_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op5 = linalg.contract((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=(C.type,),
+ inputs=(A, Btransposed),
+ outputs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op5 = linalg.contract((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+ # And now with memrefs...
+
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=[],
+ inputs=(Amem, Bmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.contract((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=[],
+ inputs=(Amem, Btransposedmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.contract((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+ print(module)
✅ With the latest revision this PR passed the Python code formatter.
Lol did you really whip this up following the discussion on discourse yesterday or had you been working on this? Either way nice work - I'll take a close look soon.
Ya so … is the error I alluded to in the discourse discussion.
Fixed now. Will explain tomorrow why this works / why it has to be this way. Same fix will need to happen for …
@@ -606,7 +622,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
DefaultValuedMatmulIndexingMapsAttr:$indexing_maps, // DONOTMERGE(rolfmorel): explain why this is necessary
I get what you're going for here, but the right way to do this is just to add a custom builder that constructs this attribute rather than adding a whole new type. That won't work for Python because Python can only call the default builder (which is maybe why you went this direction), but there it's the same story - the MatmulOp builder just needs to construct the correct indexing maps.
The thing is, as far as I understand now, no builder (really no op-specific C++ code at all) is called when an op gets constructed from Python (see from here to here to here to here to finally here). The only thing one can do is pass in default values for attributes. Could you point me to an example of specifying "the default builder" for an op?
And yes, preferably I would just have DefaultValuedOptionalAttr<AffineMapArrayAttr, "MatmulOp::getDefaultIndexingMaps($_builder.getContext())">:$indexing_maps in MatmulOp's arguments (note, to my knowledge, we cannot specify custom "builders" in this list), but that doesn't work. The problem is that DefaultValuedOptionalAttr assumes that the defaultValue can be given without access to the context. However, that is not true in this case (that is, attempting the above leads to error: use of undeclared identifier '$_builder'). The workaround is to just expand what DefaultValuedOptionalAttr does to get DefaultValuedMatmulIndexingMapsAttr, in which we can modify constBuilderCall so that we can pass the context in case of the defaultValue. That the resulting attr is not "anonymous" - and hence we get a whole new attr type - is rather unfortunate though.
If you have suggestions for how to effect a cleaner approach, I am all ears!
Note, the latest commit gets rid of the whole new type: now no other code needs to be aware that we aren't using DefaultValuedOptionalAttr<AffineMapArrayAttr, ...>.
Latest commit now moves the builderCall parameter to DefaultValued(Optional)Attr. Probably an even cleaner solution is to just allow $_builder in Attr's defaultValue... (that's an exploration for another time though).
And yes, preferably I would just have DefaultValuedOptionalAttr<AffineMapArrayAttr, "MatmulOp::getDefaultIndexingMaps($_builder.getContext())">:$indexing_maps in MatmulOp's arguments (note, to my knowledge, we cannot specify custom "builders" in this list), but that doesn't work.
Rather than let my dreams be dreams, I decided to just get this to work. I doubt we will get to a cleaner solution than this.
Nevertheless, @makslevental, if you could give some pointers to info on "default builders" and how Python is supposed to invoke those, that would be much appreciated!
…t valued attr class: Should probably be moved to CommonAttrConstraints.td
260d689 to 538566a (Compare)
Reverts changes to DefaultValued(Optional)Attr.
Okay - decided a deep dive was actually a good use of a Sunday afternoon. I have now dealt with the "default values can only be C++ values" issue at the root: it was not possible to access the context within Attr's defaultValue, so partial support for $_builder there has been expanded to full support. To me, the above is an elegant solution to the "no builder method is invoked when ops are constructed from Python" problem. It is also a general quality of life improvement for others who come across this strange limitation (i.e. that Attr's defaultValue could not depend on the context).
Thanks for the quick turnaround @rolfmorel!
I ran a quick test internally and this mostly does what I expect.
One small discrepancy is how this changes the input args from OpDSL; I now have the following:

    if self.matmul_impl == linalg.matmul:
        self.matmul_impl([A, B], outs=[C])
    else:
        self.matmul_impl(A, B, outs=[C])

Generally, I think I prefer your variant, but as OpDSL is gradually replaced, the size of this switch will grow.
How about accepting both variants for now? (not a blocker for me)
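For illustration only, a rough sketch of the "accept both variants" idea floated above; matmul_compat is a hypothetical helper, not something added by this PR, and it simply forwards to the wrapper as defined in this diff.

from mlir.dialects import linalg

def matmul_compat(*operands, outs, indexing_maps, cast=None):
    # Accept both the OpDSL-derived positional style, matmul(A, B, outs=[C]),
    # and this PR's sequence style, matmul([A, B], outs=[C]).
    if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
        inputs = list(operands[0])
    else:
        inputs = list(operands)
    return linalg.matmul(inputs, outs=outs, indexing_maps=indexing_maps, cast=cast)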
Thanks for pointing that out! I have now changed the argument format to that of the OpDSL-derived ops (I didn't check carefully enough what that format was and went with what seemed more consistent with the MLIR style). @rengolin and I both don't care much for which syntax we go with. I personally feel there isn't much value in allowing the alternative style now for just a couple of ops (versus allowing it/switching over all linalg ops at once). If others do have a preference for switching/allowing the alternative style now, I am happy to comply.
I think for now we keep the old syntax, and if someone wants to change it, that can be a new PR/RFC.
Thanks @nicolasvasilache and @rengolin - I have now gone ahead and merged the PR! Syntax changes and doing the same thing to other ops can be done in later PRs.
SGTM, thanks for fixing!
[MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API (llvm#126377): Now that linalg.matmul is in tablegen, "hand write" the Python wrapper that OpDSL used to derive. Similarly, add a Python wrapper for the new linalg.contract op. Required the following misc. fixes: 1) make linalg.matmul's parsing and printing consistent w.r.t. whether indexing_maps occurs before or after operands, i.e. per the test cases it comes _before_. 2) tablegen for linalg.contract did not state it accepted an optional cast attr. 3) In ODS's C++-generating code, expand partial support for `$_builder` access in `Attr::defaultValue` to full support. This enables access to the current `MlirContext` when constructing the default value (as is required when the default value consists of affine maps).
Now that linalg.matmul is in tablegen, "hand write" the Python wrapper that OpDSL used to derive. Similarly, add a Python wrapper for the new linalg.contract op.
Required the following misc. fixes:
1) make linalg.matmul's parsing and printing consistent w.r.t. whether indexing_maps occurs before or after operands, i.e. per the test cases it comes _before_.
2) tablegen for linalg.contract did not state it accepted an optional cast attr.
3) In ODS's C++-generating code, expand partial support for `$_builder` access in `Attr::defaultValue` to full support. This enables access to the current `MlirContext` when constructing the default value (as is required when the default value consists of affine maps).
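As a closing illustration of fix (2), a hedged sketch of passing the newly declared cast attribute through the Python wrapper. The i8/i32 types and the TypeFn.cast_unsigned value are illustrative assumptions (TypeFn is assumed to be exposed on the linalg module, as the wrapper's type hint suggests); this is not code from the PR.

from mlir import ir
from mlir.dialects import func, linalg

with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    i8 = ir.IntegerType.get_signless(8)
    i32 = ir.IntegerType.get_signless(32)
    with ir.InsertionPoint(module.body):
        dim_m, dim_n, dim_k = (ir.AffineDimExpr.get(i) for i in range(3))
        maps = [
            ir.AffineMap.get(3, 0, [dim_m, dim_k]),
            ir.AffineMap.get(3, 0, [dim_k, dim_n]),
            ir.AffineMap.get(3, 0, [dim_m, dim_n]),
        ]

        @func.FuncOp.from_py_func(
            ir.RankedTensorType.get((4, 8), i8),
            ir.RankedTensorType.get((8, 12), i8),
            ir.RankedTensorType.get((4, 12), i32),
        )
        def contract_with_cast(a, b, c):
            # cast_unsigned overrides the cast_signed default declared in the
            # tablegen change above, so the i8 inputs are zero-extended to i32.
            linalg.contract(
                (a, b),
                outs=(c,),
                indexing_maps=maps,
                cast=linalg.TypeFn.cast_unsigned,
            )

    print(module)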