Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 232 additions & 7 deletions datafusion/functions/src/math/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@

use crate::utils::{calculate_binary_decimal_math_cast, calculate_binary_math};

use arrow::array::ArrayRef;
use arrow::array::{Array, ArrayRef, AsArray};
use arrow::datatypes::DataType::{
Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64,
Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64, Int8, Int16, Int32,
Int64, UInt8, UInt16, UInt32, UInt64,
};
use arrow::datatypes::{
ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type,
Decimal256Type, DecimalType, Float32Type, Float64Type, Int8Type, Int16Type,
Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
};
use arrow::datatypes::{Field, FieldRef};
use arrow::error::ArrowError;
Expand All @@ -37,6 +39,7 @@ use datafusion_expr::{
ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
use num_traits::{PrimInt, Signed, cast, checked_pow};
use std::sync::Arc;

fn output_scale_for_decimal(precision: u8, input_scale: i8, decimal_places: i32) -> i8 {
Expand Down Expand Up @@ -185,6 +188,7 @@ impl RoundFunc {
vec![TypeSignatureClass::Integer],
NativeType::Int32,
);
let integer = Coercion::new_exact(TypeSignatureClass::Integer);
let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32()));
let float64 = Coercion::new_implicit(
TypeSignatureClass::Native(logical_float64()),
Expand All @@ -199,6 +203,11 @@ impl RoundFunc {
decimal_places.clone(),
]),
TypeSignature::Coercible(vec![decimal]),
TypeSignature::Coercible(vec![
integer.clone(),
decimal_places.clone(),
]),
TypeSignature::Coercible(vec![integer]),
TypeSignature::Coercible(vec![
float32.clone(),
decimal_places.clone(),
Expand Down Expand Up @@ -245,6 +254,7 @@ impl ScalarUDFImpl for RoundFunc {
// extra precision to accommodate potential carry-over.
let return_type =
match input_type {
input_type if input_type.is_integer() => input_type.clone(),
Float32 => Float32,
Decimal32(precision, scale) => calculate_new_precision_scale::<
Decimal32Type,
Expand Down Expand Up @@ -308,6 +318,9 @@ impl ScalarUDFImpl for RoundFunc {
};

match (value_scalar, args.return_type()) {
(value_scalar, return_type) if return_type.is_integer() => {
round_integer_scalar(value_scalar, return_type, dp)
}
(ScalarValue::Float32(Some(v)), _) => {
let rounded = round_float(*v, dp)?;
Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
Expand Down Expand Up @@ -468,6 +481,20 @@ fn round_columnar(
let decimal_places_is_array = matches!(decimal_places, ColumnarValue::Array(_));

let arr: ArrayRef = match (value_array.data_type(), return_type) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please restore a fast path, when decimal_places is a non-negative scalar literal and input_type == return_type, return value_array directly, only fall into round_integer_array for negative/array decimal_places.

@pchintar pchintar Jun 3, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @kumarUjjawal I've restored the fast path for non-negative scalar decimal_places values while still using round_integer_array for negative and column-valued decimal_places.

Before:

(input_type, return_type)
    if input_type == return_type && is_integer_data_type(return_type) =>
{
    round_integer_array(value_array.as_ref(), decimal_places, return_type)?
}

Now:

(input_type, return_type)
    if input_type == return_type && is_integer_data_type(return_type) =>
{
    match decimal_places {
        ColumnarValue::Scalar(ScalarValue::Int32(Some(dp))) if *dp >= 0 => {
            value_array
        }
        _ => round_integer_array(
            value_array.as_ref(),
            decimal_places,
            return_type,
        )?,
    }
}

(input_type, return_type)
Comment thread
Jefffrey marked this conversation as resolved.
if input_type == return_type && return_type.is_integer() =>
{
match decimal_places {
ColumnarValue::Scalar(ScalarValue::Int32(Some(dp))) if *dp >= 0 => {
value_array
}
_ => round_integer_array(
value_array.as_ref(),
decimal_places,
return_type,
)?,
}
}
(Float64, _) => {
let result = calculate_binary_math::<Float64Type, Int32Type, Float64Type, _>(
value_array.as_ref(),
Expand Down Expand Up @@ -518,7 +545,7 @@ fn round_columnar(
},
*precision,
*new_scale,
&DataType::Int32,
&Int32,
)?;
result as _
}
Expand Down Expand Up @@ -552,7 +579,7 @@ fn round_columnar(
},
*precision,
*new_scale,
&DataType::Int32,
&Int32,
)?;
result as _
}
Expand Down Expand Up @@ -586,7 +613,7 @@ fn round_columnar(
},
*precision,
*new_scale,
&DataType::Int32,
&Int32,
)?;
result as _
}
Expand Down Expand Up @@ -620,7 +647,7 @@ fn round_columnar(
},
*precision,
*new_scale,
&DataType::Int32,
&Int32,
)?;
result as _
}
Expand All @@ -634,6 +661,204 @@ fn round_columnar(
}
}

fn round_signed_integer<T>(
value: T,
decimal_places: i32,
type_name: &str,
) -> Result<T, ArrowError>
where
T: PrimInt + Signed,
{
if decimal_places >= 0 || value == T::zero() {
return Ok(value);
}

let ten = cast::<_, T>(10).expect("10 fits in all integer types");
let Some(factor) = checked_pow(ten, decimal_places.unsigned_abs() as usize) else {
return Ok(T::zero());
};

let two = cast::<_, T>(2).expect("2 fits in all integer types");
let one = T::one();
let threshold = factor / two;
let mut quotient = value / factor;
let remainder = value % factor;

if remainder >= threshold {
quotient = quotient.checked_add(&one).ok_or_else(|| {
ArrowError::ComputeError(format!("Overflow while rounding {type_name}"))
})?;
} else if remainder <= -threshold {
quotient = quotient.checked_sub(&one).ok_or_else(|| {
ArrowError::ComputeError(format!("Overflow while rounding {type_name}"))
})?;
}

quotient.checked_mul(&factor).ok_or_else(|| {
ArrowError::ComputeError(format!("Overflow while rounding {type_name}"))
})
}

fn round_unsigned_integer<T>(
value: T,
decimal_places: i32,
type_name: &str,
) -> Result<T, ArrowError>
where
T: PrimInt,
{
if decimal_places >= 0 || value == T::zero() {
return Ok(value);
}

let ten = cast::<_, T>(10).expect("10 fits in all integer types");
let Some(factor) = checked_pow(ten, decimal_places.unsigned_abs() as usize) else {
return Ok(T::zero());
};

let two = cast::<_, T>(2).expect("2 fits in all integer types");
let one = T::one();
let threshold = factor / two;
let mut quotient = value / factor;
let remainder = value % factor;

if remainder >= threshold {
quotient = quotient.checked_add(&one).ok_or_else(|| {
ArrowError::ComputeError(format!("Overflow while rounding {type_name}"))
})?;
}

quotient.checked_mul(&factor).ok_or_else(|| {
ArrowError::ComputeError(format!("Overflow while rounding {type_name}"))
})
}

fn round_integer_scalar(
value: &ScalarValue,
return_type: &DataType,
decimal_places: i32,
) -> Result<ColumnarValue> {
match (value, return_type) {
(ScalarValue::Int8(Some(v)), Int8) => Ok(ColumnarValue::Scalar(
ScalarValue::Int8(Some(round_signed_integer(*v, decimal_places, "Int8")?)),
)),
(ScalarValue::Int16(Some(v)), Int16) => Ok(ColumnarValue::Scalar(
ScalarValue::Int16(Some(round_signed_integer(*v, decimal_places, "Int16")?)),
)),
(ScalarValue::Int32(Some(v)), Int32) => Ok(ColumnarValue::Scalar(
ScalarValue::Int32(Some(round_signed_integer(*v, decimal_places, "Int32")?)),
)),
(ScalarValue::Int64(Some(v)), Int64) => Ok(ColumnarValue::Scalar(
ScalarValue::Int64(Some(round_signed_integer(*v, decimal_places, "Int64")?)),
)),
(ScalarValue::UInt8(Some(v)), UInt8) => {
Ok(ColumnarValue::Scalar(ScalarValue::UInt8(Some(
round_unsigned_integer(*v, decimal_places, "UInt8")?,
))))
}
(ScalarValue::UInt16(Some(v)), UInt16) => {
Ok(ColumnarValue::Scalar(ScalarValue::UInt16(Some(
round_unsigned_integer(*v, decimal_places, "UInt16")?,
))))
}
(ScalarValue::UInt32(Some(v)), UInt32) => {
Ok(ColumnarValue::Scalar(ScalarValue::UInt32(Some(
round_unsigned_integer(*v, decimal_places, "UInt32")?,
))))
}
(ScalarValue::UInt64(Some(v)), UInt64) => {
Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(
round_unsigned_integer(*v, decimal_places, "UInt64")?,
))))
}
_ => internal_err!(
"Unexpected integer round input/output types: {} -> {}",
value.data_type(),
return_type
),
}
}

macro_rules! round_integer_array {
($ARRAY:expr, $DP:expr, $ARRAY_TYPE:ty, $ROUND_FN:ident, $TYPE_NAME:expr) => {{
let array = $ARRAY.as_primitive::<$ARRAY_TYPE>();

let result = calculate_binary_math::<$ARRAY_TYPE, Int32Type, $ARRAY_TYPE, _>(
array,
$DP,
|v, dp| $ROUND_FN(v, dp, $TYPE_NAME),
)?;

Ok(result as ArrayRef)
}};
}

fn round_integer_array(
value_array: &dyn Array,
decimal_places: &ColumnarValue,
return_type: &DataType,
) -> Result<ArrayRef> {
match return_type {
Int8 => round_integer_array!(
value_array,
decimal_places,
Int8Type,
round_signed_integer,
"Int8"
),
Int16 => round_integer_array!(
value_array,
decimal_places,
Int16Type,
round_signed_integer,
"Int16"
),
Int32 => round_integer_array!(
value_array,
decimal_places,
Int32Type,
round_signed_integer,
"Int32"
),
Int64 => round_integer_array!(
value_array,
decimal_places,
Int64Type,
round_signed_integer,
"Int64"
),
UInt8 => round_integer_array!(
value_array,
decimal_places,
UInt8Type,
round_unsigned_integer,
"UInt8"
),
UInt16 => round_integer_array!(
value_array,
decimal_places,
UInt16Type,
round_unsigned_integer,
"UInt16"
),
UInt32 => round_integer_array!(
value_array,
decimal_places,
UInt32Type,
round_unsigned_integer,
"UInt32"
),
UInt64 => round_integer_array!(
value_array,
decimal_places,
UInt64Type,
round_unsigned_integer,
"UInt64"
),
_ => internal_err!("Unexpected return type for integer round: {return_type}"),
}
}

fn round_float<T>(value: T, decimal_places: i32) -> Result<T, ArrowError>
where
T: num_traits::Float,
Expand Down
37 changes: 15 additions & 22 deletions datafusion/spark/src/function/math/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,18 +462,7 @@ fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result<Columna
impl_integer_array_round!(array, UInt32Type, scale, enable_ansi_mode)
}
DataType::UInt64 => {
let array = array.as_primitive::<UInt64Type>();
let result: PrimitiveArray<UInt64Type> = array.try_unary(|x| {
let v_i64 = i64::try_from(x).map_err(|_| {
(exec_err!(
"round: UInt64 value {x} exceeds i64::MAX and cannot be rounded"
) as Result<(), _>)
.unwrap_err()
Comment thread
Jefffrey marked this conversation as resolved.
})?;
round_integer(v_i64, scale, enable_ansi_mode)
.map(|v| v as u64)
})?;
Ok(ColumnarValue::Array(Arc::new(result)))
impl_integer_array_round!(array, UInt64Type, scale, enable_ansi_mode)
}

// Float types
Expand Down Expand Up @@ -588,16 +577,20 @@ fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result<Columna
Ok(ColumnarValue::Scalar(ScalarValue::UInt32(Some(result))))
}
ScalarValue::UInt64(Some(v)) => {
let v_i64 = i64::try_from(*v).map_err(|_| {
(exec_err!(
"round: UInt64 value {v} exceeds i64::MAX and cannot be rounded"
) as Result<(), _>)
.unwrap_err()
})?;
let result = round_integer(v_i64, scale, enable_ansi_mode)?;
Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(
result as u64,
))))
if scale >= 0 {
Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(*v))))
} else {
let v_i64 = i64::try_from(*v).map_err(|_| {
(exec_err!(
"round: UInt64 value {v} exceeds i64::MAX and cannot be rounded"
) as Result<(), _>)
.unwrap_err()
})?;
let result = round_integer(v_i64, scale, enable_ansi_mode)?;
Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(
result as u64,
))))
}
}

// Float scalars
Expand Down
Loading
Loading