diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h index eefa2c4724833..346cf62cdb8e8 100644 --- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h +++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h @@ -22,34 +22,11 @@ class OperationPass; namespace linalg { -//===----------------------------------------------------------------------===// -// Patterns to convert a LinalgOp to func.call @external library implementation. -//===----------------------------------------------------------------------===// -// These patterns are exposed individually because they are expected to be -// typically used individually. - -// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()` -// function. The implementation of the function can be either in the same module -// or in an externally linked library. -// This is a generic entry point for all LinalgOp, except for CopyOp, for which -// more specialized patterns are provided. -class LinalgOpToLibraryCallRewrite - : public OpInterfaceRewritePattern { -public: - using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - - LogicalResult matchAndRewrite(LinalgOp op, - PatternRewriter &rewriter) const override; -}; - -/// Populate the given list with patterns that convert from Linalg to Standard. -void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns); - -} // namespace linalg - /// Create a pass to convert Linalg operations to the Standard dialect. std::unique_ptr> createConvertLinalgToStandardPass(); +} // namespace linalg + } // namespace mlir #endif // MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 58ee87cf82039..7a3ffa97bd521 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -713,17 +713,6 @@ def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> { ]; } -//===----------------------------------------------------------------------===// -// LinalgToStandard -//===----------------------------------------------------------------------===// - -def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> { - let summary = "Convert the operations from the linalg dialect into the " - "Standard dialect"; - let constructor = "mlir::createConvertLinalgToStandardPass()"; - let dependentDialects = ["func::FuncDialect", "memref::MemRefDialect"]; -} - //===----------------------------------------------------------------------===// // MathToLibm //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index d96ad919b65f0..99c6d1c14674a 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -58,6 +58,13 @@ def ConvertLinalgToParallelLoopsPass ]; } +def ConvertLinalgToFunctionCallsPass + : Pass<"convert-linalg-to-function-calls", "ModuleOp"> { + let summary = "Convert the operations from the Linalg dialect into " + "function calls"; + let dependentDialects = ["func::FuncDialect", "LLVM::LLVMDialect"]; +} + def LinalgFoldUnitExtentDimsPass : Pass<"linalg-fold-unit-extent-dims", ""> { let summary = "Remove unit-extent dimension in Linalg ops on tensors"; let options = [ diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 1dc700f22c202..1ae2713651287 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1885,6 +1885,32 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns); /// convert to a `linalg.dot`. void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns); +//===----------------------------------------------------------------------===// +// Patterns to convert a LinalgOp to func.call @external library implementation. +// +// These patterns are exposed individually because they are expected to be +// typically used individually. +//===----------------------------------------------------------------------===// + +// Creates a new call to the type-canonicalized `LinalgOp::getLibraryCallName()` +// function. The implementation of the function can be either in the same module +// or in an externally linked library. +// This is a generic entry point for all LinalgOp, except for CopyOp, for which +// more specialized patterns are provided. +class LinalgOpToLibraryCallRewrite + : public OpInterfaceRewritePattern { +public: + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(LinalgOp op, + PatternRewriter &rewriter) const override; +}; + +/// Populates the given list with patterns that convert from Linalg to library +/// calls using the `func` dialect. +void populateLinalgToFunctionCallsConversionPatterns( + RewritePatternSet &patterns); + } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 62461c0cea08a..1c7318bb584d4 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -31,7 +31,6 @@ add_subdirectory(GPUToSPIRV) add_subdirectory(GPUToVulkan) add_subdirectory(IndexToLLVM) add_subdirectory(IndexToSPIRV) -add_subdirectory(LinalgToStandard) add_subdirectory(LLVMCommon) add_subdirectory(MathToFuncs) add_subdirectory(MathToLibm) diff --git a/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt b/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt deleted file mode 100644 index 7fc4af5403185..0000000000000 --- a/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -add_mlir_conversion_library(MLIRLinalgToStandard - LinalgToStandard.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LinalgToStandard - - DEPENDS - MLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRFuncDialect - MLIRIR - MLIRLinalgDialect - MLIRLinalgTransforms - MLIRLLVMDialect - MLIRMemRefDialect - MLIRPass - MLIRSCFDialect - MLIRTransforms - ) diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 3594b08413812..d6bdf1d52dd1d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms EliminateEmptyTensors.cpp EraseUnusedOperandsAndResults.cpp FoldAddIntoDest.cpp + FunctionCalls.cpp FusePadOpWithLinalgProducer.cpp Fusion.cpp Generalization.cpp @@ -68,6 +69,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms MLIRMeshTransforms MLIRLinalgDialect MLIRLinalgUtils + MLIRLLVMDialect MLIRSCFDialect MLIRSCFTransforms MLIRPass diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp similarity index 87% rename from mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp rename to mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp index 4d1f35c767304..a202dac0aa232 100644 --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp @@ -1,4 +1,4 @@ -//===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===// +//===- LinalgToFunctionCalls.cpp - Linalg to function calls conversion ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,20 +6,19 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" - #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Pass/Pass.h" namespace mlir { -#define GEN_PASS_DEF_CONVERTLINALGTOSTANDARD -#include "mlir/Conversion/Passes.h.inc" +#define GEN_PASS_DEF_CONVERTLINALGTOFUNCTIONCALLSPASS +#include "mlir/Dialect/Linalg/Passes.h.inc" } // namespace mlir using namespace mlir; @@ -123,8 +122,7 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite( return success(); } -/// Populate the given list with patterns that convert from Linalg to Standard. -void mlir::linalg::populateLinalgToStandardConversionPatterns( +void mlir::linalg::populateLinalgToFunctionCallsConversionPatterns( RewritePatternSet &patterns) { // TODO: ConvOp conversion needs to export a descriptor with relevant // attribute values such as kernel striding and dilation. @@ -132,13 +130,14 @@ void mlir::linalg::populateLinalgToStandardConversionPatterns( } namespace { -struct ConvertLinalgToStandardPass - : public impl::ConvertLinalgToStandardBase { +struct ConvertLinalgToFunctionCallsPass + : public impl::ConvertLinalgToFunctionCallsPassBase< + ConvertLinalgToFunctionCallsPass> { void runOnOperation() override; }; } // namespace -void ConvertLinalgToStandardPass::runOnOperation() { +void ConvertLinalgToFunctionCallsPass::runOnOperation() { auto module = getOperation(); ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalOp(); RewritePatternSet patterns(&getContext()); - populateLinalgToStandardConversionPatterns(patterns); + populateLinalgToFunctionCallsConversionPatterns(patterns); if (failed(applyFullConversion(module, target, std::move(patterns)))) signalPassFailure(); } - -std::unique_ptr> -mlir::createConvertLinalgToStandardPass() { - return std::make_unique(); -} diff --git a/mlir/test/Dialect/Linalg/library-calls.mlir b/mlir/test/Dialect/Linalg/function-calls.mlir similarity index 61% rename from mlir/test/Dialect/Linalg/library-calls.mlir rename to mlir/test/Dialect/Linalg/function-calls.mlir index 1fa675d8b4b68..103fcb16c5173 100644 --- a/mlir/test/Dialect/Linalg/library-calls.mlir +++ b/mlir/test/Dialect/Linalg/function-calls.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-linalg-to-std -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-linalg-to-function-calls -split-input-file --verify-diagnostics | FileCheck %s func.func private @printMemrefF32(memref<*xf32>) @@ -99,3 +99,85 @@ func.func @test_add(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8 ins(%D, %E: memref<16xf32>, memref<16xf32>) outs(%F: memref<16xf32>) return } + +// ----- + +func.func @dot(%arg0: memref>, + %arg1: memref>, + %arg2: memref) { + linalg.dot ins(%arg0, %arg1: memref>, + memref>) + outs(%arg2: memref) + return +} +// CHECK-LABEL: func @dot( +// CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref>, +// CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref>, +// CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref) { +// CHECK: %[[o0:.*]] = memref.cast %[[arg0]] : +// CHECK-SAME: memref> to memref> +// CHECK: %[[o1:.*]] = memref.cast %[[arg1]] : +// CHECK-SAME: memref> to memref> +// CHECK: %[[o2:.*]] = memref.cast %[[arg2]] : +// CHECK-SAME: memref to memref> +// CHECK: call @linalg_dot_viewsxf32_viewsxf32_viewf32( +// CHECK-SAME: %[[o0]], %[[o1]], %[[o2]]) : +// CHECK-SAME: memref>, memref>, memref> + +// ----- + +#matmul_accesses = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +#matmul_trait = { + iterator_types = ["parallel", "parallel", "reduction"], + indexing_maps = #matmul_accesses, + library_call = "external_outerproduct_matmul" +} + +!vector_type_A = vector<4xf32> +!vector_type_B = vector<4xf32> +!vector_type_C = vector<4x4xf32> + +!matrix_type_A = memref +!matrix_type_B = memref +!matrix_type_C = memref + +func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) { + linalg.generic #matmul_trait + ins(%A, %B : !matrix_type_A, !matrix_type_B) + outs(%C : !matrix_type_C) { + ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C): + %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B + linalg.yield %d: !vector_type_C + } + return +} +// CHECK-LABEL: func @matmul_vec_impl( +// CHECK: call @external_outerproduct_matmul(%{{.*}}) : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> + +func.func @func(%arg0: tensor, %arg1: tensor) { + // expected-error @below {{failed to legalize}} + %0 = linalg.generic { + indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor) outs(%arg1 : tensor) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor + return +} + +// ----- + +func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> { + // expected-error @below {{failed to legalize}} + %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32> + return %0 : tensor<4x8xf32> +} diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir deleted file mode 100644 index f50016f9ea477..0000000000000 --- a/mlir/test/Dialect/Linalg/standard.mlir +++ /dev/null @@ -1,81 +0,0 @@ -// RUN: mlir-opt %s -convert-linalg-to-std --split-input-file -verify-diagnostics | FileCheck %s - -func.func @dot(%arg0: memref>, - %arg1: memref>, - %arg2: memref) { - linalg.dot ins(%arg0, %arg1: memref>, - memref>) - outs(%arg2: memref) - return -} -// CHECK-LABEL: func @dot( -// CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref>, -// CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref>, -// CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref) { -// CHECK: %[[o0:.*]] = memref.cast %[[arg0]] : -// CHECK-SAME: memref> to memref> -// CHECK: %[[o1:.*]] = memref.cast %[[arg1]] : -// CHECK-SAME: memref> to memref> -// CHECK: %[[o2:.*]] = memref.cast %[[arg2]] : -// CHECK-SAME: memref to memref> -// CHECK: call @linalg_dot_viewsxf32_viewsxf32_viewf32( -// CHECK-SAME: %[[o0]], %[[o1]], %[[o2]]) : -// CHECK-SAME: memref>, memref>, memref> - -// ----- - -#matmul_accesses = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (m, n)> -] -#matmul_trait = { - iterator_types = ["parallel", "parallel", "reduction"], - indexing_maps = #matmul_accesses, - library_call = "external_outerproduct_matmul" -} - -!vector_type_A = vector<4xf32> -!vector_type_B = vector<4xf32> -!vector_type_C = vector<4x4xf32> - -!matrix_type_A = memref -!matrix_type_B = memref -!matrix_type_C = memref - -func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) { - linalg.generic #matmul_trait - ins(%A, %B : !matrix_type_A, !matrix_type_B) - outs(%C : !matrix_type_C) { - ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C): - %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B - linalg.yield %d: !vector_type_C - } - return -} -// CHECK-LABEL: func @matmul_vec_impl( -// CHECK: call @external_outerproduct_matmul(%{{.*}}) : - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0)> - -func.func @func(%arg0: tensor, %arg1: tensor) { - // expected-error @below {{failed to legalize}} - %0 = linalg.generic { - indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} - ins(%arg0 : tensor) outs(%arg1 : tensor) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor - return -} - -// ----- - -func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> { - // expected-error @below {{failed to legalize}} - %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32> - return %0 : tensor<4x8xf32> -}