Skip to content

Commit 664f14a

Browse files
Suharsh Sivakumarbjacob
Suharsh Sivakumar
authored and committed
Per-channel output rescale and int8 input support for NEON.
1 parent 2390b74 commit 664f14a

17 files changed

+1267
-50
lines changed

CONTRIBUTORS

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Maciek Chociej <[email protected]>
1818
Justine Tunney <[email protected]>
1919
Mark J. Matthews <[email protected]>
2020
Marie White <[email protected]>
21+
Suharsh Sivakumar <[email protected]>
2122

2223
Intel:
2324
Sagi Marcovich <[email protected]>

fixedpoint/fixedpoint.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ tIntegerType Neg(tIntegerType a) {
121121
// in the overflow case, we just want to avoid undefined behavior.
122122
//
123123
// tIntegerType may be int32 or any narrower signed type.
124-
template <typename tIntegerType>
125-
tIntegerType ShiftLeft(tIntegerType a, int offset) {
124+
template <typename tIntegerType, typename OffsetType>
125+
tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
126126
const std::int64_t wide_a = static_cast<std::int64_t>(a);
127127
const std::int64_t wide_shifted = wide_a * (1 << offset);
128128
const auto min = std::numeric_limits<tIntegerType>::min();
@@ -353,8 +353,8 @@ inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
353353

354354
// Correctly-rounded-to-nearest division by a power-of-two.
355355
// Also known as a rounding arithmetic right shift.
356-
template <typename IntegerType>
357-
inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
356+
template <typename IntegerType, typename ExponentType>
357+
inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) {
358358
assert(exponent >= 0);
359359
assert(exponent <= 31);
360360
const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);

fixedpoint/fixedpoint_neon.h

+26
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,16 @@ inline int16x8_t ShiftLeft(int16x8_t a, int offset) {
114114
return vshlq_s16(a, vdupq_n_s16(offset));
115115
}
116116

117+
template <>
118+
inline int32x4_t ShiftLeft(int32x4_t a, int32x4_t offset) {
119+
return vshlq_s32(a, offset);
120+
}
121+
122+
template <>
123+
inline int16x8_t ShiftLeft(int16x8_t a, int16x8_t offset) {
124+
return vshlq_s16(a, offset);  // per-lane shift on 16-bit lanes; the s32 intrinsic does not accept int16x8_t
125+
}
126+
117127
template <>
118128
inline int32x4_t ShiftRight(int32x4_t a, int offset) {
119129
return vshlq_s32(a, vdupq_n_s32(-offset));
@@ -282,6 +292,22 @@ inline int16x8_t RoundingDivideByPOT(int16x8_t x, int exponent) {
282292
return vrshlq_s16(fixed_up_x, shift_vec);
283293
}
284294

295+
template <>
296+
inline int32x4_t RoundingDivideByPOT(int32x4_t x, int32x4_t exponent) {
297+
const int32x4_t shift_vec = vnegq_s32(exponent);
298+
const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
299+
const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
300+
return vrshlq_s32(fixed_up_x, shift_vec);
301+
}
302+
303+
template <>
304+
inline int16x8_t RoundingDivideByPOT(int16x8_t x, int16x8_t exponent) {
305+
const int16x8_t shift_vec = vnegq_s16(exponent);
306+
const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15);
307+
const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
308+
return vrshlq_s16(fixed_up_x, shift_vec);
309+
}
310+
285311
template <int Exponent>
286312
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> {
287313
static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); }

internal/dispatch_gemm_shape.h

+16
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,22 @@ struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
8585
}
8686
};
8787

88+
template <VectorShape Shape>
89+
struct TransposeImpl<OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>> {
90+
typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> SrcType;
91+
static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value;
92+
typedef OutputStageScaleInt32ByFixedPointAndExponentPC<TransposedShape>
93+
DstType;
94+
static DstType Run(const SrcType& src) {
95+
DstType dst;
96+
dst.result_fixedpoint_multiplier =
97+
Transpose(src.result_fixedpoint_multiplier);
98+
dst.result_exponent = Transpose(src.result_exponent);
99+
dst.result_offset_after_shift = src.result_offset_after_shift;
100+
return dst;
101+
}
102+
};
103+
88104
template <typename VectorMapType>
89105
struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> {
90106
typedef OutputStageBiasAddition<VectorMapType> SrcType;

internal/kernel.h

+21-4
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,24 @@ struct KernelSideFormat {
145145
static const int kCells = tCells;
146146
static const int kWidth = kCells * Cell::kWidth;
147147
static const int kDepth = Cell::kDepth;
148-
typedef std::uint8_t Scalar;
148+
typedef std::uint8_t Scalar; // The scalar type of the Format.
149+
typedef std::uint8_t InputScalar; // The scalar type of the original input.
149150
};
150151

152+
// KernelSideFormat for int8 fast kernel trick. The original input is uint8, but
153+
// packing converts it to int8.
151154
template <typename tCellFormat, int tCells>
152155
struct KernelSideFormatInt8 : KernelSideFormat<tCellFormat, tCells> {
153156
typedef std::int8_t Scalar;
157+
typedef std::uint8_t InputScalar;
158+
};
159+
160+
// KernelSideFormat for int8 inputs, enabling int8 fast kernel trick without
161+
// pack conversion.
162+
template <typename tCellFormat, int tCells>
163+
struct KernelSideFormatInt8Inputs : KernelSideFormat<tCellFormat, tCells> {
164+
typedef std::int8_t Scalar;
165+
typedef std::int8_t InputScalar;
154166
};
155167

156168
// KernelFormat describes fully the input data layout that a kernel expects.
@@ -216,19 +228,24 @@ struct KernelBase {
216228
virtual ~KernelBase() {}
217229
};
218230

219-
template <typename KernelScalarType>
231+
template <typename InputKernelScalarType, typename KernelScalarType>
220232
struct ZeroPointInputValue {};
221233

222234
template <>
223-
struct ZeroPointInputValue<std::uint8_t> {
235+
struct ZeroPointInputValue<std::uint8_t, std::uint8_t> {
224236
static constexpr std::uint8_t kValue = 0;
225237
};
226238

227239
template <>
228-
struct ZeroPointInputValue<std::int8_t> {
240+
struct ZeroPointInputValue<std::uint8_t, std::int8_t> {
229241
static constexpr std::uint8_t kValue = 128;
230242
};
231243

244+
template <>
245+
struct ZeroPointInputValue<std::int8_t, std::int8_t> {
246+
static constexpr std::uint8_t kValue = 0;
247+
};
248+
232249
} // namespace gemmlowp
233250

234251
#endif // GEMMLOWP_INTERNAL_KERNEL_H_

internal/kernel_default.h

+39-29
Original file line numberDiff line numberDiff line change
@@ -20,74 +20,84 @@
2020

2121
#include "../public/bit_depth.h"
2222
#include "common.h"
23+
#include "kernel.h"
2324
#include "kernel_reference.h"
2425

2526
namespace gemmlowp {
2627

27-
template <bool MaxProductIsLessThan4096, bool LhsAlwaysNonzero>
28+
template <bool MaxProductIsLessThan4096, bool IsUnsigned, bool LhsNonZero>
2829
struct DefaultKernelImpl {};
2930

30-
// Partial specialization implementing the logic that if we want to use
31-
// a kernel for LhsAlwaysNonzero but do not have such a kernel, then we fall
32-
// back to a generic kernel not taking advantage of LhsAlwaysNonzero.
33-
template <bool LhsAlwaysNonzero>
34-
struct DefaultKernelImpl<true, LhsAlwaysNonzero>
35-
: DefaultKernelImpl<false, LhsAlwaysNonzero> {};
36-
3731
// Partial specialization implementing the logic that if we want to use
3832
// a kernel for MaxProductIsLessThan4096 but do not have such a kernel, then we
3933
// fall back to a generic kernel not taking advantage of
4034
// MaxProductIsLessThan4096.
35+
template <bool LhsNonZero>
36+
struct DefaultKernelImpl<true, true, LhsNonZero>
37+
: DefaultKernelImpl<false, true, LhsNonZero> {};
38+
39+
// Partial specialization implementing the logic that if we want to use
40+
// a kernel for LhsNonZero but do not have such a kernel, then we fall
41+
// back to a generic kernel not taking advantage of LhsNonZero.
4142
template <bool MaxProductIsLessThan4096>
42-
struct DefaultKernelImpl<MaxProductIsLessThan4096, true>
43-
: DefaultKernelImpl<MaxProductIsLessThan4096, false> {};
43+
struct DefaultKernelImpl<MaxProductIsLessThan4096, true, true>
44+
: DefaultKernelImpl<MaxProductIsLessThan4096, true, false> {};
4445

4546
template <typename BitDepthParams>
4647
struct DefaultKernel
4748
: DefaultKernelImpl<(BitDepthParams::LhsRange::kMaxValue *
4849
BitDepthParams::RhsRange::kMaxValue <
4950
4096),
50-
(BitDepthParams::LhsRange::kMinValue > 0)> {};
51+
(BitDepthParams::LhsRange::kMinValue >= 0),
52+
(BitDepthParams::LhsRange::kMinValue > 0 ||
53+
(BitDepthParams::LhsRange::kMaxValue <= 127 &&
54+
BitDepthParams::LhsRange::kMinValue > -128))> {};
5155

5256
} // end namespace gemmlowp
5357

54-
#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, \
55-
LhsAlwaysNonzero, Kernel) \
56-
namespace gemmlowp { \
57-
template <> \
58-
struct DefaultKernelImpl<MaxProductIsLessThan4096, LhsAlwaysNonzero> \
59-
: Kernel {}; \
58+
#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, IsUnsigned, \
59+
LhsAlwaysNonZero, Kernel) \
60+
namespace gemmlowp { \
61+
template <> \
62+
struct DefaultKernelImpl<MaxProductIsLessThan4096, IsUnsigned, \
63+
LhsAlwaysNonZero> : Kernel {}; \
6064
}
6165

66+
// User-provided int8 inputs are only supported in the NEON path currently.
6267
#if defined GEMMLOWP_NEON_32
6368
#include "kernel_neon.h"
64-
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_32_Kernel12x4Depth2)
65-
GEMMLOWP_SET_DEFAULT_KERNEL(true, false,
69+
GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, NEON_32_Kernel12x4Depth2)
70+
GEMMLOWP_SET_DEFAULT_KERNEL(true, true, false,
6671
NEON_32_Kernel12x4Depth2Assuming12BitProducts)
67-
GEMMLOWP_SET_DEFAULT_KERNEL(false, true,
72+
GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true,
6873
NEON_32bit_GEMM_Int8Operands_LhsNonzero)
74+
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, true,
75+
NEON_32bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs)
6976
#elif defined GEMMLOWP_NEON_64
7077
#include "kernel_neon.h"
7178
#if defined GEMMLOWP_DOTPROD_KERNEL
72-
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_64_Kernel12x8Depth4_dotprod)
79+
GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false,
80+
NEON_64_Kernel12x8Depth4_dotprod)
7381
#else
74-
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_64_Kernel12x8Depth2)
75-
GEMMLOWP_SET_DEFAULT_KERNEL(false, true,
82+
GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, NEON_64_Kernel12x8Depth2)
83+
GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true,
7684
NEON_64bit_GEMM_Int8Operands_LhsNonzero)
7785
#endif
86+
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, true,
87+
NEON_64bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs)
7888
#elif defined(GEMMLOWP_MSA)
7989
#include "kernel_msa.h"
80-
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, MSA_Kernel12x8Depth2)
81-
GEMMLOWP_SET_DEFAULT_KERNEL(false, true, MSA_GEMM_Int8Operands_LhsNonzero)
90+
GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, MSA_Kernel12x8Depth2)
91+
GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true, MSA_GEMM_Int8Operands_LhsNonzero)
8292
#elif defined GEMMLOWP_SSE4_32
8393
#include "kernel_sse.h"
84-
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_32_Kernel4x4Depth2)
94+
GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, SSE4_32_Kernel4x4Depth2)
8595
#elif defined GEMMLOWP_SSE4_64
8696
#include "kernel_sse.h"
87-
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_64_Kernel12x4Depth2)
97+
GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, SSE4_64_Kernel12x4Depth2)
8898
#elif defined GEMMLOWP_AVX2_64
8999
#include "kernel_avx.h"
90-
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, AVX2_64_Kernel24x8Depth2)
100+
GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, AVX2_64_Kernel24x8Depth2)
91101
#else
92102
#include "kernel_reference.h"
93103
namespace gemmlowp {
@@ -96,7 +106,7 @@ typedef ReferenceKernel<KernelFormat<
96106
KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> > >
97107
DefaultReferenceKernel;
98108
}
99-
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, DefaultReferenceKernel)
109+
GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, DefaultReferenceKernel)
100110
#endif
101111

102112
#endif // GEMMLOWP_INTERNAL_KERNEL_DEFAULT_H_

internal/kernel_neon.h

+22
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,17 @@ struct NEON_32bit_GEMM_Int8Operands_LhsNonzero : KernelBase {
924924
}
925925
};
926926

927+
// Same as NEON_32bit_GEMM_Int8Operands_LhsNonzero, but uses a side format that
928+
// requires that user inputs were originally int8. This avoids the uint8->int8
929+
// conversion in the pack step.
930+
struct NEON_32bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs
931+
: NEON_32bit_GEMM_Int8Operands_LhsNonzero {
932+
typedef KernelFormat<
933+
KernelSideFormatInt8Inputs<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
934+
KernelSideFormatInt8Inputs<CellFormat<2, 16, CellOrder::WidthMajor>, 1> >
935+
Format;
936+
};
937+
927938
#endif // GEMMLOWP_NEON_32
928939

929940
// The kernels here are specifically arm 64bit assembly, not arm 32bit.
@@ -1265,6 +1276,17 @@ struct NEON_64bit_GEMM_Int8Operands_LhsNonzero : KernelBase {
12651276
}
12661277
};
12671278

1279+
// Same as NEON_64bit_GEMM_Int8Operands_LhsNonzero, but uses a side format that
1280+
// requires that user inputs were originally int8. This avoids the uint8->int8
1281+
// conversion in the pack step.
1282+
struct NEON_64bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs
1283+
: NEON_64bit_GEMM_Int8Operands_LhsNonzero {
1284+
typedef KernelFormat<
1285+
KernelSideFormatInt8Inputs<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
1286+
KernelSideFormatInt8Inputs<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
1287+
Format;
1288+
};
1289+
12681290
// Our main GEMM kernel.
12691291
struct NEON_64_Kernel12x8Depth2 : KernelBase {
12701292
typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>,

internal/output.h

+66-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <cmath>
2323
#include <tuple>
2424
#include <type_traits>
25+
#include <typeinfo>
2526

2627
#include "../fixedpoint/fixedpoint.h"
2728
#include "../public/output_stages.h"
@@ -179,7 +180,47 @@ struct OutputStageEvalBufferImpl<OutputStageScaleInt32ByFixedPointAndExponent,
179180
int right_shift;
180181
};
181182

182-
// Implementation of OutputStageSaturatingCastToUint8 for scalar data
183+
template <int Rows, int Cols, VectorShape Shape>
184+
struct OutputStageEvalImpl<
185+
OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>,
186+
RegisterBlock<std::int32_t, Rows, Cols>> {
187+
typedef RegisterBlock<std::int32_t, Rows, Cols> InputType;
188+
typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType;
189+
190+
typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> OutputStage;
191+
192+
OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
193+
194+
OutputType Eval(InputType input, int row, int col) const {
195+
OutputType output;
196+
const int pos = Shape == VectorShape::Row ? col : row;
197+
using RegisterType = typename InputType::RegisterType;
198+
const RegisterType result_offset_after_shift =
199+
Dup<RegisterType>(output_stage.result_offset_after_shift);
200+
auto left_shift =
201+
LoadForBroadcasting<InputType>(output_stage.result_exponent, pos);
202+
auto right_shift =
203+
LoadForBroadcasting<InputType>(output_stage.result_exponent, pos);
204+
const auto result_fixedpoint_multiplier = LoadForBroadcasting<InputType>(
205+
output_stage.result_fixedpoint_multiplier, pos);
206+
for (int i = 0; i < decltype(left_shift)::kRegisterCount; i++) {
207+
left_shift.buf.reg[i] = Max(left_shift.buf.reg[i], 0);
208+
right_shift.buf.reg[i] = Max(-right_shift.buf.reg[i], 0);
209+
}
210+
const auto mulhigh_val = BroadcastSaturatingRoundingDoublingHighMul(
211+
BroadcastShiftLeft(input, left_shift), result_fixedpoint_multiplier);
212+
const auto rdpot_val =
213+
BroadcastRoundingDivideByPOT(mulhigh_val, right_shift);
214+
for (int i = 0; i < InputType::kRegisterCount; i++) {
215+
output.buf.reg[i] = Add(rdpot_val.buf.reg[i], result_offset_after_shift);
216+
}
217+
return output;
218+
}
219+
220+
const OutputStage& output_stage;
221+
};
222+
223+
// Implementation of OutputStageSaturatingCastToUint8 for scalar data.
183224
template <int Size>
184225
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
185226
RegisterBuffer<std::int32_t, Size>> {
@@ -202,7 +243,30 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
202243
}
203244
};
204245

205-
// Implementation of OutputStageSaturatingCastToInt16 for scalar data
246+
// Implementation of OutputStageSaturatingCastToInt8 for scalar data.
247+
template <int Size>
248+
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
249+
RegisterBuffer<std::int32_t, Size>> {
250+
typedef RegisterBuffer<std::int32_t, Size> InputType;
251+
typedef RegisterBuffer<std::int8_t, Size> OutputType;
252+
static_assert(InputType::kRegisterLanes == 1,
253+
"This path is only for scalar values");
254+
255+
typedef OutputStageSaturatingCastToInt8 OutputStage;
256+
257+
OutputStageEvalBufferImpl(const OutputStage&) {}
258+
259+
OutputType Eval(InputType input) const {
260+
OutputType output;
261+
for (int i = 0; i < InputType::kRegisterCount; i++) {
262+
std::int32_t data = input.reg[i];
263+
output.reg[i] = data > 127 ? 127 : data < -128 ? -128 : data;
264+
}
265+
return output;
266+
}
267+
};
268+
269+
// Implementation of OutputStageSaturatingCastToInt16 for scalar data.
206270
template <int Size>
207271
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
208272
RegisterBuffer<std::int32_t, Size>> {

0 commit comments

Comments
 (0)