Skip to content

Commit

Permalink
work around broken gcc _mm512_cmp_ph_mask, reduce sign conversion war…
Browse files Browse the repository at this point in the history
…nings. Refs #2494

PiperOrigin-RevId: 732900038
  • Loading branch information
jan-wassenberg authored and copybara-github committed Mar 3, 2025
1 parent abe6999 commit 8b83f59
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 16 deletions.
32 changes: 23 additions & 9 deletions hwy/ops/x86_256-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2413,29 +2413,37 @@ HWY_API Vec256<int32_t> BroadcastSignBit(const Vec256<int32_t> v) {
return ShiftRight<31>(v);
}

#if HWY_TARGET <= HWY_AVX3

// Arithmetic right shift of every int64_t lane by the compile-time count
// kBits, using the native 64-bit arithmetic-shift intrinsic.
template <int kBits>
HWY_API Vec256<int64_t> ShiftRight(const Vec256<int64_t> v) {
  // Shift64Count is the count type alias declared elsewhere; the cast converts
  // the template argument to it before the value reaches the intrinsic.
  const auto shifted =
      _mm256_srai_epi64(v.raw, static_cast<Shift64Count>(kBits));
  return Vec256<int64_t>{shifted};
}

// Fills each int64_t lane with copies of its own sign bit: negative lanes
// become all-ones, non-negative lanes become zero.
HWY_API Vec256<int64_t> BroadcastSignBit(const Vec256<int64_t> v) {
  // An arithmetic shift by (lane width - 1) replicates the top bit everywhere.
  constexpr int kSignBitIndex = 63;
  return ShiftRight<kSignBitIndex>(v);
}

#else // AVX2

// Unlike above, this will be used to implement int64_t ShiftRight.
// Returns all-ones in lanes where v is negative and zero elsewhere. This
// branch is only compiled when HWY_TARGET is not AVX3-or-better, so the native
// 64-bit arithmetic shift (_mm256_srai_epi64) is not used here; the sign is
// obtained from a signed comparison against zero instead.
HWY_API Vec256<int64_t> BroadcastSignBit(const Vec256<int64_t> v) {
  const DFromV<decltype(v)> d;
  // (v < 0) yields a per-lane mask; VecFromMask expands it to all-ones/zero.
  return VecFromMask(v < Zero(d));
}

// Emulated arithmetic right shift of int64_t lanes by the compile-time count
// kBits, for targets without a native 64-bit arithmetic shift (this branch is
// only compiled when HWY_TARGET is not AVX3-or-better).
//
// The result is assembled from two parts:
//  - `right`: the lanes shifted right as unsigned values, which leaves the
//    upper kBits of each lane zero;
//  - `sign`: the broadcast sign bit shifted left by (64 - kBits), which is
//    all-ones in exactly those upper kBits for negative lanes, else zero.
// NOTE(review): kBits == 0 would instantiate ShiftLeft<64>; confirm callers
// never pass 0 or that ShiftLeft rejects it at compile time.
template <int kBits>
HWY_API Vec256<int64_t> ShiftRight(const Vec256<int64_t> v) {
  const Full256<int64_t> di;
  const Full256<uint64_t> du;
  const auto right = BitCast(di, ShiftRight<kBits>(BitCast(du, v)));
  const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v));
  return right | sign;
}

#endif // #if HWY_TARGET <= HWY_AVX3

// ------------------------------ IfNegativeThenElse (BroadcastSignBit)
HWY_API Vec256<int8_t> IfNegativeThenElse(Vec256<int8_t> v, Vec256<int8_t> yes,
Vec256<int8_t> no) {
Expand Down Expand Up @@ -2495,6 +2503,10 @@ HWY_API Vec256<int32_t> IfNegativeThenNegOrUndefIfZero(Vec256<int32_t> mask,

// ------------------------------ ShiftLeftSame

// Disable sign conversion warnings for GCC debug intrinsics.
// The push saves the current diagnostic state so that the matching
// HWY_DIAGNOSTICS(pop) after the ShiftRightSame overloads can restore it.
// NOTE(review): the first argument (4245 4365) appears to be the MSVC warning
// list and the second the GCC/Clang flag - confirm against the definition of
// HWY_DIAGNOSTICS_OFF.
HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")

HWY_API Vec256<uint16_t> ShiftLeftSame(const Vec256<uint16_t> v,
const int bits) {
#if HWY_COMPILER_GCC
Expand Down Expand Up @@ -2642,6 +2654,8 @@ HWY_API Vec256<int8_t> ShiftRightSame(Vec256<int8_t> v, const int bits) {
return (shifted ^ shifted_sign) - shifted_sign;
}

HWY_DIAGNOSTICS(pop)

// ------------------------------ Neg (Xor, Sub)

// Tag dispatch instead of SFINAE for MSVC 2017 compatibility
Expand Down
37 changes: 30 additions & 7 deletions hwy/ops/x86_512-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1478,7 +1478,11 @@ HWY_API Vec512<T> Ror(Vec512<T> a, Vec512<T> b) {
// ------------------------------ ShiftLeftSame

// GCC <14 and Clang <11 do not follow the Intel documentation for AVX-512
// shift-with-immediate: the counts should all be unsigned int. Despite casting,
// we still see warnings in GCC debug builds, hence we disable sign-conversion
// warnings for this section.
HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")

#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1100
using Shift16Count = int;
using Shift3264Count = int;
Expand Down Expand Up @@ -1642,6 +1646,8 @@ HWY_API Vec512<int8_t> ShiftRightSame(Vec512<int8_t> v, const int bits) {
return (shifted ^ shifted_sign) - shifted_sign;
}

HWY_DIAGNOSTICS(pop)

// ------------------------------ Minimum

// Unsigned
Expand Down Expand Up @@ -2946,11 +2952,25 @@ HWY_API Vec512<int64_t> BroadcastSignBit(Vec512<int64_t> v) {

// ------------------------------ Floating-point classification (Not)

// Guarded like its callers below: __m512h is only used when float16 vector
// support is enabled.
#if HWY_HAVE_FLOAT16 || HWY_IDE
namespace detail {

// Wrapper around _mm512_fpclass_ph_mask that returns a full 32-lane mask.
//
// kCategories is a bitwise OR of HWY_X86_FPCLASS_* flags. It is a template
// argument rather than a function parameter because it is forwarded as the
// immediate operand of the intrinsic/builtin, and because all callers invoke
// this as Fix_mm512_fpclass_ph_mask<kCategories>(v). Being a template also
// allows this header-defined helper to be included from multiple translation
// units without violating the one-definition rule.
//
// Returns a mask with one bit per float16 lane of `v`, set if the lane belongs
// to any of the requested categories.
template <int kCategories>
__mmask32 Fix_mm512_fpclass_ph_mask(__m512h v) {
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1500
  // GCC's _mm512_cmp_ph_mask uses `__mmask8` instead of `__mmask32`, hence only
  // the first 8 lanes are set. Call the builtin directly with a 32-bit
  // all-ones mask so that every lane is classified.
  // NOTE(review): the comment names _mm512_cmp_ph_mask whereas this wrapper
  // replaces _mm512_fpclass_ph_mask - confirm which GCC intrinsic is affected.
  return static_cast<__mmask32>(__builtin_ia32_fpclassph512_mask(
      static_cast<__v32hf>(v), kCategories, static_cast<__mmask32>(-1)));
#else
  return _mm512_fpclass_ph_mask(v, kCategories);
#endif
}

}  // namespace detail
#endif  // HWY_HAVE_FLOAT16 || HWY_IDE

#if HWY_HAVE_FLOAT16 || HWY_IDE

// Returns a mask that is true in each float16 lane of `v` holding a NaN
// (signaling or quiet).
HWY_API Mask512<float16_t> IsNaN(Vec512<float16_t> v) {
  // Go through the detail wrapper rather than calling _mm512_fpclass_ph_mask
  // directly, so that the workaround for affected GCC versions is applied.
  constexpr int kCategories = HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN;
  return Mask512<float16_t>{
      detail::Fix_mm512_fpclass_ph_mask<kCategories>(v.raw)};
}

HWY_API Mask512<float16_t> IsEitherNaN(Vec512<float16_t> a,
Expand All @@ -2963,15 +2983,18 @@ HWY_API Mask512<float16_t> IsEitherNaN(Vec512<float16_t> a,
}

// Returns a mask that is true in each float16 lane of `v` holding positive or
// negative infinity.
HWY_API Mask512<float16_t> IsInf(Vec512<float16_t> v) {
  // Named category flags instead of a raw hex constant; routed through the
  // detail wrapper so that the workaround for affected GCC versions is applied.
  constexpr int kCategories = HWY_X86_FPCLASS_POS_INF | HWY_X86_FPCLASS_NEG_INF;
  return Mask512<float16_t>{
      detail::Fix_mm512_fpclass_ph_mask<kCategories>(v.raw)};
}

// Returns whether normal/subnormal/zero. fpclass doesn't have a flag for
// positive, so we have to check for inf/NaN and negate.
HWY_API Mask512<float16_t> IsFinite(Vec512<float16_t> v) {
  // Classify as "NaN or infinity" via the detail wrapper (which applies the
  // workaround for affected GCC versions), then invert the mask.
  constexpr int kCategories = HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN |
                              HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF;
  return Not(Mask512<float16_t>{
      detail::Fix_mm512_fpclass_ph_mask<kCategories>(v.raw)});
}

#endif // HWY_HAVE_FLOAT16
Expand Down

0 comments on commit 8b83f59

Please sign in to comment.