Bump LLVM (#1015)
Fixes after bump:
  - Update pack/unpack op from Tensor to Linalg
  - Update calls to upstream APIs
  - Improve pipelines - lower UB dialect, better pass ordering
  - Update references: mlir-cpu-runner -> mlir-runner
  - Other minor syntax updates
adam-smnk authored Feb 18, 2025
1 parent 37f936c commit cb1e22f
Showing 83 changed files with 455 additions and 454 deletions.
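
The most pervasive of the "calls to upstream APIs" updates below is mechanical: upstream shortened the greedy pattern driver entry point, so `applyPatternsAndFoldGreedily` becomes `applyPatternsGreedily` throughout. A minimal sketch of the new call shape, using a hypothetical helper and no TPP-specific patterns (illustration only, not code from this commit):

    // Sketch of the post-bump greedy driver API; runExamplePatterns is a
    // hypothetical helper, not code from this repository.
    #include "mlir/IR/PatternMatch.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    using namespace mlir;

    static void runExamplePatterns(Operation *root, MLIRContext *ctx) {
      RewritePatternSet patterns(ctx);
      // patterns.add<SomePattern>(ctx);  // register rewrites as before

      // Old name: applyPatternsAndFoldGreedily(root, std::move(patterns));
      // New name after the bump:
      (void)applyPatternsGreedily(root, std::move(patterns));
    }
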
2 changes: 1 addition & 1 deletion benchmarks/mlir/fp32-pack-gemm-operand-a-512x1024.mlir
@@ -4,7 +4,7 @@
// BENCH_TOTAL_FLOPS: 2097152

func.func @entry(%arg0: tensor<512x1024xf32>, %arg1: tensor<16x32x32x32xf32>) -> tensor<16x32x32x32xf32> {
%pack = tensor.pack %arg0
%pack = linalg.pack %arg0
inner_dims_pos = [0, 1]
inner_tiles = [32, 32]
into %arg1 : tensor<512x1024xf32> -> tensor<16x32x32x32xf32>
2 changes: 1 addition & 1 deletion benchmarks/mlir/fp32-pack-gemm-operand-b-512x1024.mlir
@@ -4,7 +4,7 @@
// BENCH_TOTAL_FLOPS: 2097152

func.func @entry(%arg0: tensor<1024x512xf32>, %arg1: tensor<16x32x32x32xf32>) -> tensor<16x32x32x32xf32> {
%0 = tensor.pack %arg0
%0 = linalg.pack %arg0
outer_dims_perm = [1, 0]
inner_dims_pos = [0, 1]
inner_tiles = [32, 32]
2 changes: 1 addition & 1 deletion benchmarks/mlir/fp32-unpack-gemm-operand-a-512x512.mlir
@@ -4,7 +4,7 @@
// BENCH_TOTAL_FLOPS: 1048576

func.func @entry(%arg0: tensor<16x16x32x32xf32>, %arg1: tensor<512x512xf32>) -> tensor<512x512xf32> {
%unpack = tensor.unpack %arg0
%unpack = linalg.unpack %arg0
inner_dims_pos = [0, 1]
inner_tiles = [32, 32]
into %arg1 : tensor<16x16x32x32xf32> -> tensor<512x512xf32>
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
@@ -1 +1 @@
3654f1baa66f524c89e40ab24e18e594e56363e9
2b71df5a74cb5bd67f3f34277749dc920fd35105
4 changes: 2 additions & 2 deletions docs/TPPDialect.md
@@ -172,12 +172,12 @@ Should be fused with the user(s).
GEMM ops have transposed versions, we should use this op to annotate operands.

## Tensor pack
The tensor operation `tensor.pack` does a "block transpose" (n,m <-> m,n) copies.
The tensor operation `linalg.pack` does a "block transpose" (n,m <-> m,n) copies.
We lower this to a series of `tpp.copy` into temporary tiles if needed.
But the idea is that all constant tensors would have been packed by the compiler already and all input packs would be combined at the beginning.

## Tensor Unpack
The tensor operation `tensor.unpack` does a "block transpose" (n,m <-> m,n) copies.
The tensor operation `linalg.unpack` does a "block transpose" (n,m <-> m,n) copies.

## VNNI Pack
Packs into VNNI shape.
2 changes: 1 addition & 1 deletion include/TPP/IR/StructuredOpMatcher.h
@@ -190,7 +190,7 @@ struct HasStaticStrides {
SmallVector<int64_t> strides;
if (auto memRefType = dyn_cast_or_null<MemRefType>(operandType)) {
int64_t offset;
if (failed(getStridesAndOffset(memRefType, strides, offset)))
if (failed(memRefType.getStridesAndOffset(strides, offset)))
return false;
if (llvm::any_of(strides, [](int64_t stride) {
return stride == ShapedType::kDynamic;
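
The hunk above also shows a second recurring update in this commit: `getStridesAndOffset` is now called as a member of `MemRefType` rather than as a free function. A small sketch of the new call shape (illustrative only, written as a standalone helper rather than the matcher class above):

    // Sketch: strided-layout query after the bump; not code from this commit.
    #include "mlir/IR/BuiltinTypes.h"
    #include "llvm/ADT/STLExtras.h"
    #include "llvm/ADT/SmallVector.h"

    using namespace mlir;

    static bool hasFullyStaticStrides(MemRefType memRefType) {
      SmallVector<int64_t> strides;
      int64_t offset;
      // Old form: failed(getStridesAndOffset(memRefType, strides, offset))
      if (failed(memRefType.getStridesAndOffset(strides, offset)))
        return false;
      return llvm::none_of(strides, [](int64_t stride) {
        return stride == ShapedType::kDynamic;
      });
    }
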
14 changes: 7 additions & 7 deletions include/TPP/Passes.td
@@ -262,26 +262,26 @@ def CombineXsmmOpPass : Pass<"combine-xsmm-op-optimization", "func::FuncOp"> {
}

def PropagatePackUnPack : Pass<"propagate-pack-and-unpack", "func::FuncOp"> {
let summary = "Propagate tensor.pack and tensor.unpack";
let summary = "Propagate linalg.pack and linalg.unpack";
let description = [{
Attempt to push tensor.pack and tensor.unpack at the boundaries. Currently,
Attempt to push linalg.pack and linalg.unpack at the boundaries. Currently,
it propagates through linalg element-wise operations. Only one operand in the
generic must come from a tensor.pack/tensor.unpack.
generic must come from a linalg.pack/linalg.unpack.
}];
}

def SimplifyAndCanonicalizePack : Pass<"simplify-pack", "func::FuncOp"> {
let summary = "Simplify and canonicalize tensor.pack";
let summary = "Simplify and canonicalize linalg.pack";
let description = [{
Apply `tensor.pack` and `tensor.unpack` canonicalization and simplification
Apply `linalg.pack` and `linalg.unpack` canonicalization and simplification
patterns.
}];
}

def ConstantFoldPack : Pass<"constant-fold-pack", "ModuleOp"> {
let summary = "Constant fold tensor.pack";
let summary = "Constant fold linalg.pack";
let description = [{
Reduce pack overhead by folding tensor.pack into constant tensors.
Reduce pack overhead by folding linalg.pack into constant tensors.
}];
let dependentDialects = ["linalg::LinalgDialect",
"tensor::TensorDialect",
@@ -184,7 +184,7 @@ struct ConvertCheckToLoops
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateCheckToLoopsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

@@ -127,7 +127,7 @@ struct ConvertLinalgToFunc
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
patterns.add<ConvertMatmulOp>(ctx);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

@@ -585,7 +585,7 @@ static FailureOr<BrgemmInfo> checkAccess(linalg::LinalgOp linalgOp, unsigned m,
strideB = (*stridesOnB)[*batchPosCodomainB];
}

auto loops = linalgOp.computeStaticLoopSizes();
auto loops = linalgOp.getStaticLoopRanges();
int64_t batchVal = (batchPos) ? loops[batchPos.value()] : 0;

bool isVnni = vnni::utils::isInVnniLayout(linalgOp);
@@ -847,7 +847,7 @@ void ConvertLinalgToXsmm::runOnOperation() {
SmallVector<StringRef> skipPatterns(skipOperations.begin(),
skipOperations.end());
tpp::populateLinalgToXsmmPatterns(patterns, skipPatterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}

2 changes: 1 addition & 1 deletion lib/TPP/Conversion/ConvertPerfToFunc/ConvertPerfToFunc.cpp
@@ -242,7 +242,7 @@ struct ConvertPerfToFunc
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populatePerfToFuncPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

@@ -105,7 +105,7 @@ struct ConvertPerfToLoops
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populatePerfToLoopsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

@@ -248,7 +248,7 @@ struct ConvertVectorToXsmm

void runOnOperation() final {
PatternRewriter rewriter(&getContext());
if (failed(applyPatternsAndFoldGreedily(getOperation(), patterns))) {
if (failed(applyPatternsGreedily(getOperation(), patterns))) {
signalPassFailure();
}
}
2 changes: 1 addition & 1 deletion lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp
@@ -432,7 +432,7 @@ struct ConvertXsmmToFunc
ConvertGemmDispatchOp, ConvertBrgemmDispatchOp,
ConvertFusedBrgemmOp, ConvertIntelAMXTileConfigDispatchOp>(
patterns.getContext());
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

13 changes: 8 additions & 5 deletions lib/TPP/DefaultPipeline.cpp
@@ -192,25 +192,28 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
options.amx = vnni::utils::hasAMX();
pm.addPass(createConvertVectorToLLVMPass(options));
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
pm.addPass(createConvertSCFToCFPass());
pm.addPass(createSCFToControlFlowPass());
if (defParallel)
pm.addPass(createConvertOpenMPToLLVMPass());
pm.addPass(createConvertMathToLLVMPass());

pm.addNestedPass<func::FuncOp>(createGpuAsyncRegionPass());
pm.addPass(createGpuToLLVMConversionPass());
GpuModuleToBinaryPassOptions gpuModuleToBinaryPassOptions;
gpuModuleToBinaryPassOptions.compilationTarget = "fatbin";
pm.addPass(createGpuModuleToBinaryPass(gpuModuleToBinaryPassOptions));
pm.addPass(createConvertMathToLLVMPass());
pm.addPass(createAsyncToAsyncRuntimePass());
pm.addPass(createAsyncRuntimeRefCountingPass());
pm.addPass(createConvertAsyncToLLVMPass());
pm.addPass(createConvertIndexToLLVMPass());

pm.addPass(createConvertFuncToLLVMPass());

pm.addNestedPass<func::FuncOp>(createArithToLLVMConversionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createCSEPass());
pm.addPass(createArithToLLVMConversionPass());
pm.addPass(createConvertControlFlowToLLVMPass());
pm.addPass(createUBToLLVMConversionPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
pm.addPass(createReconcileUnrealizedCastsPass());

// Anything useful has been lowered by now.
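
Two details of the reordered lowering tail above are worth calling out: the SCF-to-CF conversion is now built with `createSCFToControlFlowPass()`, and the UB dialect is lowered explicitly with `createUBToLLVMConversionPass()` before `reconcile-unrealized-casts`, so stray UB-dialect ops do not reach LLVM translation unconverted. A rough sketch of such a tail, with the pass order assumed for illustration rather than copied from the TPP pipeline:

    // Sketch of a generic LLVM-lowering tail after the bump; the exact ordering
    // here is an assumption for illustration, not the TPP DefaultPipeline.
    #include "mlir/Conversion/Passes.h"
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Transforms/Passes.h"

    static void buildLoweringTail(mlir::PassManager &pm) {
      pm.addPass(mlir::createSCFToControlFlowPass());    // was createConvertSCFToCFPass()
      pm.addPass(mlir::createArithToLLVMConversionPass());
      pm.addPass(mlir::createConvertControlFlowToLLVMPass());
      pm.addPass(mlir::createUBToLLVMConversionPass());  // lower leftover UB dialect ops
      pm.addPass(mlir::createCanonicalizerPass());
      pm.addPass(mlir::createCSEPass());
      pm.addPass(mlir::createReconcileUnrealizedCastsPass());
    }
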
4 changes: 2 additions & 2 deletions lib/TPP/DefaultTppPasses.cpp
@@ -104,7 +104,7 @@ struct DefaultTppPasses
if (linalgToLoops) {
// Lower linalg directly to loops.
// Skip all TPP transformations.
// Generalize tensor.pack and tensor.unpack.
// Generalize linalg.pack and linalg.unpack.
pm.addPass(createLowerPacksAndUnPacks());
pm.addNestedPass<func::FuncOp>(createDecomposeAggregatedOps());
pm.addPass(createBufferize());
@@ -120,7 +120,7 @@
TppMappingOptions tppMappingOptions{lowerPackUnpackWithoutTranspose};
pm.addPass(createTppMapping(tppMappingOptions));

// Generalize tensor.pack and tensor.unpack.
// Generalize linalg.pack and linalg.unpack.
pm.addPass(createLowerPacksAndUnPacks());
pm.addPass(createCleanup());

2 changes: 1 addition & 1 deletion lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
@@ -151,7 +151,7 @@ getVectorUnaryInfo(MemRefType shapedType, MemRefType inputType,
SmallVector<int64_t> strides;
int64_t offset;

if (failed(getStridesAndOffset(memrefType, strides, offset))) {
if (failed(memrefType.getStridesAndOffset(strides, offset))) {
return failure();
}
if (strides.empty()) {
2 changes: 1 addition & 1 deletion lib/TPP/GPU/GpuConversion.cpp
@@ -58,7 +58,7 @@ struct GpuConversion : public tpp::impl::GpuConversionBase<GpuConversion>,
void constructPipeline() override {
// Map loops into GPU kernels.
pm.addNestedPass<func::FuncOp>(createGpuMapParallelLoopsPass());
pm.addNestedPass<func::FuncOp>(createParallelLoopToGpuPass());
pm.addNestedPass<func::FuncOp>(createConvertParallelLoopToGpuPass());
pm.addPass(createCleanup());

// First lower linalg using custom patterns then fall back to
2 changes: 1 addition & 1 deletion lib/TPP/GPU/GpuDataTransfer.cpp
@@ -238,7 +238,7 @@ class GpuDataTransfer : public tpp::impl::GpuDataTransferBase<GpuDataTransfer> {
RewritePatternSet patterns(ctx);
// TODO: Add cleanup patterns to minimize data copies.
patterns.add<TransferDataToGpu>(ctx);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

2 changes: 1 addition & 1 deletion lib/TPP/GPU/GpuInlineConstants.cpp
@@ -81,7 +81,7 @@ struct GpuInlineConstants
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateGpuInlineConstantsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

4 changes: 2 additions & 2 deletions lib/TPP/GPU/GpuToCuda.cpp
@@ -68,7 +68,7 @@ struct GpuToCuda : public tpp::impl::GpuToCudaBase<GpuToCuda>,
pm.addNestedPass<gpu::GPUModuleOp>(arith::createArithExpandOpsPass());
pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());
pm.addNestedPass<gpu::GPUModuleOp>(createConvertVectorToSCFPass());
pm.addNestedPass<gpu::GPUModuleOp>(createConvertSCFToCFPass());
pm.addNestedPass<gpu::GPUModuleOp>(createSCFToControlFlowPass());

pm.addNestedPass<gpu::GPUModuleOp>(createConvertNVGPUToNVVMPass());
pm.addNestedPass<gpu::GPUModuleOp>(createConvertGpuOpsToNVVMOps());
@@ -77,6 +77,7 @@
pm.addNestedPass<gpu::GPUModuleOp>(createConvertFuncToLLVMPass());
pm.addNestedPass<gpu::GPUModuleOp>(createArithToLLVMConversionPass());
pm.addNestedPass<gpu::GPUModuleOp>(createConvertIndexToLLVMPass());
pm.addNestedPass<gpu::GPUModuleOp>(createUBToLLVMConversionPass());

GpuNVVMAttachTargetOptions nvvmTargetOptions;
nvvmTargetOptions.triple = gpuTriple;
@@ -85,7 +86,6 @@
pm.addPass(createGpuNVVMAttachTarget(nvvmTargetOptions));

// Create CUDA kernels.
pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
pm.addNestedPass<gpu::GPUModuleOp>(createReconcileUnrealizedCastsPass());
2 changes: 1 addition & 1 deletion lib/TPP/GPU/GpuVectorize.cpp
@@ -109,7 +109,7 @@ struct GpuVectorize : public tpp::impl::GpuVectorizeBase<GpuVectorize> {
vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

6 changes: 3 additions & 3 deletions lib/TPP/GPU/LinalgToXeGPU.cpp
@@ -884,7 +884,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,

// DPAS only works with F32 accumulators.
auto dpasResType =
VectorType::get(dpasTypeC.getShape(), FloatType::getF32(ctx));
VectorType::get(dpasTypeC.getShape(), Float32Type::get(ctx));

// Extend the accumulation values if needed.
auto convOutPrecision = !typeC.getElementType().isF32();
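
The change above tracks an upstream builtin-type cleanup: element types are now built through the concrete type classes (here `Float32Type::get`) instead of the `FloatType::getF32`-style helpers. A minimal sketch (illustrative, not code from this commit):

    // Sketch: building an f32 vector type after the bump; not code from this commit.
    #include "mlir/IR/BuiltinTypes.h"

    static mlir::VectorType makeF32Vector(mlir::MLIRContext *ctx,
                                          llvm::ArrayRef<int64_t> shape) {
      // Old form: VectorType::get(shape, FloatType::getF32(ctx));
      return mlir::VectorType::get(shape, mlir::Float32Type::get(ctx));
    }
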
@@ -1397,12 +1397,12 @@ struct LinalgToXeGPU : public tpp::impl::LinalgToXeGPUBase<LinalgToXeGPU> {
// Run GEMM pattern first to allow fusion with its consumers.
RewritePatternSet gemmPatterns(&getContext());
populateLinalgGemmToXeGPUPatterns(gemmPatterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(gemmPatterns));
(void)applyPatternsGreedily(getOperation(), std::move(gemmPatterns));

// Convert other remaining ops.
RewritePatternSet patterns(&getContext());
populateLinalgEltwiseToXeGPUPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

4 changes: 2 additions & 2 deletions lib/TPP/Runner/MLIRBench.cpp
@@ -88,13 +88,13 @@ LogicalResult MLIRBench::findKernel(StringRef name) {

} else {
// If there is no entry function, and multiple functions, bail
return module.emitError("No valid entry point, use mlir-cpu-runner");
return module.emitError("No valid entry point, use mlir-runner");
}

// Ignore functions that return more than one result
auto funcType = kernel.getFunctionType();
if (funcType.getNumResults() > 1)
return module.emitError("Multiple return values, use mlir-cpu-runner");
return module.emitError("Multiple return values, use mlir-runner");

return success();
}
3 changes: 1 addition & 2 deletions lib/TPP/Transforms/BrgemmLinalgTiling.cpp
@@ -230,8 +230,7 @@ struct BrgemmLinalgTiling : public tpp::impl::BrgemmLinalgTilingBase<BrgemmLinal
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
}
};
} // namespace tpp
2 changes: 1 addition & 1 deletion lib/TPP/Transforms/Bufferize.cpp
@@ -158,7 +158,7 @@ void Bufferize::runOnOperation() {
bufferization::buildBufferDeallocationPipeline(passManager, options);
}

passManager.addPass(createBufferizationToMemRefPass());
passManager.addPass(createConvertBufferizationToMemRefPass());
if (failed(runPipeline(passManager, moduleOp)))
return signalPassFailure();
}
2 changes: 1 addition & 1 deletion lib/TPP/Transforms/CombineXsmmPass.cpp
@@ -153,7 +153,7 @@ struct CombineXsmmOpPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateCombinePatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
} // namespace
14 changes: 7 additions & 7 deletions lib/TPP/Transforms/ConstantFoldPack.cpp
@@ -28,11 +28,11 @@ namespace tpp {

namespace {

// Helper pattern - lower tensor.pack operations that pack constants.
struct LowerConstantPacking : public OpRewritePattern<tensor::PackOp> {
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
// Helper pattern - lower linalg.pack operations that pack constants.
struct LowerConstantPacking : public OpRewritePattern<linalg::PackOp> {
using OpRewritePattern<linalg::PackOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tensor::PackOp packOp,
LogicalResult matchAndRewrite(linalg::PackOp packOp,
PatternRewriter &rewriter) const override {
auto constOp = packOp.getSource().getDefiningOp<arith::ConstantOp>();
if (!constOp)
@@ -52,7 +52,7 @@ struct LowerConstantPacking : public OpRewritePattern<tensor::PackOp> {
return rewriter.notifyMatchFailure(
packOp, "expects destination with static shape");

// If it is a splat constant, skip and let tensor.pack folder to handle this
// If it is a splat constant, skip and let linalg.pack folder to handle this
// case.
if (denseAttr.isSplat())
return rewriter.notifyMatchFailure(
@@ -77,13 +77,13 @@ struct ConstantFoldPack
// Apply canonicalization to fold trivial cases and linalg constant folders
// to cleanup lowered packs.
linalg::FillOp::getCanonicalizationPatterns(patterns, ctx);
tensor::PackOp::getCanonicalizationPatterns(patterns, ctx);
linalg::PackOp::getCanonicalizationPatterns(patterns, ctx);
tensor::populateRewriteAsConstantPatterns(
patterns, [](OpOperand *) -> bool { return true; });
linalg::populateConstantFoldLinalgOperations(
patterns, [](OpOperand *) -> bool { return true; });

(void)applyPatternsAndFoldGreedily(module, std::move(patterns));
(void)applyPatternsGreedily(module, std::move(patterns));
}
};
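
This file shows the C++ side of the pack/unpack move noted in the commit message: `tensor::PackOp`/`tensor::UnPackOp` now live in the Linalg dialect, so rewrite patterns anchor on `linalg::PackOp`. A minimal skeleton of such a pattern (hypothetical placeholder logic, not this repository's LowerConstantPacking):

    // Sketch: a rewrite pattern anchored on the relocated pack op; the body is a
    // placeholder and intentionally never rewrites anything.
    #include "mlir/Dialect/Linalg/IR/Linalg.h"
    #include "mlir/IR/PatternMatch.h"

    struct ExamplePackPattern
        : public mlir::OpRewritePattern<mlir::linalg::PackOp> {
      using OpRewritePattern::OpRewritePattern;

      mlir::LogicalResult
      matchAndRewrite(mlir::linalg::PackOp packOp,
                      mlir::PatternRewriter &rewriter) const override {
        // A real pattern (e.g. constant folding of packed weights) would rewrite
        // here; this illustrative skeleton only reports a match failure.
        return rewriter.notifyMatchFailure(packOp, "illustrative pattern only");
      }
    };
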

2 changes: 1 addition & 1 deletion lib/TPP/Transforms/ConvInitSimplify.cpp
@@ -114,7 +114,7 @@ struct ConvInitSimplify
void runOnOperation() override {
RewritePatternSet patterns(getOperation().getContext());
patterns.add<EliminateZeroInitAndAddBiasToInit>(patterns.getContext());
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
