Skip to content

Commit 831527a

Browse files
authored
[FMV][GlobalOpt] Statically resolve calls to versioned functions. (#87939)
To deduce whether the optimization is legal we need to compare the target
features between caller and callee versions. The criteria for bypassing
the resolver are the following:

 * If the callee's feature set is a subset of the caller's feature set,
   then the callee is a candidate for direct call.

 * Among such candidates the one of highest priority is the best match
   and it shall be picked, unless there is a version of the callee with
   higher priority than the best match which cannot be picked from a
   higher priority caller (directly or through the resolver).

 * For every higher priority callee version than the best match, there
   is a higher priority caller version whose feature set availability
   is implied by the callee's feature set.

Example:

Callers and Callees are ordered in decreasing priority.
The arrows indicate successful call redirections.

  Caller        Callee      Explanation
=========================================================================
mops+sve2 --+--> mops       all the callee versions are subsets of the
            |               caller but mops has the highest priority
            |
     mops --+    sve2       between mops and default callees, mops wins

      sve        sve        between sve and default callees, sve wins
                            but sve2 does not have a high priority caller

  default -----> default    sve (callee) implies sve (caller),
                            sve2(callee) implies sve (caller),
                            mops(callee) implies mops(caller)
1 parent 101109f commit 831527a

File tree

9 files changed

+608
-10
lines changed

9 files changed

+608
-10
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

+17
Original file line numberDiff line numberDiff line change
@@ -1870,6 +1870,13 @@ class TargetTransformInfo {
18701870
/// false, but it shouldn't matter what it returns anyway.
18711871
bool hasArmWideBranch(bool Thumb) const;
18721872

1873+
/// Returns a bitmask constructed from the target-features or fmv-features
1874+
/// metadata of a function.
1875+
uint64_t getFeatureMask(const Function &F) const;
1876+
1877+
/// Returns true if this is an instance of a function with multiple versions.
1878+
bool isMultiversionedFunction(const Function &F) const;
1879+
18731880
/// \return The maximum number of function arguments the target supports.
18741881
unsigned getMaxNumArgs() const;
18751882

@@ -2312,6 +2319,8 @@ class TargetTransformInfo::Concept {
23122319
virtual VPLegalization
23132320
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
23142321
virtual bool hasArmWideBranch(bool Thumb) const = 0;
2322+
virtual uint64_t getFeatureMask(const Function &F) const = 0;
2323+
virtual bool isMultiversionedFunction(const Function &F) const = 0;
23152324
virtual unsigned getMaxNumArgs() const = 0;
23162325
virtual unsigned getNumBytesToPadGlobalArray(unsigned Size,
23172326
Type *ArrayType) const = 0;
@@ -3144,6 +3153,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
31443153
return Impl.hasArmWideBranch(Thumb);
31453154
}
31463155

3156+
uint64_t getFeatureMask(const Function &F) const override {
3157+
return Impl.getFeatureMask(F);
3158+
}
3159+
3160+
bool isMultiversionedFunction(const Function &F) const override {
3161+
return Impl.isMultiversionedFunction(F);
3162+
}
3163+
31473164
unsigned getMaxNumArgs() const override {
31483165
return Impl.getMaxNumArgs();
31493166
}

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

+4
Original file line numberDiff line numberDiff line change
@@ -1039,6 +1039,10 @@ class TargetTransformInfoImplBase {
10391039

10401040
bool hasArmWideBranch(bool) const { return false; }
10411041

1042+
/// Default implementation: reports an empty (zero) feature mask for every
/// function. The argument is not inspected.
uint64_t getFeatureMask(const Function &F) const {
  return 0;
}
1043+
1044+
/// Default implementation: no function is considered multiversioned. The
/// argument is not inspected.
bool isMultiversionedFunction(const Function &F) const {
  return false;
}
1045+
10421046
unsigned getMaxNumArgs() const { return UINT_MAX; }
10431047

10441048
unsigned getNumBytesToPadGlobalArray(unsigned Size, Type *ArrayType) const {

llvm/include/llvm/TargetParser/AArch64TargetParser.h

+8-5
Original file line numberDiff line numberDiff line change
@@ -270,13 +270,16 @@ void fillValidCPUArchList(SmallVectorImpl<StringRef> &Values);
270270

271271
bool isX18ReservedByDefault(const Triple &TT);
272272

273-
// Return the priority for a given set of FMV features.
273+
// For a given set of feature names, which can be either target-features, or
274+
// fmv-features metadata, expand their dependencies and then return a bitmask
275+
// corresponding to the entries of AArch64::FeatPriorities.
274276
uint64_t getFMVPriority(ArrayRef<StringRef> Features);
275277

276-
// For given feature names, return a bitmask corresponding to the entries of
277-
// AArch64::CPUFeatures. The values in CPUFeatures are not bitmasks themselves,
278-
// they are sequential (0, 1, 2, 3, ...). The resulting bitmask is used at
279-
// runtime to test whether a certain FMV feature is available on the host.
278+
// For a given set of FMV feature names, expand their dependencies and then
279+
// return a bitmask corresponding to the entries of AArch64::CPUFeatures.
280+
// The values in CPUFeatures are not bitmasks themselves, they are sequential
281+
// (0, 1, 2, 3, ...). The resulting bitmask is used at runtime to test whether
282+
// a certain FMV feature is available on the host.
280283
uint64_t getCpuSupportsMask(ArrayRef<StringRef> Features);
281284

282285
void PrintSupportedExtensions();

llvm/lib/Analysis/TargetTransformInfo.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,14 @@ bool TargetTransformInfo::hasArmWideBranch(bool Thumb) const {
13831383
return TTIImpl->hasArmWideBranch(Thumb);
13841384
}
13851385

1386+
uint64_t TargetTransformInfo::getFeatureMask(const Function &F) const {
1387+
return TTIImpl->getFeatureMask(F);
1388+
}
1389+
1390+
bool TargetTransformInfo::isMultiversionedFunction(const Function &F) const {
1391+
return TTIImpl->isMultiversionedFunction(F);
1392+
}
1393+
13861394
unsigned TargetTransformInfo::getMaxNumArgs() const {
13871395
return TTIImpl->getMaxNumArgs();
13881396
}

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "llvm/IR/IntrinsicsAArch64.h"
2424
#include "llvm/IR/PatternMatch.h"
2525
#include "llvm/Support/Debug.h"
26+
#include "llvm/TargetParser/AArch64TargetParser.h"
2627
#include "llvm/Transforms/InstCombine/InstCombiner.h"
2728
#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
2829
#include <algorithm>
@@ -248,6 +249,19 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
248249
return false;
249250
}
250251

252+
uint64_t AArch64TTIImpl::getFeatureMask(const Function &F) const {
253+
StringRef AttributeStr =
254+
isMultiversionedFunction(F) ? "fmv-features" : "target-features";
255+
StringRef FeatureStr = F.getFnAttribute(AttributeStr).getValueAsString();
256+
SmallVector<StringRef, 8> Features;
257+
FeatureStr.split(Features, ",");
258+
return AArch64::getFMVPriority(Features);
259+
}
260+
261+
bool AArch64TTIImpl::isMultiversionedFunction(const Function &F) const {
262+
return F.hasFnAttribute("fmv-features");
263+
}
264+
251265
bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
252266
const Function *Callee) const {
253267
SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

+4
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
8989
unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
9090
unsigned DefaultCallPenalty) const;
9191

92+
uint64_t getFeatureMask(const Function &F) const;
93+
94+
bool isMultiversionedFunction(const Function &F) const;
95+
9296
/// \name Scalar TTI Implementations
9397
/// @{
9498

llvm/lib/TargetParser/AArch64TargetParser.cpp

+26-5
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,33 @@ std::optional<AArch64::ArchInfo> AArch64::ArchInfo::findBySubArch(StringRef SubA
4848
return {};
4949
}
5050

51+
// Scans the FMV info table for the first entry whose (optional) extension ID
// equals ExtID. Entries without an ID never match. Returns an empty optional
// when no entry matches.
std::optional<AArch64::FMVInfo> lookupFMVByID(AArch64::ArchExtKind ExtID) {
  for (const AArch64::FMVInfo &Candidate : AArch64::getFMVInfo()) {
    if (!Candidate.ID)
      continue;
    if (*Candidate.ID == ExtID)
      return Candidate;
  }
  return std::nullopt;
}
57+
5158
// Computes the FMV priority bitmask for a list of feature names.
//
// Each name is first tried as an FMV extension name; if that fails it is
// tried as a target-feature string and mapped back to the FMV entry sharing
// the same extension ID. Every resolved extension is enabled in an
// ExtensionSet, and the mask is then assembled from the priority bit of each
// FMV entry whose extension ended up enabled in that set.
uint64_t AArch64::getFMVPriority(ArrayRef<StringRef> Features) {
  ExtensionSet EnabledExts;
  for (const StringRef Name : Features) {
    std::optional<FMVInfo> Match = parseFMVExtension(Name);
    if (!Match) {
      // Not an FMV name: fall back to the target-feature spelling.
      if (std::optional<ExtensionInfo> Ext = targetFeatureToExtension(Name))
        Match = lookupFMVByID(Ext->ID);
    }
    if (!Match || !Match->ID)
      continue;
    EnabledExts.enable(*Match->ID);
  }

  uint64_t Mask = 0;
  for (const FMVInfo &Entry : getFMVInfo()) {
    if (!Entry.ID || !EnabledExts.Enabled.test(*Entry.ID))
      continue;
    Mask |= (1ULL << Entry.PriorityBit);
  }
  return Mask;
}
5879

5980
uint64_t AArch64::getCpuSupportsMask(ArrayRef<StringRef> Features) {

llvm/lib/Transforms/IPO/GlobalOpt.cpp

+162
Original file line numberDiff line numberDiff line change
@@ -2641,6 +2641,165 @@ DeleteDeadIFuncs(Module &M,
26412641
return Changed;
26422642
}
26432643

2644+
// Follows the use-def chain of \p V backwards until it finds a Function,
2645+
// in which case it collects in \p Versions. Return true on successful
2646+
// use-def chain traversal, false otherwise.
2647+
static bool collectVersions(TargetTransformInfo &TTI, Value *V,
2648+
SmallVectorImpl<Function *> &Versions) {
2649+
if (auto *F = dyn_cast<Function>(V)) {
2650+
if (!TTI.isMultiversionedFunction(*F))
2651+
return false;
2652+
Versions.push_back(F);
2653+
} else if (auto *Sel = dyn_cast<SelectInst>(V)) {
2654+
if (!collectVersions(TTI, Sel->getTrueValue(), Versions))
2655+
return false;
2656+
if (!collectVersions(TTI, Sel->getFalseValue(), Versions))
2657+
return false;
2658+
} else if (auto *Phi = dyn_cast<PHINode>(V)) {
2659+
for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
2660+
if (!collectVersions(TTI, Phi->getIncomingValue(I), Versions))
2661+
return false;
2662+
} else {
2663+
// Unknown instruction type. Bail.
2664+
return false;
2665+
}
2666+
return true;
2667+
}
2668+
2669+
// Bypass the IFunc Resolver of MultiVersioned functions when possible. To
2670+
// deduce whether the optimization is legal we need to compare the target
2671+
// features between caller and callee versions. The criteria for bypassing
2672+
// the resolver are the following:
2673+
//
2674+
// * If the callee's feature set is a subset of the caller's feature set,
2675+
// then the callee is a candidate for direct call.
2676+
//
2677+
// * Among such candidates the one of highest priority is the best match
2678+
// and it shall be picked, unless there is a version of the callee with
2679+
// higher priority than the best match which cannot be picked from a
2680+
// higher priority caller (directly or through the resolver).
2681+
//
2682+
// * For every higher priority callee version than the best match, there
2683+
// is a higher priority caller version whose feature set availability
2684+
// is implied by the callee's feature set.
2685+
//
2686+
static bool OptimizeNonTrivialIFuncs(
2687+
Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
2688+
bool Changed = false;
2689+
2690+
// Cache containing the mask constructed from a function's target features.
2691+
DenseMap<Function *, uint64_t> FeatureMask;
2692+
2693+
for (GlobalIFunc &IF : M.ifuncs()) {
2694+
if (IF.isInterposable())
2695+
continue;
2696+
2697+
Function *Resolver = IF.getResolverFunction();
2698+
if (!Resolver)
2699+
continue;
2700+
2701+
if (Resolver->isInterposable())
2702+
continue;
2703+
2704+
TargetTransformInfo &TTI = GetTTI(*Resolver);
2705+
2706+
// Discover the callee versions.
2707+
SmallVector<Function *> Callees;
2708+
if (any_of(*Resolver, [&TTI, &Callees](BasicBlock &BB) {
2709+
if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
2710+
if (!collectVersions(TTI, Ret->getReturnValue(), Callees))
2711+
return true;
2712+
return false;
2713+
}))
2714+
continue;
2715+
2716+
assert(!Callees.empty() && "Expecting successful collection of versions");
2717+
2718+
// Cache the feature mask for each callee.
2719+
for (Function *Callee : Callees) {
2720+
auto [It, Inserted] = FeatureMask.try_emplace(Callee);
2721+
if (Inserted)
2722+
It->second = TTI.getFeatureMask(*Callee);
2723+
}
2724+
2725+
// Sort the callee versions in decreasing priority order.
2726+
sort(Callees, [&](auto *LHS, auto *RHS) {
2727+
return FeatureMask[LHS] > FeatureMask[RHS];
2728+
});
2729+
2730+
// Find the callsites and cache the feature mask for each caller.
2731+
SmallVector<Function *> Callers;
2732+
DenseMap<Function *, SmallVector<CallBase *>> CallSites;
2733+
for (User *U : IF.users()) {
2734+
if (auto *CB = dyn_cast<CallBase>(U)) {
2735+
if (CB->getCalledOperand() == &IF) {
2736+
Function *Caller = CB->getFunction();
2737+
auto [FeatIt, FeatInserted] = FeatureMask.try_emplace(Caller);
2738+
if (FeatInserted)
2739+
FeatIt->second = TTI.getFeatureMask(*Caller);
2740+
auto [CallIt, CallInserted] = CallSites.try_emplace(Caller);
2741+
if (CallInserted)
2742+
Callers.push_back(Caller);
2743+
CallIt->second.push_back(CB);
2744+
}
2745+
}
2746+
}
2747+
2748+
// Sort the caller versions in decreasing priority order.
2749+
sort(Callers, [&](auto *LHS, auto *RHS) {
2750+
return FeatureMask[LHS] > FeatureMask[RHS];
2751+
});
2752+
2753+
auto implies = [](uint64_t A, uint64_t B) { return (A & B) == B; };
2754+
2755+
// Index to the highest priority candidate.
2756+
unsigned I = 0;
2757+
// Now try to redirect calls starting from higher priority callers.
2758+
for (Function *Caller : Callers) {
2759+
assert(I < Callees.size() && "Found callers of equal priority");
2760+
2761+
Function *Callee = Callees[I];
2762+
uint64_t CallerBits = FeatureMask[Caller];
2763+
uint64_t CalleeBits = FeatureMask[Callee];
2764+
2765+
// In the case of FMV callers, we know that all higher priority callers
2766+
// than the current one did not get selected at runtime, which helps
2767+
// reason about the callees (if they have versions that mandate presence
2768+
// of the features which we already know are unavailable on this target).
2769+
if (TTI.isMultiversionedFunction(*Caller)) {
2770+
// If the feature set of the caller implies the feature set of the
2771+
// highest priority candidate then it shall be picked. In case of
2772+
// identical sets advance the candidate index one position.
2773+
if (CallerBits == CalleeBits)
2774+
++I;
2775+
else if (!implies(CallerBits, CalleeBits)) {
2776+
// Keep advancing the candidate index as long as the caller's
2777+
// features are a subset of the current candidate's.
2778+
while (implies(CalleeBits, CallerBits)) {
2779+
if (++I == Callees.size())
2780+
break;
2781+
CalleeBits = FeatureMask[Callees[I]];
2782+
}
2783+
continue;
2784+
}
2785+
} else {
2786+
// We can't reason much about non-FMV callers. Just pick the highest
2787+
// priority callee if it matches, otherwise bail.
2788+
if (I > 0 || !implies(CallerBits, CalleeBits))
2789+
continue;
2790+
}
2791+
auto &Calls = CallSites[Caller];
2792+
for (CallBase *CS : Calls)
2793+
CS->setCalledOperand(Callee);
2794+
Changed = true;
2795+
}
2796+
if (IF.use_empty() ||
2797+
all_of(IF.users(), [](User *U) { return isa<GlobalAlias>(U); }))
2798+
NumIFuncsResolved++;
2799+
}
2800+
return Changed;
2801+
}
2802+
26442803
static bool
26452804
optimizeGlobalsInModule(Module &M, const DataLayout &DL,
26462805
function_ref<TargetLibraryInfo &(Function &)> GetTLI,
@@ -2707,6 +2866,9 @@ optimizeGlobalsInModule(Module &M, const DataLayout &DL,
27072866
// Optimize IFuncs whose callee's are statically known.
27082867
LocalChange |= OptimizeStaticIFuncs(M);
27092868

2869+
// Optimize IFuncs based on the target features of the caller.
2870+
LocalChange |= OptimizeNonTrivialIFuncs(M, GetTTI);
2871+
27102872
// Remove any IFuncs that are now dead.
27112873
LocalChange |= DeleteDeadIFuncs(M, NotDiscardableComdats);
27122874

0 commit comments

Comments
 (0)