Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/expressions/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ mod primitive_filter;
mod result;
mod static_filter;
mod strategy;
mod transform;

use static_filter::StaticFilter;
use strategy::instantiate_static_filter;
Expand Down
111 changes: 108 additions & 3 deletions datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,22 @@ where
fn check(&self, needle: T::Native) -> bool {
// Membership is a single bit lookup: the needle, converted with
// `as_usize()`, is used directly as the index into `self.bits`.
self.bits.get_bit(needle.as_usize())
}

/// Evaluates the IN-list predicate over a raw slice of native values.
///
/// This is the zero-copy entry point used when the caller already holds the
/// values as `&[T::Native]` (e.g. after a type reinterpretation) rather than
/// as an Arrow array.
///
/// * `values` - the needles to test, one per output row.
/// * `nulls` - optional validity buffer for `values`, forwarded unchanged.
/// * `negated` - forwarded unchanged to `build_in_list_result`.
#[inline]
pub(super) fn contains_slice(
    &self,
    values: &[T::Native],
    nulls: Option<&NullBuffer>,
    negated: bool,
) -> BooleanArray {
    let row_count = values.len();
    let list_has_nulls = self.null_count > 0;
    let probe = |row: usize| {
        // SAFETY: `build_in_list_result` invokes this closure for
        // indices in `0..values.len()`.
        let needle = unsafe { *values.get_unchecked(row) };
        self.check(needle)
    };
    build_in_list_result(row_count, nulls, list_has_nulls, negated, probe)
}
}

impl<T> StaticFilter for BitmapFilter<T>
Expand Down Expand Up @@ -174,6 +190,98 @@ where
}
}

/// A branchless filter for very small fixed-width primitive IN lists.
///
/// The list size `N` is a const generic, so the membership check is a
/// fixed-length comparison chain over an inline `[T::Native; N]` array.
///
/// Uses const generics to unroll the membership check into a fixed-size
/// comparison chain, outperforming hash lookups for small lists due to:
/// - No branching (uses bitwise OR to combine comparisons)
/// - Better CPU pipelining
/// - No hash computation overhead
pub(super) struct BranchlessFilter<T: ArrowPrimitiveType, const N: usize> {
/// Number of null entries in the IN-list array the filter was built from.
null_count: usize,
/// The `N` non-null IN-list values, stored in the order they were read
/// from the source array.
values: [T::Native; N],
}

impl<T: ArrowPrimitiveType, const N: usize> BranchlessFilter<T, N>
where
    T::Native: Copy + PartialEq,
{
    /// Builds a filter from `in_array` when it holds exactly `N` non-null values.
    ///
    /// Returns `None` when `in_array` cannot be viewed as a primitive array of
    /// `T`, or when its non-null count differs from `N`; the caller is expected
    /// to try another `N` or another strategy. Otherwise returns
    /// `Some(Ok(filter))`.
    pub(super) fn try_new(in_array: &ArrayRef) -> Option<Result<Self>> {
        let primitive = in_array.as_primitive_opt::<T>()?;
        let null_count = primitive.null_count();
        if primitive.len() - null_count != N {
            return None;
        }

        // `T::default_value()` (from `ArrowPrimitiveType`) is only a
        // placeholder: the length check above means the loop below writes
        // every one of the `N` slots.
        let mut values = [T::default_value(); N];
        let mut filled = 0;
        for (slot, value) in primitive.iter().flatten().enumerate() {
            values[slot] = value;
            filled = slot + 1;
        }
        debug_assert_eq!(filled, N);

        Some(Ok(Self { null_count, values }))
    }

    /// Returns `true` if `needle` equals any stored value.
    ///
    /// Every stored value is compared and the results are combined with a
    /// non-short-circuiting `|`, so the work done does not depend on where
    /// (or whether) a match occurs.
    #[inline(always)]
    fn check(&self, needle: T::Native) -> bool {
        let mut hit = false;
        for &candidate in &self.values {
            hit |= candidate == needle;
        }
        hit
    }

    /// Evaluates the IN-list predicate over a raw slice of native values.
    ///
    /// This is the zero-copy entry point used when the caller already holds
    /// the values as `&[T::Native]` (e.g. after a type reinterpretation).
    /// `nulls` and `negated` are forwarded unchanged to
    /// `build_in_list_result`.
    #[inline]
    pub(super) fn contains_slice(
        &self,
        values: &[T::Native],
        nulls: Option<&NullBuffer>,
        negated: bool,
    ) -> BooleanArray {
        let row_count = values.len();
        let list_has_nulls = self.null_count > 0;
        let probe = |row: usize| {
            // SAFETY: `build_in_list_result` invokes this closure for
            // indices in `0..values.len()`.
            let needle = unsafe { *values.get_unchecked(row) };
            self.check(needle)
        };
        build_in_list_result(row_count, nulls, list_has_nulls, negated, probe)
    }
}

impl<T: ArrowPrimitiveType, const N: usize> StaticFilter for BranchlessFilter<T, N>
where
    T::Native: Copy + PartialEq + Send + Sync,
{
    /// Number of null entries recorded from the IN-list array.
    fn null_count(&self) -> usize {
        self.null_count
    }

    /// Evaluates the IN-list predicate for every row of `v`.
    ///
    /// # Errors
    ///
    /// Returns an execution error if `v` cannot be downcast to a primitive
    /// array of `T`.
    fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
        // Runs first, exactly as before. Presumably this handles
        // dictionary-encoded input and may return early; see the macro
        // definition to confirm.
        handle_dictionary!(self, v, negated);

        let primitive = v.as_primitive_opt::<T>().ok_or_else(|| {
            exec_datafusion_err!("Failed to downcast array to primitive type")
        })?;
        let haystack = primitive.values();
        let list_has_nulls = self.null_count > 0;

        let result = build_in_list_result(
            primitive.len(),
            primitive.nulls(),
            list_has_nulls,
            negated,
            // The closure stays in call-argument position so the inline
            // attribute remains valid.
            #[inline(always)]
            |row| {
                // SAFETY: `build_in_list_result` invokes this closure for
                // indices in `0..primitive.len()`, which matches
                // `haystack.len()`.
                let needle = unsafe { *haystack.get_unchecked(row) };
                self.check(needle)
            },
        );
        Ok(result)
    }
}

/// Wrapper for f32 that implements Hash and Eq using bit comparison.
/// This treats NaN values as equal to each other when they have the same bit pattern.
#[derive(Clone, Copy)]
Expand Down Expand Up @@ -359,9 +467,6 @@ macro_rules! primitive_static_filter {
};
}

// Generate specialized filters for all integer primitive types
primitive_static_filter!(Int8StaticFilter, Int8Type);
primitive_static_filter!(Int16StaticFilter, Int16Type);
primitive_static_filter!(Int32StaticFilter, Int32Type);
primitive_static_filter!(Int64StaticFilter, Int64Type);
primitive_static_filter!(UInt32StaticFilter, UInt32Type);
Expand Down
135 changes: 116 additions & 19 deletions datafusion/physical-expr/src/expressions/in_list/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,84 @@ use std::sync::Arc;

use arrow::array::ArrayRef;
use arrow::compute::cast;
use arrow::datatypes::{DataType, UInt8Type, UInt16Type};
use datafusion_common::Result;
use arrow::datatypes::*;
use datafusion_common::{Result, exec_datafusion_err};

use super::array_static_filter::ArrayStaticFilter;
use super::primitive_filter::*;
use super::static_filter::StaticFilter;
use super::transform::{make_bitmap_filter, make_branchless_filter};

/// Maximum list size (inclusive) for branchless lookup on 1-byte primitives
/// (Int8, UInt8). Longer 1-byte lists use the bitmap filter instead.
// NOTE(review): the cutoffs in this group are not monotonic in element width
// (1B = 16, 2B = 8, 4B = 32, 8B = 16, 16B = 4). Presumably they are
// benchmark-derived; confirm the 2-byte and 4-byte values are intentional.
const BRANCHLESS_MAX_1B: usize = 16;

/// Maximum list size (inclusive) for branchless lookup on 2-byte primitives
/// (Int16, UInt16). Longer 2-byte lists use the bitmap filter instead.
const BRANCHLESS_MAX_2B: usize = 8;

/// Maximum list size (inclusive) for branchless lookup on 4-byte primitives
/// (Int32, UInt32, Float32). Longer lists use the generic strategy.
const BRANCHLESS_MAX_4B: usize = 32;

/// Maximum list size (inclusive) for branchless lookup on 8-byte primitives
/// (Int64, UInt64, Float64). Longer lists use the generic strategy.
const BRANCHLESS_MAX_8B: usize = 16;

/// Maximum list size (inclusive) for branchless lookup on 16-byte types
/// (Decimal128). Longer lists use the generic strategy.
const BRANCHLESS_MAX_16B: usize = 4;

/// The lookup strategy to use for a given data type and list size.
///
/// Produced by `select_strategy` from the type's primitive byte width and
/// the IN-list length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FilterStrategy {
/// Bitmap filter for 1-byte-wide types whose list is longer than
/// `BRANCHLESS_MAX_1B`.
Bitmap1B,
/// Bitmap filter for 2-byte-wide types whose list is longer than
/// `BRANCHLESS_MAX_2B`.
Bitmap2B,
/// Branchless OR-chain for small lists.
Branchless,
/// Generic fallback: wider types above their branchless cutoff, and any
/// type without a supported primitive width.
Generic,
}

/// Picks the lookup strategy for an IN list of `len` entries of type `dt`.
///
/// The decision is keyed on `dt.primitive_width()`:
/// - widths 1, 2, 4, 8 and 16 use [`FilterStrategy::Branchless`] while `len`
///   is at or below the width's `BRANCHLESS_MAX_*` cutoff;
/// - above the cutoff, widths 1 and 2 fall back to their bitmap strategy and
///   widths 4, 8 and 16 fall back to [`FilterStrategy::Generic`];
/// - any other width (or no primitive width) is always
///   [`FilterStrategy::Generic`].
fn select_strategy(dt: &DataType, len: usize) -> FilterStrategy {
    // (inclusive branchless cutoff, strategy used once the cutoff is exceeded)
    let (branchless_max, above_cutoff) = match dt.primitive_width() {
        Some(1) => (BRANCHLESS_MAX_1B, FilterStrategy::Bitmap1B),
        Some(2) => (BRANCHLESS_MAX_2B, FilterStrategy::Bitmap2B),
        Some(4) => (BRANCHLESS_MAX_4B, FilterStrategy::Generic),
        Some(8) => (BRANCHLESS_MAX_8B, FilterStrategy::Generic),
        Some(16) => (BRANCHLESS_MAX_16B, FilterStrategy::Generic),
        _ => return FilterStrategy::Generic,
    };

    if len > branchless_max {
        above_cutoff
    } else {
        FilterStrategy::Branchless
    }
}

/// Creates the optimal static filter for the given array.
pub(super) fn instantiate_static_filter(
in_array: ArrayRef,
) -> Result<Arc<dyn StaticFilter + Send + Sync>> {
Expand All @@ -36,22 +107,48 @@ pub(super) fn instantiate_static_filter(
DataType::Dictionary(_, value_type) => cast(&in_array, value_type.as_ref())?,
_ => in_array,
};
match in_array.data_type() {
// Integer primitive types
DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)),
DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)),
DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)),
DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)),
DataType::UInt8 => Ok(Arc::new(BitmapFilter::<UInt8Type>::try_new(&in_array)?)),
DataType::UInt16 => Ok(Arc::new(BitmapFilter::<UInt16Type>::try_new(&in_array)?)),
DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)),
DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)),
// Float primitive types (use ordered wrappers for Hash/Eq)
DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)),
DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)),
_ => {
/* fall through to generic implementation for unsupported types (Struct, etc.) */
Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?))
}
use FilterStrategy::*;

let len = in_array.len();
let dt = in_array.data_type();
let strategy = select_strategy(dt, len);

match (dt, strategy) {
// Bitmap filters for 1-byte and 2-byte types
(_, Bitmap1B) => make_bitmap_filter::<UInt8Type>(&in_array),
(_, Bitmap2B) => make_bitmap_filter::<UInt16Type>(&in_array),

// Branchless filters for small lists of primitives
(_, Branchless) => dispatch_branchless(&in_array).ok_or_else(|| {
exec_datafusion_err!(
"Branchless strategy selected but no filter for {:?}",
dt
)
})?,

// Fallback for larger primitive lists or complex types.
(_, Generic) => match dt {
DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)),
DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)),
DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)),
DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)),
DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)),
DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)),
_ => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)),
},
}
}

/// Builds a branchless filter for `arr`, choosing the implementation by the
/// byte width of its data type.
///
/// All types sharing a byte width go through one instantiation: the unsigned
/// integer type of that width for 1, 2, 4 and 8 bytes, and `Decimal128Type`
/// for 16 bytes. Returns `None` when the type has no primitive width or the
/// width is not one of those five.
fn dispatch_branchless(
    arr: &ArrayRef,
) -> Option<Result<Arc<dyn StaticFilter + Send + Sync>>> {
    let width = arr.data_type().primitive_width()?;
    let filter = match width {
        1 => make_branchless_filter::<UInt8Type>(arr),
        2 => make_branchless_filter::<UInt16Type>(arr),
        4 => make_branchless_filter::<UInt32Type>(arr),
        8 => make_branchless_filter::<UInt64Type>(arr),
        16 => make_branchless_filter::<Decimal128Type>(arr),
        _ => return None,
    };
    Some(filter)
}
Loading
Loading