From 1c207f1b6e8bba69dfbbcbd72704b4d720e363d0 Mon Sep 17 00:00:00 2001 From: vporpo Date: Wed, 12 Feb 2025 15:06:30 -0800 Subject: [PATCH] [SandboxVec][DAG] Fix DAG when old interval is mem free (#126983) This patch fixes a bug in `DependencyGraph::extend()` when the old interval contains no memory instructions. When this is the case we should do a full dependency scan of the new interval. --- .../SandboxVectorizer/DependencyGraph.cpp | 11 +++--- .../SandboxVectorizer/DependencyGraphTest.cpp | 39 +++++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp index 06a5e3bed7f03..098b296c30ab8 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp @@ -122,6 +122,8 @@ MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval &Intvl, Interval MemDGNodeIntervalBuilder::make(const Interval &Instrs, DependencyGraph &DAG) { + if (Instrs.empty()) + return {}; auto *TopMemN = getTopMemDGNode(Instrs, DAG); // If we couldn't find a mem node in range TopN - BotN then it's empty. if (TopMemN == nullptr) @@ -529,8 +531,8 @@ Interval DependencyGraph::extend(ArrayRef Instrs) { } } }; - if (DAGInterval.empty()) { - assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!"); + auto MemDAGInterval = MemDGNodeIntervalBuilder::make(DAGInterval, *this); + if (MemDAGInterval.empty()) { FullScan(NewInterval); } // 2. The new section is below the old section. @@ -550,8 +552,7 @@ Interval DependencyGraph::extend(ArrayRef Instrs) { // range including both NewInterval and DAGInterval until DstN, for each DstN. else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) { auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); - auto SrcRangeFull = MemDGNodeIntervalBuilder::make( - DAGInterval.getUnionInterval(NewInterval), *this); + auto SrcRangeFull = MemDAGInterval.getUnionInterval(DstRange); for (MemDGNode &DstN : DstRange) { auto SrcRange = Interval(SrcRangeFull.top(), DstN.getPrevNode()); @@ -589,7 +590,7 @@ Interval DependencyGraph::extend(ArrayRef Instrs) { // When scanning for deps with destination in DAGInterval we need to // consider sources from the NewInterval only, because all intra-DAGInterval // dependencies have already been created. - auto DstRangeOld = MemDGNodeIntervalBuilder::make(DAGInterval, *this); + auto DstRangeOld = MemDAGInterval; auto SrcRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); for (MemDGNode &DstN : DstRangeOld) scanAndAddDeps(DstN, SrcRange); diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp index f1e9afefb4531..37f29428e900a 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp @@ -1013,3 +1013,42 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %arg) { EXPECT_EQ(S2N->getNextNode(), S1N); EXPECT_EQ(S1N->getNextNode(), nullptr); } + +// Extending an "Old" interval with no mem instructions. +TEST_F(DependencyGraphTest, ExtendDAGWithNoMem) { + parseIR(C, R"IR( +define void @foo(ptr %ptr, i8 %v, i8 %v0, i8 %v1, i8 %v2, i8 %v3) { + store i8 %v0, ptr %ptr + store i8 %v1, ptr %ptr + %zext1 = zext i8 %v to i32 + %zext2 = zext i8 %v to i32 + store i8 %v2, ptr %ptr + store i8 %v3, ptr %ptr + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + auto It = BB->begin(); + auto *S0 = cast(&*It++); + auto *S1 = cast(&*It++); + auto *Z1 = cast(&*It++); + auto *Z2 = cast(&*It++); + auto *S2 = cast(&*It++); + auto *S3 = cast(&*It++); + + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); + // Create a non-empty DAG that contains no memory instructions. + DAG.extend({Z1, Z2}); + // Now extend it downwards. + DAG.extend({S2, S3}); + EXPECT_TRUE(memDependency(DAG.getNode(S2), DAG.getNode(S3))); + + // Same but upwards. + DAG.clear(); + DAG.extend({Z1, Z2}); + DAG.extend({S0, S1}); + EXPECT_TRUE(memDependency(DAG.getNode(S0), DAG.getNode(S1))); +}