From 8d15cdf80c0f41371039870604e8e2d84691382c Mon Sep 17 00:00:00 2001 From: pchintar <89355405+pchintar@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:39:13 +0530 Subject: [PATCH] Preserve integer precision in round() --- datafusion/functions/src/math/round.rs | 239 +++++++++++++++++- datafusion/spark/src/function/math/round.rs | 37 ++- datafusion/sqllogictest/test_files/scalar.slt | 55 ++++ datafusion/sqllogictest/test_files/select.slt | 11 +- .../test_files/spark/math/round.slt | 24 ++ 5 files changed, 331 insertions(+), 35 deletions(-) diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index aacc8820a8cb6..62f1c3540b9ce 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -17,13 +17,15 @@ use crate::utils::{calculate_binary_decimal_math_cast, calculate_binary_math}; -use arrow::array::ArrayRef; +use arrow::array::{Array, ArrayRef, AsArray}; use arrow::datatypes::DataType::{ - Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64, + Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64, Int8, Int16, Int32, + Int64, UInt8, UInt16, UInt32, UInt64, }; use arrow::datatypes::{ ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, - Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type, + Decimal256Type, DecimalType, Float32Type, Float64Type, Int8Type, Int16Type, + Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, }; use arrow::datatypes::{Field, FieldRef}; use arrow::error::ArrowError; @@ -37,6 +39,7 @@ use datafusion_expr::{ ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; +use num_traits::{PrimInt, Signed, cast, checked_pow}; use std::sync::Arc; fn output_scale_for_decimal(precision: u8, input_scale: i8, decimal_places: i32) -> i8 { @@ -185,6 +188,7 @@ impl RoundFunc { vec![TypeSignatureClass::Integer], NativeType::Int32, ); + let integer = Coercion::new_exact(TypeSignatureClass::Integer); let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32())); let float64 = Coercion::new_implicit( TypeSignatureClass::Native(logical_float64()), @@ -199,6 +203,11 @@ impl RoundFunc { decimal_places.clone(), ]), TypeSignature::Coercible(vec![decimal]), + TypeSignature::Coercible(vec![ + integer.clone(), + decimal_places.clone(), + ]), + TypeSignature::Coercible(vec![integer]), TypeSignature::Coercible(vec![ float32.clone(), decimal_places.clone(), @@ -245,6 +254,7 @@ impl ScalarUDFImpl for RoundFunc { // extra precision to accommodate potential carry-over. let return_type = match input_type { + input_type if input_type.is_integer() => input_type.clone(), Float32 => Float32, Decimal32(precision, scale) => calculate_new_precision_scale::< Decimal32Type, @@ -308,6 +318,9 @@ impl ScalarUDFImpl for RoundFunc { }; match (value_scalar, args.return_type()) { + (value_scalar, return_type) if return_type.is_integer() => { + round_integer_scalar(value_scalar, return_type, dp) + } (ScalarValue::Float32(Some(v)), _) => { let rounded = round_float(*v, dp)?; Ok(ColumnarValue::Scalar(ScalarValue::from(rounded))) @@ -468,6 +481,20 @@ fn round_columnar( let decimal_places_is_array = matches!(decimal_places, ColumnarValue::Array(_)); let arr: ArrayRef = match (value_array.data_type(), return_type) { + (input_type, return_type) + if input_type == return_type && return_type.is_integer() => + { + match decimal_places { + ColumnarValue::Scalar(ScalarValue::Int32(Some(dp))) if *dp >= 0 => { + value_array + } + _ => round_integer_array( + value_array.as_ref(), + decimal_places, + return_type, + )?, + } + } (Float64, _) => { let result = calculate_binary_math::( value_array.as_ref(), @@ -518,7 +545,7 @@ fn round_columnar( }, *precision, *new_scale, - &DataType::Int32, + &Int32, )?; result as _ } @@ -552,7 +579,7 @@ fn round_columnar( }, *precision, *new_scale, - &DataType::Int32, + &Int32, )?; result as _ } @@ -586,7 +613,7 @@ fn round_columnar( }, *precision, *new_scale, - &DataType::Int32, + &Int32, )?; result as _ } @@ -620,7 +647,7 @@ fn round_columnar( }, *precision, *new_scale, - &DataType::Int32, + &Int32, )?; result as _ } @@ -634,6 +661,204 @@ fn round_columnar( } } +fn round_signed_integer( + value: T, + decimal_places: i32, + type_name: &str, +) -> Result +where + T: PrimInt + Signed, +{ + if decimal_places >= 0 || value == T::zero() { + return Ok(value); + } + + let ten = cast::<_, T>(10).expect("10 fits in all integer types"); + let Some(factor) = checked_pow(ten, decimal_places.unsigned_abs() as usize) else { + return Ok(T::zero()); + }; + + let two = cast::<_, T>(2).expect("2 fits in all integer types"); + let one = T::one(); + let threshold = factor / two; + let mut quotient = value / factor; + let remainder = value % factor; + + if remainder >= threshold { + quotient = quotient.checked_add(&one).ok_or_else(|| { + ArrowError::ComputeError(format!("Overflow while rounding {type_name}")) + })?; + } else if remainder <= -threshold { + quotient = quotient.checked_sub(&one).ok_or_else(|| { + ArrowError::ComputeError(format!("Overflow while rounding {type_name}")) + })?; + } + + quotient.checked_mul(&factor).ok_or_else(|| { + ArrowError::ComputeError(format!("Overflow while rounding {type_name}")) + }) +} + +fn round_unsigned_integer( + value: T, + decimal_places: i32, + type_name: &str, +) -> Result +where + T: PrimInt, +{ + if decimal_places >= 0 || value == T::zero() { + return Ok(value); + } + + let ten = cast::<_, T>(10).expect("10 fits in all integer types"); + let Some(factor) = checked_pow(ten, decimal_places.unsigned_abs() as usize) else { + return Ok(T::zero()); + }; + + let two = cast::<_, T>(2).expect("2 fits in all integer types"); + let one = T::one(); + let threshold = factor / two; + let mut quotient = value / factor; + let remainder = value % factor; + + if remainder >= threshold { + quotient = quotient.checked_add(&one).ok_or_else(|| { + ArrowError::ComputeError(format!("Overflow while rounding {type_name}")) + })?; + } + + quotient.checked_mul(&factor).ok_or_else(|| { + ArrowError::ComputeError(format!("Overflow while rounding {type_name}")) + }) +} + +fn round_integer_scalar( + value: &ScalarValue, + return_type: &DataType, + decimal_places: i32, +) -> Result { + match (value, return_type) { + (ScalarValue::Int8(Some(v)), Int8) => Ok(ColumnarValue::Scalar( + ScalarValue::Int8(Some(round_signed_integer(*v, decimal_places, "Int8")?)), + )), + (ScalarValue::Int16(Some(v)), Int16) => Ok(ColumnarValue::Scalar( + ScalarValue::Int16(Some(round_signed_integer(*v, decimal_places, "Int16")?)), + )), + (ScalarValue::Int32(Some(v)), Int32) => Ok(ColumnarValue::Scalar( + ScalarValue::Int32(Some(round_signed_integer(*v, decimal_places, "Int32")?)), + )), + (ScalarValue::Int64(Some(v)), Int64) => Ok(ColumnarValue::Scalar( + ScalarValue::Int64(Some(round_signed_integer(*v, decimal_places, "Int64")?)), + )), + (ScalarValue::UInt8(Some(v)), UInt8) => { + Ok(ColumnarValue::Scalar(ScalarValue::UInt8(Some( + round_unsigned_integer(*v, decimal_places, "UInt8")?, + )))) + } + (ScalarValue::UInt16(Some(v)), UInt16) => { + Ok(ColumnarValue::Scalar(ScalarValue::UInt16(Some( + round_unsigned_integer(*v, decimal_places, "UInt16")?, + )))) + } + (ScalarValue::UInt32(Some(v)), UInt32) => { + Ok(ColumnarValue::Scalar(ScalarValue::UInt32(Some( + round_unsigned_integer(*v, decimal_places, "UInt32")?, + )))) + } + (ScalarValue::UInt64(Some(v)), UInt64) => { + Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some( + round_unsigned_integer(*v, decimal_places, "UInt64")?, + )))) + } + _ => internal_err!( + "Unexpected integer round input/output types: {} -> {}", + value.data_type(), + return_type + ), + } +} + +macro_rules! round_integer_array { + ($ARRAY:expr, $DP:expr, $ARRAY_TYPE:ty, $ROUND_FN:ident, $TYPE_NAME:expr) => {{ + let array = $ARRAY.as_primitive::<$ARRAY_TYPE>(); + + let result = calculate_binary_math::<$ARRAY_TYPE, Int32Type, $ARRAY_TYPE, _>( + array, + $DP, + |v, dp| $ROUND_FN(v, dp, $TYPE_NAME), + )?; + + Ok(result as ArrayRef) + }}; +} + +fn round_integer_array( + value_array: &dyn Array, + decimal_places: &ColumnarValue, + return_type: &DataType, +) -> Result { + match return_type { + Int8 => round_integer_array!( + value_array, + decimal_places, + Int8Type, + round_signed_integer, + "Int8" + ), + Int16 => round_integer_array!( + value_array, + decimal_places, + Int16Type, + round_signed_integer, + "Int16" + ), + Int32 => round_integer_array!( + value_array, + decimal_places, + Int32Type, + round_signed_integer, + "Int32" + ), + Int64 => round_integer_array!( + value_array, + decimal_places, + Int64Type, + round_signed_integer, + "Int64" + ), + UInt8 => round_integer_array!( + value_array, + decimal_places, + UInt8Type, + round_unsigned_integer, + "UInt8" + ), + UInt16 => round_integer_array!( + value_array, + decimal_places, + UInt16Type, + round_unsigned_integer, + "UInt16" + ), + UInt32 => round_integer_array!( + value_array, + decimal_places, + UInt32Type, + round_unsigned_integer, + "UInt32" + ), + UInt64 => round_integer_array!( + value_array, + decimal_places, + UInt64Type, + round_unsigned_integer, + "UInt64" + ), + _ => internal_err!("Unexpected return type for integer round: {return_type}"), + } +} + fn round_float(value: T, decimal_places: i32) -> Result where T: num_traits::Float, diff --git a/datafusion/spark/src/function/math/round.rs b/datafusion/spark/src/function/math/round.rs index 05745666183d3..471d38d804cac 100644 --- a/datafusion/spark/src/function/math/round.rs +++ b/datafusion/spark/src/function/math/round.rs @@ -462,18 +462,7 @@ fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result { - let array = array.as_primitive::(); - let result: PrimitiveArray = array.try_unary(|x| { - let v_i64 = i64::try_from(x).map_err(|_| { - (exec_err!( - "round: UInt64 value {x} exceeds i64::MAX and cannot be rounded" - ) as Result<(), _>) - .unwrap_err() - })?; - round_integer(v_i64, scale, enable_ansi_mode) - .map(|v| v as u64) - })?; - Ok(ColumnarValue::Array(Arc::new(result))) + impl_integer_array_round!(array, UInt64Type, scale, enable_ansi_mode) } // Float types @@ -588,16 +577,20 @@ fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result { - let v_i64 = i64::try_from(*v).map_err(|_| { - (exec_err!( - "round: UInt64 value {v} exceeds i64::MAX and cannot be rounded" - ) as Result<(), _>) - .unwrap_err() - })?; - let result = round_integer(v_i64, scale, enable_ansi_mode)?; - Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some( - result as u64, - )))) + if scale >= 0 { + Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(*v)))) + } else { + let v_i64 = i64::try_from(*v).map_err(|_| { + (exec_err!( + "round: UInt64 value {v} exceeds i64::MAX and cannot be rounded" + ) as Result<(), _>) + .unwrap_err() + })?; + let result = round_integer(v_i64, scale, enable_ansi_mode)?; + Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some( + result as u64, + )))) + } } // Float scalars diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index c34d4696d52f8..65b78acd46234 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -965,6 +965,61 @@ select round(a), round(b), round(c) from small_floats; 0 0 1 1 0 0 +# round int64 should preserve exact values above Float64 precision range +query TI +select arrow_typeof(round(arrow_cast(9007199254740993, 'Int64'))), + round(arrow_cast(9007199254740993, 'Int64')); +---- +Int64 9007199254740993 + +# round int64 with positive decimal_places should preserve exact values above Float64 precision range +query TI +select arrow_typeof(round(arrow_cast(9007199254740993, 'Int64'), 2)), + round(arrow_cast(9007199254740993, 'Int64'), 2); +---- +Int64 9007199254740993 + +# round int64 with negative decimal_places +query TI +select arrow_typeof(round(arrow_cast(125, 'Int64'), -1)), + round(arrow_cast(125, 'Int64'), -1); +---- +Int64 130 + +# round int64 with column decimal_places +query I +select round(v, dp) +from (values (arrow_cast(125, 'Int64'), 1), + (arrow_cast(125, 'Int64'), -1)) as t(v, dp); +---- +125 +130 + +# round int64 overflow with negative decimal_places +query error Overflow while rounding Int64 +select round(arrow_cast(9223372036854775807, 'Int64'), -1); + +# round uint64 should preserve exact values +query TI +select arrow_typeof(round(arrow_cast(18446744073709551615, 'UInt64'))), + round(arrow_cast(18446744073709551615, 'UInt64')); +---- +UInt64 18446744073709551615 + +# round uint64 with positive decimal_places should preserve exact values +query TI +select arrow_typeof(round(arrow_cast(18446744073709551615, 'UInt64'), 2)), + round(arrow_cast(18446744073709551615, 'UInt64'), 2); +---- +UInt64 18446744073709551615 + +# round int64 to place larger than the number itself +query TI +select arrow_typeof(round(arrow_cast(125, 'Int64'), -5)), + round(arrow_cast(125, 'Int64'), -5); +---- +Int64 0 + # round with too large # max Int32 is 2147483647 query error round decimal_places 2147483648 is out of supported i32 range diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index c7e5ed12fc0af..3a9ae30d04275 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1591,16 +1591,15 @@ WHERE CAST(ROUND(b) as INT) = a ORDER BY CAST(ROUND(b) as INT); ---- logical_plan -01)Sort: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) ASC NULLS LAST -02)--Filter: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) = annotated_data_finite2.a -03)----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) = annotated_data_finite2.a] +01)Sort: CAST(round(annotated_data_finite2.b) AS Int32) ASC NULLS LAST +02)--Filter: CAST(round(annotated_data_finite2.b) AS Int32) = annotated_data_finite2.a +03)----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[CAST(round(annotated_data_finite2.b) AS Int32) = annotated_data_finite2.a] physical_plan -01)SortPreservingMergeExec: [CAST(round(CAST(b@2 AS Float64)) AS Int32) ASC NULLS LAST] -02)--FilterExec: CAST(round(CAST(b@2 AS Float64)) AS Int32) = a@1 +01)SortPreservingMergeExec: [round(b@2) ASC NULLS LAST] +02)--FilterExec: round(b@2) = a@1 03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1, maintains_sort_order=true 04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], file_type=csv, has_header=true - statement ok drop table annotated_data_finite2; diff --git a/datafusion/sqllogictest/test_files/spark/math/round.slt b/datafusion/sqllogictest/test_files/spark/math/round.slt index 91c5bdf0506f5..49956846ac814 100644 --- a/datafusion/sqllogictest/test_files/spark/math/round.slt +++ b/datafusion/sqllogictest/test_files/spark/math/round.slt @@ -222,6 +222,18 @@ SELECT round(25::bigint, -1::int); ---- 30 +# round(bigint) should preserve exact values above Float64's exact integer range +query IT +SELECT round(arrow_cast(9007199254740993, 'Int64')), arrow_typeof(round(arrow_cast(9007199254740993, 'Int64'))); +---- +9007199254740993 Int64 + +# round(bigint, positive scale) should also preserve exact values above Float64's exact integer range +query IT +SELECT round(arrow_cast(9007199254740993, 'Int64'), 2::int), arrow_typeof(round(arrow_cast(9007199254740993, 'Int64'), 2::int)); +---- +9007199254740993 Int64 + # round(smallint, -1) query I SELECT round(25::smallint, -1::int); @@ -268,6 +280,18 @@ SELECT round(arrow_cast(25, 'UInt64'), -1::int); ---- 30 +# round(uint64) should preserve exact values above Float64's exact integer range +query IT +SELECT round(arrow_cast(18446744073709551615, 'UInt64')), arrow_typeof(round(arrow_cast(18446744073709551615, 'UInt64'))); +---- +18446744073709551615 UInt64 + +# round(uint64, positive scale) should also preserve exact values above Float64's exact integer range +query IT +SELECT round(arrow_cast(18446744073709551615, 'UInt64'), 2::int), arrow_typeof(round(arrow_cast(18446744073709551615, 'UInt64'), 2::int)); +---- +18446744073709551615 UInt64 + # round(uint32, positive scale) — no-op for integers query I SELECT round(arrow_cast(42, 'UInt32'), 2::int);