Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Feb 20, 2025
1 parent d18ca56 commit 8001c7a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
1 change: 1 addition & 0 deletions mlir/docs/Dialects/Linalg/_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -695,3 +695,4 @@ the same IR.
## Operations

[include "Dialects/LinalgOps.md"]
[include "Dialects/LinalgRelayoutOps.td"]
20 changes: 14 additions & 6 deletions mlir/python/mlir/dialects/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __init__(
generic = region_op(GenericOp_, terminator=YieldOp)


def create_op(
def _create_matmul_like_op(
op_type,
*ins: Union[Operation, OpView, Value],
outs: Sequence[Union[Operation, OpView, Value]],
Expand Down Expand Up @@ -183,7 +183,11 @@ def matmul(
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
cast: Optional[Union[TypeFn, Attribute]] = None,
):
return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
return _get_op_result_or_op_results(
_create_matmul_like_op(
MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
)
)


def batch_matmul(
Expand All @@ -192,8 +196,10 @@ def batch_matmul(
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
cast: Optional[Union[TypeFn, Attribute]] = None,
):
return create_op(
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
return _get_op_result_or_op_results(
_create_matmul_like_op(
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
)
)


Expand All @@ -203,8 +209,10 @@ def contract(
indexing_maps: Sequence[AffineMapAttr],
cast: Optional[Union[TypeFn, Attribute]] = None,
):
return create_op(
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
return _get_op_result_or_op_results(
_create_matmul_like_op(
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
)
)


Expand Down

0 comments on commit 8001c7a

Please sign in to comment.