Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --convert-mpi-to-llvm pass #14

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_MPITOLLVM_H
#define MLIR_CONVERSION_MPITOLLVM_H

#include "mlir/IR/DialectRegistry.h"

namespace mlir {

class LLVMTypeConverter;
class RewritePatternSet;

#define GEN_PASS_DECL_MPITOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"

namespace mpi {
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);

void registerConvertMPIToLLVMInterface(DialectRegistry &registry);

} // namespace mpi
} // namespace mlir

#endif // MLIR_CONVERSION_MPITOLLVM_H
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class MPI_Type<string name, string typeMnemonic, list<Trait> traits = []>
//===----------------------------------------------------------------------===//

def MPI_Retval : MPI_Type<"Retval", "retval"> {
let summary = "MPI function call return value";
let summary = "MPI function call return value (!mpi.retval)";
let description = [{
This type represents a return value from an MPI function call.
This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code.
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllExtensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef MLIR_INITALLEXTENSIONS_H_
#define MLIR_INITALLEXTENSIONS_H_

#include "Conversion/MPIToLLVM/MPIToLLVM.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
Expand Down Expand Up @@ -62,6 +63,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertFuncToLLVMInterface(registry);
index::registerConvertIndexToLLVMInterface(registry);
registerConvertMathToLLVMInterface(registry);
mpi::registerConvertMPIToLLVMInterface(registry);
registerConvertMemRefToLLVMInterface(registry);
registerConvertNVVMToLLVMInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
add_subdirectory(MPIToLLVM)
add_subdirectory(NVGPUToNVVM)
add_subdirectory(NVVMToLLVM)
add_subdirectory(OpenACCToSCF)
Expand Down
17 changes: 17 additions & 0 deletions mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
add_mlir_conversion_library(MLIRMPIToLLVM
MPIToLLVM.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRMPIDialect
)
230 changes: 230 additions & 0 deletions mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Pass/Pass.h"

#include <mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h>

using namespace mlir;

namespace {

struct InitOpLowering : ConvertOpToLLVMPattern<mpi::InitOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

struct CommRankOpLowering : ConvertOpToLLVMPattern<mpi::CommRankOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

struct FinalizeOpLowering : ConvertOpToLLVMPattern<mpi::FinalizeOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

// TODO: this was copied from GPUOpsLowering.cpp:288
// is this okay, or should this be moved to some common file?
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
ConversionPatternRewriter &rewriter,
StringRef name,
LLVM::LLVMFunctionType type) {
LLVM::LLVMFuncOp ret;
if (!(ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))) {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
LLVM::Linkage::External);
}
return ret;
}

// TODO: this is pretty close to getOrDefineFunction, can probably be factored
LLVM::GlobalOp getOrDefineExternalStruct(ModuleOp &moduleOp, const Location loc,
ConversionPatternRewriter &rewriter,
StringRef name,
LLVM::LLVMStructType type) {
LLVM::GlobalOp ret;
if (!(ret = moduleOp.lookupSymbol<LLVM::GlobalOp>(name))) {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
ret = rewriter.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/false, LLVM::Linkage::External, name,
/*value=*/Attribute(), /*alignment=*/0, 0);
}
return ret;
}

} // namespace

//===----------------------------------------------------------------------===//
// InitOpLowering
//===----------------------------------------------------------------------===//

LogicalResult
InitOpLowering::matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// get loc
auto loc = op.getLoc();

// ptrType `!llvm.ptr`
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());

// instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
Value llvmnull = nullPtrOp.getRes();

// grab a reference to the global module op:
auto moduleOp = op->getParentOfType<ModuleOp>();

// LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
auto initFuncType =
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
// get or create function declaration:
LLVM::LLVMFuncOp initDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);

// replace init with function call
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
ValueRange{llvmnull, llvmnull});

return success();
}

//===----------------------------------------------------------------------===//
// FinalizeOpLowering
//===----------------------------------------------------------------------===//

LogicalResult
FinalizeOpLowering::matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// get loc
auto loc = op.getLoc();

// grab a reference to the global module op:
auto moduleOp = op->getParentOfType<ModuleOp>();

// LLVM Function type representing `i32 MPI_Finalize()`
auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
// get or create function declaration:
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(moduleOp, loc, rewriter,
"MPI_Finalize", initFuncType);

// replace init with function call
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});

return success();
}

//===----------------------------------------------------------------------===//
// CommRankLowering
//===----------------------------------------------------------------------===//

LogicalResult
CommRankOpLowering::matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// get some helper vars
auto loc = op.getLoc();
auto context = rewriter.getContext();
auto i32 = rewriter.getI32Type();

// ptrType `!llvm.ptr`
Type ptrType = LLVM::LLVMPointerType::get(context);

// get external opaque struct pointer type
auto commStructT = LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", context);

// grab a reference to the global module op:
auto moduleOp = op->getParentOfType<ModuleOp>();

// make sure global op definition exists
getOrDefineExternalStruct(moduleOp, loc, rewriter, "MPI_COMM_WORLD",
commStructT);

// get address of @MPI_COMM_WORLD
auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
auto commWorld = rewriter.create<LLVM::AddressOfOp>(
loc, ptrType, SymbolRefAttr::get(context, "MPI_COMM_WORLD"));

// LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
auto rankFuncType = LLVM::LLVMFunctionType::get(i32, {ptrType, ptrType});
// get or create function declaration:
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);

// replace init with function call
auto callOp = rewriter.create<LLVM::CallOp>(
loc, initDecl, ValueRange{commWorld.getRes(), rankptr.getRes()});

// load the rank into a register
auto loadedRank =
rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());

// if retval is checked, replace uses of retval with the results from the call
// op
SmallVector<Value> replacements;
if (op.getRetval()) {
replacements.push_back(callOp.getResult());
}
// replace all uses, then erase op
replacements.push_back(loadedRank.getRes());
rewriter.replaceOp(op, replacements);

return success();
}

//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//

void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<InitOpLowering>(converter);
patterns.add<CommRankOpLowering>(converter);
patterns.add<FinalizeOpLowering>(converter);
}

//===----------------------------------------------------------------------===//
// ConvertToLLVMPatternInterface implementation
//===----------------------------------------------------------------------===//

namespace {
/// Implement the interface to convert Func to LLVM.
struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
/// Hook for derived dialect interface to provide conversion patterns
/// and mark dialect legal for the conversion target.
void populateConvertToLLVMConversionPatterns(
ConversionTarget &target, LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns) const final {
mpi::populateMPIToLLVMConversionPatterns(typeConverter, patterns);
}
};
} // namespace

void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
dialect->addInterfaces<FuncToLLVMDialectInterface>();
});
}
40 changes: 40 additions & 0 deletions mlir/test/Conversion/MPIToLLVM/ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: mlir-opt -convert-to-llvm %s | FileCheck %s

module {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32

func.func @mpi_test(%arg0: memref<100xf32>) {
%0 = mpi.init : !mpi.retval
// CHECK: %7 = llvm.mlir.zero : !llvm.ptr
// CHECK-NEXT: %8 = llvm.call @MPI_Init(%7, %7) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK-NEXT: %9 = builtin.unrealized_conversion_cast %8 : i32 to !mpi.retval


%retval, %rank = mpi.comm_rank : !mpi.retval, i32
// CHECK: %10 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %11 = llvm.alloca %10 x i32 : (i32) -> !llvm.ptr
// CHECK-NEXT: %12 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
// CHECK-NEXT: %13 = llvm.call @MPI_Comm_rank(%12, %11) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK-NEXT: %14 = llvm.load %11 : !llvm.ptr -> i32
// CHECK-NEXT: %15 = builtin.unrealized_conversion_cast %13 : i32 to !mpi.retval

mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32

%1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval

mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32

%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval

%3 = mpi.finalize : !mpi.retval
// CHECK: %18 = llvm.call @MPI_Finalize() : () -> i32

%4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1

%5 = mpi.error_class %0 : !mpi.retval
return
}
}