diff --git a/datafusion/functions-aggregate/benches/approx_distinct.rs b/datafusion/functions-aggregate/benches/approx_distinct.rs index 4608c39d548b9..2ab783d9acf05 100644 --- a/datafusion/functions-aggregate/benches/approx_distinct.rs +++ b/datafusion/functions-aggregate/benches/approx_distinct.rs @@ -20,10 +20,12 @@ use std::sync::Arc; use arrow::array::{ ArrayRef, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, - Int8Array, Int16Array, Int64Array, StringArray, StringViewArray, UInt8Array, - UInt16Array, + Int8Array, Int16Array, Int64Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, + IntervalYearMonthArray, StringArray, StringViewArray, UInt8Array, UInt16Array, +}; +use arrow::datatypes::{ + DataType, Field, IntervalDayTime, IntervalMonthDayNano, IntervalUnit, Schema, i256, }; -use arrow::datatypes::{DataType, Field, Schema, i256}; use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ @@ -154,6 +156,38 @@ fn create_i16_array(n_distinct: usize) -> Int16Array { .collect() } +/// Creates an `IntervalYearMonthArray` where values are drawn from `0..n_distinct`. +fn create_interval_year_month_array(n_distinct: usize) -> IntervalYearMonthArray { + let mut rng = StdRng::seed_from_u64(42); + (0..BATCH_SIZE) + .map(|_| Some(rng.random_range(0..n_distinct as i32))) + .collect() +} + +/// Creates an `IntervalDayTimeArray` where values are drawn from a pool of +/// `n_distinct` values. +fn create_interval_day_time_array(n_distinct: usize) -> IntervalDayTimeArray { + let mut rng = StdRng::seed_from_u64(42); + let pool: Vec<IntervalDayTime> = (0..n_distinct) + .map(|i| IntervalDayTime::new(i as i32, i as i32 * 100)) + .collect(); + (0..BATCH_SIZE) + .map(|_| Some(pool[rng.random_range(0..pool.len())])) + .collect() +} + +/// Creates an `IntervalMonthDayNanoArray` where values are drawn from a pool of +/// `n_distinct` values. 
+fn create_interval_month_day_nano_array(n_distinct: usize) -> IntervalMonthDayNanoArray { + let mut rng = StdRng::seed_from_u64(42); + let pool: Vec<IntervalMonthDayNano> = (0..n_distinct) + .map(|i| IntervalMonthDayNano::new(i as i32, i as i32, i as i64 * 1_000)) + .collect(); + (0..BATCH_SIZE) + .map(|_| Some(pool[rng.random_range(0..pool.len())])) + .collect() +} + /// Creates a pool of `n_distinct` random strings of the given length. fn create_string_pool(n_distinct: usize, string_length: usize) -> Vec<String> { let mut rng = StdRng::seed_from_u64(42); @@ -333,6 +367,58 @@ fn approx_distinct_benchmark(c: &mut Criterion) { .unwrap() }) }); + + // Interval benchmarks + for pct in [80, 99] { + let n_distinct = BATCH_SIZE * pct / 100; + + // IntervalYearMonth + let values = Arc::new(create_interval_year_month_array(n_distinct)) as ArrayRef; + c.bench_function( + &format!("approx_distinct interval year_month {pct}% distinct"), + |b| { + b.iter(|| { + let mut accumulator = + prepare_accumulator(DataType::Interval(IntervalUnit::YearMonth)); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }, + ); + + // IntervalDayTime + let values = Arc::new(create_interval_day_time_array(n_distinct)) as ArrayRef; + c.bench_function( + &format!("approx_distinct interval day_time {pct}% distinct"), + |b| { + b.iter(|| { + let mut accumulator = + prepare_accumulator(DataType::Interval(IntervalUnit::DayTime)); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }, + ); + + // IntervalMonthDayNano + let values = + Arc::new(create_interval_month_day_nano_array(n_distinct)) as ArrayRef; + c.bench_function( + &format!("approx_distinct interval month_day_nano {pct}% distinct"), + |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::Interval( + IntervalUnit::MonthDayNano, + )); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }, + ); + } } /// Build a `GroupsAccumulator` the same way the aggregate operator does: 
use the @@ -424,6 +510,32 @@ fn build_grouped_batches(data_type: &DataType) -> Vec<(ArrayRef, Vec<usize>)> { .with_precision_and_scale(*p, *s) .unwrap(), ), + DataType::Interval(IntervalUnit::YearMonth) => Arc::new( + (0..BATCH_SIZE) + .map(|_| Some(rng.random::<i32>())) + .collect::<IntervalYearMonthArray>(), + ), + DataType::Interval(IntervalUnit::DayTime) => Arc::new( + (0..BATCH_SIZE) + .map(|_| { + Some(IntervalDayTime::new( + rng.random::<i32>(), + rng.random::<i32>(), + )) + }) + .collect::<IntervalDayTimeArray>(), + ), + DataType::Interval(IntervalUnit::MonthDayNano) => Arc::new( + (0..BATCH_SIZE) + .map(|_| { + Some(IntervalMonthDayNano::new( + rng.random::<i32>(), + rng.random::<i32>(), + rng.random::<i64>(), + )) + }) + .collect::<IntervalMonthDayNanoArray>(), + ), other => panic!("unsupported grouped bench type: {other}"), }; (values, group_indices) @@ -445,6 +557,9 @@ fn approx_distinct_grouped_benchmark(c: &mut Criterion) { DataType::Decimal64(DECIMAL64_PRECISION, DECIMAL_SCALE), DataType::Decimal128(DECIMAL128_PRECISION, DECIMAL_SCALE), DataType::Decimal256(DECIMAL256_PRECISION, DECIMAL_SCALE), + DataType::Interval(IntervalUnit::YearMonth), + DataType::Interval(IntervalUnit::DayTime), + DataType::Interval(IntervalUnit::MonthDayNano), ] { let batches = build_grouped_batches(&data_type); let label = format!("{data_type:?} {N_GROUPS} groups"); diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 1062a478b7bea..5fe3f350d73fb 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -26,6 +26,7 @@ use arrow::buffer::NullBuffer; use arrow::datatypes::{ ArrowPrimitiveType, DataType, Date32Type, Date64Type, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Field, FieldRef, Int32Type, Int64Type, + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, 
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt32Type, UInt64Type, @@ -759,6 +760,15 @@ impl AggregateUDFImpl for ApproxDistinct { DataType::Timestamp(TimeUnit::Nanosecond, _) => { Box::new(NumericHLLAccumulator::<TimestampNanosecondType>::new()) } + DataType::Interval(IntervalUnit::YearMonth) => { + Box::new(NumericHLLAccumulator::<IntervalYearMonthType>::new()) + } + DataType::Interval(IntervalUnit::DayTime) => { + Box::new(NumericHLLAccumulator::<IntervalDayTimeType>::new()) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + Box::new(NumericHLLAccumulator::<IntervalMonthDayNanoType>::new()) + } DataType::Decimal32(_, _) => { Box::new(NumericHLLAccumulator::<Decimal32Type>::new()) } @@ -831,6 +841,9 @@ fn is_hll_groups_type(data_type: &DataType) -> bool { | DataType::Timestamp(TimeUnit::Millisecond, _) | DataType::Timestamp(TimeUnit::Microsecond, _) | DataType::Timestamp(TimeUnit::Nanosecond, _) + | DataType::Interval(IntervalUnit::YearMonth) + | DataType::Interval(IntervalUnit::DayTime) + | DataType::Interval(IntervalUnit::MonthDayNano) | DataType::Decimal32(_, _) | DataType::Decimal64(_, _) | DataType::Decimal128(_, _) @@ -853,9 +866,10 @@ mod tests { use super::*; use arrow::array::{ AsArray, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, - Int64Array, StringViewArray, + Int64Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, + IntervalYearMonthArray, StringViewArray, }; - use arrow::datatypes::i256; + use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano, i256}; use std::sync::Arc; // A string longer than the 12-byte inline limit const LONG: &str = "this string is definitely longer than twelve bytes"; @@ -995,6 +1009,36 @@ mod tests { assert_count_numerical_acc_and_group_acc::<Decimal256Type>(decimal_256, 6); } + #[test] + fn interval_support_numerical_acc_and_group_acc() { + let year_month: ArrayRef = + Arc::new(IntervalYearMonthArray::from(vec![1, 2, 2, 3, 3, 3, 0, 0])); + assert_count_numerical_acc_and_group_acc::<IntervalYearMonthType>( + year_month, 4, + ); + + let day_time: ArrayRef = Arc::new(IntervalDayTimeArray::from(vec![ + 
IntervalDayTime::new(1, 0), + IntervalDayTime::new(1, 0), + IntervalDayTime::new(1, 5), + IntervalDayTime::new(2, 0), + ])); + assert_count_numerical_acc_and_group_acc::<IntervalDayTimeType>(day_time, 3); + + let month_day_nano: ArrayRef = + Arc::new(IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNano::new(1, 0, 0), + IntervalMonthDayNano::new(1, 0, 0), + IntervalMonthDayNano::new(1, 0, 5), + IntervalMonthDayNano::new(0, 2, 0), + IntervalMonthDayNano::new(0, 0, 0), + ])); + assert_count_numerical_acc_and_group_acc::<IntervalMonthDayNanoType>( + month_day_nano, + 4, + ); + } + /// `approx_distinct(v) FILTER (WHERE nullable_bool)` — a NULL filter row /// must not be counted (null filter is treated the same as false). #[test] diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index bd88cdc1ac111..dbf5063a30188 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2006,6 +2006,35 @@ statement ok DROP TABLE approx_distinct_decimal_test; +# This test runs approx_distinct over the intervals YearMonth, +# DayTime, MonthDayNano for the scalar and the grouped path. 
+statement ok +CREATE TABLE approx_distinct_interval_test (g INT, ym INTERVAL, dt INTERVAL, mdn INTERVAL) AS VALUES + (1, INTERVAL '1' MONTH, INTERVAL '1' DAY, INTERVAL '1' MONTH), + (1, INTERVAL '2' MONTH, INTERVAL '1 day 5 hours', INTERVAL '1 day 5 nanoseconds'), + (1, INTERVAL '2' MONTH, INTERVAL '1 day 5 hours', INTERVAL '1 day 5 nanoseconds'), + (2, INTERVAL '3' YEAR, INTERVAL '2' DAY, INTERVAL '2' DAY), + (2, INTERVAL '0' MONTH, INTERVAL '0' DAY, INTERVAL '0' DAY), + (2, INTERVAL '0' MONTH, INTERVAL '0' DAY, INTERVAL '0' DAY); + +# Scalar path +query III +SELECT approx_distinct(ym), approx_distinct(dt), approx_distinct(mdn) FROM approx_distinct_interval_test; +---- +4 4 4 + +# Grouped path +query IIII +SELECT g, approx_distinct(ym), approx_distinct(dt), approx_distinct(mdn) +FROM approx_distinct_interval_test GROUP BY g ORDER BY g; +---- +1 2 2 2 +2 2 2 2 + +statement ok +DROP TABLE approx_distinct_interval_test; + + ## This test executes the APPROX_PERCENTILE_CONT aggregation against the test ## data, asserting the estimated quantiles are ±5% their actual values.