Skip to content

Commit b7a4e51

Browse files
Nicoshevfacebook-github-bot
authored andcommitted
Enable KleidiAI for FP32 (#3818)
Summary: Pull Request resolved: #3818 X-link: facebookresearch/FBGEMM#903 This diff enables KleidiAI for FP32 operations It pulls #3751 and tweaks the code to treat beta nans as #0, likewise done for FP16 A 10x performance increase has been observed on FP32xFP32->FP32 matmul operations Reviewed By: embg Differential Revision: D70398308 fbshipit-source-id: 1a419511c3eb8eb293e9fd30f0a42450d0b24e87
1 parent 746bddd commit b7a4e51

6 files changed

+2230
-55
lines changed

Makefile.FP16Benchmark.aarch64

-46
This file was deleted.

include/fbgemm/FbgemmFPCommon.h

+33-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/*
22
* Copyright (c) Meta Platforms, Inc. and affiliates.
3-
* All rights reserved.
3+
* Copyright 2024-2025 Arm Limited and/or its affiliates
4+
* <[email protected]> All rights reserved.
45
*
56
* This source code is licensed under the BSD-style license found in the
67
* LICENSE file in the root directory of this source tree.
@@ -57,6 +58,22 @@ struct GemmParams<float16> {
5758
#endif
5859
};
5960

61+
template <>
62+
struct GemmParams<float> {
63+
uint64_t k;
64+
float* A;
65+
const float* B;
66+
float beta;
67+
float* C;
68+
uint64_t ldc;
69+
uint64_t b_block_cols;
70+
#ifdef FBGEMM_ENABLE_KLEIDIAI
71+
uint64_t lda;
72+
#else
73+
uint64_t b_block_size;
74+
#endif
75+
};
76+
6077
template <typename T>
6178
using funcptr_t = void (*)(GemmParams<T>*);
6279
template <typename T>
@@ -175,7 +192,9 @@ void cblas_gemm_compute(
175192
assert(kernel_nrows * kb < static_cast<int64_t>(scratchpad->size()));
176193
if (m != 1) {
177194
#ifdef FBGEMM_ENABLE_KLEIDIAI
178-
if constexpr (std::is_same<T, float16>::value) {
195+
if constexpr (
196+
std::is_same<T, float16>::value ||
197+
std::is_same<T, float>::value) {
179198
gp.A = const_cast<float*>(&A[m2 * k + k_ind]);
180199
} else {
181200
#endif
@@ -201,7 +220,9 @@ void cblas_gemm_compute(
201220
gp.ldc = ldc * sizeof(C[0]);
202221
gp.b_block_cols = nbcol;
203222
#ifdef FBGEMM_ENABLE_KLEIDIAI
204-
if constexpr (std::is_same<T, float16>::value) {
223+
if constexpr (
224+
std::is_same<T, float16>::value ||
225+
std::is_same<T, float>::value) {
205226
gp.lda = k * sizeof(A[0]);
206227
} else {
207228
#endif
@@ -218,7 +239,9 @@ void cblas_gemm_compute(
218239
gp.b_block_cols = jb_end - jb_begin;
219240
if (gp.b_block_cols) {
220241
#ifdef FBGEMM_USE_REF_KERNEL
221-
if constexpr (std::is_same<T, float16>::value) {
242+
if constexpr (
243+
std::is_same<T, float16>::value ||
244+
std::is_same<T, float>::value) {
222245
kernels[kernel_nrows](&gp);
223246
} else {
224247
ref_kernel<T>(kernel_nrows, &gp, C, m, n, simd_width);
@@ -238,7 +261,9 @@ void cblas_gemm_compute(
238261
gp.b_block_cols = jb_end - jb_begin;
239262
if (gp.b_block_cols) {
240263
#ifdef FBGEMM_USE_REF_KERNEL
241-
if constexpr (std::is_same<T, float16>::value) {
264+
if constexpr (
265+
std::is_same<T, float16>::value ||
266+
std::is_same<T, float>::value) {
242267
kernels[kernel_nrows](&gp);
243268
} else {
244269
ref_kernel(kernel_nrows, &gp, C, m, n, simd_width);
@@ -269,7 +294,9 @@ void cblas_gemm_compute(
269294
gp.ldc = Bp.blockColSize() * sizeof(C[0]);
270295
gp.b_block_cols = 1;
271296
#ifdef FBGEMM_USE_REF_KERNEL
272-
if constexpr (std::is_same<T, float16>::value) {
297+
if constexpr (
298+
std::is_same<T, float16>::value ||
299+
std::is_same<T, float>::value) {
273300
kernels[kernel_nrows](&gp);
274301
} else {
275302
ref_kernel<T>(

include/fbgemm/FbgemmPackMatrixB.h

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/*
22
* Copyright (c) Meta Platforms, Inc. and affiliates.
3-
* All rights reserved.
3+
* Copyright 2024-2025 Arm Limited and/or its affiliates
4+
* <[email protected]> All rights reserved.
45
*
56
* This source code is licensed under the BSD-style license found in the
67
* LICENSE file in the root directory of this source tree.
@@ -63,7 +64,7 @@ class PackedGemmMatrixB {
6364
const float* smat,
6465
const int brow = 512)
6566
: nrow_(nrow), ncol_(ncol), brow_(brow), kernel_ncol_blocks_(2) {
66-
#if defined(FBGEMM_ENABLE_KLEIDIAI)
67+
#ifdef FBGEMM_ENABLE_KLEIDIAI
6768
if (std::is_same<T, float16>::value) {
6869
kernel_ncol_blocks_ = 1;
6970
}
@@ -92,7 +93,7 @@ class PackedGemmMatrixB {
9293
nbcol_(nbcol),
9394
size_(size),
9495
kernel_ncol_blocks_(2) {
95-
#if defined(FBGEMM_ENABLE_KLEIDIAI)
96+
#ifdef FBGEMM_ENABLE_KLEIDIAI
9697
if (std::is_same<T, float16>::value) {
9798
kernel_ncol_blocks_ = 1;
9899
}
@@ -120,6 +121,11 @@ class PackedGemmMatrixB {
120121
nbcol_(nbcol),
121122
size_(size),
122123
kernel_ncol_blocks_(kernel_ncol_blocks) {
124+
#ifdef FBGEMM_ENABLE_KLEIDIAI
125+
if (std::is_same<T, float16>::value) {
126+
kernel_ncol_blocks_ = 1;
127+
}
128+
#endif
123129
pmat_ = static_cast<T*>(pmat);
124130
packed_ = true;
125131
pmat_passed_in = true;

src/fp32/FbgemmFP32.cc

+29
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*
22
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* Copyright 2025 Arm Limited and/or its affiliates <[email protected]>
34
* All rights reserved.
45
*
56
* This source code is licensed under the BSD-style license found in the
@@ -11,9 +12,15 @@
1112
#include <cmath>
1213
#include <utility>
1314

15+
#ifndef __aarch64__
1416
#include "./FbgemmFP32UKernelsAvx2.h"
1517
#include "./FbgemmFP32UKernelsAvx512.h"
1618
#include "./FbgemmFP32UKernelsAvx512_256.h"
19+
#else
20+
#ifdef FBGEMM_ENABLE_KLEIDIAI
21+
#include "./KleidiAIFP32UKernelsNeon.h"
22+
#endif
23+
#endif
1724
#include "fbgemm/Fbgemm.h"
1825
#include "fbgemm/FbgemmFPCommon.h"
1926

@@ -80,6 +87,19 @@ constexpr kernel_array_t<float> kernel_f32_avx512_256 = {
8087
nullptr};
8188
#endif
8289

90+
#ifdef __aarch64__
91+
#ifdef FBGEMM_ENABLE_KLEIDIAI
92+
constexpr kernel_array_t<float> kernel_fp32_neon = {
93+
nullptr,
94+
kleidiai::gemmkernel_1x2_Neon_fp32_fA0fB0fC0,
95+
kleidiai::gemmkernel_2x2_Neon_fp32_fA0fB0fC0,
96+
kleidiai::gemmkernel_3x2_Neon_fp32_fA0fB0fC0,
97+
kleidiai::gemmkernel_4x2_Neon_fp32_fA0fB0fC0,
98+
kleidiai::gemmkernel_5x2_Neon_fp32_fA0fB0fC0,
99+
kleidiai::gemmkernel_6x2_Neon_fp32_fA0fB0fC0,
100+
};
101+
#endif
102+
#endif
83103
} // namespace
84104

85105
template <>
@@ -90,9 +110,18 @@ const isa_descriptor<float>& getIsaHandlers(inst_set_t isa, float) {
90110
std::make_tuple(kernel_f32_avx512, partition_avx512);
91111
static isa_descriptor<float> avx512_256_descriptor =
92112
std::make_tuple(kernel_f32_avx512_256, partition_avx512);
113+
#ifdef __aarch64__
114+
#ifdef FBGEMM_ENABLE_KLEIDIAI
115+
static isa_descriptor<float> neon_descriptor =
116+
std::make_tuple(kernel_fp32_neon, partition_sve128);
117+
#endif
118+
#endif
93119

94120
switch (isa) {
95121
case inst_set_t::sve:
122+
#ifdef FBGEMM_ENABLE_KLEIDIAI
123+
return neon_descriptor;
124+
#endif
96125
case inst_set_t::anyarch:
97126
case inst_set_t::avx2:
98127
return avx2_descriptor;

0 commit comments

Comments
 (0)