Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements serde-lexical #676

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ serde-str = ["serde-with-str"]
serde-with-arbitrary-precision = ["serde", "serde_json/arbitrary_precision", "serde_json/std"]
serde-with-float = ["serde"]
serde-with-str = ["serde"]
serde-lexical = ["std", "serde"]
std = ["arrayvec/std", "borsh?/std", "bytes?/std", "rand?/std", "rkyv?/std", "serde?/std", "serde_json?/std"]
tokio-pg = ["db-tokio-postgres"] # Backwards compatibility

Expand Down
5 changes: 5 additions & 0 deletions make/tests/serde.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dependencies = [
"test-serde-with-arbitrary-precision",
"test-serde-with-float",
"test-serde-with-str",
"test-serde-lexical",
]

[tasks.test-serde-float]
Expand Down Expand Up @@ -41,3 +42,7 @@ args = ["test", "--workspace", "--tests", "--features=serde-with-float", "serde"
# Runs `cargo test` across the workspace with the `serde-with-str` feature enabled,
# filtering on tests matching "serde" and skipping those matching "generated".
[tasks.test-serde-with-str]
command = "cargo"
args = ["test", "--workspace", "--tests", "--features=serde-with-str", "serde", "--", "--skip", "generated"]

# Runs `cargo test` across the workspace with the `serde-lexical` feature enabled,
# filtering on tests matching "serde" and skipping those matching "generated".
[tasks.test-serde-lexical]
command = "cargo"
args = ["test", "--workspace", "--tests", "--features=serde-lexical", "serde", "--", "--skip", "generated"]
210 changes: 205 additions & 5 deletions src/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ pub mod str_option {
}
}

#[cfg(not(feature = "serde-str"))]
#[cfg(not(any(feature = "serde-str", feature = "serde-lexical")))]
impl<'de> serde::Deserialize<'de> for Decimal {
fn deserialize<D>(deserializer: D) -> Result<Decimal, D::Error>
where
Expand Down Expand Up @@ -517,7 +517,7 @@ impl<'de> serde::de::Deserialize<'de> for DecimalFromString {
}
}

#[cfg(not(feature = "serde-float"))]
#[cfg(not(any(feature = "serde-float", feature = "serde-lexical")))]
impl serde::Serialize for Decimal {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
Expand Down Expand Up @@ -551,6 +551,164 @@ impl serde::Serialize for Decimal {
}
}

/// Serialize/deserialize Decimals preserving the lexical order in the serialized form.
/// This is particularly useful to keep the decimal ordered when stored as a key in a key value store.
///
/// When the `serde-lexical` feature is enabled this module provides the `Serialize` and
/// `Deserialize` implementations of `Decimal` itself, so fields need no `#[serde(with = "...")]`
/// attribute.
///
/// # Wire format
///
/// A value is serialized as a byte string of exactly 25 bytes:
///
/// | bytes   | content                                                                 |
/// |---------|-------------------------------------------------------------------------|
/// | 0       | sign: `0x00` for negative values, `0xFF` otherwise                      |
/// | 1..=12  | integer part of the absolute value, 96-bit big-endian                   |
/// | 13..=24 | fractional part of the absolute value scaled by 10^28, 96-bit big-endian |
///
/// For negative values every bit of the integer and fractional components is inverted, so that
/// a larger magnitude sorts first.
///
/// ```
/// # use serde::{Serialize, Deserialize};
/// # use rust_decimal::Decimal;
/// # use std::str::FromStr;
///
/// #[derive(Serialize, Deserialize)]
/// pub struct LexicalExample {
///     value: Decimal,
/// }
///
/// let value1 = LexicalExample { value: Decimal::from_str("123.45").unwrap() };
/// let value2 = LexicalExample { value: Decimal::from_str("678.123").unwrap() };
/// let bin1 = bincode::serialize(&value1).unwrap();
/// let bin2 = bincode::serialize(&value2).unwrap();
/// assert!(bin1 < bin2);
///
/// ```
#[cfg(feature = "serde-lexical")]
pub mod lexical {
    use super::*;
    use serde::de::Visitor;
    use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
    use std::io::{Cursor, Read, Write};

    /// Scale used for the encoded fractional part, so every fraction has the same precision.
    const FRACTION_SCALE: u32 = 28;
    /// Total encoded length: one sign byte plus two 12-byte components.
    const ENCODED_LEN: usize = 1 + 12 + 12;
    /// Sign byte written for negative values; sorts before `POSITIVE_SIGN`.
    const NEGATIVE_SIGN: u8 = 0x00;
    /// Sign byte written for zero and positive values.
    const POSITIVE_SIGN: u8 = 0xFF;

    /// Flips every bit of an encoded component. Applied to negative values so that a larger
    /// magnitude yields a lexicographically smaller byte string.
    fn invert_bits(bytes: &mut [u8; 12]) {
        bytes.iter_mut().for_each(|byte| *byte = !*byte);
    }

    /// Builds the error returned when the encoded bytes do not describe a valid `Decimal`.
    fn invalid_data(message: &'static str) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::InvalidData, message)
    }

    impl Serialize for Decimal {
        /// Serializes the decimal as a 25-byte, order-preserving byte string.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let mut buffer = Vec::with_capacity(ENCODED_LEN);
            self.write_to(&mut buffer).map_err(serde::ser::Error::custom)?;
            serializer.serialize_bytes(&buffer)
        }
    }

    impl<'de> Deserialize<'de> for Decimal {
        /// Deserializes a decimal from the byte string produced by `serialize`.
        ///
        /// Both native byte strings and sequences of `u8` are accepted, so formats that have
        /// no byte-string type can round-trip the value as well.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            struct DecimalVisitor;

            impl<'de> Visitor<'de> for DecimalVisitor {
                type Value = Decimal;

                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                    formatter.write_str("a byte array representing a Decimal")
                }

                fn visit_bytes<E>(self, value: &[u8]) -> Result<Decimal, E>
                where
                    E: de::Error,
                {
                    // Reject truncated input and trailing garbage up front.
                    if value.len() != ENCODED_LEN {
                        return Err(de::Error::invalid_length(value.len(), &self));
                    }
                    let mut cursor = Cursor::new(value);
                    Decimal::read_from(&mut cursor).map_err(de::Error::custom)
                }

                fn visit_seq<A>(self, mut seq: A) -> Result<Decimal, A::Error>
                where
                    A: de::SeqAccess<'de>,
                {
                    let mut buffer = [0u8; ENCODED_LEN];
                    for (index, slot) in buffer.iter_mut().enumerate() {
                        *slot = match seq.next_element::<u8>()? {
                            Some(byte) => byte,
                            None => return Err(de::Error::invalid_length(index, &self)),
                        };
                    }
                    if seq.next_element::<u8>()?.is_some() {
                        return Err(de::Error::invalid_length(ENCODED_LEN + 1, &self));
                    }
                    self.visit_bytes(&buffer)
                }
            }

            deserializer.deserialize_bytes(DecimalVisitor)
        }
    }

    impl Decimal {
        /// Writes the order-preserving encoding of `self` (sign byte, integer part, fractional
        /// part) to `writer`.
        ///
        /// # Errors
        /// Returns any error reported by `writer`.
        fn write_to<W: Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
            let is_negative = self.is_sign_negative();
            let sign_byte = if is_negative { NEGATIVE_SIGN } else { POSITIVE_SIGN };

            // Split the absolute value into its integer and fractional parts.
            let magnitude = self.normalize().abs();
            let integer_part = magnitude.trunc();
            let fraction = magnitude - integer_part;
            // Scale the fraction up to the largest precision so that all fractions compare
            // digit for digit regardless of their original scale.
            let fractional_part = if fraction.is_zero() {
                0i128
            } else {
                fraction.mantissa() * 10_i128.pow(FRACTION_SCALE - fraction.scale())
            };

            let mut integer_bytes = Self::serialize_i96(integer_part.mantissa());
            let mut fractional_bytes = Self::serialize_i96(fractional_part);
            if is_negative {
                invert_bits(&mut integer_bytes);
                invert_bits(&mut fractional_bytes);
            }

            writer.write_all(&[sign_byte])?;
            writer.write_all(&integer_bytes)?;
            writer.write_all(&fractional_bytes)?;
            Ok(())
        }

        /// Reads a decimal previously written by `write_to` from `reader`.
        ///
        /// # Errors
        /// Returns an error when `reader` fails or ends early, when the sign byte is neither
        /// `0x00` nor `0xFF`, when the fractional component is not below 10^28, or when the
        /// decoded value does not fit in a `Decimal`.
        fn read_from<R: Read>(reader: &mut R) -> Result<Decimal, std::io::Error> {
            let mut sign_byte = [0u8; 1];
            let mut integer_bytes = [0u8; 12];
            let mut fractional_bytes = [0u8; 12];

            reader.read_exact(&mut sign_byte)?;
            reader.read_exact(&mut integer_bytes)?;
            reader.read_exact(&mut fractional_bytes)?;

            let is_negative = match sign_byte[0] {
                NEGATIVE_SIGN => true,
                POSITIVE_SIGN => false,
                _ => return Err(invalid_data("invalid sign byte in lexical Decimal encoding")),
            };

            // Negative values are stored bit-inverted; undo that to recover the magnitude.
            if is_negative {
                invert_bits(&mut integer_bytes);
                invert_bits(&mut fractional_bytes);
            }

            let integer_part = Self::deserialize_i96(integer_bytes);
            let fractional_part = Self::deserialize_i96(fractional_bytes);
            // A well-formed fraction is always strictly below 1, i.e. below 10^28 once scaled.
            if fractional_part >= 10_i128.pow(FRACTION_SCALE) {
                return Err(invalid_data("fractional part out of range in lexical Decimal encoding"));
            }

            let integer = Decimal::from_i128_with_scale(integer_part, 0);
            let fraction = Decimal::from_i128_with_scale(fractional_part, FRACTION_SCALE);

            // Use checked arithmetic: the bytes come from outside and must never cause a panic.
            let mut value = integer
                .checked_add(fraction)
                .ok_or_else(|| invalid_data("lexical Decimal encoding overflows Decimal"))?;
            if is_negative {
                value.set_sign_negative(true);
            }
            Ok(value)
        }

        /// Serializes an i128 to a 12 bytes array, preserving the lexicographical order.
        ///
        /// Only the low 96 bits are kept; callers pass non-negative values that fit in 96 bits.
        fn serialize_i96(value: i128) -> [u8; 12] {
            let mut buf = [0u8; 12];
            // Take the last 12 bytes (96 bits) of the big-endian representation.
            buf.copy_from_slice(&value.to_be_bytes()[4..]);
            buf
        }

        /// Deserializes an i128 from a 12 bytes array, restoring the original value.
        ///
        /// The upper 32 bits are zero-filled, so the result is never negative.
        fn deserialize_i96(buf: [u8; 12]) -> i128 {
            let mut int_buf = [0x00; 16];
            int_buf[4..].copy_from_slice(&buf);
            i128::from_be_bytes(int_buf)
        }
    }
}

#[cfg(test)]
mod test {
use super::*;
Expand All @@ -562,7 +720,7 @@ mod test {
}

#[test]
#[cfg(not(feature = "serde-str"))]
#[cfg(not(any(feature = "serde-str", feature = "serde-lexical")))]
fn deserialize_valid_decimal() {
let data = [
("{\"amount\":\"1.234\"}", "1.234"),
Expand Down Expand Up @@ -605,7 +763,7 @@ mod test {
}

#[test]
#[cfg(not(feature = "serde-float"))]
#[cfg(not(any(feature = "serde-float", feature = "serde-lexical")))]
fn serialize_decimal() {
let record = Record {
amount: Decimal::new(1234, 3),
Expand All @@ -615,7 +773,7 @@ mod test {
}

#[test]
#[cfg(not(feature = "serde-float"))]
#[cfg(not(any(feature = "serde-float", feature = "serde-lexical")))]
fn serialize_negative_zero() {
let record = Record { amount: -Decimal::ZERO };
let serialized = serde_json::to_string(&record).unwrap();
Expand Down Expand Up @@ -918,4 +1076,46 @@ mod test {
assert_eq!(deserialized.value, original.value);
assert!(deserialized.value.is_none());
}

/// Checks that the lexical encoding round-trips and that the encoded bytes sort in the same
/// order as the decimal values themselves.
#[test]
#[cfg(feature = "serde-lexical")]
fn serialize_decimal_lexical() {
    // Values are listed in ascending order. Negative zero comes before positive zero because
    // the sign is part of the encoding.
    let data = [
        Decimal::MIN,
        Decimal::from_i128_with_scale(-1000, 0),
        Decimal::from_i128_with_scale(-100, 0),
        Decimal::from_i128_with_scale(-10, 0),
        Decimal::from_i128_with_scale(-31415926535897932384626433833, 28),
        Decimal::from_i128_with_scale(-2, 0),
        Decimal::from_i128_with_scale(-102, 2),
        Decimal::from_i128_with_scale(-101, 2),
        Decimal::from_i128_with_scale(-10000000000000000000000000002, 28),
        Decimal::from_i128_with_scale(-10000000000000000000000000001, 28),
        Decimal::NEGATIVE_ONE,
        Decimal::from_i128_with_scale(-2, 28),
        Decimal::from_i128_with_scale(-1, 28),
        -Decimal::ZERO,
        Decimal::ZERO,
        Decimal::from_i128_with_scale(1, 28),
        Decimal::from_i128_with_scale(2, 28),
        Decimal::ONE,
        Decimal::from_i128_with_scale(10000000000000000000000000001, 28),
        Decimal::from_i128_with_scale(10000000000000000000000000002, 28),
        Decimal::from_i128_with_scale(101, 2),
        Decimal::from_i128_with_scale(102, 2),
        Decimal::TWO,
        Decimal::from_i128_with_scale(31415926535897932384626433833, 28),
        Decimal::TEN,
        Decimal::ONE_HUNDRED,
        Decimal::ONE_THOUSAND,
        Decimal::MAX,
    ];
    let mut previous: Option<Vec<u8>> = None;
    for value in data {
        let encoded = bincode::serialize(&value).unwrap();
        let decoded: Decimal = bincode::deserialize(&encoded[..]).unwrap();
        assert_eq!(value, decoded);
        // Every entry encodes distinctly, so the ordering must be strict.
        if let Some(previous) = &previous {
            assert!(previous < &encoded, "{value} -> {previous:?} < {encoded:?}");
        }
        previous = Some(encoded);
    }
}
}