From bdabea2d5b39f329e6ddd5e0044bb002586f0944 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Mon, 10 Feb 2025 13:05:13 +0100 Subject: [PATCH] [MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API (#126377) Now that linalg.matmul is in tablegen, "hand write" the Python wrapper that OpDSL used to derive. Similarly, add a Python wrapper for the new linalg.contract op. Required following misc. fixes: 1) make linalg.matmul's parsing and printing consistent w.r.t. whether indexing_maps occurs before or after operands, i.e. per the tests cases it comes _before_. 2) tablegen for linalg.contract did not state it accepted an optional cast attr. 3) In ODS's C++-generating code, expand partial support for `$_builder` access in `Attr::defaultValue` to full support. This enables access to the current `MlirContext` when constructing the default value (as is required when the default value consists of affine maps). --- .../Dialect/Linalg/IR/LinalgStructuredOps.td | 8 +- mlir/include/mlir/IR/CommonAttrConstraints.td | 3 + mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 10 +- mlir/python/mlir/dialects/linalg/__init__.py | 46 ++++ mlir/test/Dialect/Linalg/named-ops.mlir | 16 +- mlir/test/python/dialects/linalg/ops.py | 210 ++++++++++++++++++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 31 ++- mlir/tools/mlir-tblgen/OpFormatGen.cpp | 22 +- mlir/tools/mlir-tblgen/RewriterGen.cpp | 2 +- 9 files changed, 316 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 110ed7d2fc00e..29cb8035b583b 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -606,7 +606,10 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ let arguments = (ins Variadic:$inputs, Variadic:$outputs, - DefaultValuedOptionalAttr:$indexing_maps, + DefaultValuedOptionalAttr< + AffineMapArrayAttr, + "MatmulOp::getDefaultIndexingMaps($_builder.getContext())" + >:$indexing_maps, DefaultValuedOptionalAttr:$cast ); let results = (outs Variadic:$result_tensors); @@ -752,7 +755,8 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [ let arguments = (ins Variadic:$inputs, Variadic:$outputs, - AffineMapArrayAttr:$indexing_maps + AffineMapArrayAttr:$indexing_maps, + DefaultValuedOptionalAttr:$cast ); let results = (outs Variadic:$result_tensors); // NB: The only reason this op has a region - and it get populated at op build diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td index 599f5ecba5803..2beb1e8110afe 100644 --- a/mlir/include/mlir/IR/CommonAttrConstraints.td +++ b/mlir/include/mlir/IR/CommonAttrConstraints.td @@ -50,6 +50,9 @@ class Attr : // Default value for attribute. // Requires a constBuilderCall defined. + // + // Format: `$_builder` will be expanded to the relevant builder, e.g. to allow + // access to the current context. string defaultValue = ?; // The value type of this attribute. This corresponds to the mlir::Type that diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index b50931f15826c..d40cec02df633 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -3666,11 +3666,6 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) { } void MatmulOp::print(OpAsmPrinter &p) { - SmallVector elidedAttrs = { - "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"}; - printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(), - elidedAttrs); - SmallVector indexingMaps = llvm::map_to_vector( MatmulOp::getDefaultIndexingMaps(getContext()), [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); @@ -3680,6 +3675,11 @@ void MatmulOp::print(OpAsmPrinter &p) { [&](Attribute attr) { p.printAttribute(attr); }); p << "]"; } + + SmallVector elidedAttrs = { + "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"}; + printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(), + elidedAttrs); } /// Verify the user defined indexing maps. diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index 742262a9c4969..5cda4769d593f 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -147,3 +147,49 @@ def __init__( generic = region_op(GenericOp_, terminator=YieldOp) + + +def matmul( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + indexing_maps: Optional[Sequence[AffineMapAttr]] = None, + cast: Optional[Union[TypeFn, Attribute]] = None, +): + ins = [_get_op_result_or_value(input) for input in ins] + if len(outs) > 1: + raise ValueError(f"{outs=} must have length 1.") + init = _get_op_result_or_value(outs[0]) + result_types = [init.type] if isinstance(init.type, RankedTensorType) else [] + + op = MatmulOp( + result_tensors=result_types, + inputs=ins, + outputs=[init], + indexing_maps=indexing_maps, + cast=cast, + ) + fill_builtin_region(op.operation) + return op + + +def contract( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + indexing_maps: Sequence[AffineMapAttr], + cast: Optional[Union[TypeFn, Attribute]] = None, +): + ins = [_get_op_result_or_value(input) for input in ins] + if len(outs) > 1: + raise ValueError(f"{outs=} must have length 1.") + init = _get_op_result_or_value(outs[0]) + result_types = [init.type] if isinstance(init.type, RankedTensorType) else [] + + op = ContractOp( + result_tensors=result_types, + inputs=ins, + outputs=[init], + indexing_maps=indexing_maps, + cast=cast, + ) + fill_builtin_region(op.operation) + return op diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir index ed8683522c74a..68ea97be911a6 100644 --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -1269,7 +1269,7 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5 // CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>, // CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>, // CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { -// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) // CHECK: return // CHECK: } @@ -1294,7 +1294,7 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7 // CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>, // CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>, // CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { -// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) // CHECK: return // CHECK: } @@ -1315,6 +1315,7 @@ func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: m // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: func @matmul_bcast_a // CHECK: linalg.matmul +// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>) // CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>) @@ -1335,6 +1336,7 @@ func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %ar // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: func @matmul_bcast_a_dim1 // CHECK: linalg.matmul +// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>) // CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>) @@ -1355,6 +1357,7 @@ func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: m // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: func @matmul_bcast_b // CHECK: linalg.matmul +// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>) // CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>) @@ -1376,7 +1379,7 @@ func.func @matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: m // CHECK-LABEL: func.func @matmul_bcast_a_b( // CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>, // CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { -// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]] +// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) // CHECK: return // CHECK: } @@ -1397,6 +1400,7 @@ func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %ar // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: func @matmul_bcast_b_dim1 // CHECK: linalg.matmul +// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>) // CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>) @@ -1420,7 +1424,7 @@ func.func @dynamic_matmul_bcast_a(%arg0: memref, %arg1: memref, // CHECK-SAME: %[[VAL_0:.*]]: memref, // CHECK-SAME: %[[VAL_1:.*]]: memref, // CHECK-SAME: %[[VAL_2:.*]]: memref) { -// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref, memref) outs(%[[VAL_2]] : memref) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref, memref) outs(%[[VAL_2]] : memref) // CHECK: return // CHECK: } @@ -1444,7 +1448,7 @@ func.func @matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf3 // CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, // CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>, // CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { -// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) // CHECK: return // CHECK: } @@ -1468,7 +1472,7 @@ func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf3 // CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>, // CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>, // CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { -// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) // CHECK: return // CHECK: } diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index ac7186c24bed8..94f8ea4faf4a8 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -256,3 +256,213 @@ def f(a, b): module.operation.verify() print(module) + + +# CHECK-LABEL: TEST: testMatmulOp +@run +def testMatmulOp(): + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + a_shape = (4, 8) + b_shape = (8, 12) + b_transposed_shape = (12, 8) + c_shape = (4, 12) + + dimM = ir.AffineDimExpr.get(0) + dimN = ir.AffineDimExpr.get(1) + dimK = ir.AffineDimExpr.get(2) + + # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> + # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> + # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> + a_map = ir.AffineMap.get(3, 0, [dimM, dimK]) + b_map = ir.AffineMap.get(3, 0, [dimK, dimN]) + c_map = ir.AffineMap.get(3, 0, [dimM, dimN]) + b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK]) + + # CHECK: func.func @matmul_op( + @func.FuncOp.from_py_func( + # CHECK-SAME: %[[A:.*]]: tensor<4x8xf32>, + RankedTensorType.get(a_shape, f32), + # CHECK-SAME: %[[Amem:.*]]: memref<4x8xf32>, + MemRefType.get(a_shape, f32), + # CHECK-SAME: %[[B:.*]]: tensor<8x12xf32>, + RankedTensorType.get(b_shape, f32), + # CHECK-SAME: %[[Bmem:.*]]: memref<8x12xf32>, + MemRefType.get(b_shape, f32), + # CHECK-SAME: %[[BTrans:.*]]: tensor<12x8xf32>, + RankedTensorType.get(b_transposed_shape, f32), + # CHECK-SAME: %[[BTransmem:.*]]: memref<12x8xf32>, + MemRefType.get(b_transposed_shape, f32), + # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>, + RankedTensorType.get(c_shape, f32), + # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>) + MemRefType.get(c_shape, f32), + ) + def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem): + # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>) + res = linalg.MatmulOp( + result_tensors=(C.type,), + inputs=(A, B), + outputs=(C,), + ) + linalg.fill_builtin_region(res.operation) + # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>) + res = linalg.matmul(A, B, outs=(C,)) + + # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>) + res = linalg.MatmulOp( + result_tensors=(C.type,), + inputs=(A, Btransposed), + outputs=(C,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + linalg.fill_builtin_region(res.operation) + # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>) + res = linalg.matmul( + A, + Btransposed, + outs=(C,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + + # And now with memrefs... + + # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>) + res = linalg.MatmulOp( + result_tensors=[], + inputs=(Amem, Bmem), + outputs=(Cmem,), + ) + linalg.fill_builtin_region(res.operation) + # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>) + linalg.matmul(Amem, Bmem, outs=(Cmem,)) + + # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>) + res = linalg.MatmulOp( + result_tensors=[], + inputs=(Amem, Btransposedmem), + outputs=(Cmem,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + linalg.fill_builtin_region(res.operation) + # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>) + linalg.matmul( + Amem, + Btransposedmem, + outs=(Cmem,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + + print(module) + + +# CHECK-LABEL: TEST: testContractOp +@run +def testContractOp(): + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + a_shape = (4, 8) + b_shape = (8, 12) + b_transposed_shape = (12, 8) + c_shape = (4, 12) + + dimM = ir.AffineDimExpr.get(0) + dimN = ir.AffineDimExpr.get(1) + dimK = ir.AffineDimExpr.get(2) + + # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> + # CHECK: #[[$B_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)> + # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> + # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> + a_map = ir.AffineMap.get(3, 0, [dimM, dimK]) + b_map = ir.AffineMap.get(3, 0, [dimK, dimN]) + c_map = ir.AffineMap.get(3, 0, [dimM, dimN]) + b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK]) + + # CHECK: func.func @matmul_as_contract_op( + @func.FuncOp.from_py_func( + # CHECK-SAME: %[[A:.*]]: tensor<4x8xf32>, + RankedTensorType.get(a_shape, f32), + # CHECK-SAME: %[[Amem:.*]]: memref<4x8xf32>, + MemRefType.get(a_shape, f32), + # CHECK-SAME: %[[B:.*]]: tensor<8x12xf32>, + RankedTensorType.get(b_shape, f32), + # CHECK-SAME: %[[Bmem:.*]]: memref<8x12xf32>, + MemRefType.get(b_shape, f32), + # CHECK-SAME: %[[BTrans:.*]]: tensor<12x8xf32>, + RankedTensorType.get(b_transposed_shape, f32), + # CHECK-SAME: %[[BTransmem:.*]]: memref<12x8xf32>, + MemRefType.get(b_transposed_shape, f32), + # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>, + RankedTensorType.get(c_shape, f32), + # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>) + MemRefType.get(c_shape, f32), + ) + def matmul_as_contract_op( + A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem + ): + # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>) + op4 = linalg.ContractOp( + result_tensors=(C.type,), + inputs=(A, B), + outputs=(C,), + indexing_maps=[a_map, b_map, c_map], + ) + linalg.fill_builtin_region(op4.operation) + # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>) + op5 = linalg.contract( + A, B, outs=(C,), indexing_maps=[a_map, b_map, c_map] + ) + + # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>) + op4 = linalg.ContractOp( + result_tensors=(C.type,), + inputs=(A, Btransposed), + outputs=(C,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + linalg.fill_builtin_region(op4.operation) + # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>) + op5 = linalg.contract( + A, + Btransposed, + outs=(C,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + # And now with memrefs... + + # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>) + op4 = linalg.ContractOp( + result_tensors=[], + inputs=(Amem, Bmem), + outputs=(Cmem,), + indexing_maps=[a_map, b_map, c_map], + ) + linalg.fill_builtin_region(op4.operation) + # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>) + linalg.contract( + Amem, Bmem, outs=(Cmem,), indexing_maps=[a_map, b_map, c_map] + ) + + # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>) + op4 = linalg.ContractOp( + result_tensors=[], + inputs=(Amem, Btransposedmem), + outputs=(Cmem,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + linalg.fill_builtin_region(op4.operation) + # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>) + linalg.contract( + Amem, + Btransposedmem, + outs=(Cmem,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + + print(module) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index a970cbc5caceb..629e863dac5e3 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1334,8 +1334,9 @@ static void emitAttrGetterWithReturnType(FmtContext &fctx, PrintFatalError("DefaultValuedAttr of type " + attr.getAttrDefName() + " must have a constBuilder"); } - std::string defaultValue = std::string( - tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); + std::string defaultValue = + std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx, + tgfmt(attr.getDefaultValue(), &fctx))); body << " if (!attr)\n return " << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf(defaultValue)) @@ -1467,6 +1468,7 @@ void OpEmitter::genPropertiesSupport() { os << " if (!attr) attr = dict.get(\"result_segment_sizes\");"; } + fctx.withBuilder(odsBuilder); setPropMethod << "{\n" << formatv(propFromAttrFmt, tgfmt(prop.getConvertFromAttributeCall(), @@ -1479,7 +1481,7 @@ void OpEmitter::genPropertiesSupport() { prop.getStorageTypeValueOverride()); } else if (prop.hasDefaultValue()) { setPropMethod << formatv(attrGetDefaultFmt, name, - prop.getDefaultValue()); + tgfmt(prop.getDefaultValue(), &fctx)); } else { setPropMethod << formatv(attrGetNoDefaultFmt, name); } @@ -2919,6 +2921,9 @@ getBuilderSignature(const Builder &builder) { arguments.emplace_back("::mlir::OpBuilder &", odsBuilder); arguments.emplace_back("::mlir::OperationState &", builderOpState); + FmtContext fctx; + fctx.withBuilder(odsBuilder); + for (unsigned i = 0, e = params.size(); i < e; ++i) { // If no name is provided, generate one. std::optional paramName = params[i].getName(); @@ -2931,7 +2936,7 @@ getBuilderSignature(const Builder &builder) { defaultValue = *defaultParamValue; arguments.emplace_back(params[i].getCppType(), std::move(name), - defaultValue); + tgfmt(defaultValue, &fctx)); } return arguments; @@ -3189,6 +3194,9 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, } } + FmtContext fctx; + fctx.withBuilder(odsBuilder); + for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) { Argument arg = op.getArg(i); if (const auto *operand = @@ -3210,7 +3218,7 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, StringRef type = prop.getInterfaceType(); std::string defaultValue; if (prop.hasDefaultValue() && i >= defaultValuedAttrLikeStartIndex) { - defaultValue = prop.getDefaultValue(); + defaultValue = tgfmt(prop.getDefaultValue(), &fctx); } bool isOptional = prop.hasDefaultValue(); paramList.emplace_back(type, propArg->name, StringRef(defaultValue), @@ -3242,7 +3250,7 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, if (i >= defaultValuedAttrStartIndex) { if (attrParamKind == AttrParamKind::UnwrappedValue && canUseUnwrappedRawValue(attr)) - defaultValue += attr.getDefaultValue(); + defaultValue += tgfmt(attr.getDefaultValue(), &fctx); else defaultValue += "nullptr"; } @@ -4172,6 +4180,9 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( staticVerifierEmitter(staticVerifierEmitter), emitHelper(op, /*emitForOp=*/false) { + FmtContext fctx; + fctx.withBuilder(odsBuilder); + genericAdaptorBase.declare(Visibility::Public); bool useProperties = emitHelper.hasProperties(); if (useProperties) { @@ -4212,7 +4223,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( if (prop.hasStorageTypeValueOverride()) os << " = " << prop.getStorageTypeValueOverride(); else if (prop.hasDefaultValue()) - os << " = " << prop.getDefaultValue(); + os << " = " << tgfmt(prop.getDefaultValue(), &fctx); comparatorOs << " rhs." << name << " == this->" << name << " &&\n"; // Emit accessors using the interface type. @@ -4454,7 +4465,6 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( if (auto *m = genericAdaptor.addMethod("RangeT", "getOperands")) m->body() << " return odsOperands;"; - FmtContext fctx; fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())"); // Generate named accessor with Attribute return type. @@ -4481,8 +4491,9 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( // Use the default value if attribute is not set. // TODO: this is inefficient, we are recreating the attribute for every // call. This should be set instead. - std::string defaultValue = std::string( - tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); + std::string defaultValue = + std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx, + tgfmt(attr.getDefaultValue(), &fctx))); body << "if (!attr)\n attr = " << defaultValue << ";\n"; } body << "return attr;\n"; diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index f03a3bfd398ed..fe724e86d6707 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1999,7 +1999,7 @@ static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())"); body << getter << "Attr() != " << tgfmt(attr.getConstBuilderTemplate(), &fctx, - attr.getDefaultValue()); + tgfmt(attr.getDefaultValue(), &fctx)); } if (optionalAndDefault) body << ")"; @@ -2007,8 +2007,10 @@ static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, PropertyVariable &propElement) { - body << op.getGetterName(propElement.getVar()->name) - << "() != " << propElement.getVar()->prop.getDefaultValue(); + FmtContext fctx; + fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())"); + body << op.getGetterName(propElement.getVar()->name) << "() != " + << tgfmt(propElement.getVar()->prop.getDefaultValue(), &fctx); } /// Elide the variadic segment size attributes if necessary. @@ -2045,8 +2047,9 @@ static void genPropDictPrinter(OperationFormat &fmt, Operator &op, const StringRef &name = namedAttr.name; FmtContext fctx; fctx.withBuilder("odsBuilder"); - std::string defaultValue = std::string( - tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); + std::string defaultValue = + std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx, + tgfmt(attr.getDefaultValue(), &fctx))); body << " {\n"; body << " ::mlir::Builder odsBuilder(getContext());\n"; body << " ::mlir::Attribute attr = " << op.getGetterName(name) @@ -2059,8 +2062,10 @@ static void genPropDictPrinter(OperationFormat &fmt, Operator &op, // Similarly, elide default-valued properties. for (const NamedProperty &prop : op.getProperties()) { if (prop.prop.hasDefaultValue()) { + FmtContext fctx; + fctx.withBuilder("odsBuilder"); body << " if (" << op.getGetterName(prop.name) - << "() == " << prop.prop.getDefaultValue() << ") {"; + << "() == " << tgfmt(prop.prop.getDefaultValue(), &fctx) << ") {"; body << " elidedProps.push_back(\"" << prop.name << "\");\n"; body << " }\n"; } @@ -2094,8 +2099,9 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op, const StringRef &name = namedAttr.name; FmtContext fctx; fctx.withBuilder("odsBuilder"); - std::string defaultValue = std::string( - tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); + std::string defaultValue = + std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx, + tgfmt(attr.getDefaultValue(), &fctx))); body << " {\n"; body << " ::mlir::Builder odsBuilder(getContext());\n"; body << " ::mlir::Attribute attr = " << op.getGetterName(name) diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index a041c4d327779..f6eb5bdfe568e 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -879,7 +879,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, if (attr.hasDefaultValue()) { os << "if (!tblgen_attr) tblgen_attr = " << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, - attr.getDefaultValue())) + tgfmt(attr.getDefaultValue(), &fmtCtx))) << ";\n"; } else if (attr.isOptional()) { // For a missing attribute that is optional according to definition, we