#include "quantize.h"
#include <torch/extension.h>

#include <algorithm>
#include <cmath>

// Symmetric per-tensor quantization to int8: q = clamp(round(x / scale), -127, 127).
template <typename T>
torch::Tensor quantize_symmetric_per_tensor(const torch::Tensor& input, const torch::Tensor& scale) {
    torch::Tensor output = torch::empty_like(input, c10::TensorOptions(c10::kChar).dtype(), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    auto qdata = reinterpret_cast<int8_t*>(output.data_ptr());
    auto numel = input.numel();
    const T* const data = input.data_ptr<T>();
    float float_scale = static_cast<float>(scale.data_ptr<T>()[0]);
    // Guard against a zero scale to avoid dividing by zero.
    float inv_scale = float_scale == 0.0f ? 1.0f : 1.0f / float_scale;
    for (const auto i : c10::irange(numel)) {
        // Round to the nearest integer, then clamp to the symmetric int8 range [-127, 127].
        int64_t qvalue = std::lrintf(static_cast<float>(data[i]) * inv_scale);
        qvalue = std::max<int64_t>(-127, std::min<int64_t>(qvalue, 127));
        qdata[i] = static_cast<int8_t>(qvalue);
    }
    return output;
}
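
// Worked example of the mapping above (values chosen purely for illustration):
// with scale = 0.05, an input of 1.0 quantizes to round(1.0 / 0.05) = 20, while
// an input of 10.0 rounds to 200 and saturates to the clamp bound of 127.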


// Return the index of the last non-unit dimension of the scale, or -1 if every
// dimension is 1 (i.e. the scale is effectively per-tensor).
int get_scale_axis(const torch::Tensor& scale) {
    int axis = -1;
    auto scale_dims = scale.sizes();
    for (const auto i : c10::irange(scale_dims.size())) {
        if (scale_dims[i] != 1) {
            axis = static_cast<int>(i);
        }
    }
    return axis;
}
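
// For instance, a scale of shape (1, 64, 1) yields axis 1 (per-axis), while a
// scale of shape (1, 1, 1) yields -1 (per-tensor). The shapes are illustrative only.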


// Quantize to int8, dispatching on the scale dtype. Only per-tensor scales are
// supported here; per-axis scales raise an error.
torch::Tensor quantize_symmetric_char(const torch::Tensor& input,
                                      const torch::Tensor& scale) {
    int axis = get_scale_axis(scale);
    if (axis == -1) {
        auto scale_dtype = scale.dtype();
        if (scale_dtype == at::ScalarType::Float) {
            return quantize_symmetric_per_tensor<float>(input, scale);
        }
        if (scale_dtype == at::ScalarType::Half) {
            return quantize_symmetric_per_tensor<at::Half>(input, scale);
        }
        TORCH_CHECK(false, "Unsupported scale dtype: ", scale_dtype);
    }
    TORCH_CHECK(false, "symmetric per-axis quantization is not supported");
}


torch::Tensor quantize_symmetric(const torch::Tensor& input,
                                 const torch::Tensor& scale,
                                 at::ScalarType dtype) {
    // The scale must either be a scalar (per-tensor) or have the same number of
    // dimensions as the input so that it can broadcast against it (per-axis).
    bool scalar_scale = (scale.dim() == 0);
    bool broadcastable_scale = (input.dim() == scale.dim());
    TORCH_CHECK(scalar_scale || broadcastable_scale,
                "Quantization scale must be scalar or broadcastable to the base tensor.");
    TORCH_CHECK((scale.dtype() == at::ScalarType::Float) || (scale.dtype() == at::ScalarType::Half),
                "Quantization scale must be float or float16.");
    if (dtype == at::ScalarType::Char) {
        return quantize_symmetric_char(input, scale);
    }
    TORCH_CHECK_NOT_IMPLEMENTED(false, "quantize_symmetric is not supported for ", dtype);
}
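
// A minimal sketch of how this entry point could be exposed to Python through a
// pybind11 extension module. The module and argument names below are assumptions
// for illustration; the actual registration may live elsewhere in the project.
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//       m.def("quantize_symmetric", &quantize_symmetric,
//             "Symmetric quantization of a tensor",
//             py::arg("input"), py::arg("scale"), py::arg("dtype"));
//   }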