From 8001c7a8da5b5bfbee3369aa55a146d316e34c36 Mon Sep 17 00:00:00 2001 From: max Date: Wed, 19 Feb 2025 21:35:38 -0500 Subject: [PATCH] fixup --- mlir/docs/Dialects/Linalg/_index.md | 1 + mlir/python/mlir/dialects/linalg/__init__.py | 20 ++++++++++++++------ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/mlir/docs/Dialects/Linalg/_index.md b/mlir/docs/Dialects/Linalg/_index.md index 976f0fd3c7e91..b519e4159f186 100644 --- a/mlir/docs/Dialects/Linalg/_index.md +++ b/mlir/docs/Dialects/Linalg/_index.md @@ -695,3 +695,4 @@ the same IR. ## Operations [include "Dialects/LinalgOps.md"] +[include "Dialects/LinalgRelayoutOps.td"] diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index b99344d34db89..61262dcaaa823 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -153,7 +153,7 @@ def __init__( generic = region_op(GenericOp_, terminator=YieldOp) -def create_op( +def _create_matmul_like_op( op_type, *ins: Union[Operation, OpView, Value], outs: Sequence[Union[Operation, OpView, Value]], @@ -183,7 +183,11 @@ def matmul( indexing_maps: Optional[Sequence[AffineMapAttr]] = None, cast: Optional[Union[TypeFn, Attribute]] = None, ): - return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast) + return _get_op_result_or_op_results( + _create_matmul_like_op( + MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) + ) def batch_matmul( @@ -192,8 +196,10 @@ def batch_matmul( indexing_maps: Optional[Sequence[AffineMapAttr]] = None, cast: Optional[Union[TypeFn, Attribute]] = None, ): - return create_op( - BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + return _get_op_result_or_op_results( + _create_matmul_like_op( + BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) ) @@ -203,8 +209,10 @@ def contract( indexing_maps: Sequence[AffineMapAttr], cast: Optional[Union[TypeFn, Attribute]] = None, ): - return create_op( - ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + return _get_op_result_or_op_results( + _create_matmul_like_op( + ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) )