From 2e1d91f4afe5964a1393fe80f708332eb2c041f1 Mon Sep 17 00:00:00 2001 From: Emmanuel Keller Date: Thu, 5 Sep 2024 17:36:24 +0100 Subject: [PATCH] Implements serde lexical --- Cargo.toml | 1 + make/tests/serde.toml | 5 + src/serde.rs | 210 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 211 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a2c00f34..35555c93 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,6 +82,7 @@ serde-str = ["serde-with-str"] serde-with-arbitrary-precision = ["serde", "serde_json/arbitrary_precision", "serde_json/std"] serde-with-float = ["serde"] serde-with-str = ["serde"] +serde-lexical = ["std", "serde"] std = ["arrayvec/std", "borsh?/std", "bytes?/std", "rand?/std", "rkyv?/std", "serde?/std", "serde_json?/std"] tokio-pg = ["db-tokio-postgres"] # Backwards compatability diff --git a/make/tests/serde.toml b/make/tests/serde.toml index 34412cc8..ea0ee302 100644 --- a/make/tests/serde.toml +++ b/make/tests/serde.toml @@ -8,6 +8,7 @@ dependencies = [ "test-serde-with-arbitrary-precision", "test-serde-with-float", "test-serde-with-str", + "test-serde-lexical", ] [tasks.test-serde-float] @@ -41,3 +42,7 @@ args = ["test", "--workspace", "--tests", "--features=serde-with-float", "serde" [tasks.test-serde-with-str] command = "cargo" args = ["test", "--workspace", "--tests", "--features=serde-with-str", "serde", "--", "--skip", "generated"] + +[tasks.test-serde-lexical] +command = "cargo" +args = ["test", "--workspace", "--tests", "--features=serde-lexical", "serde", "--", "--skip", "generated"] diff --git a/src/serde.rs b/src/serde.rs index c18e08ec..1ef2d83f 100644 --- a/src/serde.rs +++ b/src/serde.rs @@ -276,7 +276,7 @@ pub mod str_option { } } -#[cfg(not(feature = "serde-str"))] +#[cfg(not(any(feature = "serde-str", feature = "serde-lexical")))] impl<'de> serde::Deserialize<'de> for Decimal { fn deserialize(deserializer: D) -> Result where @@ -515,7 +515,7 @@ impl<'de> serde::de::Deserialize<'de> for DecimalFromString { } } -#[cfg(not(feature = "serde-float"))] +#[cfg(not(any(feature = "serde-float", feature = "serde-lexical")))] impl serde::Serialize for Decimal { fn serialize(&self, serializer: S) -> Result where @@ -549,6 +549,164 @@ impl serde::Serialize for Decimal { } } +/// Serialize/deserialize Decimals preserving the lexical order in the serialized form. +/// This is particularly useful to keep the decimal ordered when stored as a key in a key value store. +/// +/// ``` +/// # use serde::{Serialize, Deserialize}; +/// # use rust_decimal::Decimal; +/// # use std::str::FromStr; +/// +/// #[derive(Serialize, Deserialize)] +/// pub struct LexicalExample { +/// #[serde(with = "rust_decimal::serde::lexical")] +/// value: Decimal, +/// } +/// +/// let value1 = LexicalExample { value: Decimal::from_str("123.45").unwrap() }; +/// let value2 = LexicalExample { value: Decimal::from_str("678.123").unwrap() }; +/// let bin1 = bincode::serialize(&value1).unwrap(); +/// let bin2 = bincode::serialize(&value2).unwrap(); +/// assert!(bin1 < bin2); +/// +/// ``` +#[cfg(feature = "serde-lexical")] +pub mod lexical { + use std::io::{Cursor, Read, Write}; + use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; + use serde::de::Visitor; + use super::*; + + impl Serialize for Decimal { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut buffer = Vec::new(); + self.write_to(&mut buffer).map_err(serde::ser::Error::custom)?; + serializer.serialize_bytes(&buffer) + } + } + + impl<'de> Deserialize<'de> for Decimal { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct DecimalVisitor; + + impl<'de> Visitor<'de> for DecimalVisitor { + type Value = Decimal; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a byte array representing a Decimal") + } + + fn visit_bytes(self, value: &[u8]) -> Result + where + E: de::Error, + { + let mut cursor = Cursor::new(value); + Decimal::read_from(&mut cursor).map_err(de::Error::custom) + } + } + + deserializer.deserialize_bytes(DecimalVisitor) + } + } + + impl Decimal { + // Method to write Decimal to a writer (e.g., Vec) + fn write_to(&self, writer: &mut W) -> Result<(), std::io::Error> { + let is_negative = self.is_sign_negative(); + let sign_byte = if is_negative { 0x00 } else { 0xFF }; + // Separate the decimal into integer and fractional parts + let value = self.normalize().abs(); + let integer_part = value.trunc(); + let fractional_part = value - integer_part; + // The fractional part is scaled up to use the largest possible precision. + let fractional_part= if !fractional_part.is_zero() { + fractional_part.mantissa() * 10_i128.pow(28 - fractional_part.scale()) + } else { + 0i128 + }; + + // Prepare 96-bit integer (maximum Decimal precision) and fractional parts + let mut integer_bytes = Self::serialize_i96(integer_part.mantissa()); + let mut fractional_bytes = Self::serialize_i96(fractional_part); + if is_negative { + for byte in &mut integer_bytes { + *byte = !*byte; + } + for byte in &mut fractional_bytes { + *byte = !*byte; + } + } + // + writer.write_all(&[sign_byte])?; + writer.write_all(&integer_bytes)?; + writer.write_all(&fractional_bytes)?; + Ok(()) + } + + // Method to read Decimal from a reader (e.g., &[u8]) + fn read_from(reader: &mut R) -> Result { + let mut sign_byte = [0u8; 1]; + let mut integer_bytes = [0u8; 12]; + let mut fractional_bytes = [0u8; 12]; + + // Read the sign byte, integer part, and fractional part from the reader + reader.read_exact(&mut sign_byte)?; + reader.read_exact(&mut integer_bytes)?; + reader.read_exact(&mut fractional_bytes)?; + + let is_negative =sign_byte[0] == 0x00; + + if is_negative { + for byte in &mut integer_bytes { + *byte = !*byte; + } + for byte in &mut fractional_bytes { + *byte = !*byte; + } + } + + // Convert the 96-bit big-endian byte arrays back to u128 + let integer_part = Self::deserialize_i96( integer_bytes); + let fractional_part = Self::deserialize_i96( fractional_bytes); + + // Convert integer and fractional parts back to Decimal + let i = Decimal::from_i128_with_scale(integer_part, 0); + let f =Decimal::from_i128_with_scale(fractional_part, 28); + + // Reconstitute the decimal value + let mut value = i + f; + // Apply the sign + if is_negative { + value.set_sign_negative(true); + } + Ok(value) + } + + /// Serializes an i128 to a 12 bytes array, preserving the lexicographical order. + fn serialize_i96(value: i128) -> [u8; 12] { + let mut buf = [0u8; 12]; + // Take the last 12 bytes (96 bits) + buf.copy_from_slice(&value.to_be_bytes()[4..]); + buf + } + + /// Deserializes an i128 from a 12 bytes array, restoring the original value. + fn deserialize_i96( buf: [u8; 12]) -> i128 { + let mut int_buf = [0x00; 16]; + int_buf[4..].copy_from_slice(&buf); + i128::from_be_bytes(int_buf) + } + + } + +} + #[cfg(test)] mod test { use super::*; @@ -560,7 +718,7 @@ mod test { } #[test] - #[cfg(not(feature = "serde-str"))] + #[cfg(not(any(feature = "serde-str", feature = "serde-lexical")))] fn deserialize_valid_decimal() { let data = [ ("{\"amount\":\"1.234\"}", "1.234"), @@ -603,7 +761,7 @@ mod test { } #[test] - #[cfg(not(feature = "serde-float"))] + #[cfg(not(any(feature = "serde-float", feature = "serde-lexical")))] fn serialize_decimal() { let record = Record { amount: Decimal::new(1234, 3), @@ -613,7 +771,7 @@ mod test { } #[test] - #[cfg(not(feature = "serde-float"))] + #[cfg(not(any(feature = "serde-float", feature = "serde-lexical")))] fn serialize_negative_zero() { let record = Record { amount: -Decimal::ZERO }; let serialized = serde_json::to_string(&record).unwrap(); @@ -916,4 +1074,46 @@ mod test { assert_eq!(deserialized.value, original.value); assert!(deserialized.value.is_none()); } + + #[test] + #[cfg(feature = "serde-lexical")] + fn serialize_decimal_lexical() { + let data = [ + Decimal::from_i128_with_scale(-1000, 0), + Decimal::from_i128_with_scale(-100, 0), + Decimal::from_i128_with_scale(-10, 0), + Decimal::from_i128_with_scale(-31415926535897932384626433833, 28), + Decimal::from_i128_with_scale(-2, 0), + Decimal::from_i128_with_scale(-102, 2), + Decimal::from_i128_with_scale(-101, 2), + Decimal::from_i128_with_scale(-10000000000000000000000000002, 28), + Decimal::from_i128_with_scale(-10000000000000000000000000001, 28), + Decimal::NEGATIVE_ONE, + Decimal::from_i128_with_scale(-2, 28), + Decimal::from_i128_with_scale(-1, 28), + -Decimal::ZERO, + Decimal::ZERO, + Decimal::from_i128_with_scale(1, 28), + Decimal::from_i128_with_scale(2, 28), + Decimal::ONE, + Decimal::from_i128_with_scale(10000000000000000000000000001, 28), + Decimal::from_i128_with_scale(10000000000000000000000000002, 28), + Decimal::from_i128_with_scale(101, 2), + Decimal::from_i128_with_scale(102, 2), + Decimal::TWO, + Decimal::from_i128_with_scale(31415926535897932384626433833, 28), + Decimal::TEN, + Decimal::ONE_HUNDRED, + Decimal::ONE_THOUSAND, + Decimal::MAX, + ]; + let mut previous = bincode::serialize(&Decimal::MIN).unwrap(); + for value in data { + let encoded = bincode::serialize(&value).unwrap(); + let decoded: Decimal = bincode::deserialize(&encoded[..]).unwrap(); + assert_eq!(value, decoded); + assert!(previous <= encoded, "{value} -> {previous:?} <= {encoded:?}"); + previous = encoded; + } + } }