Skip to content

Commit

Permalink
Fix RVV build/tests
Browse files Browse the repository at this point in the history
AllBits0/1: add D arg, matching the existing docs
Add native MaskedEq/Ne etc
Use second HWY_IF_CONSTEXPR instead of else to avoid lint complaint
Use PositiveIota to prevent overflow, and add an arg to match Iota
Remove unnecessary stddef/stdint in tests, export from test_util-inl
Limit values in TestMaskedWidenMulPairwiseAdd to prevent loss of precision
TestAllInsertIntoUpper: only for not-already-minimal vectors
reduction_test: use proper min/max initializer, N can overflow
shift_test: fix padding and initialization for >128 bit vectors
topology: add tolerance for computed sets
PiperOrigin-RevId: 731693011
  • Loading branch information
jan-wassenberg authored and copybara-github committed Mar 3, 2025
1 parent abe6999 commit 77e0f50
Show file tree
Hide file tree
Showing 16 changed files with 194 additions and 159 deletions.
6 changes: 4 additions & 2 deletions hwy/contrib/sort/shared-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,15 @@ static_assert(SortConstants::MaxBufBytes<2>(64) <= 1664, "Unexpectedly high");
// vqsort isn't available on HWY_SCALAR, and builds time out on MSVC opt and
// Armv7 debug, and Armv8 GCC 11 asan hits an internal compiler error likely
// due to https://gcc.gnu.org/bugzilla/show_bug.cgi?id=97696. Armv8 Clang
// hwasan/msan/tsan/asan also fail to build SVE (b/335157772).
// hwasan/msan/tsan/asan also fail to build SVE (b/335157772). RVV currently
// has a compiler issue.
#undef VQSORT_ENABLED
#undef VQSORT_COMPILER_COMPATIBLE

#if (HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD) || \
(HWY_ARCH_ARM_V7 && HWY_IS_DEBUG_BUILD) || \
(HWY_ARCH_ARM_A64 && HWY_COMPILER_GCC_ACTUAL && HWY_IS_ASAN)
(HWY_ARCH_ARM_A64 && HWY_COMPILER_GCC_ACTUAL && HWY_IS_ASAN) || \
(HWY_ARCH_RISCV)
#define VQSORT_COMPILER_COMPATIBLE 0
#else
#define VQSORT_COMPILER_COMPATIBLE 1
Expand Down
3 changes: 2 additions & 1 deletion hwy/contrib/thread_pool/topology.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1189,7 +1189,8 @@ HWY_MAYBE_UNUSED void ComputeSets(Cache& c) {
if (c.sets == 0) {
  c.sets = static_cast<uint32_t>(sets);
} else {
  // Tolerate a small discrepancy between the previously stored and newly
  // computed set counts. Use the absolute difference: a bare
  // `c.sets - sets` is unsigned arithmetic and wraps to a huge value
  // whenever sets > c.sets, which would trigger the abort even for an
  // off-by-one in that direction.
  const size_t reported = static_cast<size_t>(c.sets);
  const size_t diff = (reported >= sets) ? (reported - sets) : (sets - reported);
  if (diff > 1) {
    HWY_ABORT("Inconsistent cache sets %u != %zu\n", c.sets, sets);
  }
}
Expand Down
18 changes: 9 additions & 9 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,8 @@ HWY_API Mask<DTo> DemoteMaskTo(DTo d_to, DFrom d_from, Mask<DFrom> m) {
#else
#define HWY_NATIVE_LOAD_HIGHER
#endif
template <class D, typename T, class V = VFromD<D>(), HWY_IF_LANES_GT_D(D, 1)>
template <class D, typename T, class V = VFromD<D>(), HWY_IF_LANES_GT_D(D, 1),
HWY_IF_POW2_GT_D(D, -3)>
HWY_API V InsertIntoUpper(D d, T* p, V a) {
Half<D> dh;
const VFromD<decltype(dh)> b = LoadU(dh, p);
Expand Down Expand Up @@ -7753,11 +7754,11 @@ HWY_API V MaskedOr(M m, V a, V b) {
#define HWY_NATIVE_ALLONES
#endif

template <class V>
HWY_API bool AllBits1(V a) {
const RebindToUnsigned<DFromV<V>> du;
template <class D, class V = VFromD<D>>
HWY_API bool AllBits1(D d, V v) {
const RebindToUnsigned<decltype(d)> du;
using TU = TFromD<decltype(du)>;
return AllTrue(du, Eq(BitCast(du, a), Set(du, hwy::HighestValue<TU>())));
return AllTrue(du, Eq(BitCast(du, v), Set(du, hwy::HighestValue<TU>())));
}
#endif // HWY_NATIVE_ALLONES

Expand All @@ -7768,10 +7769,9 @@ HWY_API bool AllBits1(V a) {
#define HWY_NATIVE_ALLZEROS
#endif

template <class V>
HWY_API bool AllBits0(V a) {
DFromV<V> d;
return AllTrue(d, Eq(a, Zero(d)));
// Returns whether every bit of every lane of v is zero, i.e. v compares
// equal to the zero vector of descriptor d.
template <class D, class V = VFromD<D>>
HWY_API bool AllBits0(D d, V v) {
  const V zero = Zero(d);
  return AllTrue(d, Eq(v, zero));
}
#endif // HWY_NATIVE_ALLZEROS

Expand Down
96 changes: 72 additions & 24 deletions hwy/ops/rvv-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1571,6 +1571,20 @@ HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulSub, fnmacc, _ALL)

// ================================================== COMPARE

// ------------------------------ MClear

// mask = f(): generator for nullary functions returning a mask. OP is the
// mask-op suffix (here `clr`, i.e. __riscv_vmclr_m_bMLEN, which produces an
// all-zero mask); MLEN selects the vboolMLEN_t type.
#define HWY_RVV_RETM(SEW, SHIFT, MLEN, NAME, OP) \
  HWY_API HWY_RVV_M(MLEN) NAME##MLEN() { \
    return __riscv_vm##OP##_m_b##MLEN(HWY_RVV_AVL(SEW, SHIFT)); \
  }

namespace detail {
// Defines one MClear##MLEN per mask type: returns a cleared (all-false) mask,
// used below as the inactive-lane source for masked comparisons.
HWY_RVV_FOREACH_B(HWY_RVV_RETM, MClear, clr)  // with ##MLEN suffix
}  // namespace detail

#undef HWY_RVV_RETM

// Comparisons set a mask bit to 1 if the condition is true, else 0. The XX in
// vboolXX_t is a power of two divisor for vector bits. SEW=8 / LMUL=1 = 1/8th
// of all bits; SEW=8 / LMUL=4 = half of all bits.
Expand All @@ -1584,6 +1598,16 @@ HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulSub, fnmacc, _ALL)
a, b, HWY_RVV_AVL(SEW, SHIFT)); \
}

// mask = f(mask, vector, vector): generator for masked comparisons. Uses the
// `_mu` (mask-undisturbed) intrinsic form with detail::MClear##MLEN() as the
// maskedoff operand, so inactive lanes (where m is false) come out false.
#define HWY_RVV_RETM_ARGMVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \
                            SHIFT, MLEN, NAME, OP) \
  HWY_API HWY_RVV_M(MLEN) \
      NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) a, \
           HWY_RVV_V(BASE, SEW, LMUL) b) { \
    return __riscv_v##OP##_vv_##CHAR##SEW##LMUL##_b##MLEN##_mu( \
        m, detail::MClear##MLEN(), a, b, HWY_RVV_AVL(SEW, SHIFT)); \
  }

// mask = f(vector, scalar)
#define HWY_RVV_RETM_ARGVS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \
SHIFT, MLEN, NAME, OP) \
Expand All @@ -1593,9 +1617,17 @@ HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulSub, fnmacc, _ALL)
a, b, HWY_RVV_AVL(SEW, SHIFT)); \
}

#ifdef HWY_NATIVE_MASKED_COMP
#undef HWY_NATIVE_MASKED_COMP
#else
#define HWY_NATIVE_MASKED_COMP
#endif

// ------------------------------ Eq
HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Eq, mseq, _ALL)
HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Eq, mfeq, _ALL)
HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGMVV, MaskedEq, mseq, _ALL)
HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGMVV, MaskedEq, mfeq, _ALL)

namespace detail {
HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVS, EqS, mseq_vx, _ALL)
Expand All @@ -1605,6 +1637,8 @@ HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, EqS, mfeq_vf, _ALL)
// ------------------------------ Ne
HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Ne, msne, _ALL)
HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Ne, mfne, _ALL)
HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGMVV, MaskedNe, msne, _ALL)
HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGMVV, MaskedNe, mfne, _ALL)

namespace detail {
HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVS, NeS, msne_vx, _ALL)
Expand All @@ -1615,6 +1649,9 @@ HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, NeS, mfne_vf, _ALL)
HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVV, Lt, msltu, _ALL)
HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVV, Lt, mslt, _ALL)
HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Lt, mflt, _ALL)
HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGMVV, MaskedLt, msltu, _ALL)
HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGMVV, MaskedLt, mslt, _ALL)
HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGMVV, MaskedLt, mflt, _ALL)

namespace detail {
HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVS, LtS, mslt_vx, _ALL)
Expand All @@ -1626,20 +1663,43 @@ HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, LtS, mflt_vf, _ALL)
HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVV, Le, msleu, _ALL)
HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVV, Le, msle, _ALL)
HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Le, mfle, _ALL)
HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGMVV, MaskedLe, msleu, _ALL)
HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGMVV, MaskedLe, msle, _ALL)
HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGMVV, MaskedLe, mfle, _ALL)

// Mask type for vectors of descriptor D, deduced from the return type of a
// comparison of two zero vectors.
template <class D>
using MFromD = decltype(Eq(Zero(D()), Zero(D())));

// Returns a mask that is true only in active lanes (per m) whose value is
// NaN. Relies on NaN != NaN: MaskedNe(m, v, v) is true exactly for NaN lanes.
template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedIsNaN(const M m, const V v) {
  return MaskedNe(m, v, v);
}

#undef HWY_RVV_RETM_ARGMVV
#undef HWY_RVV_RETM_ARGVV
#undef HWY_RVV_RETM_ARGVS

// ------------------------------ Gt/Ge
// ------------------------------ Gt/Ge (Lt, Le)

// Swap args to reverse comparisons: Gt/Ge are expressed via the Lt/Le
// overloads defined above (a > b  <=>  b < a).
template <class V>
HWY_API auto Gt(const V a, const V b) -> decltype(Lt(a, b)) {
  return Lt(b, a);
}

// a >= b  <=>  b <= a.
template <class V>
HWY_API auto Ge(const V a, const V b) -> decltype(Le(a, b)) {
  return Le(b, a);
}

template <class V>
HWY_API auto Gt(const V a, const V b) -> decltype(Lt(a, b)) {
return Lt(b, a);
// Masked a > b: true only in active lanes (per m) where a > b, implemented by
// swapping the operands of MaskedLt.
template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedGt(M m, V a, V b) {
  return MaskedLt(m, b, a);
}

// Masked a >= b, via operand-swapped MaskedLe.
template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedGe(M m, V a, V b) {
  return MaskedLe(m, b, a);
}

// ------------------------------ TestBit
Expand Down Expand Up @@ -1713,10 +1773,6 @@ HWY_RVV_FOREACH_F(HWY_RVV_IF_THEN_ZERO_ELSE, IfThenZeroElse, fmerge_vfm, _ALL)
#undef HWY_RVV_IF_THEN_ZERO_ELSE

// ------------------------------ MaskFromVec

template <class D>
using MFromD = decltype(Eq(Zero(D()), Zero(D())));

template <class V>
HWY_API MFromD<DFromV<V>> MaskFromVec(const V v) {
return detail::NeS(v, 0);
Expand Down Expand Up @@ -3185,7 +3241,7 @@ HWY_RVV_FOREACH(HWY_RVV_SLIDE_DOWN, SlideDown, slidedown, _ALL)
HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \
static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \
HWY_IF_CONSTEXPR(kIndex == 0) { return Trunc(v); } \
else { \
HWY_IF_CONSTEXPR(kIndex != 0) { \
return Trunc(SlideDown( \
v, Lanes(HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), \
SHIFT - 1){}))); \
Expand All @@ -3197,7 +3253,7 @@ HWY_RVV_FOREACH(HWY_RVV_SLIDE_DOWN, SlideDown, slidedown, _ALL)
HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \
static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \
HWY_IF_CONSTEXPR(kIndex == 0) { return v; } \
else { \
HWY_IF_CONSTEXPR(kIndex != 0) { \
return SlideDown( \
v, Lanes(HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), \
SHIFT){}) / \
Expand All @@ -3215,13 +3271,9 @@ template <size_t kIndex, class D>
static HWY_INLINE HWY_MAYBE_UNUSED VFromD<AdjustSimdTagToMinVecPow2<Half<D>>>
Get(D d, VFromD<D> v) {
static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1");

const AdjustSimdTagToMinVecPow2<Half<decltype(d)>> dh;
HWY_IF_CONSTEXPR(kIndex == 0 || detail::IsFull(d)) {
(void)dh;
return Get<kIndex>(v);
}
else {
HWY_IF_CONSTEXPR(kIndex == 0 || detail::IsFull(d)) { return Get<kIndex>(v); }
HWY_IF_CONSTEXPR(kIndex != 0 && !detail::IsFull(d)) {
const AdjustSimdTagToMinVecPow2<Half<decltype(d)>> dh;
const size_t slide_down_amt =
(dh.Pow2() < DFromV<decltype(v)>().Pow2()) ? Lanes(dh) : (Lanes(d) / 2);
return ResizeBitCast(dh, SlideDown(v, slide_down_amt));
Expand All @@ -3240,9 +3292,7 @@ Get(D d, VFromD<D> v) {
return __riscv_v##OP##_v_v_##CHAR##SEW##LMUL##_tu(dest, Ext(d, v), \
half_N); \
} \
else { \
return SlideUp(dest, Ext(d, v), half_N); \
} \
HWY_IF_CONSTEXPR(kIndex != 0) { return SlideUp(dest, Ext(d, v), half_N); } \
}
#define HWY_RVV_PARTIAL_VEC_SET_HALF_SMALLEST( \
BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, MLEN, NAME, OP) \
Expand All @@ -3254,9 +3304,7 @@ Get(D d, VFromD<D> v) {
HWY_IF_CONSTEXPR(kIndex == 0) { \
return __riscv_v##OP##_v_v_##CHAR##SEW##LMUL##_tu(dest, v, half_N); \
} \
else { \
return SlideUp(dest, v, half_N); \
} \
HWY_IF_CONSTEXPR(kIndex != 0) { return SlideUp(dest, v, half_N); } \
}
HWY_RVV_FOREACH(HWY_RVV_PARTIAL_VEC_SET_HALF, PartialVecSetHalf, mv, _GET_SET)
HWY_RVV_FOREACH(HWY_RVV_PARTIAL_VEC_SET_HALF, PartialVecSetHalf, mv,
Expand All @@ -3276,7 +3324,7 @@ HWY_RVV_FOREACH(HWY_RVV_PARTIAL_VEC_SET_HALF_SMALLEST, PartialVecSetHalf, mv,
return __riscv_v##OP##_v_##CHAR##SEW##LMULH##_##CHAR##SEW##LMUL( \
dest, kIndex, v); /* no AVL */ \
} \
else { \
HWY_IF_CONSTEXPR(!detail::IsFull(d)) { \
const Half<decltype(d)> dh; \
return PartialVecSetHalf<kIndex>(dest, v, Lanes(dh)); \
} \
Expand Down
Loading

0 comments on commit 77e0f50

Please sign in to comment.