Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon by 5%-15% #3860

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 41 additions & 21 deletions src/QuantUtilsNeon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,20 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
size_t input_rows,
int input_columns,
OutputType* output) {
int output_columns = input_columns - 2 * sizeof(float);
size_t output_columns = std::max<int>(input_columns - 2 * sizeof(float), 0);

svbool_t allTruePred = svptrue_b32();
size_t output_columns_mod = output_columns % 8;
svbool_t lastPredA = svwhilelt_b32_u64(0, output_columns_mod);
svbool_t lastPredB = svwhilelt_b32_u64(4, output_columns_mod);
svbool_t lastPredC = svwhilelt_b16_u64(0, output_columns_mod);

for (size_t row = 0; row < input_rows; ++row) {
const std::uint8_t* input_row = input + row * input_columns;
const float* input_row_scale_bias =
reinterpret_cast<const float*>(input_row + output_columns);
OutputType* output_row = output + row * output_columns;

svbool_t pred = svptrue_b32();

float scale = input_row_scale_bias[0];
float bias = input_row_scale_bias[1];
svfloat32_t scale_v = svdup_n_f32(scale);
Expand All @@ -59,17 +63,19 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
float32x4x2_t* output_row_v = reinterpret_cast<float32x4x2_t*>(output_row);
float16x8_t* output_row_v_half = reinterpret_cast<float16x8_t*>(output_row);

int colIndex = 0;
for (int colMax = output_columns / 8; colIndex < colMax; ++colIndex) {
size_t colIndex = 0;
for (size_t colMax = output_columns / 8; colIndex < colMax; ++colIndex) {
svuint32_t in_v_0 = svld1ub_u32(
pred, reinterpret_cast<const uint8_t*>(input_row_v_0 + colIndex));
allTruePred,
reinterpret_cast<const uint8_t*>(input_row_v_0 + colIndex));
svuint32_t in_v_1 = svld1ub_u32(
pred, reinterpret_cast<const uint8_t*>(input_row_v_1 + colIndex));
svfloat32_t in_v_0_f = svcvt_f32_u32_x(pred, in_v_0);
svfloat32_t in_v_1_f = svcvt_f32_u32_x(pred, in_v_1);
allTruePred,
reinterpret_cast<const uint8_t*>(input_row_v_1 + colIndex));
svfloat32_t in_v_0_f = svcvt_f32_u32_x(allTruePred, in_v_0);
svfloat32_t in_v_1_f = svcvt_f32_u32_x(allTruePred, in_v_1);

in_v_0_f = svmad_f32_m(pred, in_v_0_f, scale_v, bias_v);
in_v_1_f = svmad_f32_m(pred, in_v_1_f, scale_v, bias_v);
in_v_0_f = svmad_f32_m(allTruePred, in_v_0_f, scale_v, bias_v);
in_v_1_f = svmad_f32_m(allTruePred, in_v_1_f, scale_v, bias_v);

if constexpr (std::is_same<OutputType, float>()) {
output_row_v[colIndex].val[0] = svget_neonq(in_v_0_f);
Expand All @@ -83,17 +89,31 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
}
}

#pragma clang loop vectorize(disable)
#pragma clang loop unroll(disable)
for (colIndex *= 8; colIndex < output_columns; ++colIndex) {
float output_value = input_row[colIndex] * input_row_scale_bias[0] +
input_row_scale_bias[1];
if (std::is_same<OutputType, float>()) {
output_row[colIndex] = output_value;
} else {
output_row[colIndex] = cpu_float2half_rn(output_value);
}
svuint32_t in_v_0 = svld1ub_u32(
lastPredA, reinterpret_cast<const uint8_t*>(input_row_v_0 + colIndex));
svuint32_t in_v_1 = svld1ub_u32(
lastPredB, reinterpret_cast<const uint8_t*>(input_row_v_1 + colIndex));
svfloat32_t in_v_0_f = svcvt_f32_u32_x(lastPredA, in_v_0);
svfloat32_t in_v_1_f = svcvt_f32_u32_x(lastPredB, in_v_1);

in_v_0_f = svmad_f32_m(lastPredA, in_v_0_f, scale_v, bias_v);
in_v_1_f = svmad_f32_m(lastPredB, in_v_1_f, scale_v, bias_v);

if constexpr (std::is_same<OutputType, float>()) {
svst1_f32(lastPredA, (float32_t*)&(output_row_v[colIndex]), in_v_0_f);
svst1_f32(
lastPredB, (float32_t*)&(output_row_v[colIndex].val[1]), in_v_1_f);
} else {
float16x4_t dequantzed_v_half_low_low =
vcvt_f16_f32(svget_neonq(in_v_0_f));
float16x8_t dequantzed_v_half_low =
vcvt_high_f16_f32(dequantzed_v_half_low_low, svget_neonq(in_v_1_f));
svst1_f16(
lastPredC,
(float16_t*)&(output_row_v_half[colIndex]),
svset_neonq_f16(svundef_f16(), dequantzed_v_half_low));
}

} // for each row
}

Expand Down
Loading