diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 244db23925ab3..5986626a72729 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -178,6 +178,14 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> { ]; } +// TODO: +def LinalgRelayoutOpInterface : OpInterface<"RelayoutOpInterface"> { + let description = [{ + TODO + }]; + let cppNamespace = "::mlir::linalg"; +} + def LinalgFillOpInterface : OpInterface<"FillOpInterface"> { let description = [{ A fill operation is defined in general terms: diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td index 845a34e90bc09..fe0e826f6b771 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td @@ -18,6 +18,7 @@ include "mlir/Dialect/Linalg/IR/LinalgBase.td" include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "mlir/IR/OpAsmInterface.td" //===----------------------------------------------------------------------===// @@ -27,7 +28,7 @@ include "mlir/IR/OpAsmInterface.td" class Linalg_RelayoutOp traits = []> : Op, - DestinationStyleOpInterface, + DestinationStyleOpInterface, LinalgRelayoutOpInterface, ConditionallySpeculatable, NoMemoryEffect, DeclareOpInterfaceMethods, TypesMatchWith<"result type matches type of dest", diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 011f8601fa95f..34a3ebfe82f22 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3935,9 +3935,10 @@ struct FoldTensorCastProducerOp LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override { - // Reject PackOp/UnpackOp - there are dedicated patterns for that instead. + // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns + // for that instead. if (!foldTensorCastPrecondition(op) || - isa(*op)) + isa(*op)) return failure(); SmallVector newResultTypes(op->getResultTypes());