Skip to content

Commit d612b6b

Browse files
committed
feat(cpp): add quantize_symmetric CPU kernel
For now only per-tensor quantization is supported.
1 parent 89e365a commit d612b6b

File tree

5 files changed

+103
-1
lines changed

5 files changed

+103
-1
lines changed

bench/library/benchmark.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@
99
from quanto.library import disable_extensions
1010

1111

12+
def get_quantize_symmetric_bench(src_dtype, dst_dtype, per_axis, device):
13+
a = torch.rand([10240, 10240], dtype=src_dtype).to(device)
14+
scale = torch.fill((10240,), 0.5) if per_axis else torch.tensor(0.5)
15+
scale = scale.to(src_dtype).to(device)
16+
17+
def bench_fn():
18+
return torch.ops.quanto.quantize_symmetric(a, scale, dst_dtype)
19+
20+
return bench_fn
21+
22+
1223
def get_unpack_bench(bits, device):
1324
qmax = 2**bits
1425
a = torch.randint(0, qmax, [10240, 10240], dtype=torch.uint8).to(device)
@@ -69,6 +80,9 @@ def elapsed_time(self, other):
6980

7081

7182
# Registry mapping a benchmark name (selectable via --kernel) to a factory
# that takes a device and returns the corresponding zero-argument bench closure.
GET_BENCH_FUNCTIONS = {
    "quantize_symmetric_fp32_int8_per_tensor": lambda device: get_quantize_symmetric_bench(
        torch.float32, torch.int8, False, device
    ),
    "unpack_2bit": lambda device: get_unpack_bench(2, device),
    "unpack_4bit": lambda device: get_unpack_bench(4, device),
}
@@ -89,7 +103,7 @@ def main():
89103
device = torch.device("cpu")
90104
else:
91105
device = torch.device(args.device)
92-
all_kernels = ["unpack_2bit", "unpack_4bit"]
106+
all_kernels = GET_BENCH_FUNCTIONS.keys()
93107
kernels = all_kernels if args.kernel is None else [args.kernel]
94108
for kernel in kernels:
95109
get_bench_fn = GET_BENCH_FUNCTIONS[kernel]

quanto/library/ext/cpp/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def ext():
1919
_ext = load(
2020
name="quanto_cpp",
2121
sources=[
22+
f"{module_path}/quantize.cpp",
2223
f"{module_path}/unpack.cpp",
2324
f"{module_path}/pybind_module.cpp",
2425
],
@@ -27,6 +28,11 @@ def ext():
2728
return _ext
2829

2930

31+
@torch.library.impl("quanto_ext::quantize_symmetric", ["CPU"])
def quantize_symmetric_cpp(t: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype):
    """CPU implementation of the quanto_ext::quantize_symmetric operator.

    Args:
        t: the tensor to quantize.
        scale: the quantization scale (scalar for per-tensor quantization).
        dtype: the target quantized dtype (e.g. torch.int8).
            Note: the annotation was fixed from ``torch.Tensor.dtype`` (an
            attribute descriptor, not a type) to ``torch.dtype``.

    Returns:
        The quantized tensor produced by the C++ extension kernel.
    """
    return ext().quantize_symmetric(t, scale, dtype)
34+
35+
3036
@impl("quanto_ext::unpack", ["CPU", "CUDA"])
3137
def unpack_cpp(t: torch.Tensor, bits: int):
3238
return ext().unpack(t, bits)
+13
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,20 @@
11
#include <torch/extension.h>
2+
#include "quantize.h"
23
#include "unpack.h"
34

5+
// !IMPORTANT! Some python objects such as dtype, device, are not mapped to C++ types,
6+
// and need to be explicitly converted using dedicated helpers before calling a C++ method.
7+
// As a consequence, when an operation takes such an object as parameter, instead
8+
// of creating a binding directly to the C++ method, you must create a binding to a
9+
// lambda method that converts the unmapped types and calls the C++ method.
10+
// See the binding of quantize_symmetric for instance.
411

512
// Python bindings for the quanto C++ extension.
// quantize_symmetric receives a python dtype object, which has no direct C++
// mapping, so it is bound through a lambda that converts it to a ScalarType
// before forwarding to the C++ implementation.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "quantize_symmetric",
      [](const torch::Tensor& tensor, const torch::Tensor& scale_tensor, py::object py_dtype) {
        // Convert the python dtype object to a C++ scalar type first.
        auto scalar_type = torch::python::detail::py_object_to_dtype(py_dtype);
        return quantize_symmetric(tensor, scale_tensor, scalar_type);
      },
      "quantize_symmetric");
  m.def("unpack", &unpack, "unpack");
}

quanto/library/ext/cpp/quantize.cpp

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#include "quantize.h"
2+
#include <torch/extension.h>
3+
4+
5+
template <typename T>
6+
torch::Tensor quantize_symmetric_per_tensor(const torch::Tensor& input, const torch::Tensor& scale) {
7+
torch::Tensor output = torch::empty_like(input, c10::TensorOptions(c10::kChar).dtype(), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
8+
auto qdata = reinterpret_cast<int8_t*>(output.data_ptr());
9+
auto numel = input.numel();
10+
const T* const data = input.data_ptr<T>();
11+
float float_scale = scale.data_ptr<T>()[0];
12+
float inv_scale = float_scale == 0 ? 1.0f : 1.0f / float_scale;
13+
for (const auto i : c10::irange(numel)) {
14+
int64_t qvalue = lrintf(std::nearbyint(data[i] * inv_scale));
15+
qvalue = std::max(-127LL, std::min(qvalue, 127LL));
16+
qdata[i] = static_cast<int8_t>(qvalue);
17+
}
18+
return output;
19+
}
20+
21+
22+
// Return the index of the axis along which `scale` varies (per-axis
// quantization), or -1 when every dimension is 1 (per-tensor quantization).
// If several dimensions are larger than 1, the last such axis wins.
int get_scale_axis(const torch::Tensor& scale) {
  int axis = -1;
  auto scale_dims = scale.sizes();
  // size_t index: sizes().size() is unsigned, and the previous `int` loop
  // variable triggered a signed/unsigned comparison warning.
  for (size_t i = 0; i < scale_dims.size(); ++i) {
    if (scale_dims[i] != 1) {
      axis = static_cast<int>(i);
    }
  }
  return axis;
}
32+
33+
34+
// Dispatch int8 symmetric quantization on the scale dtype.
// Only per-tensor quantization (scalar scale, axis == -1) is implemented.
torch::Tensor quantize_symmetric_char(const torch::Tensor& input,
                                      const torch::Tensor& scale) {
  int axis = get_scale_axis(scale);
  // Reject per-axis scales up front; everything below is per-tensor.
  TORCH_CHECK(axis == -1, "symmetric per-axis is not supported")
  auto scale_dtype = scale.dtype();
  if (scale_dtype == at::ScalarType::Float) {
    return quantize_symmetric_per_tensor<float>(input, scale);
  }
  if (scale_dtype == at::ScalarType::Half) {
    return quantize_symmetric_per_tensor<at::Half>(input, scale);
  }
  TORCH_CHECK(false, "Unsupported scale dtype:", scale_dtype)
}
49+
50+
51+
// Symmetric quantization entry point: validates the scale shape and dtype,
// then routes to the implementation for the requested quantized dtype.
torch::Tensor quantize_symmetric(const torch::Tensor& input,
                                 const torch::Tensor& scale,
                                 at::ScalarType dtype) {
  const auto scale_ndim = scale.sizes().size();
  const bool scalar_scale = (scale_ndim == 0);
  // A scale with the same rank as the input is considered broadcastable.
  const bool broadcastable_scale = (input.sizes().size() == scale_ndim);
  TORCH_CHECK(scalar_scale || broadcastable_scale,
              "Quantization scale must be scalar or broadcastable to the base tensor.")
  const auto scale_dtype = scale.dtype();
  TORCH_CHECK((scale_dtype == at::ScalarType::Float) || (scale_dtype == at::ScalarType::Half),
              "Quantization scale must be float or float16.")
  if (dtype == at::ScalarType::Char) {
    return quantize_symmetric_char(input, scale);
  }
  TORCH_CHECK_NOT_IMPLEMENTED(false, "quantize_symmetric not supported for ", dtype)
}

quanto/library/ext/cpp/quantize.h

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once

#include <torch/extension.h>

// Quantize `input` symmetrically to the integer `dtype` using `scale`.
// Only per-tensor quantization (scalar scale) is currently supported.
torch::Tensor quantize_symmetric(const torch::Tensor& input,
                                 const torch::Tensor& scale,
                                 at::ScalarType dtype);

0 commit comments

Comments
 (0)