Skip to content

Commit

Permalink
fix(rust,python): Fix precision/scale handling and invalid numbers in…
Browse files Browse the repository at this point in the history
… string-to-decimal conversions. (pola-rs#13548)
  • Loading branch information
cgevans authored Jan 10, 2024
1 parent 32c074e commit 10ef186
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 69 deletions.
158 changes: 96 additions & 62 deletions crates/polars-arrow/src/legacy/compute/decimal.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use atoi::FromRadix10SignedChecked;

/// Count the significant digits in a slice of ASCII digits: every byte except the
/// leading b'0's. Trailing zeros are counted, so b"001200" yields 4.
///
/// The count is computed at full `usize` width and only then narrowed, saturating at
/// `u8::MAX`. Narrowing the slice length and the leading-zero count to `u8` separately
/// wraps each of them modulo 256, which for slices longer than 255 bytes can make the
/// subtraction underflow or report far fewer digits than are actually present.
fn significant_digits(bytes: &[u8]) -> u8 {
    // Counted locally instead of via `leading_zeros`, whose `u8` result has already wrapped.
    let zeros = bytes.iter().take_while(|&&byte| byte == b'0').count();
    // `zeros` never exceeds `bytes.len()`, so the subtraction cannot underflow; the `min`
    // makes the final cast lossless.
    (bytes.len() - zeros).min(u8::MAX as usize) as u8
}

/// Return how many consecutive b'0' bytes the slice starts with (0 for an empty slice).
///
/// The count is narrowed with an `as u8` cast, so it wraps modulo 256 when the run of
/// zeros is longer than 255 bytes.
fn leading_zeros(bytes: &[u8]) -> u8 {
    // Index of the first non-'0' byte; if every byte is '0', the whole length counts.
    let first_nonzero = bytes.iter().position(|&byte| byte != b'0');
    first_nonzero.unwrap_or(bytes.len()) as u8
}
Expand All @@ -15,85 +12,92 @@ fn split_decimal_bytes(bytes: &[u8]) -> (Option<&[u8]>, Option<&[u8]>) {
(lhs, rhs)
}

/// Parse the whole slice as one signed base-10 `i128`.
///
/// Returns `None` when the underlying parser reports no value, or when it stopped
/// before consuming every byte of `bytes`.
fn parse_integer_checked(bytes: &[u8]) -> Option<i128> {
    match i128::from_radix_10_signed_checked(bytes) {
        // Accept only a parse that used the entire input.
        (Some(value), consumed) if consumed == bytes.len() => Some(value),
        _ => None,
    }
}

pub fn infer_scale(bytes: &[u8]) -> Option<u8> {
/// Assuming bytes are a well-formed decimal number (with or without a separator),
/// infer the scale of the number. If no separator is present, the scale is 0.
pub fn infer_scale(bytes: &[u8]) -> u8 {
let (_lhs, rhs) = split_decimal_bytes(bytes);
rhs.map(significant_digits)
rhs.map_or(0, |x| x.len() as u8)
}

/// Deserializes bytes to a single i128 representing a decimal
/// The decimal precision and scale are not checked.
/// Deserialize bytes to a single i128 representing a decimal, at a specified precision
/// (optional) and scale (required). If precision is not specified, it is assumed to be
/// 38 (the max precision allowed by the i128 representation). The number is checked to
/// ensure it fits within the specified precision and scale. Consistent with float parsing,
/// no decimal separator is required (eg "500", "500.", and "500.0" are all accepted); this allows
/// mixed integer/decimal sequences to be parsed as decimals. All trailing zeros are assumed to
/// be significant, whether or not a separator is present: 1200 requires precision >= 4, while 1200.200
/// requires precision >= 7 and scale >= 3. Returns None if the number is not well-formed, or does not
/// fit. Only b'.' is allowed as a decimal separator (issue #6698).
#[inline]
pub(super) fn deserialize_decimal(
mut bytes: &[u8],
precision: Option<u8>,
scale: u8,
) -> Option<i128> {
// While parse_integer_checked will parse negative numbers, we want to handle
// the negative sign ourselves, and so check for it initially, then handle it
// at the end.
let negative = bytes.first() == Some(&b'-');
if negative {
bytes = &bytes[1..];
};
let (lhs, rhs) = split_decimal_bytes(bytes);
let precision = precision.unwrap_or(u8::MAX);
let precision = precision.unwrap_or(38);

let lhs_b = lhs?;
let abs = parse_integer_checked(lhs_b).and_then(|x| {
match rhs {
Some(rhs) => {
parse_integer_checked(rhs)
.map(|y| (x, lhs_b, y, rhs))
.and_then(|(lhs, lhs_b, rhs, rhs_b)| {
let lhs_s = significant_digits(lhs_b);
let leading_zeros_rhs = leading_zeros(rhs_b);
let rhs_s = rhs_b.len() as u8 - leading_zeros_rhs;

// parameters don't match bytes
if lhs_s + rhs_s > precision || rhs_s > scale {
None
}
// significant digits don't fit scale
else if rhs_s < scale {
// scale: 2
// number: x.09
// significant digits: 1
// leading_zeros: 1
// parsed: 9
// so this is correct
if leading_zeros_rhs + rhs_s == scale {
Some((lhs, rhs))
}
// scale: 2
// number: x.9
// significant digits: 1
// parsed: 9
// so we must multiply by 10 to get 90
else {
let diff = scale as u32 - (rhs_s + leading_zeros_rhs) as u32;
Some((lhs, rhs * 10i128.pow(diff)))
}
}
// scale: 2
// number: x.90
// significant digits: 2
// parsed: 90
// so this is correct
else {
Some((lhs, rhs))
}
})
.map(|(lhs, rhs)| lhs * 10i128.pow(scale as u32) + rhs)
},
None => {
if lhs_b.len() > precision as usize || scale != 0 {
return None;

// For the purposes of decimal parsing, we assume that all digits other than leading zeros
// are significant, eg, 001200 has 4 significant digits, not 2. The Decimal type does
// not allow negative scales, so all trailing zeros on the LHS of any decimal separator
// will still take up space in the representation (eg, 1200 requires, at minimum, precision 4
// at scale 0; there is no scale -2 where it would only need precision 2).
let lhs_s = lhs_b.len() as u8 - leading_zeros(lhs_b);

let abs = parse_integer_checked(lhs_b).and_then(|x| match rhs {
// A decimal separator was found, so LHS and RHS need to be combined.
Some(rhs) => parse_integer_checked(rhs)
.map(|y| (x, y, rhs))
.and_then(|(lhs, rhs, rhs_b)| {
// We include all digits on the RHS, including both leading and trailing zeros,
// as significant. This is consistent with standard scientific practice for writing
// numbers. However, an alternative for parsing could truncate trailing zeros that extend
// beyond the scale: we choose not to do this here.
let scale_adjust = scale as i8 - rhs_b.len() as i8;

if (lhs_s + scale > precision)
|| (scale_adjust < 0)
|| (rhs_b.first() == Some(&b'-'))
{
// LHS significant figures and scale exceed precision,
// RHS significant figures (all digits in RHS) exceed scale, or
// RHS starts with a '-' and the number is not well-formed.
None
} else if (rhs_b.len() as u8) == scale {
// RHS has exactly scale significant digits, so no adjustment
// is needed to RHS.
Some((lhs, rhs))
} else {
// RHS needs adjustment to scale. scale_adjust is known to be
// positive.
Some((lhs, rhs * 10i128.pow(scale_adjust as u32)))
}
parse_integer_checked(lhs_b)
},
}
})
.map(|(lhs, rhs)| lhs * 10i128.pow(scale as u32) + rhs),
// No decimal separator was found; we have an integer / LHS only.
None => {
if (lhs_s + scale > precision) || lhs_b.is_empty() {
// Either the integer itself exceeds the precision, or we simply have
// no number at all / an empty string.
return None;
}
Some(x * 10i128.pow(scale as u32))
},
});
if negative {
Some(-abs?)
Expand Down Expand Up @@ -142,10 +146,12 @@ mod test {

let scale = 20;
let val = "0.01";
assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None);
assert_eq!(
deserialize_decimal(val.as_bytes(), precision, scale),
deserialize_decimal(val.as_bytes(), None, scale),
Some(1000000000000000000)
);

let scale = 5;
let val = "12ABC.34";
assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None);
Expand All @@ -159,6 +165,9 @@ mod test {
let val = "12.3.ABC4";
assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None);

let val = "12.-3";
assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None);

let val = "";
assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None);

Expand All @@ -168,10 +177,35 @@ mod test {
Some(500000i128)
);

let val = "5";
assert_eq!(
deserialize_decimal(val.as_bytes(), precision, scale),
Some(500000i128)
);

let val = ".5";
assert_eq!(
deserialize_decimal(val.as_bytes(), precision, scale),
Some(50000i128)
);

// Precision and scale fitting:
let val = b"1200";
assert_eq!(deserialize_decimal(val, None, 0), Some(1200));
assert_eq!(deserialize_decimal(val, Some(4), 0), Some(1200));
assert_eq!(deserialize_decimal(val, Some(3), 0), None);
assert_eq!(deserialize_decimal(val, Some(4), 1), None);

let val = b"1200.010";
assert_eq!(deserialize_decimal(val, None, 0), None); // insufficient scale
assert_eq!(deserialize_decimal(val, None, 3), Some(1200010)); // exact scale
assert_eq!(deserialize_decimal(val, None, 6), Some(1200010000)); // excess scale
assert_eq!(deserialize_decimal(val, Some(7), 0), None); // insufficient precision and scale
assert_eq!(deserialize_decimal(val, Some(7), 3), Some(1200010)); // exact precision and scale
assert_eq!(deserialize_decimal(val, Some(10), 6), Some(1200010000)); // exact precision, excess scale
assert_eq!(deserialize_decimal(val, Some(5), 6), None); // insufficient precision, excess scale
assert_eq!(deserialize_decimal(val, Some(5), 3), None); // insufficient precision, exact scale
assert_eq!(deserialize_decimal(val, Some(12), 5), Some(120001000)); // excess precision, excess scale
assert_eq!(deserialize_decimal(val, None, 35), None); // scale causes insufficient precision
}
}
42 changes: 35 additions & 7 deletions crates/polars-core/src/chunked_array/ops/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use crate::prelude::*;

impl StringChunked {
/// Convert an [`StringChunked`] to a [`Series`] of [`DataType::Decimal`].
/// The parameters needed for the decimal type are inferred.
/// The scale needed for the decimal type is inferred. Parsing is not strict.
/// Scale inference assumes that all tested strings are well-formed numbers,
/// and may produce unexpected results for scale if this is not the case.
///
/// If the decimal `precision` and `scale` are already known, consider
/// using the `cast` method.
Expand All @@ -11,14 +13,40 @@ impl StringChunked {
let mut iter = self.into_iter();
let mut valid_count = 0;
while let Some(Some(v)) = iter.next() {
if let Some(scale_value) = arrow::legacy::compute::decimal::infer_scale(v.as_bytes()) {
scale = std::cmp::max(scale, scale_value);
valid_count += 1;
if valid_count == infer_length {
break;
}
let scale_value = arrow::legacy::compute::decimal::infer_scale(v.as_bytes());
scale = std::cmp::max(scale, scale_value);
valid_count += 1;
if valid_count == infer_length {
break;
}
}

self.cast(&DataType::Decimal(None, Some(scale as usize)))
}
}

#[cfg(test)]
mod test {
    /// Converts a mix of decimal strings and one non-numeric string with `to_decimal(6)`
    /// and pins the resulting dtype, length, and selected values.
    #[test]
    fn test_inferred_length() {
        use super::*;

        let converted = StringChunked::from_slice(
            "test",
            &[
                "1.0",
                "invalid",
                "225.0",
                "3.00045",
                "-4.0",
                "5.104",
                "5.25251525353",
            ],
        )
        .to_decimal(6)
        .unwrap();

        // The expected scale is 5, matching "3.00045".
        // NOTE(review): presumably the 11-decimal-place last entry is outside the
        // 6-value inference window and so does not raise the scale; confirm against
        // `to_decimal`.
        assert_eq!(converted.dtype(), &DataType::Decimal(None, Some(5)));
        assert_eq!(converted.len(), 7);

        // (row index, expected value): the non-numeric string and the entry with more
        // decimal places than the expected scale are both expected to be null.
        let expected = [
            (0, AnyValue::Decimal(100000, 5)),
            (1, AnyValue::Null),
            (3, AnyValue::Decimal(300045, 5)),
            (4, AnyValue::Decimal(-400000, 5)),
            (6, AnyValue::Null),
        ];
        for (row, want) in expected {
            assert_eq!(converted.get(row).unwrap(), want);
        }
    }
}
6 changes: 6 additions & 0 deletions crates/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,12 @@ impl AnyValue<'_> {
let avs = struct_to_avs_static(*idx, arr, fields);
fields_right == avs
},
#[cfg(feature = "dtype-decimal")]
(Decimal(v_l, scale_l), Decimal(v_r, scale_r)) => {
// Decimal equality here requires that both value and scale be equal, eg
// 1.2 at scale 1, and 1.20 at scale 2, are not equal.
*v_l == *v_r && *scale_l == *scale_r
},
_ => false,
}
}
Expand Down

0 comments on commit 10ef186

Please sign in to comment.