From 2448cdf9ff52e8451eceee5e412612fb71398b76 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 31 Dec 2024 15:10:56 +0100 Subject: [PATCH] [mlir] move LinalgToStandard to Linalg as ConvertToFunctionCalls The remnants of the ConvertLinalgToStandard pass were still present in the codebase under this name, years after the Standard dialect was dismantled. Practically, this pass / pattern set was only performing the rewrite of Linalg operaitons to function calls. All this makes the existence of the pass highly confusing. Move the logic under Linalg/Transforms, similarly to other "lowerings" from Linalg, e.g., the one to (affine or SCF) loops. Rename ConvertLinalgToStandard to ConvertLinalgToFunctionCalls. Merge the two relevant test files, ironically, one of them was called library-calls.mlir. Simplify the code a little. --- .../LinalgToStandard/LinalgToStandard.h | 27 +----- mlir/include/mlir/Conversion/Passes.td | 11 --- mlir/include/mlir/Dialect/Linalg/Passes.td | 7 ++ .../Dialect/Linalg/Transforms/Transforms.h | 26 ++++++ mlir/lib/Conversion/CMakeLists.txt | 1 - .../LinalgToStandard/CMakeLists.txt | 23 ----- .../Dialect/Linalg/Transforms/CMakeLists.txt | 2 + .../Linalg/Transforms/FunctionCalls.cpp} | 28 +++---- ...library-calls.mlir => function-calls.mlir} | 84 ++++++++++++++++++- mlir/test/Dialect/Linalg/standard.mlir | 81 ------------------ 10 files changed, 131 insertions(+), 159 deletions(-) delete mode 100644 mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt rename mlir/lib/{Conversion/LinalgToStandard/LinalgToStandard.cpp => Dialect/Linalg/Transforms/FunctionCalls.cpp} (87%) rename mlir/test/Dialect/Linalg/{library-calls.mlir => function-calls.mlir} (61%) delete mode 100644 mlir/test/Dialect/Linalg/standard.mlir diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h index eefa2c4724833..346cf62cdb8e8 100644 --- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h +++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h @@ -22,34 +22,11 @@ class OperationPass; namespace linalg { -//===----------------------------------------------------------------------===// -// Patterns to convert a LinalgOp to func.call @external library implementation. -//===----------------------------------------------------------------------===// -// These patterns are exposed individually because they are expected to be -// typically used individually. - -// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()` -// function. The implementation of the function can be either in the same module -// or in an externally linked library. -// This is a generic entry point for all LinalgOp, except for CopyOp, for which -// more specialized patterns are provided. -class LinalgOpToLibraryCallRewrite - : public OpInterfaceRewritePattern { -public: - using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - - LogicalResult matchAndRewrite(LinalgOp op, - PatternRewriter &rewriter) const override; -}; - -/// Populate the given list with patterns that convert from Linalg to Standard. -void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns); - -} // namespace linalg - /// Create a pass to convert Linalg operations to the Standard dialect. std::unique_ptr> createConvertLinalgToStandardPass(); +} // namespace linalg + } // namespace mlir #endif // MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 58ee87cf82039..7a3ffa97bd521 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -713,17 +713,6 @@ def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> { ]; } -//===----------------------------------------------------------------------===// -// LinalgToStandard -//===----------------------------------------------------------------------===// - -def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> { - let summary = "Convert the operations from the linalg dialect into the " - "Standard dialect"; - let constructor = "mlir::createConvertLinalgToStandardPass()"; - let dependentDialects = ["func::FuncDialect", "memref::MemRefDialect"]; -} - //===----------------------------------------------------------------------===// // MathToLibm //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index d96ad919b65f0..99c6d1c14674a 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -58,6 +58,13 @@ def ConvertLinalgToParallelLoopsPass ]; } +def ConvertLinalgToFunctionCallsPass + : Pass<"convert-linalg-to-function-calls", "ModuleOp"> { + let summary = "Convert the operations from the Linalg dialect into " + "function calls"; + let dependentDialects = ["func::FuncDialect", "LLVM::LLVMDialect"]; +} + def LinalgFoldUnitExtentDimsPass : Pass<"linalg-fold-unit-extent-dims", ""> { let summary = "Remove unit-extent dimension in Linalg ops on tensors"; let options = [ diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 1dc700f22c202..1ae2713651287 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1885,6 +1885,32 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns); /// convert to a `linalg.dot`. void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns); +//===----------------------------------------------------------------------===// +// Patterns to convert a LinalgOp to func.call @external library implementation. +// +// These patterns are exposed individually because they are expected to be +// typically used individually. +//===----------------------------------------------------------------------===// + +// Creates a new call to the type-canonicalized `LinalgOp::getLibraryCallName()` +// function. The implementation of the function can be either in the same module +// or in an externally linked library. +// This is a generic entry point for all LinalgOp, except for CopyOp, for which +// more specialized patterns are provided. +class LinalgOpToLibraryCallRewrite + : public OpInterfaceRewritePattern { +public: + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(LinalgOp op, + PatternRewriter &rewriter) const override; +}; + +/// Populates the given list with patterns that convert from Linalg to library +/// calls using the `func` dialect. +void populateLinalgToFunctionCallsConversionPatterns( + RewritePatternSet &patterns); + } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 62461c0cea08a..1c7318bb584d4 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -31,7 +31,6 @@ add_subdirectory(GPUToSPIRV) add_subdirectory(GPUToVulkan) add_subdirectory(IndexToLLVM) add_subdirectory(IndexToSPIRV) -add_subdirectory(LinalgToStandard) add_subdirectory(LLVMCommon) add_subdirectory(MathToFuncs) add_subdirectory(MathToLibm) diff --git a/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt b/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt deleted file mode 100644 index 7fc4af5403185..0000000000000 --- a/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -add_mlir_conversion_library(MLIRLinalgToStandard - LinalgToStandard.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LinalgToStandard - - DEPENDS - MLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRFuncDialect - MLIRIR - MLIRLinalgDialect - MLIRLinalgTransforms - MLIRLLVMDialect - MLIRMemRefDialect - MLIRPass - MLIRSCFDialect - MLIRTransforms - ) diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 3594b08413812..d6bdf1d52dd1d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms EliminateEmptyTensors.cpp EraseUnusedOperandsAndResults.cpp FoldAddIntoDest.cpp + FunctionCalls.cpp FusePadOpWithLinalgProducer.cpp Fusion.cpp Generalization.cpp @@ -68,6 +69,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms MLIRMeshTransforms MLIRLinalgDialect MLIRLinalgUtils + MLIRLLVMDialect MLIRSCFDialect MLIRSCFTransforms MLIRPass diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp similarity index 87% rename from mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp rename to mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp index 4d1f35c767304..a202dac0aa232 100644 --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp @@ -1,4 +1,4 @@ -//===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===// +//===- LinalgToFunctionCalls.cpp - Linalg to function calls conversion ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,20 +6,19 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" - #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Pass/Pass.h" namespace mlir { -#define GEN_PASS_DEF_CONVERTLINALGTOSTANDARD -#include "mlir/Conversion/Passes.h.inc" +#define GEN_PASS_DEF_CONVERTLINALGTOFUNCTIONCALLSPASS +#include "mlir/Dialect/Linalg/Passes.h.inc" } // namespace mlir using namespace mlir; @@ -123,8 +122,7 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite( return success(); } -/// Populate the given list with patterns that convert from Linalg to Standard. -void mlir::linalg::populateLinalgToStandardConversionPatterns( +void mlir::linalg::populateLinalgToFunctionCallsConversionPatterns( RewritePatternSet &patterns) { // TODO: ConvOp conversion needs to export a descriptor with relevant // attribute values such as kernel striding and dilation. @@ -132,13 +130,14 @@ void mlir::linalg::populateLinalgToStandardConversionPatterns( } namespace { -struct ConvertLinalgToStandardPass - : public impl::ConvertLinalgToStandardBase { +struct ConvertLinalgToFunctionCallsPass + : public impl::ConvertLinalgToFunctionCallsPassBase< + ConvertLinalgToFunctionCallsPass> { void runOnOperation() override; }; } // namespace -void ConvertLinalgToStandardPass::runOnOperation() { +void ConvertLinalgToFunctionCallsPass::runOnOperation() { auto module = getOperation(); ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalOp(); RewritePatternSet patterns(&getContext()); - populateLinalgToStandardConversionPatterns(patterns); + populateLinalgToFunctionCallsConversionPatterns(patterns); if (failed(applyFullConversion(module, target, std::move(patterns)))) signalPassFailure(); } - -std::unique_ptr> -mlir::createConvertLinalgToStandardPass() { - return std::make_unique(); -} diff --git a/mlir/test/Dialect/Linalg/library-calls.mlir b/mlir/test/Dialect/Linalg/function-calls.mlir similarity index 61% rename from mlir/test/Dialect/Linalg/library-calls.mlir rename to mlir/test/Dialect/Linalg/function-calls.mlir index 1fa675d8b4b68..103fcb16c5173 100644 --- a/mlir/test/Dialect/Linalg/library-calls.mlir +++ b/mlir/test/Dialect/Linalg/function-calls.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-linalg-to-std -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-linalg-to-function-calls -split-input-file --verify-diagnostics | FileCheck %s func.func private @printMemrefF32(memref<*xf32>) @@ -99,3 +99,85 @@ func.func @test_add(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8 ins(%D, %E: memref<16xf32>, memref<16xf32>) outs(%F: memref<16xf32>) return } + +// ----- + +func.func @dot(%arg0: memref>, + %arg1: memref>, + %arg2: memref) { + linalg.dot ins(%arg0, %arg1: memref>, + memref>) + outs(%arg2: memref) + return +} +// CHECK-LABEL: func @dot( +// CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref>, +// CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref>, +// CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref) { +// CHECK: %[[o0:.*]] = memref.cast %[[arg0]] : +// CHECK-SAME: memref> to memref> +// CHECK: %[[o1:.*]] = memref.cast %[[arg1]] : +// CHECK-SAME: memref> to memref> +// CHECK: %[[o2:.*]] = memref.cast %[[arg2]] : +// CHECK-SAME: memref to memref> +// CHECK: call @linalg_dot_viewsxf32_viewsxf32_viewf32( +// CHECK-SAME: %[[o0]], %[[o1]], %[[o2]]) : +// CHECK-SAME: memref>, memref>, memref> + +// ----- + +#matmul_accesses = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +#matmul_trait = { + iterator_types = ["parallel", "parallel", "reduction"], + indexing_maps = #matmul_accesses, + library_call = "external_outerproduct_matmul" +} + +!vector_type_A = vector<4xf32> +!vector_type_B = vector<4xf32> +!vector_type_C = vector<4x4xf32> + +!matrix_type_A = memref +!matrix_type_B = memref +!matrix_type_C = memref + +func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) { + linalg.generic #matmul_trait + ins(%A, %B : !matrix_type_A, !matrix_type_B) + outs(%C : !matrix_type_C) { + ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C): + %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B + linalg.yield %d: !vector_type_C + } + return +} +// CHECK-LABEL: func @matmul_vec_impl( +// CHECK: call @external_outerproduct_matmul(%{{.*}}) : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> + +func.func @func(%arg0: tensor, %arg1: tensor) { + // expected-error @below {{failed to legalize}} + %0 = linalg.generic { + indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor) outs(%arg1 : tensor) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor + return +} + +// ----- + +func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> { + // expected-error @below {{failed to legalize}} + %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32> + return %0 : tensor<4x8xf32> +} diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir deleted file mode 100644 index f50016f9ea477..0000000000000 --- a/mlir/test/Dialect/Linalg/standard.mlir +++ /dev/null @@ -1,81 +0,0 @@ -// RUN: mlir-opt %s -convert-linalg-to-std --split-input-file -verify-diagnostics | FileCheck %s - -func.func @dot(%arg0: memref>, - %arg1: memref>, - %arg2: memref) { - linalg.dot ins(%arg0, %arg1: memref>, - memref>) - outs(%arg2: memref) - return -} -// CHECK-LABEL: func @dot( -// CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref>, -// CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref>, -// CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref) { -// CHECK: %[[o0:.*]] = memref.cast %[[arg0]] : -// CHECK-SAME: memref> to memref> -// CHECK: %[[o1:.*]] = memref.cast %[[arg1]] : -// CHECK-SAME: memref> to memref> -// CHECK: %[[o2:.*]] = memref.cast %[[arg2]] : -// CHECK-SAME: memref to memref> -// CHECK: call @linalg_dot_viewsxf32_viewsxf32_viewf32( -// CHECK-SAME: %[[o0]], %[[o1]], %[[o2]]) : -// CHECK-SAME: memref>, memref>, memref> - -// ----- - -#matmul_accesses = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (m, n)> -] -#matmul_trait = { - iterator_types = ["parallel", "parallel", "reduction"], - indexing_maps = #matmul_accesses, - library_call = "external_outerproduct_matmul" -} - -!vector_type_A = vector<4xf32> -!vector_type_B = vector<4xf32> -!vector_type_C = vector<4x4xf32> - -!matrix_type_A = memref -!matrix_type_B = memref -!matrix_type_C = memref - -func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) { - linalg.generic #matmul_trait - ins(%A, %B : !matrix_type_A, !matrix_type_B) - outs(%C : !matrix_type_C) { - ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C): - %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B - linalg.yield %d: !vector_type_C - } - return -} -// CHECK-LABEL: func @matmul_vec_impl( -// CHECK: call @external_outerproduct_matmul(%{{.*}}) : - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0)> - -func.func @func(%arg0: tensor, %arg1: tensor) { - // expected-error @below {{failed to legalize}} - %0 = linalg.generic { - indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} - ins(%arg0 : tensor) outs(%arg1 : tensor) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor - return -} - -// ----- - -func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> { - // expected-error @below {{failed to legalize}} - %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32> - return %0 : tensor<4x8xf32> -}