
[MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API #126377

Merged: 11 commits, Feb 10, 2025
21 changes: 19 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -555,6 +555,22 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
// Op definition for MatmulOp
//===----------------------------------------------------------------------===//


// DONOTMERGE(rolfmorel): explain why the below is necessary
def DefaultValuedMatmulIndexingMapsAttr :
Attr<AffineMapArrayAttr.predicate, AffineMapArrayAttr.summary> {
let storageType = AffineMapArrayAttr.storageType;
let returnType = AffineMapArrayAttr.returnType;
let convertFromStorage = AffineMapArrayAttr.convertFromStorage;
let constBuilderCall = "$_builder.getAffineMapArrayAttr($0.empty() ? MatmulOp::getDefaultIndexingMaps($_builder.getContext()) : $0)";
let defaultValue = "SmallVector<AffineMap>()";
let valueType = AffineMapArrayAttr.valueType;
let isOptional = 1;

let baseAttr = AffineMapArrayAttr;
}
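
For context, the fallback in the constBuilderCall above resolves to MatmulOp::getDefaultIndexingMaps, i.e. the three canonical matmul maps. A minimal sketch of the equivalent construction through the MLIR Python bindings (illustrative only, not part of this diff; assumes the in-tree mlir package):

from mlir.ir import AffineDimExpr, AffineMap, AffineMapAttr, ArrayAttr, Context

with Context():
    d0, d1, d2 = (AffineDimExpr.get(i) for i in range(3))
    # Iteration space (d0, d1, d2) = (m, n, k); A: (m, k), B: (k, n), C: (m, n).
    default_maps = [
        AffineMap.get(3, 0, [d0, d2]),
        AffineMap.get(3, 0, [d2, d1]),
        AffineMap.get(3, 0, [d0, d1]),
    ]
    attr = ArrayAttr.get([AffineMapAttr.get(m) for m in default_maps])
    print(attr)
    # [affine_map<(d0, d1, d2) -> (d0, d2)>,
    #  affine_map<(d0, d1, d2) -> (d2, d1)>,
    #  affine_map<(d0, d1, d2) -> (d0, d1)>]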


def MatmulOp : LinalgStructuredBase_Op<"matmul", [
AttrSizedOperandSegments,
LinalgContractionOpInterface]> {
@@ -606,7 +622,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
DefaultValuedMatmulIndexingMapsAttr:$indexing_maps, // DONOTMERGE(rolfmorel): explain why this is necessary
@makslevental (Contributor) commented on Feb 9, 2025:

I get what you're going for here, but the right way to do this is just to add a custom builder that constructs this attribute, rather than adding a whole new type. That won't work for Python, since Python can only call the default builder (which is maybe why you went this direction), but there it's the same story: the MatmulOp builder just needs to construct the correct indexing maps.

@rolfmorel (Contributor, author) replied on Feb 9, 2025:

The thing is, as far as I understand, no builder (really, no op-specific C++ code at all) is called when an op gets constructed from Python (see from here to here to here to here to, finally, here). The only thing one can do is pass in default values for attributes. Could you point me to an example of specifying "the default builder" for an op?

And yes, preferably I would just have DefaultValuedOptionalAttr<AffineMapArrayAttr, "MatmulOp::getDefaultIndexingMaps($_builder.getContext())">:$indexing_maps in MatmulOp's arguments (note that, to my knowledge, we cannot specify custom "builders" in this list), but that doesn't work. The problem is that DefaultValuedOptionalAttr assumes the defaultValue can be given without access to the context. That is not true in this case: attempting the above leads to error: use of undeclared identifier '$_builder'. The workaround is to expand what DefaultValuedOptionalAttr does into DefaultValuedMatmulIndexingMapsAttr, where we can modify constBuilderCall so that the context gets passed in for the defaultValue. That the resulting attr is not "anonymous" - and hence we get a whole new attr type - is rather unfortunate, though.

If you have suggestions for how to effect a cleaner approach, I am all ears!

@rolfmorel (Contributor, author) added:

Note: the latest commit gets rid of the whole new type; now no other code needs to be aware that we aren't using DefaultValuedOptionalAttr<AffineMapArrayAttr, ...>.

@rolfmorel (Contributor, author) added:

The latest commit moves the builderCall parameter to DefaultValued(Optional)Attr.

Probably an even cleaner solution is to just allow $_builder in Attr's defaultValue... (that's an exploration for another time though).

@rolfmorel (Contributor, author) added:

> And yes, preferably I would just have DefaultValuedOptionalAttr<AffineMapArrayAttr, "MatmulOp::getDefaultIndexingMaps($_builder.getContext())">:$indexing_maps in MatmulOp's arguments (note, to my knowledge, we cannot specify custom "builders" in this list), but that doesn't work.

Rather than let my dreams be dreams, I decided to just get this to work. I doubt we will get to a cleaner solution than this.

Nevertheless, @makslevental, if you could give some pointers to info on "default builders" and how Python is supposed to invoke those, that would be much appreciated!

DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
@@ -752,7 +768,8 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
AffineMapArrayAttr:$indexing_maps
AffineMapArrayAttr:$indexing_maps,
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
);
let results = (outs Variadic<AnyShaped>:$result_tensors);
// NB: The only reason this op has a region - and it gets populated at op build
10 changes: 5 additions & 5 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3666,11 +3666,6 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
}

void MatmulOp::print(OpAsmPrinter &p) {
SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
elidedAttrs);

SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
MatmulOp::getDefaultIndexingMaps(getContext()),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -3680,6 +3675,11 @@ void MatmulOp::print(OpAsmPrinter &p) {
[&](Attribute attr) { p.printAttribute(attr); });
p << "]";
}

SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
elidedAttrs);
}

/// Verify the user defined indexing maps.
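The net effect of moving the indexing_maps printing ahead of printNamedStructuredOp is that non-default maps now appear before ins(...)/outs(...) in the assembly, matching the test updates below. A hedged round-trip check from Python (a sketch; assumes a build where Context() comes with the upstream dialects registered):

from mlir.ir import Context, Module

ASM = """
func.func @t(%a: memref<5x3xf32>, %b: memref<5x7xf32>, %c: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
      affine_map<(d0, d1, d2) -> (d2, d0)>,
      affine_map<(d0, d1, d2) -> (d2, d1)>,
      affine_map<(d0, d1, d2) -> (d0, d1)>]
    ins(%a, %b : memref<5x3xf32>, memref<5x7xf32>)
    outs(%c : memref<3x7xf32>)
  return
}
"""

with Context():
    module = Module.parse(ASM)
    # The transposed-A maps are non-default, so they survive the round-trip
    # and now print before ins/outs rather than trailing the op.
    print(module)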
53 changes: 53 additions & 0 deletions mlir/python/mlir/dialects/linalg/__init__.py
@@ -147,3 +147,56 @@ def __init__(


generic = region_op(GenericOp_, terminator=YieldOp)


@register_attribute_builder("DefaultValuedMatmulIndexingMapsAttr")
def _DefaultValuedMatmulIndexingMapsAttr(x, context):
return ArrayAttr.get([AffineMapAttr.get(v) for v in x])


def matmul(
inputs: Sequence[Union[Operation, OpView, Value]],
*,
outs: Sequence[Union[Operation, OpView, Value]],
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
cast: Optional[Union[TypeFn, Attribute]] = None,
):
inputs = [_get_op_result_or_value(input) for input in inputs]
if len(outs) > 1:
raise ValueError(f"{outs=} must have length 1.")
init = _get_op_result_or_value(outs[0])
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []

op = MatmulOp(
result_tensors=result_types,
inputs=inputs,
outputs=[init],
indexing_maps=indexing_maps,
cast=cast,
)
fill_builtin_region(op.operation)
return op


def contract(
inputs: Sequence[Union[Operation, OpView, Value]],
*,
outs: Sequence[Union[Operation, OpView, Value]],
indexing_maps: Sequence[AffineMapAttr],
cast: Optional[Union[TypeFn, Attribute]] = None,
):
inputs = [_get_op_result_or_value(input) for input in inputs]
if len(outs) > 1:
raise ValueError(f"{outs=} must have length 1.")
init = _get_op_result_or_value(outs[0])
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []

op = ContractOp(
result_tensors=result_types,
inputs=inputs,
outputs=[init],
indexing_maps=indexing_maps,
cast=cast,
)
fill_builtin_region(op.operation)
return op
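
Putting the two wrappers together, a hedged end-to-end sketch (the function name mm, the shapes, and the explicit maps passed to contract are made up for illustration; the signatures are the ones defined above):

from mlir.ir import (
    AffineDimExpr, AffineMap, Context, F32Type, InsertionPoint,
    Location, Module, RankedTensorType,
)
from mlir.dialects import func, linalg

with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    a_t = RankedTensorType.get((4, 8), f32)
    b_t = RankedTensorType.get((8, 16), f32)
    c_t = RankedTensorType.get((4, 16), f32)
    d0, d1, d2 = (AffineDimExpr.get(i) for i in range(3))
    with InsertionPoint(module.body):
        fn = func.FuncOp("mm", ((a_t, b_t, c_t), (c_t,)))
        with InsertionPoint(fn.add_entry_block()):
            a, b, c = fn.arguments
            # With indexing_maps omitted, MatmulOp picks up its defaults.
            mm = linalg.matmul((a, b), outs=(c,))
            # contract takes its maps explicitly; raw AffineMaps should be
            # fine, assuming the registered attribute builder wraps each one
            # in an AffineMapAttr (as the matmul builder above does).
            ct = linalg.contract(
                (a, b),
                outs=(mm.result,),
                indexing_maps=[
                    AffineMap.get(3, 0, [d0, d2]),
                    AffineMap.get(3, 0, [d2, d1]),
                    AffineMap.get(3, 0, [d0, d1]),
                ],
            )
            func.ReturnOp([ct.result])
    print(module)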
16 changes: 10 additions & 6 deletions mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1269,7 +1269,7 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }

@@ -1294,7 +1294,7 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }

@@ -1315,6 +1315,7 @@ func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: m
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_a
// CHECK: linalg.matmul
// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

@@ -1335,6 +1336,7 @@ func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %ar
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_a_dim1
// CHECK: linalg.matmul
// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

@@ -1355,6 +1357,7 @@ func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: m
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_b
// CHECK: linalg.matmul
// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

@@ -1376,7 +1379,7 @@ func.func @matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: m
// CHECK-LABEL: func.func @matmul_bcast_a_b(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }

@@ -1397,6 +1400,7 @@ func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %ar
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_b_dim1
// CHECK: linalg.matmul
// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

@@ -1420,7 +1424,7 @@ func.func @dynamic_matmul_bcast_a(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>,
// CHECK-SAME: %[[VAL_0:.*]]: memref<?xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<?x?xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>)
// CHECK: return
// CHECK: }

@@ -1444,7 +1448,7 @@ func.func @matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf3
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }

@@ -1468,7 +1472,7 @@ func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf3
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
