Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 118 additions & 3 deletions datafusion/functions-aggregate/benches/approx_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ use std::sync::Arc;

use arrow::array::{
ArrayRef, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array,
Int8Array, Int16Array, Int64Array, StringArray, StringViewArray, UInt8Array,
UInt16Array,
Int8Array, Int16Array, Int64Array, IntervalDayTimeArray, IntervalMonthDayNanoArray,
IntervalYearMonthArray, StringArray, StringViewArray, UInt8Array, UInt16Array,
};
use arrow::datatypes::{
DataType, Field, IntervalDayTime, IntervalMonthDayNano, IntervalUnit, Schema, i256,
};
use arrow::datatypes::{DataType, Field, Schema, i256};
use criterion::{Criterion, criterion_group, criterion_main};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::{
Expand Down Expand Up @@ -154,6 +156,38 @@ fn create_i16_array(n_distinct: usize) -> Int16Array {
.collect()
}

/// Builds an `IntervalYearMonthArray` holding `BATCH_SIZE` non-null entries,
/// each sampled from `0..n_distinct` by an RNG seeded with a fixed value so
/// every run sees the same data.
fn create_interval_year_month_array(n_distinct: usize) -> IntervalYearMonthArray {
    let mut rng = StdRng::seed_from_u64(42);
    // Exclusive upper bound of the sampled range, in the array's native width.
    let upper_bound = n_distinct as i32;
    std::iter::repeat_with(|| Some(rng.random_range(0..upper_bound)))
        .take(BATCH_SIZE)
        .collect()
}

/// Builds an `IntervalDayTimeArray` holding `BATCH_SIZE` non-null entries, each
/// picked at random from a pre-built pool of `n_distinct` values.
///
/// Pool entry `i` is `IntervalDayTime::new(i, i * 100)`. The RNG uses a fixed
/// seed so the generated array is the same on every run.
fn create_interval_day_time_array(n_distinct: usize) -> IntervalDayTimeArray {
    let mut rng = StdRng::seed_from_u64(42);

    // Materialise the candidate values up front; sampling then only indexes.
    let mut pool = Vec::with_capacity(n_distinct);
    for i in 0..n_distinct {
        let first = i as i32;
        pool.push(IntervalDayTime::new(first, first * 100));
    }

    std::iter::repeat_with(|| {
        let pick = rng.random_range(0..pool.len());
        Some(pool[pick])
    })
    .take(BATCH_SIZE)
    .collect()
}

/// Builds an `IntervalMonthDayNanoArray` holding `BATCH_SIZE` non-null entries,
/// each picked at random from a pre-built pool of `n_distinct` values.
///
/// Pool entry `i` is `IntervalMonthDayNano::new(i, i, i * 1_000)`. The RNG uses
/// a fixed seed so the generated array is the same on every run.
fn create_interval_month_day_nano_array(n_distinct: usize) -> IntervalMonthDayNanoArray {
    let mut rng = StdRng::seed_from_u64(42);

    let pool: Vec<_> = (0..n_distinct)
        .map(|i| {
            let small = i as i32;
            let scaled = i as i64 * 1_000;
            IntervalMonthDayNano::new(small, small, scaled)
        })
        .collect();
    let pool_len = pool.len();

    (0..BATCH_SIZE)
        .map(|_| {
            let pick = rng.random_range(0..pool_len);
            Some(pool[pick])
        })
        .collect()
}

/// Creates a pool of `n_distinct` random strings of the given length.
fn create_string_pool(n_distinct: usize, string_length: usize) -> Vec<String> {
let mut rng = StdRng::seed_from_u64(42);
Expand Down Expand Up @@ -333,6 +367,58 @@ fn approx_distinct_benchmark(c: &mut Criterion) {
.unwrap()
})
});

// Interval benchmarks: one benchmark per interval unit (YearMonth, DayTime,
// MonthDayNano), each run with 80% and 99% of the batch being distinct values.
for pct in [80, 99] {
// Size of the value pool, expressed as `pct` percent of the batch size.
let n_distinct = BATCH_SIZE * pct / 100;

// IntervalYearMonth
// The input array is built once, outside the timed closure.
let values = Arc::new(create_interval_year_month_array(n_distinct)) as ArrayRef;
c.bench_function(
&format!("approx_distinct interval year_month {pct}% distinct"),
|b| {
b.iter(|| {
// A new accumulator is prepared on every iteration, so the measured
// time covers `prepare_accumulator` plus a single `update_batch` call.
let mut accumulator =
prepare_accumulator(DataType::Interval(IntervalUnit::YearMonth));
accumulator
// `from_ref` views the one `ArrayRef` as a one-element slice without cloning it.
.update_batch(std::slice::from_ref(&values))
.unwrap()
})
},
);

// IntervalDayTime
// Shadows the previous `values`; same measurement shape as above.
let values = Arc::new(create_interval_day_time_array(n_distinct)) as ArrayRef;
c.bench_function(
&format!("approx_distinct interval day_time {pct}% distinct"),
|b| {
b.iter(|| {
let mut accumulator =
prepare_accumulator(DataType::Interval(IntervalUnit::DayTime));
accumulator
.update_batch(std::slice::from_ref(&values))
.unwrap()
})
},
);

// IntervalMonthDayNano
// Shadows the previous `values`; same measurement shape as above.
let values =
Arc::new(create_interval_month_day_nano_array(n_distinct)) as ArrayRef;
c.bench_function(
&format!("approx_distinct interval month_day_nano {pct}% distinct"),
|b| {
b.iter(|| {
let mut accumulator = prepare_accumulator(DataType::Interval(
IntervalUnit::MonthDayNano,
));
accumulator
.update_batch(std::slice::from_ref(&values))
.unwrap()
})
},
);
}
}

/// Build a `GroupsAccumulator` the same way the aggregate operator does: use the
Expand Down Expand Up @@ -424,6 +510,32 @@ fn build_grouped_batches(data_type: &DataType) -> Vec<(ArrayRef, Vec<usize>)> {
.with_precision_and_scale(*p, *s)
.unwrap(),
),
DataType::Interval(IntervalUnit::YearMonth) => Arc::new(
(0..BATCH_SIZE)
.map(|_| Some(rng.random::<i32>()))
.collect::<IntervalYearMonthArray>(),
),
DataType::Interval(IntervalUnit::DayTime) => Arc::new(
(0..BATCH_SIZE)
.map(|_| {
Some(IntervalDayTime::new(
rng.random::<i32>(),
rng.random::<i32>(),
))
})
.collect::<IntervalDayTimeArray>(),
),
DataType::Interval(IntervalUnit::MonthDayNano) => Arc::new(
(0..BATCH_SIZE)
.map(|_| {
Some(IntervalMonthDayNano::new(
rng.random::<i32>(),
rng.random::<i32>(),
rng.random::<i64>(),
))
})
.collect::<IntervalMonthDayNanoArray>(),
),
other => panic!("unsupported grouped bench type: {other}"),
};
(values, group_indices)
Expand All @@ -445,6 +557,9 @@ fn approx_distinct_grouped_benchmark(c: &mut Criterion) {
DataType::Decimal64(DECIMAL64_PRECISION, DECIMAL_SCALE),
DataType::Decimal128(DECIMAL128_PRECISION, DECIMAL_SCALE),
DataType::Decimal256(DECIMAL256_PRECISION, DECIMAL_SCALE),
DataType::Interval(IntervalUnit::YearMonth),
DataType::Interval(IntervalUnit::DayTime),
DataType::Interval(IntervalUnit::MonthDayNano),
] {
let batches = build_grouped_batches(&data_type);
let label = format!("{data_type:?} {N_GROUPS} groups");
Expand Down
48 changes: 46 additions & 2 deletions datafusion/functions-aggregate/src/approx_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use arrow::buffer::NullBuffer;
use arrow::datatypes::{
ArrowPrimitiveType, DataType, Date32Type, Date64Type, Decimal32Type, Decimal64Type,
Decimal128Type, Decimal256Type, Field, FieldRef, Int32Type, Int64Type,
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType, TimestampSecondType, UInt32Type, UInt64Type,
Expand Down Expand Up @@ -759,6 +760,15 @@ impl AggregateUDFImpl for ApproxDistinct {
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
Box::new(NumericHLLAccumulator::<TimestampNanosecondType>::new())
}
DataType::Interval(IntervalUnit::YearMonth) => {
Box::new(NumericHLLAccumulator::<IntervalYearMonthType>::new())
}
DataType::Interval(IntervalUnit::DayTime) => {
Box::new(NumericHLLAccumulator::<IntervalDayTimeType>::new())
}
DataType::Interval(IntervalUnit::MonthDayNano) => {
Box::new(NumericHLLAccumulator::<IntervalMonthDayNanoType>::new())
}
DataType::Decimal32(_, _) => {
Box::new(NumericHLLAccumulator::<Decimal32Type>::new())
}
Expand Down Expand Up @@ -831,6 +841,9 @@ fn is_hll_groups_type(data_type: &DataType) -> bool {
| DataType::Timestamp(TimeUnit::Millisecond, _)
| DataType::Timestamp(TimeUnit::Microsecond, _)
| DataType::Timestamp(TimeUnit::Nanosecond, _)
| DataType::Interval(IntervalUnit::YearMonth)
| DataType::Interval(IntervalUnit::DayTime)
| DataType::Interval(IntervalUnit::MonthDayNano)
| DataType::Decimal32(_, _)
| DataType::Decimal64(_, _)
| DataType::Decimal128(_, _)
Expand All @@ -853,9 +866,10 @@ mod tests {
use super::*;
use arrow::array::{
AsArray, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array,
Int64Array, StringViewArray,
Int64Array, IntervalDayTimeArray, IntervalMonthDayNanoArray,
IntervalYearMonthArray, StringViewArray,
};
use arrow::datatypes::i256;
use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano, i256};
use std::sync::Arc;
// A string longer than the 12-byte inline limit
const LONG: &str = "this string is definitely longer than twelve bytes";
Expand Down Expand Up @@ -995,6 +1009,36 @@ mod tests {
assert_count_numerical_acc_and_group_acc::<Decimal256Type>(decimal_256, 6);
}

/// Feeds small arrays of each interval unit through
/// `assert_count_numerical_acc_and_group_acc` and checks the expected number
/// of distinct values, including duplicates and all-zero intervals.
#[test]
fn interval_support_numerical_acc_and_group_acc() {
    // Year-month: the values 0, 1, 2 and 3 occur, several of them repeatedly.
    let ym_values: Vec<i32> = vec![1, 2, 2, 3, 3, 3, 0, 0];
    let ym_array: ArrayRef = Arc::new(IntervalYearMonthArray::from(ym_values));
    assert_count_numerical_acc_and_group_acc::<IntervalYearMonthType>(ym_array, 4);

    // Day-time: the first two entries are identical; the third differs from
    // them only in its second component.
    let dt_values = vec![
        IntervalDayTime::new(1, 0),
        IntervalDayTime::new(1, 0),
        IntervalDayTime::new(1, 5),
        IntervalDayTime::new(2, 0),
    ];
    let dt_array: ArrayRef = Arc::new(IntervalDayTimeArray::from(dt_values));
    assert_count_numerical_acc_and_group_acc::<IntervalDayTimeType>(dt_array, 3);

    // Month-day-nano: one duplicate pair, plus entries that differ in a single
    // component and an all-zero interval.
    let mdn_values = vec![
        IntervalMonthDayNano::new(1, 0, 0),
        IntervalMonthDayNano::new(1, 0, 0),
        IntervalMonthDayNano::new(1, 0, 5),
        IntervalMonthDayNano::new(0, 2, 0),
        IntervalMonthDayNano::new(0, 0, 0),
    ];
    let mdn_array: ArrayRef = Arc::new(IntervalMonthDayNanoArray::from(mdn_values));
    assert_count_numerical_acc_and_group_acc::<IntervalMonthDayNanoType>(mdn_array, 4);
}

/// `approx_distinct(v) FILTER (WHERE nullable_bool)` — a NULL filter row
/// must not be counted (null filter is treated the same as false).
#[test]
Expand Down
29 changes: 29 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2006,6 +2006,35 @@ statement ok
DROP TABLE approx_distinct_decimal_test;


# This test runs approx_distinct over interval columns populated with
# year-month, day-time, and month-day-nanosecond style values, exercising
# both the scalar and the grouped path.
statement ok
CREATE TABLE approx_distinct_interval_test (g INT, ym INTERVAL, dt INTERVAL, mdn INTERVAL) AS VALUES
(1, INTERVAL '1' MONTH, INTERVAL '1' DAY, INTERVAL '1' MONTH),
(1, INTERVAL '2' MONTH, INTERVAL '1 day 5 hours', INTERVAL '1 day 5 nanoseconds'),
(1, INTERVAL '2' MONTH, INTERVAL '1 day 5 hours', INTERVAL '1 day 5 nanoseconds'),
(2, INTERVAL '3' YEAR, INTERVAL '2' DAY, INTERVAL '2' DAY),
(2, INTERVAL '0' MONTH, INTERVAL '0' DAY, INTERVAL '0' DAY),
(2, INTERVAL '0' MONTH, INTERVAL '0' DAY, INTERVAL '0' DAY);

# Scalar path
query III
SELECT approx_distinct(ym), approx_distinct(dt), approx_distinct(mdn) FROM approx_distinct_interval_test;
----
4 4 4

# Grouped path
query IIII
SELECT g, approx_distinct(ym), approx_distinct(dt), approx_distinct(mdn)
FROM approx_distinct_interval_test GROUP BY g ORDER BY g;
----
1 2 2 2
2 2 2 2

statement ok
DROP TABLE approx_distinct_interval_test;



## This test executes the APPROX_PERCENTILE_CONT aggregation against the test
## data, asserting the estimated quantiles are ±5% their actual values.
Expand Down
Loading