Add a quantize_symmetric operation and the corresponding CPU kernel #83

Merged 3 commits on Feb 9, 2024.
16 changes: 15 additions & 1 deletion bench/library/benchmark.py
@@ -9,6 +9,17 @@
from quanto.library import disable_extensions


def get_quantize_symmetric_bench(src_dtype, dst_dtype, per_axis, device):
    a = torch.rand([10240, 10240], dtype=src_dtype).to(device)
    scale = torch.full((10240,), 0.5) if per_axis else torch.tensor(0.5)
    scale = scale.to(src_dtype).to(device)

    def bench_fn():
        return torch.ops.quanto.quantize_symmetric(a, scale, dst_dtype)

    return bench_fn


def get_unpack_bench(bits, device):
    qmax = 2**bits
    a = torch.randint(0, qmax, [10240, 10240], dtype=torch.uint8).to(device)
@@ -69,6 +80,9 @@ def elapsed_time(self, other):


GET_BENCH_FUNCTIONS = {
    "quantize_symmetric_fp32_int8_per_tensor": lambda device: get_quantize_symmetric_bench(
        torch.float32, torch.int8, False, device
    ),
    "unpack_2bit": lambda device: get_unpack_bench(2, device),
    "unpack_4bit": lambda device: get_unpack_bench(4, device),
}
@@ -89,7 +103,7 @@ def main():
        device = torch.device("cpu")
    else:
        device = torch.device(args.device)
-   all_kernels = ["unpack_2bit", "unpack_4bit"]
+   all_kernels = GET_BENCH_FUNCTIONS.keys()
    kernels = all_kernels if args.kernel is None else [args.kernel]
    for kernel in kernels:
        get_bench_fn = GET_BENCH_FUNCTIONS[kernel]
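Each entry in GET_BENCH_FUNCTIONS is a factory that takes a device and returns a zero-argument callable for the harness to time. A minimal sketch of exercising the new entry outside the harness, assuming the script is run from bench/library so the module imports directly (the timing loop below is illustrative, not the benchmark's own measurement code):

import timeit

import torch
from benchmark import GET_BENCH_FUNCTIONS  # assumed import path, run from bench/library

device = torch.device("cpu")
bench_fn = GET_BENCH_FUNCTIONS["quantize_symmetric_fp32_int8_per_tensor"](device)
bench_fn()  # warmup call
print(timeit.timeit(bench_fn, number=10) / 10, "s per call")
assert bench_fn().dtype == torch.int8  # per-tensor fp32 -> int8 quantization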
6 changes: 6 additions & 0 deletions quanto/library/ext/cpp/__init__.py
@@ -19,6 +19,7 @@ def ext():
        _ext = load(
            name="quanto_cpp",
            sources=[
                f"{module_path}/quantize.cpp",
                f"{module_path}/unpack.cpp",
                f"{module_path}/pybind_module.cpp",
            ],
@@ -27,6 +28,11 @@
    return _ext


@torch.library.impl("quanto_ext::quantize_symmetric", ["CPU"])
def quantize_symmetric_cpp(t: torch.Tensor, scale: torch.Tensor, dtype: torch.Tensor.dtype):
    return ext().quantize_symmetric(t, scale, dtype)


@impl("quanto_ext::unpack", ["CPU", "CUDA"])
def unpack_cpp(t: torch.Tensor, bits: int):
    return ext().unpack(t, bits)
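With this registration, a CPU call to torch.ops.quanto.quantize_symmetric reaches the C++ kernel whenever the extension is enabled and builds, and falls back to the pure Python quanto_py implementation otherwise. A minimal sketch of a call through the dispatcher, assuming that importing quanto.library performs the registrations, as the package layout suggests:

import torch
import quanto.library  # assumed to register the quanto, quanto_py and quanto_ext ops

t = torch.rand(4, 4, dtype=torch.float32)
scale = torch.tensor(0.1)
qdata = torch.ops.quanto.quantize_symmetric(t, scale, torch.int8)
print(qdata.dtype)  # torch.int8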
13 changes: 13 additions & 0 deletions quanto/library/ext/cpp/pybind_module.cpp
@@ -1,7 +1,20 @@
#include <torch/extension.h>
#include "quantize.h"
#include "unpack.h"

// !IMPORTANT! Some Python objects, such as dtype and device, are not mapped to C++ types
// and need to be explicitly converted using dedicated helpers before calling a C++ method.
// As a consequence, when an operation takes such an object as a parameter, instead
// of creating a binding directly to the C++ method, you must create a binding to a
// lambda that converts the unmapped types and then calls the C++ method.
// See the binding of quantize_symmetric below for an example.

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("quantize_symmetric",
        [](const torch::Tensor& t, const torch::Tensor& scale, py::object dtype) {
          return quantize_symmetric(t,
                                    scale,
                                    torch::python::detail::py_object_to_dtype(dtype));
        }, "quantize_symmetric");
  m.def("unpack", &unpack, "unpack");
}
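Because the lambda performs the conversion, Python callers can pass a plain torch.dtype to the raw binding. A minimal sketch, assuming the extension compiles on the host (going through torch.ops.quanto is the normal route):

import torch
from quanto.library.ext.cpp import ext

t = torch.rand(16, dtype=torch.float32)
scale = torch.tensor(0.5)
qdata = ext().quantize_symmetric(t, scale, torch.int8)  # dtype crosses the boundary as a py::object
print(qdata.dtype)  # torch.int8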
64 changes: 64 additions & 0 deletions quanto/library/ext/cpp/quantize.cpp
@@ -0,0 +1,64 @@
#include "quantize.h"
#include <torch/extension.h>


template <typename T>
torch::Tensor quantize_symmetric_per_tensor(const torch::Tensor& input, const torch::Tensor& scale) {
  torch::Tensor output = torch::empty_like(input, c10::TensorOptions(c10::kChar).dtype(), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto qdata = reinterpret_cast<int8_t*>(output.data_ptr());
  auto numel = input.numel();
  const T* const data = input.data_ptr<T>();
  float float_scale = scale.data_ptr<T>()[0];
  float inv_scale = float_scale == 0 ? 1.0f : 1.0f / float_scale;
  for (const auto i : c10::irange(numel)) {
    int64_t qvalue = lrintf(std::nearbyint(data[i] * inv_scale));
    qvalue = std::max<int64_t>(-127, std::min<int64_t>(qvalue, 127));
    qdata[i] = static_cast<int8_t>(qvalue);
  }
  return output;
}


int get_scale_axis(const torch::Tensor& scale) {
  int axis = -1;
  auto scale_dims = scale.sizes();
  for (int i = 0; i < scale_dims.size(); ++i) {
    if (scale_dims[i] != 1) {
      axis = i;
    }
  }
  return axis;
}


torch::Tensor quantize_symmetric_char(const torch::Tensor& input,
                                      const torch::Tensor& scale) {
  int axis = get_scale_axis(scale);
  if (axis == -1) {
    auto scale_dtype = scale.dtype();
    if (scale_dtype == at::ScalarType::Float) {
      return quantize_symmetric_per_tensor<float>(input, scale);
    }
    if (scale_dtype == at::ScalarType::Half) {
      return quantize_symmetric_per_tensor<at::Half>(input, scale);
    }
    TORCH_CHECK(false, "Unsupported scale dtype: ", scale_dtype)
  }
  TORCH_CHECK(false, "symmetric per-axis is not supported")
}


torch::Tensor quantize_symmetric(const torch::Tensor& input,
                                 const torch::Tensor& scale,
                                 at::ScalarType dtype) {
  bool scalar_scale = (scale.sizes().size() == 0);
  bool broadcastable_scale = (input.sizes().size() == scale.sizes().size());
  TORCH_CHECK(scalar_scale || broadcastable_scale,
              "Quantization scale must be scalar or broadcastable to the base tensor.")
  TORCH_CHECK((scale.dtype() == at::ScalarType::Float) || (scale.dtype() == at::ScalarType::Half),
              "Quantization scale must be float or float16.")
  if (dtype == at::ScalarType::Char) {
    return quantize_symmetric_char(input, scale);
  }
  TORCH_CHECK_NOT_IMPLEMENTED(false, "quantize_symmetric not supported for ", dtype)
}
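Note that the kernel clamps to the symmetric range [-127, 127] rather than to the full int8 minimum of -128, while the Python fallback below clamps to torch.iinfo(torch.int8).min, so the two paths can differ by one code at the negative extreme. A small numeric sketch of the per-tensor arithmetic (plain Python, for illustration only):

# scale = 0.5, so inv_scale = 2.0; round, then clamp to [-127, 127]
scale = 0.5
values = [0.4, -0.6, 100.0, -100.0]
quantized = [max(-127, min(127, round(v / scale))) for v in values]
print(quantized)  # [1, -1, 127, -127]; 100 / 0.5 = 200 saturates at 127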
5 changes: 5 additions & 0 deletions quanto/library/ext/cpp/quantize.h
@@ -0,0 +1,5 @@
#include <torch/extension.h>

torch::Tensor quantize_symmetric(const torch::Tensor& input,
                                 const torch::Tensor& scale,
                                 at::ScalarType dtype);
12 changes: 7 additions & 5 deletions quanto/library/ops.py
@@ -38,18 +38,20 @@ def define(name, schema):
    for libname in ["quanto", "quanto_py", "quanto_ext"]:
        torch.library.define(f"{libname}::{name}", schema)

-    # Provide the inplementation for all dispatch key in the main library
+    # Provide the implementation for all dispatch keys in the main library
    @torch.library.impl(f"quanto::{name}", "default")
    def impl(*args, **kwargs):
        if _ext_enabled:
            try:
                return getattr(torch.ops.quanto_ext, name)(*args, **kwargs)
            except Exception as e:
-                warnings.warn(
-                    f"A {type(e)} exception occured while calling the optimized kernel for quanto::{name}."
-                    "Falling back to default implementation."
-                )
+                if isinstance(e, NotImplementedError):
+                    message = f"No optimized kernel found for quanto::{name}."
+                else:
+                    message = f"An exception was raised while calling the optimized kernel for quanto::{name}: {e}"
+                warnings.warn(message + " Falling back to default implementation.")
        return getattr(torch.ops.quanto_py, name)(*args, **kwargs)


define("quantize_symmetric", "(Tensor self, Tensor scale, ScalarType dtype) -> Tensor")
define("unpack", "(Tensor self, int bits) -> Tensor")
1 change: 1 addition & 0 deletions quanto/library/python/__init__.py
@@ -1 +1,2 @@
from .quantize import *
from .unpack import *
15 changes: 15 additions & 0 deletions quanto/library/python/quantize.py
@@ -0,0 +1,15 @@
import torch


def dtype_info(dtype):
    info = torch.finfo if dtype.is_floating_point else torch.iinfo
    return info(dtype)


@torch.library.impl("quanto_py::quantize_symmetric", "default")
def quantize_symmetric(t: torch.Tensor, scale: torch.Tensor, dtype: torch.Tensor.dtype):
    info = dtype_info(dtype)
    data = t / scale
    if not dtype.is_floating_point:
        data = torch.round(data)
    return torch.clamp(data, min=info.min, max=info.max).to(dtype)
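A quick sanity check of the fallback arithmetic, assuming the ops are registered (a sketch, not part of the PR's test suite):

import torch
import quanto.library  # assumed to register the quanto_py ops

t = torch.tensor([0.0, 0.25, -0.25, 1.0])
scale = torch.tensor(0.25)
qdata = torch.ops.quanto_py.quantize_symmetric(t, scale, torch.int8)
print(qdata)  # tensor([ 0,  1, -1,  4], dtype=torch.int8)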
32 changes: 32 additions & 0 deletions test/library/test_quantize.py
@@ -0,0 +1,32 @@
import pytest
import torch
from helpers import random_tensor


@pytest.mark.parametrize("shape", [(12,), (32, 32)], ids=["vector", "matrix"])
@pytest.mark.parametrize("src_dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"])
@pytest.mark.parametrize("dst_dtype", [torch.int8, torch.float8_e4m3fn], ids=["int8", "float8"])
@pytest.mark.parametrize("per_axis", [True, False], ids=["per-axis", "per-tensor"])
def test_quantize_symmetric(shape, src_dtype, dst_dtype, per_axis, device):
    if device.type == "mps" and dst_dtype != torch.int8:
        pytest.skip("float8 types are not supported on MPS device")
    # Craft data and scale manually
    if dst_dtype.is_floating_point:
        data = random_tensor(shape, torch.float16).to(dst_dtype).to(device)
    else:
        data = torch.randint(-127, 127, shape, dtype=dst_dtype).to(device)
    if per_axis:
        scale_shape = (shape[0],) + (1,) * (len(shape) - 1)
    else:
        scale_shape = ()
    scale = torch.rand(scale_shape, dtype=src_dtype).to(device)
    # Dequantize to obtain a float tensor
    t = data.to(src_dtype) * scale
    qdata = torch.ops.quanto.quantize_symmetric(t, scale, dst_dtype)
    assert qdata.dtype == dst_dtype
    assert qdata.shape == shape
    # Direct comparison of float8 tensors is not yet supported on CPU
    if dst_dtype.is_floating_point:
        assert torch.equal(qdata.to(torch.float16), data.to(torch.float16))
    else:
        assert torch.equal(qdata, data)
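The test works backwards on purpose: it draws already-quantized codes, dequantizes them with a known scale, and checks that quantizing the resulting float tensor recovers the original codes exactly, which avoids re-deriving the rounding logic inside the test itself. Assuming the repository's usual pytest setup (the device fixture comes from the shared test helpers), the whole file can be run with pytest test/library/test_quantize.py.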