Skip to content

Commit

Permalink
Changes necessary for llvm/llvm-project#123902
Browse files Browse the repository at this point in the history
  • Loading branch information
qedawkins committed Feb 14, 2025
1 parent 8fab35c commit 40fd49f
Show file tree
Hide file tree
Showing 71 changed files with 463 additions and 458 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ module attributes {hal.device.targets = [#device_target_llvm_cpu]} {
}
}
// CHECK-LABEL: util.func public @lhs_encoding
// CHECK: tensor.pack
// CHECK: tensor.unpack
// CHECK: linalg.pack
// CHECK: linalg.unpack
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ matchDAGForUKernel(RewriterBase &rewriter, linalg::Mmt4DOp op,
}

static FailureOr<IREE::Codegen::UKernelOpInterface>
matchDAGForUKernel(RewriterBase &rewriter, tensor::PackOp op,
matchDAGForUKernel(RewriterBase &rewriter, linalg::PackOp op,
bool /*skipIntermediateRoundings*/) {
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
const char ukernelName[] = "pack";
Expand Down Expand Up @@ -386,7 +386,7 @@ matchDAGForUKernel(RewriterBase &rewriter, tensor::PackOp op,
}

static FailureOr<IREE::Codegen::UKernelOpInterface>
matchDAGForUKernel(RewriterBase &rewriter, tensor::UnPackOp op,
matchDAGForUKernel(RewriterBase &rewriter, linalg::UnPackOp op,
bool /*skipIntermediateRoundings*/) {
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
const char ukernelName[] = "unpack";
Expand Down Expand Up @@ -616,8 +616,8 @@ void CPULowerToUKernelsPass::runOnOperation() {
// these ops.
auto allTargets = [](auto target) { return true; };
patterns.insert<LowerToUKernelPattern<linalg::Mmt4DOp>,
LowerToUKernelPattern<tensor::PackOp>,
LowerToUKernelPattern<tensor::UnPackOp>>(
LowerToUKernelPattern<linalg::PackOp>,
LowerToUKernelPattern<linalg::UnPackOp>>(
context, allTargets, skipIntermediateRoundings);
// These patterns are inherently specific to the VMVX backend.
patterns.insert<LowerToUKernelPattern<IREE::Codegen::QueryTileSizesOp>>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ static void tileBatchDimsForBatchMmt4dOp(RewriterBase &rewriter,

static void tileNonPackedDimsFor3DPackOps(RewriterBase &rewriter,
FunctionOpInterface funcOp) {
funcOp.walk([&](tensor::PackOp packOp) {
funcOp.walk([&](linalg::PackOp packOp) {
if (packOp.getSourceRank() != 3 || packOp.getDestRank() != 5) {
return;
}
Expand Down Expand Up @@ -81,7 +81,7 @@ static void tileNonPackedDimsFor3DPackOps(RewriterBase &rewriter,

static void tileNonPackedDimsFor5DPUnpackOps(RewriterBase &rewriter,
FunctionOpInterface funcOp) {
funcOp.walk([&](tensor::UnPackOp unpackOp) {
funcOp.walk([&](linalg::UnPackOp unpackOp) {
if (unpackOp.getSourceRank() != 5 || unpackOp.getDestRank() != 3) {
return;
}
Expand Down Expand Up @@ -251,10 +251,10 @@ struct ConvertBatchMmt4DtoMmt4DPattern
}
};

struct Convert3DPackto2DPackPattern : public OpRewritePattern<tensor::PackOp> {
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
struct Convert3DPackto2DPackPattern : public OpRewritePattern<linalg::PackOp> {
using OpRewritePattern<linalg::PackOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tensor::PackOp packOp,
LogicalResult matchAndRewrite(linalg::PackOp packOp,
PatternRewriter &rewriter) const override {
if (packOp.getSourceRank() != 3 || packOp.getDestRank() != 5) {
return failure();
Expand Down Expand Up @@ -309,7 +309,7 @@ struct Convert3DPackto2DPackPattern : public OpRewritePattern<tensor::PackOp> {
auto reducedDest = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, packOp.getDest(), reducedDestType);

auto newPackOp = rewriter.create<tensor::PackOp>(
auto newPackOp = rewriter.create<linalg::PackOp>(
loc, reducedSrc, reducedDest, newInnerDimsPos, packOp.getMixedTiles(),
packOp.getPaddingValue(), newOuterDimsPerm);

Expand All @@ -321,10 +321,10 @@ struct Convert3DPackto2DPackPattern : public OpRewritePattern<tensor::PackOp> {
};

struct Convert5DUnPackto4DUnPackPattern
: public OpRewritePattern<tensor::UnPackOp> {
using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;
: public OpRewritePattern<linalg::UnPackOp> {
using OpRewritePattern<linalg::UnPackOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp,
LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp,
PatternRewriter &rewriter) const override {
if (unpackOp.getSourceRank() != 5 || unpackOp.getDestRank() != 3) {
return failure();
Expand Down Expand Up @@ -387,7 +387,7 @@ struct Convert5DUnPackto4DUnPackPattern
auto reducedDest = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, unpackOp.getDest(), reducedDestType);

auto newUnpackOp = rewriter.create<tensor::UnPackOp>(
auto newUnpackOp = rewriter.create<linalg::UnPackOp>(
loc, reducedSrc, reducedDest, newInnerDimsPos, unpackOp.getMixedTiles(),
newOuterDimsPerm);

Expand Down Expand Up @@ -436,8 +436,8 @@ void CPUPrepareUkernelsPass::runOnOperation() {
tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
tensor::PackOp::getCanonicalizationPatterns(patterns, ctx);
tensor::UnPackOp::getCanonicalizationPatterns(patterns, ctx);
linalg::PackOp::getCanonicalizationPatterns(patterns, ctx);
linalg::UnPackOp::getCanonicalizationPatterns(patterns, ctx);
tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
tensor::populateFoldTensorEmptyPatterns(patterns);
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ func.func @mmt4d_bf16bf16f32(%arg0 : tensor<?x?x16x2xbf16>, %arg1 : tensor<?x?x1
func.func @pack_i8i8_x86(%arg0 : tensor<?x?xi8>, %arg1 : tensor<?x?x7x8xi8>, %arg2 : i8) -> tensor<?x?x7x8xi8> attributes {
hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {ukernels = "all", target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
} {
%result = tensor.pack %arg0 padding_value(%arg2 : i8) inner_dims_pos = [0, 1] inner_tiles = [7, 8] into %arg1
%result = linalg.pack %arg0 padding_value(%arg2 : i8) inner_dims_pos = [0, 1] inner_tiles = [7, 8] into %arg1
: tensor<?x?xi8> -> tensor<?x?x7x8xi8>
func.return %result : tensor<?x?x7x8xi8>
}
Expand Down Expand Up @@ -315,7 +315,7 @@ func.func @pack_i8i8_x86(%arg0 : tensor<?x?xi8>, %arg1 : tensor<?x?x7x8xi8>, %ar
func.func @pack_i8i8(%arg0 : tensor<?x?xi8>, %arg1 : tensor<?x?x7x8xi8>, %arg2 : i8) -> tensor<?x?x7x8xi8> attributes {
hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {ukernels = "all"}>
} {
%result = tensor.pack %arg0 padding_value(%arg2 : i8) inner_dims_pos = [0, 1] inner_tiles = [7, 8] into %arg1
%result = linalg.pack %arg0 padding_value(%arg2 : i8) inner_dims_pos = [0, 1] inner_tiles = [7, 8] into %arg1
: tensor<?x?xi8> -> tensor<?x?x7x8xi8>
func.return %result : tensor<?x?x7x8xi8>
}
Expand Down Expand Up @@ -344,7 +344,7 @@ func.func @pack_i8i8(%arg0 : tensor<?x?xi8>, %arg1 : tensor<?x?x7x8xi8>, %arg2 :
func.func @pack_f16f16(%arg0 : tensor<?x?xf16>, %arg1 : tensor<?x?x7x8xf16>, %arg2 : f16) -> tensor<?x?x7x8xf16> attributes {
hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {ukernels = "all"}>
} {
%result = tensor.pack %arg0 padding_value(%arg2 : f16) inner_dims_pos = [0, 1] inner_tiles = [7, 8] into %arg1
%result = linalg.pack %arg0 padding_value(%arg2 : f16) inner_dims_pos = [0, 1] inner_tiles = [7, 8] into %arg1
: tensor<?x?xf16> -> tensor<?x?x7x8xf16>
func.return %result : tensor<?x?x7x8xf16>
}
Expand Down Expand Up @@ -373,7 +373,7 @@ func.func @pack_f16f16(%arg0 : tensor<?x?xf16>, %arg1 : tensor<?x?x7x8xf16>, %ar
func.func @pack_bf16bf16(%arg0 : tensor<?x?xbf16>, %arg1 : tensor<?x?x7x8xbf16>, %arg2 : bf16) -> tensor<?x?x7x8xbf16> attributes {
hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {ukernels = "all"}>
} {
%result = tensor.pack %arg0 padding_value(%arg2 : bf16) inner_dims_pos = [0, 1] inner_tiles = [7, 8] into %arg1
%result = linalg.pack %arg0 padding_value(%arg2 : bf16) inner_dims_pos = [0, 1] inner_tiles = [7, 8] into %arg1
: tensor<?x?xbf16> -> tensor<?x?x7x8xbf16>
func.return %result : tensor<?x?x7x8xbf16>
}
Expand Down Expand Up @@ -401,7 +401,7 @@ func.func @pack_bf16bf16(%arg0 : tensor<?x?xbf16>, %arg1 : tensor<?x?x7x8xbf16>,
func.func @pack_i32i32_transpose_inner(%arg0 : tensor<?x?xi32>, %arg1 : tensor<?x?x7x8xi32>, %arg2 : i32) -> tensor<?x?x7x8xi32> attributes {
hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {ukernels = "all"}>
} {
%result = tensor.pack %arg0 padding_value(%arg2 : i32) inner_dims_pos = [1, 0] inner_tiles = [7, 8] into %arg1
%result = linalg.pack %arg0 padding_value(%arg2 : i32) inner_dims_pos = [1, 0] inner_tiles = [7, 8] into %arg1
: tensor<?x?xi32> -> tensor<?x?x7x8xi32>
func.return %result : tensor<?x?x7x8xi32>
}
Expand Down Expand Up @@ -430,19 +430,19 @@ func.func @pack_i32i32_transpose_inner(%arg0 : tensor<?x?xi32>, %arg1 : tensor<?
func.func @pack_f32f32_transpose_inner_and_outer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?x7x8xf32>, %arg2 : f32) -> tensor<?x?x7x8xf32> attributes {
hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {ukernels = "all"}>
} {
%result = tensor.pack %arg0 padding_value(%arg2 : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [7, 8] into %arg1
%result = linalg.pack %arg0 padding_value(%arg2 : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [7, 8] into %arg1
: tensor<?x?xf32> -> tensor<?x?x7x8xf32>
func.return %result : tensor<?x?x7x8xf32>
}

// -----

// Check that tensor.pack is not lowered to a microkernel by default - it should
// Check that linalg.unpack is not lowered to a microkernel by default - it should
// only be on VMVX.
// CHECK: func @unpack_f16f16_default
// CHECK: tensor.unpack
// CHECK: linalg.unpack
func.func @unpack_f16f16_default(%arg0 : tensor<?x?x7x8xf16>, %arg1 : tensor<?x?xf16>) -> tensor<?x?xf16> {
%result = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [7, 8] into %arg1
%result = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [7, 8] into %arg1
: tensor<?x?x7x8xf16> -> tensor<?x?xf16>
func.return %result : tensor<?x?xf16>
}
Expand All @@ -468,7 +468,7 @@ func.func @unpack_f16f16_default(%arg0 : tensor<?x?x7x8xf16>, %arg1 : tensor<?x?
func.func @unpack_f16f16(%arg0 : tensor<?x?x7x8xf16>, %arg1 : tensor<?x?xf16>) -> tensor<?x?xf16> attributes {
hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {ukernels = "all"}>
} {
%result = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [7, 8] into %arg1
%result = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [7, 8] into %arg1
: tensor<?x?x7x8xf16> -> tensor<?x?xf16>
func.return %result : tensor<?x?xf16>
}
Expand All @@ -494,7 +494,7 @@ func.func @unpack_f16f16(%arg0 : tensor<?x?x7x8xf16>, %arg1 : tensor<?x?xf16>) -
func.func @unpack_i32i32_transpose_inner(%arg0 : tensor<?x?x7x8xi32>, %arg1 : tensor<?x?xi32>) -> tensor<?x?xi32> attributes {
hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {ukernels = "all"}>
} {
%result = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [7, 8] into %arg1
%result = linalg.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [7, 8] into %arg1
: tensor<?x?x7x8xi32> -> tensor<?x?xi32>
func.return %result : tensor<?x?xi32>
}
Expand All @@ -520,7 +520,7 @@ func.func @unpack_i32i32_transpose_inner(%arg0 : tensor<?x?x7x8xi32>, %arg1 : te
func.func @unpack_f32f32_transpose_inner_and_outer(%arg0 : tensor<?x?x7x8xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> attributes {
hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {ukernels = "all"}>
} {
%result = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [7, 8] into %arg1
%result = linalg.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [7, 8] into %arg1
: tensor<?x?x7x8xf32> -> tensor<?x?xf32>
func.return %result : tensor<?x?xf32>
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func.func @pack_without_outer_dims_perm(%arg0: tensor<1x16384x512xbf16>, %arg1:
hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {ukernels = "pack", target_triple="x86_64-xyz-xyz", cpu_features=""}>
} {
%cst = arith.constant 0.000000e+00 : bf16
%pack = tensor.pack %arg0 inner_dims_pos = [1, 2] inner_tiles = [16, 2] into %arg1 : tensor<1x16384x512xbf16> -> tensor<1x1024x256x16x2xbf16>
%pack = linalg.pack %arg0 inner_dims_pos = [1, 2] inner_tiles = [16, 2] into %arg1 : tensor<1x16384x512xbf16> -> tensor<1x1024x256x16x2xbf16>
return %pack : tensor<1x1024x256x16x2xbf16>
}
// CHECK: func.func @pack_without_outer_dims_perm
Expand All @@ -168,7 +168,7 @@ func.func @pack_without_outer_dims_perm(%arg0: tensor<1x16384x512xbf16>, %arg1:
// CHECK-SAME: tensor<1x16384x512xbf16> to tensor<16384x512xbf16>
// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
// CHECK-SAME: tensor<1x1024x256x16x2xbf16> to tensor<1024x256x16x2xbf16>
// CHECK: %[[PACK:.+]] = tensor.pack %[[SRC_SLICE]]
// CHECK: %[[PACK:.+]] = linalg.pack %[[SRC_SLICE]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [16, 2]
// CHECK-SAME: into %[[DEST_SLICE]]

Expand All @@ -178,7 +178,7 @@ func.func @pack_with_outer_dims_perm(%arg0: tensor<484x16x64xbf16>, %arg1: tenso
hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {ukernels = "pack", target_triple="x86_64-xyz-xyz", cpu_features=""}>
} {
%cst = arith.constant 0.000000e+00 : bf16
%pack = tensor.pack %arg0 padding_value(%cst : bf16) outer_dims_perm = [2, 0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %arg1 : tensor<484x16x64xbf16> -> tensor<64x31x8x16x2xbf16>
%pack = linalg.pack %arg0 padding_value(%cst : bf16) outer_dims_perm = [2, 0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %arg1 : tensor<484x16x64xbf16> -> tensor<64x31x8x16x2xbf16>
return %pack : tensor<64x31x8x16x2xbf16>
}
// CHECK: func.func @pack_with_outer_dims_perm
Expand All @@ -190,7 +190,7 @@ func.func @pack_with_outer_dims_perm(%arg0: tensor<484x16x64xbf16>, %arg1: tenso
// CHECK-SAME: tensor<484x16x64xbf16> to tensor<484x16xbf16>
// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[ITER]]
// CHECK-SAME: tensor<64x31x8x16x2xbf16> to tensor<31x8x16x2xbf16>
// CHECK: %[[PACK:.+]] = tensor.pack %[[SRC_SLICE]]
// CHECK: %[[PACK:.+]] = linalg.pack %[[SRC_SLICE]]
// CHECK-SAME: padding_value(%[[PAD_VAL]] : bf16)
// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 2]
// CHECK-SAME: into %[[DEST_SLICE]]
Expand All @@ -202,19 +202,19 @@ func.func @do_not_decompose_pack(%arg0: tensor<1x16384x512xbf16>, %arg1: tensor<
hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {ukernels = "", target_triple="x86_64-xyz-xyz", cpu_features=""}>
} {
%cst = arith.constant 0.000000e+00 : bf16
%pack = tensor.pack %arg0 inner_dims_pos = [1, 2] inner_tiles = [16, 2] into %arg1 : tensor<1x16384x512xbf16> -> tensor<1x1024x256x16x2xbf16>
%pack = linalg.pack %arg0 inner_dims_pos = [1, 2] inner_tiles = [16, 2] into %arg1 : tensor<1x16384x512xbf16> -> tensor<1x1024x256x16x2xbf16>
return %pack : tensor<1x1024x256x16x2xbf16>
}
// CHECK-LABEL: func.func @do_not_decompose_pack
// CHECK: tensor.pack {{.+}} : tensor<1x16384x512xbf16> -> tensor<1x1024x256x16x2xbf16>
// CHECK: linalg.pack {{.+}} : tensor<1x16384x512xbf16> -> tensor<1x1024x256x16x2xbf16>

// -----

func.func @unpack_without_transpose(%arg0: tensor<1828x8x64x16x16xf32>) -> tensor<1828x128x1024xf32> attributes {
hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {ukernels = "unpack", target_triple="x86_64-xyz-xyz", cpu_features=""}>
} {
%6 = tensor.empty() : tensor<1828x128x1024xf32>
%unpack = tensor.unpack %arg0
%unpack = linalg.unpack %arg0
outer_dims_perm = [0, 1, 2]
inner_dims_pos = [1, 2]
inner_tiles = [16, 16]
Expand All @@ -233,7 +233,7 @@ func.func @unpack_without_transpose(%arg0: tensor<1828x8x64x16x16xf32>) -> tenso
// CHECK-SAME: : tensor<1828x8x64x16x16xf32> to tensor<8x64x16x16xf32>
// CHECK: %[[DEST_SLICE:.*]] = tensor.extract_slice %[[ITER_ARG]][%[[ITER]], 0, 0] [1, 128, 1024] [1, 1, 1]
// CHECK-SAME: : tensor<1828x128x1024xf32> to tensor<128x1024xf32>
// CHECK: %[[UNPACK:.*]] = tensor.unpack %[[SRC_SLICE]]
// CHECK: %[[UNPACK:.*]] = linalg.unpack %[[SRC_SLICE]]
// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16]
// CHECK-SAME: into %[[DEST_SLICE]] : tensor<8x64x16x16xf32> -> tensor<128x1024xf32>
// CHECK: %[[NEW_ITER_ARG:.*]] = tensor.insert_slice %[[UNPACK]] into %[[ITER_ARG]][%[[ITER]], 0, 0] [1, 128, 1024] [1, 1, 1]
Expand All @@ -250,7 +250,7 @@ func.func @unpack_outer_dim_transpose(%arg0: tensor<4x8x29241x16x16xf32>) -> ten
} {
%cst = arith.constant 0.000000e+00 : bf16
%4 = tensor.empty() : tensor<29241x128x64xf32>
%unpack = tensor.unpack %arg0 outer_dims_perm = [2, 1, 0] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %4 : tensor<4x8x29241x16x16xf32> -> tensor<29241x128x64xf32>
%unpack = linalg.unpack %arg0 outer_dims_perm = [2, 1, 0] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %4 : tensor<4x8x29241x16x16xf32> -> tensor<29241x128x64xf32>
return %unpack : tensor<29241x128x64xf32>
}
// CHECK-LABEL: func.func @unpack_outer_dim_transpose(
Expand All @@ -265,7 +265,7 @@ func.func @unpack_outer_dim_transpose(%arg0: tensor<4x8x29241x16x16xf32>) -> ten
// CHECK-SAME: : tensor<4x8x29241x16x16xf32> to tensor<4x8x16x16xf32>
// CHECK: %[[DEST_SLICE:.*]] = tensor.extract_slice %[[ITER_ARG]][%[[ITER]], 0, 0] [1, 128, 64] [1, 1, 1]
// CHECK-SAME: : tensor<29241x128x64xf32> to tensor<128x64xf32>
// CHECK: %[[UNPACK:.*]] = tensor.unpack %[[SRC_SLICE]]
// CHECK: %[[UNPACK:.*]] = linalg.unpack %[[SRC_SLICE]]
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 16]
// CHECK-SAME: into %[[DEST_SLICE]] : tensor<4x8x16x16xf32> -> tensor<128x64xf32>
// CHECK: %[[NEW_ITER_ARG:.*]] = tensor.insert_slice %[[UNPACK]] into %[[ITER_ARG]][%[[ITER]], 0, 0] [1, 128, 64] [1, 1, 1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,13 +463,13 @@ static LogicalResult adaptComputeConsumerToAvoidStackAllocation(
}

/// Replaces a tensor.empty op with bufferization.alloc_tensor op which is
/// created by tiling tensor.unpack op. It is intended because tiling unpack ops
/// created by tiling linalg.unpack op. It is intended because tiling unpack ops
/// with non-perfect sizes needs extra elements. See the tiling implementation
/// of tensor.unpack op for more details.
/// of linalg.unpack op for more details.
static LogicalResult
replaceUnpackEmptyWithAllocTensor(OpBuilder &b,
mlir::FunctionOpInterface funcOp) {
funcOp.walk([&](tensor::UnPackOp unpackOp) {
funcOp.walk([&](linalg::UnPackOp unpackOp) {
if (!unpackOp->hasOneUse() ||
!isa<tensor::ExtractSliceOp>(*(unpackOp->user_begin()))) {
return;
Expand Down
Loading

0 comments on commit 40fd49f

Please sign in to comment.