Skip to content

Commit

Permalink
[MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API (
Browse files Browse the repository at this point in the history
llvm#126377)

Now that linalg.matmul is in tablegen, "hand write" the Python wrapper
that OpDSL used to derive. Similarly, add a Python wrapper for the new
linalg.contract op.

Required following misc. fixes:
1) make linalg.matmul's parsing and printing consistent w.r.t. whether
indexing_maps occurs before or after operands, i.e. per the tests cases
it comes _before_.
2) tablegen for linalg.contract did not state it accepted an optional
cast attr.
3) In ODS's C++-generating code, expand partial support for `$_builder`
access in `Attr::defaultValue` to full support. This enables access to
the current `MlirContext` when constructing the default value (as is
required when the default value consists of affine maps).
  • Loading branch information
rolfmorel authored and Icohedron committed Feb 11, 2025
1 parent eb62d36 commit bdabea2
Show file tree
Hide file tree
Showing 9 changed files with 316 additions and 32 deletions.
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,10 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
DefaultValuedOptionalAttr<
AffineMapArrayAttr,
"MatmulOp::getDefaultIndexingMaps($_builder.getContext())"
>:$indexing_maps,
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
Expand Down Expand Up @@ -752,7 +755,8 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
AffineMapArrayAttr:$indexing_maps
AffineMapArrayAttr:$indexing_maps,
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
);
let results = (outs Variadic<AnyShaped>:$result_tensors);
// NB: The only reason this op has a region - and it get populated at op build
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/CommonAttrConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class Attr<Pred condition, string summary = ""> :

// Default value for attribute.
// Requires a constBuilderCall defined.
//
// Format: `$_builder` will be expanded to the relevant builder, e.g. to allow
// access to the current context.
string defaultValue = ?;

// The value type of this attribute. This corresponds to the mlir::Type that
Expand Down
10 changes: 5 additions & 5 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3666,11 +3666,6 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
}

void MatmulOp::print(OpAsmPrinter &p) {
SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
elidedAttrs);

SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
MatmulOp::getDefaultIndexingMaps(getContext()),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
Expand All @@ -3680,6 +3675,11 @@ void MatmulOp::print(OpAsmPrinter &p) {
[&](Attribute attr) { p.printAttribute(attr); });
p << "]";
}

SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
elidedAttrs);
}

/// Verify the user defined indexing maps.
Expand Down
46 changes: 46 additions & 0 deletions mlir/python/mlir/dialects/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,49 @@ def __init__(


generic = region_op(GenericOp_, terminator=YieldOp)


def matmul(
*ins: Union[Operation, OpView, Value],
outs: Sequence[Union[Operation, OpView, Value]],
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
cast: Optional[Union[TypeFn, Attribute]] = None,
):
ins = [_get_op_result_or_value(input) for input in ins]
if len(outs) > 1:
raise ValueError(f"{outs=} must have length 1.")
init = _get_op_result_or_value(outs[0])
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []

op = MatmulOp(
result_tensors=result_types,
inputs=ins,
outputs=[init],
indexing_maps=indexing_maps,
cast=cast,
)
fill_builtin_region(op.operation)
return op


def contract(
*ins: Union[Operation, OpView, Value],
outs: Sequence[Union[Operation, OpView, Value]],
indexing_maps: Sequence[AffineMapAttr],
cast: Optional[Union[TypeFn, Attribute]] = None,
):
ins = [_get_op_result_or_value(input) for input in ins]
if len(outs) > 1:
raise ValueError(f"{outs=} must have length 1.")
init = _get_op_result_or_value(outs[0])
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []

op = ContractOp(
result_tensors=result_types,
inputs=ins,
outputs=[init],
indexing_maps=indexing_maps,
cast=cast,
)
fill_builtin_region(op.operation)
return op
16 changes: 10 additions & 6 deletions mlir/test/Dialect/Linalg/named-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1269,7 +1269,7 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }

Expand All @@ -1294,7 +1294,7 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }

Expand All @@ -1315,6 +1315,7 @@ func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: m
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_a
// CHECK: linalg.matmul
// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

Expand All @@ -1335,6 +1336,7 @@ func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %ar
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_a_dim1
// CHECK: linalg.matmul
// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

Expand All @@ -1355,6 +1357,7 @@ func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: m
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_b
// CHECK: linalg.matmul
// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

Expand All @@ -1376,7 +1379,7 @@ func.func @matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: m
// CHECK-LABEL: func.func @matmul_bcast_a_b(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }

Expand All @@ -1397,6 +1400,7 @@ func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %ar
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_b_dim1
// CHECK: linalg.matmul
// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

Expand All @@ -1420,7 +1424,7 @@ func.func @dynamic_matmul_bcast_a(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>,
// CHECK-SAME: %[[VAL_0:.*]]: memref<?xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<?x?xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>)
// CHECK: return
// CHECK: }

Expand All @@ -1444,7 +1448,7 @@ func.func @matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf3
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }

Expand All @@ -1468,7 +1472,7 @@ func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf3
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }

Expand Down
Loading

0 comments on commit bdabea2

Please sign in to comment.