Skip to content

Commit

Permalink
Fix FromSql for Postgres Numeric NaNs (#656)
Browse files Browse the repository at this point in the history
* Fix FromSql for Postgres Numeric NaNs

This fixes a bug where from_sql was converting Numeric::NaN to 0
rather than returning an error.

It also returns a more descriptive error message when from_sql is called
with Numeric Infinity and -Infinity, which are also not representable in
Decimal.

Closes #655

* Rename to_from_sql to postgres_to_from_sql
* Rename from_sql_special_numeric to postgres_from_sql_special_numeric

---------

Co-authored-by: Paul Mason <[email protected]>
  • Loading branch information
lukoktonos and paupino authored Mar 29, 2024
1 parent f0abf16 commit 913dc5b
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
21 changes: 21 additions & 0 deletions src/postgres/driver.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
use crate::error::Error;
use crate::postgres::common::*;
use crate::Decimal;
use bytes::{BufMut, BytesMut};
use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
use std::io::{Cursor, Read};

// These are from numeric.c in the PostgreSQL source code
const NUMERIC_NAN: u16 = 0xC000;
const NUMERIC_PINF: u16 = 0xD000;
const NUMERIC_NINF: u16 = 0xF000;
const NUMERIC_SPECIAL: u16 = 0xC000;

fn read_two_bytes(cursor: &mut Cursor<&[u8]>) -> std::io::Result<[u8; 2]> {
let mut result = [0; 2];
cursor.read_exact(&mut result)?;
Expand Down Expand Up @@ -70,6 +77,20 @@ impl<'a> FromSql<'a> for Decimal {
let weight = i16::from_be_bytes(read_two_bytes(&mut raw)?); // 10000^weight
// Sign: 0x0000 = positive, 0x4000 = negative, 0xC000 = NaN
let sign = u16::from_be_bytes(read_two_bytes(&mut raw)?);

if (sign & NUMERIC_SPECIAL) == NUMERIC_SPECIAL {
let special = match sign {
NUMERIC_NAN => "NaN",
NUMERIC_PINF => "Infinity",
NUMERIC_NINF => "-Infinity",
// This shouldn't be hit unless postgres adds a new special numeric type in the
// future
_ => "unknown special numeric",
};

return Err(Box::new(Error::ConversionTo(special.to_string())));
}

// Number of digits (in base 10) to print after decimal separator
let scale = u16::from_be_bytes(read_two_bytes(&mut raw)?);

Expand Down
34 changes: 32 additions & 2 deletions tests/decimal_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3473,9 +3473,9 @@ fn declarative_ref_dec_sum() {
assert_eq!(sum, Decimal::from(45))
}

#[cfg(feature = "postgres")]
#[cfg(feature = "db-postgres")]
#[test]
fn to_from_sql() {
fn postgres_to_from_sql() {
use bytes::BytesMut;
use postgres::types::{FromSql, Kind, ToSql, Type};

Expand Down Expand Up @@ -3514,6 +3514,36 @@ fn to_from_sql() {
}
}

#[cfg(feature = "db-postgres")]
#[test]
fn postgres_from_sql_special_numeric() {
use postgres::types::{FromSql, Kind, Type};

// The numbers below are the big-endian equivalent of the NUMERIC_* masks for NAN, PINF, NINF
let tests = &[
("NaN", &[0, 0, 0, 0, 192, 0, 0, 0]),
("Infinity", &[0, 0, 0, 0, 208, 0, 0, 0]),
("-Infinity", &[0, 0, 0, 0, 240, 0, 0, 0]),
];

let t = Type::new("".into(), 0, Kind::Simple, "".into());

for (name, bytes) in tests {
let res = Decimal::from_sql(&t, *bytes);
match &res {
Ok(_) => panic!("Expected error, got Ok"),
Err(e) => {
let error_message = e.to_string();
assert!(
error_message.contains(name),
"Error message does not contain the expected value: {}",
name
);
}
}
}
}

fn hash_it(d: Decimal) -> u64 {
use core::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
Expand Down

0 comments on commit 913dc5b

Please sign in to comment.