Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

work around broken gcc _mm512_cmp_ph_mask, reduce sign conversion warnings. Refs #2494 #2507

Merged
merged 1 commit into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions hwy/ops/x86_256-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2413,29 +2413,37 @@ HWY_API Vec256<int32_t> BroadcastSignBit(const Vec256<int32_t> v) {
return ShiftRight<31>(v);
}

#if HWY_TARGET <= HWY_AVX3

template <int kBits>
HWY_API Vec256<int64_t> ShiftRight(const Vec256<int64_t> v) {
return Vec256<int64_t>{
_mm256_srai_epi64(v.raw, static_cast<Shift64Count>(kBits))};
}

HWY_API Vec256<int64_t> BroadcastSignBit(const Vec256<int64_t> v) {
return ShiftRight<63>(v);
}

#else // AVX2

// Unlike above, this will be used to implement int64_t ShiftRight.
HWY_API Vec256<int64_t> BroadcastSignBit(const Vec256<int64_t> v) {
#if HWY_TARGET == HWY_AVX2
const DFromV<decltype(v)> d;
return VecFromMask(v < Zero(d));
#else
return Vec256<int64_t>{_mm256_srai_epi64(v.raw, 63)};
#endif
}

template <int kBits>
HWY_API Vec256<int64_t> ShiftRight(const Vec256<int64_t> v) {
#if HWY_TARGET <= HWY_AVX3
return Vec256<int64_t>{
_mm256_srai_epi64(v.raw, static_cast<Shift64Count>(kBits))};
#else
const Full256<int64_t> di;
const Full256<uint64_t> du;
const auto right = BitCast(di, ShiftRight<kBits>(BitCast(du, v)));
const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v));
return right | sign;
#endif
}

#endif // #if HWY_TARGET <= HWY_AVX3

// ------------------------------ IfNegativeThenElse (BroadcastSignBit)
HWY_API Vec256<int8_t> IfNegativeThenElse(Vec256<int8_t> v, Vec256<int8_t> yes,
Vec256<int8_t> no) {
Expand Down Expand Up @@ -2495,6 +2503,10 @@ HWY_API Vec256<int32_t> IfNegativeThenNegOrUndefIfZero(Vec256<int32_t> mask,

// ------------------------------ ShiftLeftSame

// Disable sign conversion warnings for GCC debug intrinsics.
HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")

HWY_API Vec256<uint16_t> ShiftLeftSame(const Vec256<uint16_t> v,
const int bits) {
#if HWY_COMPILER_GCC
Expand Down Expand Up @@ -2642,6 +2654,8 @@ HWY_API Vec256<int8_t> ShiftRightSame(Vec256<int8_t> v, const int bits) {
return (shifted ^ shifted_sign) - shifted_sign;
}

HWY_DIAGNOSTICS(pop)

// ------------------------------ Neg (Xor, Sub)

// Tag dispatch instead of SFINAE for MSVC 2017 compatibility
Expand Down
40 changes: 33 additions & 7 deletions hwy/ops/x86_512-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1478,7 +1478,11 @@ HWY_API Vec512<T> Ror(Vec512<T> a, Vec512<T> b) {
// ------------------------------ ShiftLeftSame

// GCC <14 and Clang <11 do not follow the Intel documentation for AVX-512
// shift-with-immediate: the counts should all be unsigned int.
// shift-with-immediate: the counts should all be unsigned int. Despite casting,
// we still see warnings in GCC debug builds, hence disable.
HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")

#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1100
using Shift16Count = int;
using Shift3264Count = int;
Expand Down Expand Up @@ -1642,6 +1646,8 @@ HWY_API Vec512<int8_t> ShiftRightSame(Vec512<int8_t> v, const int bits) {
return (shifted ^ shifted_sign) - shifted_sign;
}

HWY_DIAGNOSTICS(pop)

// ------------------------------ Minimum

// Unsigned
Expand Down Expand Up @@ -2946,11 +2952,28 @@ HWY_API Vec512<int64_t> BroadcastSignBit(Vec512<int64_t> v) {

// ------------------------------ Floating-point classification (Not)

namespace detail {

template <int kCategories>
__mmask32 Fix_mm512_fpclass_ph_mask(__m512h v) {
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1500
// GCC's _mm512_cmp_ph_mask uses `__mmask8` instead of `__mmask32`, hence only
// the first 8 lanes are set.
return static_cast<__mmask32>(__builtin_ia32_fpclassph512_mask(
static_cast<__v32hf>(v), kCategories, static_cast<__mmask32>(-1)));
#else
return _mm512_fpclass_ph_mask(v, kCategories);
#endif
}

} // namespace detail

#if HWY_HAVE_FLOAT16 || HWY_IDE

HWY_API Mask512<float16_t> IsNaN(Vec512<float16_t> v) {
return Mask512<float16_t>{_mm512_fpclass_ph_mask(
v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)};
constexpr int kCategories = HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN;
return Mask512<float16_t>{
detail::Fix_mm512_fpclass_ph_mask<kCategories>(v.raw)};
}

HWY_API Mask512<float16_t> IsEitherNaN(Vec512<float16_t> a,
Expand All @@ -2963,15 +2986,18 @@ HWY_API Mask512<float16_t> IsEitherNaN(Vec512<float16_t> a,
}

HWY_API Mask512<float16_t> IsInf(Vec512<float16_t> v) {
return Mask512<float16_t>{_mm512_fpclass_ph_mask(v.raw, 0x18)};
constexpr int kCategories = HWY_X86_FPCLASS_POS_INF | HWY_X86_FPCLASS_NEG_INF;
return Mask512<float16_t>{
detail::Fix_mm512_fpclass_ph_mask<kCategories>(v.raw)};
}

// Returns whether normal/subnormal/zero. fpclass doesn't have a flag for
// positive, so we have to check for inf/NaN and negate.
HWY_API Mask512<float16_t> IsFinite(Vec512<float16_t> v) {
return Not(Mask512<float16_t>{_mm512_fpclass_ph_mask(
v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN |
HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)});
constexpr int kCategories = HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN |
HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF;
return Not(Mask512<float16_t>{
detail::Fix_mm512_fpclass_ph_mask<kCategories>(v.raw)});
}

#endif // HWY_HAVE_FLOAT16
Expand Down
Loading