
More renames + fixes
adam-smnk committed Feb 18, 2025
1 parent 630ad41 commit 88f1c8a
Showing 11 changed files with 26 additions and 24 deletions.
@@ -585,7 +585,7 @@ static FailureOr<BrgemmInfo> checkAccess(linalg::LinalgOp linalgOp, unsigned m,
     strideB = (*stridesOnB)[*batchPosCodomainB];
   }
 
-  auto loops = linalgOp.computeStaticLoopSizes();
+  auto loops = linalgOp.getStaticLoopRanges();
   int64_t batchVal = (batchPos) ? loops[batchPos.value()] : 0;
 
   bool isVnni = vnni::utils::isInVnniLayout(linalgOp);
2 changes: 1 addition & 1 deletion lib/TPP/GPU/GpuConversion.cpp
@@ -58,7 +58,7 @@ struct GpuConversion : public tpp::impl::GpuConversionBase<GpuConversion>,
   void constructPipeline() override {
     // Map loops into GPU kernels.
     pm.addNestedPass<func::FuncOp>(createGpuMapParallelLoopsPass());
-    pm.addNestedPass<func::FuncOp>(createParallelLoopToGpuPass());
+    pm.addNestedPass<func::FuncOp>(createConvertParallelLoopToGpuPass());
     pm.addPass(createCleanup());
 
     // First lower linalg using custom patterns then fall back to
2 changes: 1 addition & 1 deletion lib/TPP/GPU/LinalgToXeGPU.cpp
@@ -884,7 +884,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
 
   // DPAS only works with F32 accumulators.
   auto dpasResType =
-      VectorType::get(dpasTypeC.getShape(), FloatType::getF32(ctx));
+      VectorType::get(dpasTypeC.getShape(), Float32Type::get(ctx));
 
   // Extend the accumulation values if needed.
   auto convOutPrecision = !typeC.getElementType().isF32();
2 changes: 1 addition & 1 deletion lib/TPP/Transforms/Bufferize.cpp
@@ -158,7 +158,7 @@ void Bufferize::runOnOperation() {
     bufferization::buildBufferDeallocationPipeline(passManager, options);
   }
 
-  passManager.addPass(createBufferizationToMemRefPass());
+  passManager.addPass(createConvertBufferizationToMemRefPass());
   if (failed(runPipeline(passManager, moduleOp)))
     return signalPassFailure();
 }
4 changes: 2 additions & 2 deletions lib/TPP/Transforms/IntelAMXTileConfig.cpp
@@ -93,7 +93,7 @@ struct IntelAMXTileConfig : OpRewritePattern<InvokeOpTy> {
     auto alloca = rewriter.create<memref::AllocaOp>(
         op.getLoc(), MemRefType::get({64}, rewriter.getI8Type()));
 
-    ValueRange tileConfigInputs{alloca};
+    SmallVector<Value> tileConfigInputs{alloca};
     rewriter.create<mlir::xsmm::IntelAMXTileConfigOp>(
         op.getLoc(), tileConfigSetup, tileConfigInputs);
 
@@ -107,7 +107,7 @@ struct IntelAMXTileConfig : OpRewritePattern<InvokeOpTy> {
         xsmm::utils::getDataType(rewriter, op.getOperand(1).getType()),
         invokeOperands);
 
-    ValueRange tileResetInputs{alloca};
+    SmallVector<Value> tileResetInputs{alloca};
     rewriter.create<mlir::xsmm::IntelAMXTileConfigOp>(
         op.getLoc(), tileConfigReset, tileResetInputs);
 
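A note on the two ValueRange hunks above: swapping in SmallVector<Value> likely avoids holding a non-owning range over a temporary Value (the single-result op converts to Value when the braced list is built), since ValueRange only points at existing storage while SmallVector owns copies. That rationale is an assumption, not stated in the commit; below is a minimal standalone sketch of the same hazard using llvm::ArrayRef as a stand-in for ValueRange.

  #include "llvm/ADT/ArrayRef.h"
  #include "llvm/ADT/SmallVector.h"
  #include <cstdio>

  int main() {
    int value = 42; // stands in for a freshly created mlir::Value

    // A view built from a braced list refers to a temporary backing array that
    // dies at the end of this statement, so `view` dangles afterwards.
    llvm::ArrayRef<int> view = {value};

    // An owning container copies the element and stays valid for later uses,
    // e.g. when passed to a builder call further down the function.
    llvm::SmallVector<int> owned = {value};
    std::printf("%d\n", owned.front());
    return 0;
  }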
10 changes: 5 additions & 5 deletions lib/TPP/Transforms/LowerPacksAndUnpacks.cpp
@@ -112,7 +112,7 @@ static void fuseOrTilePacks(RewriterBase &rewriter, FunctionOpInterface func) {
         forLoops);
     if (!fusedProducer)
       continue;
-    rewriter.replaceOp(consumerPackOp, tilingResult->replacements);
+    rewriter.replaceOp(consumerPackOp, tilingResult->mergeResult.replacements);
   }
 
   // Tile packs.
@@ -124,7 +124,7 @@ static void fuseOrTilePacks(RewriterBase &rewriter, FunctionOpInterface func) {
         rewriter, cast<TilingInterface>(packOp.getOperation()), tileSizes);
     if (failed(tilingResult))
       continue;
-    rewriter.replaceOp(packOp, tilingResult->replacements);
+    rewriter.replaceOp(packOp, tilingResult->mergeResult.replacements);
   }
 
   // Tile unpacks.
@@ -136,7 +136,7 @@ static void fuseOrTilePacks(RewriterBase &rewriter, FunctionOpInterface func) {
         rewriter, cast<TilingInterface>(unPackOp.getOperation()), tileSizes);
     if (failed(tilingResult))
       continue;
-    rewriter.replaceOp(unPackOp, tilingResult->replacements);
+    rewriter.replaceOp(unPackOp, tilingResult->mergeResult.replacements);
   }
 }
 
@@ -215,7 +215,7 @@ class LowerPacksAndUnPacks
           unpackTilingOptions);
       if (failed(tilingResult))
        return signalPassFailure();
-      rewriter.replaceOp(unPackOp, tilingResult->replacements);
+      rewriter.replaceOp(unPackOp, tilingResult->mergeResult.replacements);
     });
     getOperation()->walk([&](linalg::PackOp packOp) {
       SmallVector<int64_t> tiles(packOp.getSourceType().getRank(), 1);
@@ -226,7 +226,7 @@ class LowerPacksAndUnPacks
           packTilingOptions);
       if (failed(tilingResult))
        return signalPassFailure();
-      rewriter.replaceOp(packOp, tilingResult->replacements);
+      rewriter.replaceOp(packOp, tilingResult->mergeResult.replacements);
     });
     RewritePatternSet patterns(&getContext());
     patterns.add<linalg::DecomposeOuterUnitDimsUnPackOpPattern,
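The same replacements-to-mergeResult.replacements update recurs in RewriteBatchMatmulToMatmul.cpp and SplitReductionDim.cpp further down. A minimal sketch of the updated tile-and-replace idiom, assuming the MLIR revision this commit targets (where scf::SCFTilingResult nests the final values under mergeResult); the helper name tileAndReplace is made up for illustration.

  #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Interfaces/TilingInterface.h"

  using namespace mlir;

  // Tile a TilingInterface op with scf::tileUsingSCF and replace the original
  // op with the merged results of the generated loop nest.
  static LogicalResult tileAndReplace(RewriterBase &rewriter, TilingInterface op,
                                      const scf::SCFTilingOptions &options) {
    FailureOr<scf::SCFTilingResult> tilingResult =
        scf::tileUsingSCF(rewriter, op, options);
    if (failed(tilingResult))
      return failure();
    // The replacement values now live under mergeResult instead of sitting
    // directly on the tiling result.
    rewriter.replaceOp(op.getOperation(),
                       tilingResult->mergeResult.replacements);
    return success();
  }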
@@ -42,7 +42,7 @@ namespace {
 void lowerPackAndFoldTranspose(linalg::PackOp packOp,
                                linalg::GenericOp genericOp, uint operandIdx,
                                PatternRewriter &rewriter) {
-  auto packInversionPerm = tensor::getPackInverseDestPerm(packOp);
+  auto packInversionPerm = linalg::getPackInverseDestPerm(packOp);
 
   auto res = linalg::lowerPack(rewriter, packOp);
 
2 changes: 1 addition & 1 deletion lib/TPP/Transforms/RewriteBatchMatmulToMatmul.cpp
@@ -111,7 +111,7 @@ struct RewriteBatchMatmulToMatmul
           tilingOpts);
       if (failed(tilingResult))
         return signalPassFailure();
-      rewriter.replaceOp(batchMatmulOp, tilingResult->replacements);
+      rewriter.replaceOp(batchMatmulOp, tilingResult->mergeResult.replacements);
     });
 
     // Step2:
2 changes: 1 addition & 1 deletion lib/TPP/Transforms/SplitReductionDim.cpp
@@ -81,7 +81,7 @@ struct SplitContractionReduction
       return rewriter.notifyMatchFailure(linalgOp,
                                          "failed to tile contraction");
 
-    rewriter.replaceOp(linalgOp, tilingResult->replacements);
+    rewriter.replaceOp(linalgOp, tilingResult->mergeResult.replacements);
 
     return success();
   }
4 changes: 2 additions & 2 deletions lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
@@ -674,7 +674,7 @@ struct PackVNNI : public tpp::impl::PackVNNIBase<PackVNNI> {
     RewritePatternSet patterns(ctx);
     linalg::populateLinalgDeGeneralizationPatterns(patterns);
     patterns.add<VNNIOnMatmul, VNNIOnBRGemm>(ctx);
-    tensor::populateSimplifyPackAndUnpackPatterns(patterns);
+    linalg::populateSimplifyPackAndUnpackPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -831,7 +831,7 @@ struct SimplifyAndCanonicalizePack
 
 void mlir::tpp::populateSimplifyPacking(RewritePatternSet &patterns) {
   MLIRContext *ctx = patterns.getContext();
-  tensor::populateSimplifyPackAndUnpackPatterns(patterns);
+  linalg::populateSimplifyPackAndUnpackPatterns(patterns);
   tensor::populateFoldTensorEmptyPatterns(patterns);
   linalg::PackOp::getCanonicalizationPatterns(patterns, ctx);
   linalg::UnPackOp::getCanonicalizationPatterns(patterns, ctx);
18 changes: 10 additions & 8 deletions lib/TPP/Transforms/TransformUtils.cpp
@@ -138,14 +138,16 @@ Value getSliceOperand(OpBuilder &builder, linalg::LinalgOp linalgOp,
   assert(rank == strides.size() && "expect rank == strides");
 
   Location loc = linalgOp.getLoc();
-  Type reducedType =
-      (linalgOp.hasPureTensorSemantics())
-          ? tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
-                desiredResultRank, cast<RankedTensorType>(operandType), offsets,
-                sizes, strides)
-          : memref::SubViewOp::inferRankReducedResultType(
-                getExpectedResultMemRefShape(sizes, desiredResultRank),
-                cast<MemRefType>(operandType), offsets, sizes, strides);
+  Type reducedType;
+  if (linalgOp.hasPureTensorSemantics()) {
+    reducedType = tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
+        desiredResultRank, cast<RankedTensorType>(operandType), offsets, sizes,
+        strides);
+  } else {
+    reducedType = memref::SubViewOp::inferRankReducedResultType(
+        getExpectedResultMemRefShape(sizes, desiredResultRank),
+        cast<MemRefType>(operandType), offsets, sizes, strides);
+  }
 
   Operation *extractOperation =
       (linalgOp.hasPureTensorSemantics())
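A possible reason for trading the ternary for an if/else (an assumption, the commit does not say): if the two infer helpers now return the distinct concrete types RankedTensorType and MemRefType, a ?: expression has no common type between its branches, while separate assignments to a Type variable still convert implicitly. A tiny sketch of that pattern:

  #include "mlir/IR/BuiltinTypes.h"

  using namespace mlir;

  // Pick a ranked tensor or memref type of the same shape; assigning each
  // concrete type to a common `Type` variable sidesteps the missing common
  // type between RankedTensorType and MemRefType in a ?: expression.
  Type pickSliceType(bool onTensors, ArrayRef<int64_t> shape, Type elemTy) {
    Type reduced;
    if (onTensors)
      reduced = RankedTensorType::get(shape, elemTy);
    else
      reduced = MemRefType::get(shape, elemTy);
    return reduced;
  }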
