Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] move LinalgToStandard to Linalg as ConvertToFunctionCalls #121392

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 2 additions & 25 deletions mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,11 @@ class OperationPass;

namespace linalg {

//===----------------------------------------------------------------------===//
// Patterns to convert a LinalgOp to func.call @external library implementation.
//===----------------------------------------------------------------------===//
// These patterns are exposed individually because they are expected to be
// typically used individually.

// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
// function. The implementation of the function can be either in the same module
// or in an externally linked library.
// This is a generic entry point for all LinalgOp, except for CopyOp, for which
// more specialized patterns are provided.
class LinalgOpToLibraryCallRewrite
: public OpInterfaceRewritePattern<LinalgOp> {
public:
using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override;
};

/// Populate the given list with patterns that convert from Linalg to Standard.
void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns);

} // namespace linalg

/// Create a pass to convert Linalg operations to the Standard dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToStandardPass();

} // namespace linalg

} // namespace mlir

#endif // MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_
11 changes: 0 additions & 11 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -713,17 +713,6 @@ def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> {
];
}

//===----------------------------------------------------------------------===//
// LinalgToStandard
//===----------------------------------------------------------------------===//

def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
let summary = "Convert the operations from the linalg dialect into the "
"Standard dialect";
let constructor = "mlir::createConvertLinalgToStandardPass()";
let dependentDialects = ["func::FuncDialect", "memref::MemRefDialect"];
}

//===----------------------------------------------------------------------===//
// MathToLibm
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ def ConvertLinalgToParallelLoopsPass
];
}

def ConvertLinalgToFunctionCallsPass
: Pass<"convert-linalg-to-function-calls", "ModuleOp"> {
let summary = "Convert the operations from the Linalg dialect into "
"function calls";
let dependentDialects = ["func::FuncDialect", "LLVM::LLVMDialect"];
}

def LinalgFoldUnitExtentDimsPass : Pass<"linalg-fold-unit-extent-dims", ""> {
let summary = "Remove unit-extent dimension in Linalg ops on tensors";
let options = [
Expand Down
26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1885,6 +1885,32 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);

//===----------------------------------------------------------------------===//
// Patterns to convert a LinalgOp to func.call @external library implementation.
//
// These patterns are exposed individually because they are expected to be
// typically used individually.
//===----------------------------------------------------------------------===//

// Creates a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
// function. The implementation of the function can be either in the same module
// or in an externally linked library.
// This is a generic entry point for all LinalgOp, except for CopyOp, for which
// more specialized patterns are provided.
class LinalgOpToLibraryCallRewrite
: public OpInterfaceRewritePattern<LinalgOp> {
public:
using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override;
};

/// Populates the given list with patterns that convert from Linalg to library
/// calls using the `func` dialect.
void populateLinalgToFunctionCallsConversionPatterns(
RewritePatternSet &patterns);

} // namespace linalg
} // namespace mlir

Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ add_subdirectory(GPUToSPIRV)
add_subdirectory(GPUToVulkan)
add_subdirectory(IndexToLLVM)
add_subdirectory(IndexToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(LLVMCommon)
add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
Expand Down
23 changes: 0 additions & 23 deletions mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt

This file was deleted.

2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
EliminateEmptyTensors.cpp
EraseUnusedOperandsAndResults.cpp
FoldAddIntoDest.cpp
FunctionCalls.cpp
FusePadOpWithLinalgProducer.cpp
Fusion.cpp
Generalization.cpp
Expand Down Expand Up @@ -68,6 +69,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRMeshTransforms
MLIRLinalgDialect
MLIRLinalgUtils
MLIRLLVMDialect
MLIRSCFDialect
MLIRSCFTransforms
MLIRPass
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
//===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
//===- LinalgToFunctionCalls.cpp - Linalg to function calls conversion ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTLINALGTOSTANDARD
#include "mlir/Conversion/Passes.h.inc"
#define GEN_PASS_DEF_CONVERTLINALGTOFUNCTIONCALLSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
Expand Down Expand Up @@ -123,35 +122,30 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
return success();
}

/// Populate the given list with patterns that convert from Linalg to Standard.
void mlir::linalg::populateLinalgToStandardConversionPatterns(
void mlir::linalg::populateLinalgToFunctionCallsConversionPatterns(
RewritePatternSet &patterns) {
// TODO: ConvOp conversion needs to export a descriptor with relevant
// attribute values such as kernel striding and dilation.
patterns.add<LinalgOpToLibraryCallRewrite>(patterns.getContext());
}

namespace {
struct ConvertLinalgToStandardPass
: public impl::ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
struct ConvertLinalgToFunctionCallsPass
: public impl::ConvertLinalgToFunctionCallsPassBase<
ConvertLinalgToFunctionCallsPass> {
void runOnOperation() override;
};
} // namespace

void ConvertLinalgToStandardPass::runOnOperation() {
void ConvertLinalgToFunctionCallsPass::runOnOperation() {
auto module = getOperation();
ConversionTarget target(getContext());
target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
func::FuncDialect, memref::MemRefDialect,
scf::SCFDialect>();
target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
RewritePatternSet patterns(&getContext());
populateLinalgToStandardConversionPatterns(patterns);
populateLinalgToFunctionCallsConversionPatterns(patterns);
if (failed(applyFullConversion(module, target, std::move(patterns))))
signalPassFailure();
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertLinalgToStandardPass() {
return std::make_unique<ConvertLinalgToStandardPass>();
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-linalg-to-std -split-input-file | FileCheck %s
// RUN: mlir-opt %s -convert-linalg-to-function-calls -split-input-file --verify-diagnostics | FileCheck %s

func.func private @printMemrefF32(memref<*xf32>)

Expand Down Expand Up @@ -99,3 +99,85 @@ func.func @test_add(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8
ins(%D, %E: memref<16xf32>, memref<16xf32>) outs(%F: memref<16xf32>)
return
}

// -----

func.func @dot(%arg0: memref<?xf32, strided<[1], offset: ?>>,
%arg1: memref<?xf32, strided<[1], offset: ?>>,
%arg2: memref<f32>) {
linalg.dot ins(%arg0, %arg1: memref<?xf32, strided<[1], offset: ?>>,
memref<?xf32, strided<[1], offset: ?>>)
outs(%arg2: memref<f32>)
return
}
// CHECK-LABEL: func @dot(
// CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
// CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
// CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref<f32>) {
// CHECK: %[[o0:.*]] = memref.cast %[[arg0]] :
// CHECK-SAME: memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
// CHECK: %[[o1:.*]] = memref.cast %[[arg1]] :
// CHECK-SAME: memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
// CHECK: %[[o2:.*]] = memref.cast %[[arg2]] :
// CHECK-SAME: memref<f32> to memref<f32, strided<[], offset: ?>>
// CHECK: call @linalg_dot_viewsxf32_viewsxf32_viewf32(
// CHECK-SAME: %[[o0]], %[[o1]], %[[o2]]) :
// CHECK-SAME: memref<?xf32, strided<[?], offset: ?>>, memref<?xf32, strided<[?], offset: ?>>, memref<f32, strided<[], offset: ?>>

// -----

#matmul_accesses = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
affine_map<(m, n, k) -> (m, n)>
]
#matmul_trait = {
iterator_types = ["parallel", "parallel", "reduction"],
indexing_maps = #matmul_accesses,
library_call = "external_outerproduct_matmul"
}

!vector_type_A = vector<4xf32>
!vector_type_B = vector<4xf32>
!vector_type_C = vector<4x4xf32>

!matrix_type_A = memref<?x?x!vector_type_A>
!matrix_type_B = memref<?x?x!vector_type_B>
!matrix_type_C = memref<?x?x!vector_type_C>

func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
linalg.generic #matmul_trait
ins(%A, %B : !matrix_type_A, !matrix_type_B)
outs(%C : !matrix_type_C) {
^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
%d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
linalg.yield %d: !vector_type_C
}
return
}
// CHECK-LABEL: func @matmul_vec_impl(
// CHECK: call @external_outerproduct_matmul(%{{.*}}) :

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>

func.func @func(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) {
// expected-error @below {{failed to legalize}}
%0 = linalg.generic {
indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]}
ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?xf32>
return
}

// -----

func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
// expected-error @below {{failed to legalize}}
%0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
81 changes: 0 additions & 81 deletions mlir/test/Dialect/Linalg/standard.mlir

This file was deleted.

Loading