Skip to content

Commit

Permalink
[VectorCombine] Fold vector.interleave2 with two constant splats (#12…
Browse files Browse the repository at this point in the history
…5144)

If we're interleaving 2 constant splats, for instance `<vscale x 8 x
i32> <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can
create a larger splat `<vscale x 8 x i64> <splat of ((777 << 32) |
666)>` first before casting it back into `<vscale x 16 x i32>`.
  • Loading branch information
mshockwave authored Feb 4, 2025
1 parent d810c74 commit 635ab51
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 0 deletions.
43 changes: 43 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class VectorCombine {
bool foldShuffleFromReductions(Instruction &I);
bool foldCastFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
bool foldInterleaveIntrinsics(Instruction &I);
bool shrinkType(Instruction &I);

void replaceValue(Value &Old, Value &New) {
Expand Down Expand Up @@ -3204,6 +3205,47 @@ bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
return true;
}

/// If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
/// <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a
/// larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first
/// before casting it back into `<vscale x 16 x i32>`.
///
/// The example above shows the little-endian layout. A vector bitcast places
/// element zero in the most significant bits on big-endian targets, so there
/// the two halves of the wide splat value are swapped.
///
/// The fold is only performed when the interleave2 call is strictly more
/// expensive than the bitcast according to TTI.
///
/// \returns true if \p I was replaced.
bool VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
  const APInt *SplatVal0, *SplatVal1;
  if (!match(&I, m_Intrinsic<Intrinsic::vector_interleave2>(
                     m_APInt(SplatVal0), m_APInt(SplatVal1))))
    return false;

  LLVM_DEBUG(dbgs() << "VC: Folding interleave2 with two splats: " << I
                    << "\n");

  auto *VTy =
      cast<VectorType>(cast<IntrinsicInst>(I).getArgOperand(0)->getType());
  // Same element count as the operands, but with elements twice as wide.
  auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
  unsigned Width = VTy->getElementType()->getIntegerBitWidth();

  // Just in case the costs of the interleave2 intrinsic and the bitcast are
  // both invalid, in which case we want to bail out, we use <= rather than <
  // here. Even if they both have valid and equal costs, it's probably not a
  // good idea to emit a high-cost constant splat.
  if (TTI.getInstructionCost(&I, CostKind) <=
      TTI.getCastInstrCost(Instruction::BitCast, I.getType(), ExtVTy,
                           TTI::CastContextHint::None, CostKind)) {
    LLVM_DEBUG(dbgs() << "VC: The cost to cast from " << *ExtVTy << " to "
                      << *I.getType() << " is too high.\n");
    return false;
  }

  // interleave2 produces <Op0[0], Op1[0], Op0[1], Op1[1], ...>. After the
  // bitcast, each wide element is split into two narrow ones: the low half
  // becomes the first (even) lane on little-endian targets, whereas the high
  // half becomes the first lane on big-endian targets. Pick the halves
  // accordingly so that operand 0 always lands in the even lanes.
  const bool IsBigEndian = I.getDataLayout().isBigEndian();
  const APInt &LoHalf = IsBigEndian ? *SplatVal1 : *SplatVal0;
  const APInt &HiHalf = IsBigEndian ? *SplatVal0 : *SplatVal1;

  APInt NewSplatVal = HiHalf.zext(Width * 2);
  NewSplatVal <<= Width;
  NewSplatVal |= LoHalf.zext(Width * 2);
  auto *NewSplat = ConstantVector::getSplat(
      ExtVTy->getElementCount(), ConstantInt::get(F.getContext(), NewSplatVal));

  IRBuilder<> Builder(&I);
  replaceValue(I, *Builder.CreateBitCast(NewSplat, I.getType()));
  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
Expand Down Expand Up @@ -3248,6 +3290,7 @@ bool VectorCombine::run() {
MadeChange |= scalarizeBinopOrCmp(I);
MadeChange |= scalarizeLoadExtract(I);
MadeChange |= scalarizeVPIntrinsic(I);
MadeChange |= foldInterleaveIntrinsics(I);
}

if (Opcode == Instruction::Store)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -mtriple=riscv64 -mattr=+v %s -passes=vector-combine | FileCheck %s
; RUN: opt -S -mtriple=riscv32 -mattr=+v %s -passes=vector-combine | FileCheck %s

; We should not form an i128 vector.

; Negative case: both interleave2 operands are i64 splats (666 and 777), so
; merging them into one splat would need elements twice as wide (i128). The
; expected output below keeps the interleave2 call and the vp.store unchanged.
define void @interleave2_const_splat_nxv8i64(ptr %dst) {
; CHECK-LABEL: define void @interleave2_const_splat_nxv8i64(
; CHECK-SAME: ptr [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: [[INTERLEAVE2:%.*]] = call <vscale x 8 x i64> @llvm.vector.interleave2.nxv8i64(<vscale x 4 x i64> splat (i64 666), <vscale x 4 x i64> splat (i64 777))
; CHECK-NEXT: call void @llvm.vp.store.nxv8i64.p0(<vscale x 8 x i64> [[INTERLEAVE2]], ptr [[DST]], <vscale x 8 x i1> splat (i1 true), i32 88)
; CHECK-NEXT: ret void
;
%interleave2 = call <vscale x 8 x i64> @llvm.vector.interleave2.nxv8i64(<vscale x 4 x i64> splat (i64 666), <vscale x 4 x i64> splat (i64 777))
call void @llvm.vp.store.nxv8i64.p0(<vscale x 8 x i64> %interleave2, ptr %dst, <vscale x 8 x i1> splat (i1 true), i32 88)
ret void
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -mtriple=riscv64 -mattr=+v %s -passes=vector-combine | FileCheck %s
; RUN: opt -S -mtriple=riscv32 -mattr=+v %s -passes=vector-combine | FileCheck %s
; RUN: opt -S -mtriple=riscv64 -mattr=+zve32x %s -passes=vector-combine | FileCheck %s --check-prefix=ZVE32X

; Positive case: interleaving two i32 splats (666 and 777). With the +v runs
; the call is expected to be replaced by a bitcast of a single i64 splat whose
; value is (777 << 32) | 666 = 3337189589658. With the +zve32x run the
; expected output keeps the original interleave2 call.
define void @interleave2_const_splat_nxv16i32(ptr %dst) {
; CHECK-LABEL: define void @interleave2_const_splat_nxv16i32(
; CHECK-SAME: ptr [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> bitcast (<vscale x 8 x i64> splat (i64 3337189589658) to <vscale x 16 x i32>), ptr [[DST]], <vscale x 16 x i1> splat (i1 true), i32 88)
; CHECK-NEXT: ret void
;
; ZVE32X-LABEL: define void @interleave2_const_splat_nxv16i32(
; ZVE32X-SAME: ptr [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
; ZVE32X-NEXT: [[INTERLEAVE2:%.*]] = call <vscale x 16 x i32> @llvm.vector.interleave2.nxv16i32(<vscale x 8 x i32> splat (i32 666), <vscale x 8 x i32> splat (i32 777))
; ZVE32X-NEXT: call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> [[INTERLEAVE2]], ptr [[DST]], <vscale x 16 x i1> splat (i1 true), i32 88)
; ZVE32X-NEXT: ret void
;
%interleave2 = call <vscale x 16 x i32> @llvm.vector.interleave2.nxv16i32(<vscale x 8 x i32> splat (i32 666), <vscale x 8 x i32> splat (i32 777))
call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> %interleave2, ptr %dst, <vscale x 16 x i1> splat (i1 true), i32 88)
ret void
}

0 comments on commit 635ab51

Please sign in to comment.