diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 50ff3936937bf..be73b0a9d11be 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -41,6 +41,7 @@ mod primitive_filter; mod result; mod static_filter; mod strategy; +mod transform; use static_filter::StaticFilter; use strategy::instantiate_static_filter; diff --git a/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs index 0e2ee564656ac..87fb35f5dfc36 100644 --- a/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs +++ b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs @@ -142,6 +142,22 @@ where fn check(&self, needle: T::Native) -> bool { self.bits.get_bit(needle.as_usize()) } + + /// Check membership using a raw values slice (zero-copy path for type reinterpretation). + #[inline] + pub(super) fn contains_slice( + &self, + values: &[T::Native], + nulls: Option<&NullBuffer>, + negated: bool, + ) -> BooleanArray { + build_in_list_result(values.len(), nulls, self.null_count > 0, negated, |i| { + // SAFETY: `build_in_list_result` invokes this closure for + // indices in `0..values.len()`. + let needle = unsafe { *values.get_unchecked(i) }; + self.check(needle) + }) + } } impl StaticFilter for BitmapFilter @@ -174,6 +190,98 @@ where } } +/// A branchless filter for very small fixed-width primitive IN lists. +/// +/// Uses const generics to unroll the membership check into a fixed-size +/// comparison chain, outperforming hash lookups for small lists due to: +/// - No branching (uses bitwise OR to combine comparisons) +/// - Better CPU pipelining +/// - No hash computation overhead +pub(super) struct BranchlessFilter { + null_count: usize, + values: [T::Native; N], +} + +impl BranchlessFilter +where + T::Native: Copy + PartialEq, +{ + /// Try to create a branchless filter if the array has exactly N non-null values. + pub(super) fn try_new(in_array: &ArrayRef) -> Option> { + let in_array = in_array.as_primitive_opt::()?; + let non_null_count = in_array.len() - in_array.null_count(); + if non_null_count != N { + return None; + } + // Use default_value() from ArrowPrimitiveType trait instead of Default::default() + let mut arr = [T::default_value(); N]; + let mut i = 0; + for value in in_array.iter().flatten() { + arr[i] = value; + i += 1; + } + debug_assert_eq!(i, N); + Some(Ok(Self { + null_count: in_array.null_count(), + values: arr, + })) + } + + /// Branchless membership check using OR-chain. + #[inline(always)] + fn check(&self, needle: T::Native) -> bool { + self.values + .iter() + .fold(false, |acc, &v| acc | (v == needle)) + } + + /// Check membership using a raw values slice (zero-copy path for type reinterpretation). + #[inline] + pub(super) fn contains_slice( + &self, + values: &[T::Native], + nulls: Option<&NullBuffer>, + negated: bool, + ) -> BooleanArray { + build_in_list_result(values.len(), nulls, self.null_count > 0, negated, |i| { + // SAFETY: `build_in_list_result` invokes this closure for + // indices in `0..values.len()`. + let needle = unsafe { *values.get_unchecked(i) }; + self.check(needle) + }) + } +} + +impl StaticFilter for BranchlessFilter +where + T::Native: Copy + PartialEq + Send + Sync, +{ + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + let v = v.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!("Failed to downcast array to primitive type") + })?; + let input_values = v.values(); + Ok(build_in_list_result( + v.len(), + v.nulls(), + self.null_count > 0, + negated, + #[inline(always)] + |i| { + // SAFETY: `build_in_list_result` invokes this closure for + // indices in `0..v.len()`, which matches `input_values.len()`. + let needle = unsafe { *input_values.get_unchecked(i) }; + self.check(needle) + }, + )) + } +} + /// Wrapper for f32 that implements Hash and Eq using bit comparison. /// This treats NaN values as equal to each other when they have the same bit pattern. #[derive(Clone, Copy)] @@ -359,9 +467,6 @@ macro_rules! primitive_static_filter { }; } -// Generate specialized filters for all integer primitive types -primitive_static_filter!(Int8StaticFilter, Int8Type); -primitive_static_filter!(Int16StaticFilter, Int16Type); primitive_static_filter!(Int32StaticFilter, Int32Type); primitive_static_filter!(Int64StaticFilter, Int64Type); primitive_static_filter!(UInt32StaticFilter, UInt32Type); diff --git a/datafusion/physical-expr/src/expressions/in_list/strategy.rs b/datafusion/physical-expr/src/expressions/in_list/strategy.rs index 21b658fad0382..e053ca3282b2f 100644 --- a/datafusion/physical-expr/src/expressions/in_list/strategy.rs +++ b/datafusion/physical-expr/src/expressions/in_list/strategy.rs @@ -19,13 +19,84 @@ use std::sync::Arc; use arrow::array::ArrayRef; use arrow::compute::cast; -use arrow::datatypes::{DataType, UInt8Type, UInt16Type}; -use datafusion_common::Result; +use arrow::datatypes::*; +use datafusion_common::{Result, exec_datafusion_err}; use super::array_static_filter::ArrayStaticFilter; use super::primitive_filter::*; use super::static_filter::StaticFilter; +use super::transform::{make_bitmap_filter, make_branchless_filter}; +/// Maximum list size for branchless lookup on 1-byte primitives (Int8, UInt8). +const BRANCHLESS_MAX_1B: usize = 16; + +/// Maximum list size for branchless lookup on 2-byte primitives (Int16, UInt16). +const BRANCHLESS_MAX_2B: usize = 8; + +/// Maximum list size for branchless lookup on 4-byte primitives (Int32, UInt32, Float32). +const BRANCHLESS_MAX_4B: usize = 32; + +/// Maximum list size for branchless lookup on 8-byte primitives (Int64, UInt64, Float64). +const BRANCHLESS_MAX_8B: usize = 16; + +/// Maximum list size for branchless lookup on 16-byte types (Decimal128). +const BRANCHLESS_MAX_16B: usize = 4; + +/// The lookup strategy to use for a given data type and list size. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum FilterStrategy { + /// Bitmap filter for u8/u16 domains. + Bitmap1B, + Bitmap2B, + /// Branchless OR-chain for small lists. + Branchless, + /// Generic ArrayStaticFilter fallback. + Generic, +} + +/// Selects the lookup strategy based on data type and list size. +fn select_strategy(dt: &DataType, len: usize) -> FilterStrategy { + match dt.primitive_width() { + Some(1) => { + if len <= BRANCHLESS_MAX_1B { + FilterStrategy::Branchless + } else { + FilterStrategy::Bitmap1B + } + } + Some(2) => { + if len <= BRANCHLESS_MAX_2B { + FilterStrategy::Branchless + } else { + FilterStrategy::Bitmap2B + } + } + Some(4) => { + if len <= BRANCHLESS_MAX_4B { + FilterStrategy::Branchless + } else { + FilterStrategy::Generic + } + } + Some(8) => { + if len <= BRANCHLESS_MAX_8B { + FilterStrategy::Branchless + } else { + FilterStrategy::Generic + } + } + Some(16) => { + if len <= BRANCHLESS_MAX_16B { + FilterStrategy::Branchless + } else { + FilterStrategy::Generic + } + } + _ => FilterStrategy::Generic, + } +} + +/// Creates the optimal static filter for the given array. pub(super) fn instantiate_static_filter( in_array: ArrayRef, ) -> Result> { @@ -36,22 +107,48 @@ pub(super) fn instantiate_static_filter( DataType::Dictionary(_, value_type) => cast(&in_array, value_type.as_ref())?, _ => in_array, }; - match in_array.data_type() { - // Integer primitive types - DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)), - DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)), - DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), - DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)), - DataType::UInt8 => Ok(Arc::new(BitmapFilter::::try_new(&in_array)?)), - DataType::UInt16 => Ok(Arc::new(BitmapFilter::::try_new(&in_array)?)), - DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), - DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), - // Float primitive types (use ordered wrappers for Hash/Eq) - DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)), - DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)), - _ => { - /* fall through to generic implementation for unsupported types (Struct, etc.) */ - Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) - } + use FilterStrategy::*; + + let len = in_array.len(); + let dt = in_array.data_type(); + let strategy = select_strategy(dt, len); + + match (dt, strategy) { + // Bitmap filters for 1-byte and 2-byte types + (_, Bitmap1B) => make_bitmap_filter::(&in_array), + (_, Bitmap2B) => make_bitmap_filter::(&in_array), + + // Branchless filters for small lists of primitives + (_, Branchless) => dispatch_branchless(&in_array).ok_or_else(|| { + exec_datafusion_err!( + "Branchless strategy selected but no filter for {:?}", + dt + ) + })?, + + // Fallback for larger primitive lists or complex types. + (_, Generic) => match dt { + DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), + DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)), + DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), + DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), + DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)), + DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)), + _ => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), + }, + } +} + +fn dispatch_branchless( + arr: &ArrayRef, +) -> Option>> { + // Dispatch to width-specific branchless filter. + match arr.data_type().primitive_width() { + Some(1) => Some(make_branchless_filter::(arr)), + Some(2) => Some(make_branchless_filter::(arr)), + Some(4) => Some(make_branchless_filter::(arr)), + Some(8) => Some(make_branchless_filter::(arr)), + Some(16) => Some(make_branchless_filter::(arr)), + _ => None, } } diff --git a/datafusion/physical-expr/src/expressions/in_list/transform.rs b/datafusion/physical-expr/src/expressions/in_list/transform.rs new file mode 100644 index 0000000000000..5cac04af4985b --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/transform.rs @@ -0,0 +1,327 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Type transformation utilities for InList filters. +//! +//! Some filters only depend on fixed-width value bit patterns. For those cases, +//! compatible primitive arrays can be reinterpreted to the filter's unsigned +//! storage type without copying values. + +use std::mem::size_of; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, BooleanArray, PrimitiveArray}; +use arrow::buffer::ScalarBuffer; +use arrow::datatypes::{ArrowPrimitiveType, DataType}; +use datafusion_common::{Result, exec_datafusion_err}; + +use super::primitive_filter::{BitmapFilter, BitmapFilterType, BranchlessFilter}; +use super::static_filter::{StaticFilter, handle_dictionary}; + +/// Bitmap filter for signed 1-byte and 2-byte primitive arrays. +/// +/// The bitmap implementation is keyed by an unsigned primitive type (`UInt8` or +/// `UInt16`). This wrapper keeps the original array type, such as `Int8`, and +/// only reinterprets values as the unsigned type when probing the bitmap. +struct ReinterpretedBitmap { + expected_data_type: DataType, + inner: BitmapFilter, +} + +impl StaticFilter for ReinterpretedBitmap { + fn null_count(&self) -> usize { + self.inner.null_count() + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + + if v.data_type() != &self.expected_data_type { + return Err(exec_datafusion_err!( + "BitmapFilter: expected {} array, got {}", + self.expected_data_type, + v.data_type() + )); + } + + let data = v.to_data(); + let values: &[T::Native] = &data.buffer::(0)[..v.len()]; + + Ok(self.inner.contains_slice(values, data.nulls(), negated)) + } +} + +/// Branchless filter for primitive arrays that share the same byte width. +/// +/// The inner filter stores values using the unsigned primitive type selected for +/// that width. This wrapper keeps the original array type and only reinterprets +/// values as the unsigned type while probing. +struct ReinterpretedBranchless { + expected_data_type: DataType, + inner: BranchlessFilter, +} + +impl StaticFilter for ReinterpretedBranchless +where + T: ArrowPrimitiveType + 'static, + T::Native: Copy + PartialEq + Send + Sync + 'static, +{ + fn null_count(&self) -> usize { + self.inner.null_count() + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + + if v.data_type() != &self.expected_data_type { + return Err(exec_datafusion_err!( + "BranchlessFilter: expected {} array, got {}", + self.expected_data_type, + v.data_type() + )); + } + + let data = v.to_data(); + let values: &[T::Native] = &data.buffer::(0)[..v.len()]; + + Ok(self.inner.contains_slice(values, data.nulls(), negated)) + } +} + +/// Views a primitive array as another primitive type with the same byte width. +/// +/// This does not convert values. It reuses the existing values buffer and +/// interprets each value's bytes as `T::Native`, preserving the null buffer. +/// The caller must check that the source and target primitive types have the +/// same width. +#[inline] +pub(crate) fn reinterpret_any_primitive_to( + array: &dyn Array, +) -> ArrayRef { + let data = array.to_data(); + let values = data.buffers()[0].clone(); + let buffer = ScalarBuffer::::new(values, data.offset(), data.len()); + Arc::new(PrimitiveArray::::new(buffer, array.nulls().cloned())) +} + +/// Creates a bitmap filter for 1-byte or 2-byte primitive arrays. +/// +/// Unsigned inputs use the bitmap filter directly. Signed inputs of the same +/// width are reinterpreted as the unsigned bitmap type, without copying. +pub(crate) fn make_bitmap_filter( + in_array: &ArrayRef, +) -> Result> +where + T: BitmapFilterType, +{ + if in_array.data_type() == &T::DATA_TYPE { + return Ok(Arc::new(BitmapFilter::::try_new(in_array)?)); + } + + let width = size_of::(); + if in_array.data_type().primitive_width() != Some(width) { + return Err(exec_datafusion_err!( + "BitmapFilter: expected {}-byte primitive array for {} bitmap, got {}", + width, + T::DATA_TYPE, + in_array.data_type() + )); + } + + let reinterpreted = reinterpret_any_primitive_to::(in_array.as_ref()); + let inner = BitmapFilter::::try_new(&reinterpreted)?; + Ok(Arc::new(ReinterpretedBitmap { + expected_data_type: in_array.data_type().clone(), + inner, + })) +} + +/// Creates a branchless filter for primitive types. +/// +/// Dispatches based on byte width and element count: +/// - 1-byte types (Int8, UInt8): supports 0-16 elements +/// - 2-byte types (Int16, UInt16): supports 0-8 elements +/// - 4-byte types (Int32, Float32, etc.): supports 0-32 elements +/// - 8-byte types (Int64, Float64, Timestamp, etc.): supports 0-16 elements +/// - 16-byte types (Decimal128): supports 0-4 elements +pub(crate) fn make_branchless_filter( + in_array: &ArrayRef, +) -> Result> +where + D: ArrowPrimitiveType + 'static, + D::Native: Copy + PartialEq + Send + Sync + 'static, +{ + let is_native = in_array.data_type() == &D::DATA_TYPE; + let width = size_of::(); + let arr = if is_native { + Arc::clone(in_array) + } else { + if in_array.data_type().primitive_width() != Some(width) { + return Err(exec_datafusion_err!( + "BranchlessFilter: expected {width}-byte primitive array, got {}", + in_array.data_type() + )); + } + reinterpret_any_primitive_to::(in_array.as_ref()) + }; + let n = arr.len() - arr.null_count(); + + // Helper to create the filter for a known size N + #[inline] + fn create( + arr: &ArrayRef, + is_native: bool, + expected_data_type: &DataType, + ) -> Result> + where + D::Native: Copy + PartialEq + Send + Sync + 'static, + { + let inner = BranchlessFilter::::try_new(arr) + .expect("size verified") + .expect("type verified"); + if is_native { + Ok(Arc::new(inner)) + } else { + Ok(Arc::new(ReinterpretedBranchless { + expected_data_type: expected_data_type.clone(), + inner, + })) + } + } + + // Keep the branchless path to the list sizes that benchmark well for each + // primitive width. Wider lists are expected to use bitmap or hash filters. + let max_n = match width { + 1 => 16, + 2 => 8, + 4 => 32, + 8 => 16, + 16 => 4, + w => { + return datafusion_common::exec_err!( + "Branchless filter not supported for {w}-byte types" + ); + } + }; + + if n > max_n { + return datafusion_common::exec_err!( + "Branchless filter for {width}-byte types supports 0-{max_n} elements, got {n}" + ); + } + + // `BranchlessFilter` needs `N` at compile time, so map the runtime + // list length to the corresponding const-generic instantiation. + // + // For example, this expands to: + // + // match n { + // 0 => create::(&arr, is_native), + // 1 => create::(&arr, is_native), + // ... + // 32 => create::(&arr, is_native), + // _ => unreachable!("validated branchless list length"), + // } + macro_rules! dispatch_n { + ($($n:literal),* $(,)?) => { + match n { + $($n => create::(&arr, is_native, in_array.data_type()),)* + _ => unreachable!("validated branchless list length"), + } + }; + } + + dispatch_n!( + 0, 1, 2, 3, 4, 5, 6, 7, // 0..=7 + 8, 9, 10, 11, 12, 13, 14, 15, // 8..=15 + 16, 17, 18, 19, 20, 21, 22, 23, // 16..=23 + 24, 25, 26, 27, 28, 29, 30, 31, // 24..=31 + 32, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + use arrow::array::{ArrayRef, BooleanArray, Int8Array, Int16Array, Int32Array}; + use arrow::datatypes::{UInt8Type, UInt16Type, UInt32Type}; + + #[test] + fn reinterpreted_bitmap_handles_signed_boundaries_and_slices() -> Result<()> { + let haystack: ArrayRef = Arc::new( + Int8Array::from(vec![Some(99), Some(i8::MIN), None, Some(-1), Some(42)]) + .slice(1, 3), + ); + let filter = make_bitmap_filter::(&haystack)?; + let needles = + Int8Array::from(vec![Some(7), Some(i8::MIN), Some(-1), None]).slice(1, 3); + + assert_eq!( + filter.contains(&needles, false)?, + BooleanArray::from(vec![Some(true), Some(true), None]) + ); + assert_eq!( + filter.contains(&needles, true)?, + BooleanArray::from(vec![Some(false), Some(false), None]) + ); + + let haystack: ArrayRef = Arc::new( + Int16Array::from(vec![ + Some(123), + Some(i16::MIN), + None, + Some(-1), + Some(i16::MAX), + ]) + .slice(1, 4), + ); + let filter = make_bitmap_filter::(&haystack)?; + let needles = + Int16Array::from(vec![Some(0), Some(i16::MIN), Some(7), Some(i16::MAX)]) + .slice(1, 3); + + assert_eq!( + filter.contains(&needles, false)?, + BooleanArray::from(vec![Some(true), None, Some(true)]) + ); + assert_eq!( + filter.contains(&needles, true)?, + BooleanArray::from(vec![Some(false), None, Some(false)]) + ); + + Ok(()) + } + + #[test] + fn reinterpreted_branchless_handles_slices() -> Result<()> { + let haystack: ArrayRef = Arc::new( + Int32Array::from(vec![Some(99), Some(-7), None, Some(42)]).slice(1, 3), + ); + let filter = make_branchless_filter::(&haystack)?; + let needles = + Int32Array::from(vec![Some(0), Some(-7), Some(1), Some(42)]).slice(1, 3); + + assert_eq!( + filter.contains(&needles, false)?, + BooleanArray::from(vec![Some(true), None, Some(true)]) + ); + + Ok(()) + } +}