Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions datafusion/functions-nested/src/array_avg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`ScalarUDFImpl`] definitions for array_avg function.

use crate::utils::make_scalar_function;
use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait};
use arrow::datatypes::{
DataType,
DataType::{FixedSizeList, LargeList, List, Null},
Field,
};
use datafusion_common::cast::{as_float64_array, as_generic_list_array};
use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
use datafusion_common::{Result, internal_err, plan_err, utils::take_function_args};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
};
use datafusion_macros::user_doc;
use std::sync::Arc;

make_udf_expr_and_func!(
ArrayAvg,
array_avg,
array,
"returns the arithmetic mean of elements in a numeric array.",
array_avg_udf
);

#[user_doc(
doc_section(label = "Array Functions"),
description = "Returns the arithmetic mean (sum divided by count) of the elements of the input array. NULL elements are skipped (per SQL aggregate convention) and excluded from the count. Returns NULL if the input row is NULL, every element is NULL, or the array is empty.",
syntax_example = "array_avg(array)",
sql_example = r#"```sql
> select array_avg([1.0, 2.0, 3.0]);
+----------------------------+
| array_avg(List([1.0,2.0,3.0])) |
+----------------------------+
| 2.0 |
+----------------------------+
```"#,
argument(
name = "array",
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayAvg {
signature: Signature,
aliases: Vec<String>,
}

impl Default for ArrayAvg {
fn default() -> Self {
Self::new()
}
}

impl ArrayAvg {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
aliases: vec!["list_avg".to_string()],
}
}
}

impl ScalarUDFImpl for ArrayAvg {
fn name(&self) -> &str {
"array_avg"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let [arg_type] = take_function_args(self.name(), arg_types)?;
let coercion = Some(&ListCoercion::FixedSizedListToList);

if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
return plan_err!("{} does not support type {arg_type}", self.name());
}

let coerced = if matches!(arg_type, Null) {
List(Arc::new(Field::new_list_field(DataType::Float64, true)))
} else {
coerced_type_with_base_type_only(arg_type, &DataType::Float64, coercion)
};

Ok(vec![coerced])
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_avg_inner)(&args.args)
}

fn aliases(&self) -> &[String] {
&self.aliases
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}

fn array_avg_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array] = take_function_args("array_avg", args)?;
match array.data_type() {
List(_) => general_array_avg::<i32>(array),
LargeList(_) => general_array_avg::<i64>(array),
arg_type => {
internal_err!("array_avg received unexpected type after coercion: {arg_type}")
}
}
}

fn general_array_avg<O: OffsetSizeTrait>(array: &ArrayRef) -> Result<ArrayRef> {
let list_array = as_generic_list_array::<O>(array)?;
let values = as_float64_array(list_array.values())?;
let offsets = list_array.value_offsets();

let mut builder = Float64Array::builder(list_array.len());

for row in 0..list_array.len() {
if list_array.is_null(row) {
builder.append_null();
continue;
}

let start = offsets[row].as_usize();
let end = offsets[row + 1].as_usize();

// Skip NULL elements per SQL aggregate convention (matches PostgreSQL
// AVG, DuckDB list_avg, Spark aggregate). Empty arrays and all-NULL
// arrays both yield NULL — same behavior as SQL AVG over an empty
// set or all-NULL column.
let mut sum = 0.0_f64;
let mut count: u64 = 0;
for i in start..end {
if values.is_valid(i) {
sum += values.value(i);
count += 1;
}
}

if count > 0 {
builder.append_value(sum / count as f64);
} else {
builder.append_null();
}
}

Ok(Arc::new(builder.finish()))
}
3 changes: 3 additions & 0 deletions datafusion/functions-nested/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub mod macros_lambda;

pub mod array_add;
pub mod array_any_match;
pub mod array_avg;
pub mod array_compact;
pub mod array_filter;
pub mod array_has;
Expand Down Expand Up @@ -95,6 +96,7 @@ use std::sync::Arc;
pub mod expr_fn {
pub use super::array_add::array_add;
pub use super::array_any_match::array_any_match;
pub use super::array_avg::array_avg;
pub use super::array_compact::array_compact;
pub use super::array_filter::array_filter;
pub use super::array_has::array_has;
Expand Down Expand Up @@ -181,6 +183,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
length::array_length_udf(),
array_normalize::array_normalize_udf(),
array_add::array_add_udf(),
array_avg::array_avg_udf(),
array_product::array_product_udf(),
array_scale::array_scale_udf(),
array_subtract::array_subtract_udf(),
Expand Down
165 changes: 165 additions & 0 deletions datafusion/sqllogictest/test_files/array_avg.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

## array_avg

# Basic case
query R
select array_avg([1.0, 2.0, 3.0]);
----
2

# Single element
query R
select array_avg([5.0]);
----
5

# Negative values
query R
select array_avg([-1.0, -2.0, -3.0]);
----
-2

# Positive and negative cancel
query R
select array_avg([1.0, -1.0, 2.0, -2.0]);
----
0

# Non-integer mean (sum / count)
query R
select array_avg([1.0, 2.0]);
----
1.5

# Empty array returns NULL (matches PostgreSQL AVG, DuckDB list_avg, SQL Standard AVG-of-empty-set)
query R
select array_avg(arrow_cast(make_array(), 'List(Float64)'));
----
NULL

# Bare NULL input returns NULL row
query R
select array_avg(NULL);
----
NULL

# NULL elements are skipped from BOTH the sum and the count (SQL aggregate convention).
# avg([1, NULL, 3]) = (1 + 3) / 2 = 2 — not (1 + 3) / 3.
query R
select array_avg([1.0, NULL, 3.0]);
----
2

# Single NULL among numeric: skip the NULL, divide by 1
query R
select array_avg([NULL, 10.0]);
----
10

# All-NULL array returns NULL row (matches SQL AVG over all-NULL)
query R
select array_avg(arrow_cast([NULL, NULL], 'List(Float64)'));
----
NULL

# LargeList support
query R
select array_avg(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'));
----
2

# FixedSizeList input (coerced to List)
query R
select array_avg(arrow_cast([1.0, 2.0, 3.0], 'FixedSizeList(3, Float64)'));
----
2

# Float32 inner type (coerced to Float64)
query R
select array_avg(arrow_cast([1.0, 2.0, 3.0], 'List(Float32)'));
----
2

# Int64 inner type (coerced to Float64) — integer mean returned as Float64
query R
select array_avg(arrow_cast([1, 2, 3], 'List(Int64)'));
----
2

# Integer literals (coerced to Float64)
query R
select array_avg([1, 2, 3]);
----
2

# Integer mean that is NOT an integer (3 / 2 = 1.5)
query R
select array_avg([1, 2]);
----
1.5

# Unsupported non-list input (plan error)
query error array_avg does not support type
select array_avg(1);

# Multi-row query with mix of normal, partial-NULL, all-NULL elements, empty, NULL row
query R
select array_avg(column1) from (values
(make_array(1.0, 2.0, 3.0)),
(make_array(0.0)),
(make_array(1.0, NULL, 4.0)),
(arrow_cast(make_array(), 'List(Float64)')),
(NULL)
) as t(column1);
----
2
0
2.5
NULL
NULL

# Wrong arity (zero args)
query error array_avg function requires 1 argument, got 0
select array_avg();

# Wrong arity (two args)
query error array_avg function requires 1 argument, got 2
select array_avg([1.0], [2.0]);

# Return type is Float64
query RT
select array_avg([1.0, 2.0, 3.0]), arrow_typeof(array_avg([1.0, 2.0, 3.0]));
----
2 Float64

# list_avg alias produces the same result
query R
select list_avg([1.0, 2.0, 3.0]);
----
2

# list_avg alias with NULL row propagates correctly
query R
select list_avg(column1) from (values
(make_array(1.0, 2.0)),
(NULL)
) as t(column1);
----
1.5
NULL
Loading
Loading