diff --git a/src/QuantUtilsNeon.cc b/src/QuantUtilsNeon.cc index 673caaa6c7..dce4dcfffa 100644 --- a/src/QuantUtilsNeon.cc +++ b/src/QuantUtilsNeon.cc @@ -37,7 +37,13 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( size_t input_rows, int input_columns, OutputType* output) { - int output_columns = input_columns - 2 * sizeof(float); + size_t output_columns = std::max(input_columns - 2 * sizeof(float), 0); + + svbool_t allTruePred = svptrue_b32(); + size_t output_columns_mod = output_columns % 8; + svbool_t lastPredA = svwhilelt_b32_u64(0, output_columns_mod); + svbool_t lastPredB = svwhilelt_b32_u64(4, output_columns_mod); + svbool_t lastPredC = svwhilelt_b16_u64(0, output_columns_mod); for (size_t row = 0; row < input_rows; ++row) { const std::uint8_t* input_row = input + row * input_columns; @@ -45,8 +51,6 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( reinterpret_cast(input_row + output_columns); OutputType* output_row = output + row * output_columns; - svbool_t pred = svptrue_b32(); - float scale = input_row_scale_bias[0]; float bias = input_row_scale_bias[1]; svfloat32_t scale_v = svdup_n_f32(scale); @@ -59,17 +63,19 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( float32x4x2_t* output_row_v = reinterpret_cast(output_row); float16x8_t* output_row_v_half = reinterpret_cast(output_row); - int colIndex = 0; - for (int colMax = output_columns / 8; colIndex < colMax; ++colIndex) { + size_t colIndex = 0; + for (size_t colMax = output_columns / 8; colIndex < colMax; ++colIndex) { svuint32_t in_v_0 = svld1ub_u32( - pred, reinterpret_cast(input_row_v_0 + colIndex)); + allTruePred, + reinterpret_cast(input_row_v_0 + colIndex)); svuint32_t in_v_1 = svld1ub_u32( - pred, reinterpret_cast(input_row_v_1 + colIndex)); - svfloat32_t in_v_0_f = svcvt_f32_u32_x(pred, in_v_0); - svfloat32_t in_v_1_f = svcvt_f32_u32_x(pred, in_v_1); + allTruePred, + reinterpret_cast(input_row_v_1 + colIndex)); + svfloat32_t in_v_0_f = svcvt_f32_u32_x(allTruePred, in_v_0); + svfloat32_t in_v_1_f = svcvt_f32_u32_x(allTruePred, in_v_1); - in_v_0_f = svmad_f32_m(pred, in_v_0_f, scale_v, bias_v); - in_v_1_f = svmad_f32_m(pred, in_v_1_f, scale_v, bias_v); + in_v_0_f = svmad_f32_m(allTruePred, in_v_0_f, scale_v, bias_v); + in_v_1_f = svmad_f32_m(allTruePred, in_v_1_f, scale_v, bias_v); if constexpr (std::is_same()) { output_row_v[colIndex].val[0] = svget_neonq(in_v_0_f); @@ -83,17 +89,31 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( } } -#pragma clang loop vectorize(disable) -#pragma clang loop unroll(disable) - for (colIndex *= 8; colIndex < output_columns; ++colIndex) { - float output_value = input_row[colIndex] * input_row_scale_bias[0] + - input_row_scale_bias[1]; - if (std::is_same()) { - output_row[colIndex] = output_value; - } else { - output_row[colIndex] = cpu_float2half_rn(output_value); - } + svuint32_t in_v_0 = svld1ub_u32( + lastPredA, reinterpret_cast(input_row_v_0 + colIndex)); + svuint32_t in_v_1 = svld1ub_u32( + lastPredB, reinterpret_cast(input_row_v_1 + colIndex)); + svfloat32_t in_v_0_f = svcvt_f32_u32_x(lastPredA, in_v_0); + svfloat32_t in_v_1_f = svcvt_f32_u32_x(lastPredB, in_v_1); + + in_v_0_f = svmad_f32_m(lastPredA, in_v_0_f, scale_v, bias_v); + in_v_1_f = svmad_f32_m(lastPredB, in_v_1_f, scale_v, bias_v); + + if constexpr (std::is_same()) { + svst1_f32(lastPredA, (float32_t*)&(output_row_v[colIndex]), in_v_0_f); + svst1_f32( + lastPredB, (float32_t*)&(output_row_v[colIndex].val[1]), in_v_1_f); + } else { + float16x4_t dequantzed_v_half_low_low = + vcvt_f16_f32(svget_neonq(in_v_0_f)); + float16x8_t dequantzed_v_half_low = + vcvt_high_f16_f32(dequantzed_v_half_low_low, svget_neonq(in_v_1_f)); + svst1_f16( + lastPredC, + (float16_t*)&(output_row_v_half[colIndex]), + svset_neonq_f16(svundef_f16(), dequantzed_v_half_low)); } + } // for each row }