From 1fe79bacac3b587d0287482fad968fd1a281842c Mon Sep 17 00:00:00 2001 From: Adam Alani Date: Tue, 9 Jun 2026 10:22:41 -0400 Subject: [PATCH] fix(lambda): only push referenced params into the merged batch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `LambdaExpr` compresses the body's column-index projection by enumerating every referenced `Column`/`LambdaVariable` index and packing them into a dense range. That compression is correct for outer captures, but it silently broke multi-parameter lambdas: a body like `(k, v) -> v` (where `k` is unused) would have its `LambdaVariable("v")` re-projected from index 1 to index 0 and then, at runtime, read the slot the higher-order function had filled with `k`. Per maintainer feedback on apache#22853, fix it without a breaking change to `LambdaExpr::try_new` / `expressions::lambda(...)`: * `LambdaExpr` now tracks `used_params: HashSet<String>` — the subset of its own declared parameters that the body actually references. The set is computed during a single walk of the body in `LambdaExpr::new`, with a shadow stack that ignores `LambdaVariable`s bound by nested lambdas. For `(k, v) -> func(col, (k, v2) -> k + v2 + v)` the inner `k` shadows the outer `k`, so only `v` flows up as used by the outer lambda. * `LambdaArgument` gets an `Option<HashSet<String>>` for used parameter names plus a non-breaking `new_with_used_params(...)` constructor. The existing `new(...)` calls it with `None`, which preserves the old "push every declared parameter" behavior. * `LambdaArgument::evaluate` (through `merge_captures_with_variables`) only evaluates and pushes the closures whose parameter name appears in `used_params`, preserving the original declaration order. Unused declared parameters therefore leave no slot in the merged batch, so the body's compressed indices line up directly with the columns the evaluator actually built. 
* `HigherOrderFunctionExpr::evaluate` calls `new_with_used_params` and forwards `lambda.used_params().clone()`, so all in-tree higher-order UDFs benefit automatically without any callsite change. No public API breakage: `LambdaExpr::try_new`, `expressions::lambda(...)` and `LambdaArgument::new` keep their existing signatures. Two new tests cover the unused-parameter case and the nested-lambda shadowing case; existing tests in `physical-expr` and `expr` continue to pass. --- datafusion/expr/src/higher_order_function.rs | 108 ++++++++-- .../physical-expr/src/expressions/lambda.rs | 190 ++++++++++++++++-- .../src/higher_order_function.rs | 3 +- 3 files changed, 269 insertions(+), 32 deletions(-) diff --git a/datafusion/expr/src/higher_order_function.rs b/datafusion/expr/src/higher_order_function.rs index 413714f498164..373df2d262e18 100644 --- a/datafusion/expr/src/higher_order_function.rs +++ b/datafusion/expr/src/higher_order_function.rs @@ -24,7 +24,7 @@ use crate::expr::{ use crate::type_coercion::functions::value_fields_with_higher_order_udf; use crate::udf_eq::UdfEq; use crate::{ColumnarValue, Documentation, Expr, ExprSchemable}; -use arrow::array::{ArrayRef, RecordBatch}; +use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions}; use arrow::datatypes::{DataType, FieldRef, Schema}; use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; @@ -239,6 +239,15 @@ pub struct LambdaArgument { /// For example, for `array_transform([2], v -> -v)`, /// this will be `vec![Field::new("v", DataType::Int32, true)]` params: Vec, + /// Indices into `params` of the parameters that are actually referenced + /// by `body` (taking nested-lambda shadowing into account). + /// + /// `None` means "no information, assume every declared parameter is used" + /// — that is the backwards-compatible behavior of [`Self::new`]. 
When set, + /// [`Self::evaluate`] skips evaluating and pushing the closures for the + /// parameters not listed here, so unused declared parameters do not shift + /// the columns the body's compressed indices expect. + used_param_indices: Option>, /// The body of the lambda /// /// For example, for `array_transform([2], v -> -v)`, @@ -257,26 +266,64 @@ pub struct LambdaArgument { } impl LambdaArgument { + /// Build a [`LambdaArgument`] that treats every declared parameter as + /// used. This is the backwards-compatible behavior. Prefer + /// [`Self::new_with_used_params`] when the caller knows which subset of + /// the lambda's parameters the body actually references — otherwise the + /// merged batch will still contain columns for unused parameters. pub fn new( params: Vec, body: Arc, captures: Option, ) -> Self { + Self::new_with_used_params(params, body, captures, None) + } + + /// Build a [`LambdaArgument`] knowing which subset of `params` (by name) + /// the lambda body actually references. + /// + /// When `used_params` is `Some(set)`, [`Self::evaluate`] only evaluates + /// and pushes the closures whose corresponding parameter name is in + /// `set`, in the original declaration order of `params`. Unused declared + /// parameters leave no slot in the merged batch, so the body's compressed + /// column indices line up directly. When `used_params` is `None`, + /// behavior is identical to [`Self::new`]. 
+ pub fn new_with_used_params( + params: Vec, + body: Arc, + captures: Option, + used_params: Option>, + ) -> Self { + let used_param_indices = used_params.map(|set| { + params + .iter() + .enumerate() + .filter(|(_, f)| set.contains(f.name())) + .map(|(i, _)| i) + .collect::>() + }); + + let effective_params: Vec = match &used_param_indices { + Some(indices) => indices.iter().map(|i| Arc::clone(¶ms[*i])).collect(), + None => params.clone(), + }; + let fields = match &captures { Some(batch) => batch .schema_ref() .fields() .iter() .cloned() - .chain(params.clone()) + .chain(effective_params) .collect(), - None => params.clone(), + None => effective_params, }; let schema = Arc::new(Schema::new(fields)); Self { params, + used_param_indices, body, schema, captures, @@ -344,6 +391,7 @@ impl LambdaArgument { spread_captures.as_ref(), Arc::clone(&self.schema), &self.params, + self.used_param_indices.as_deref(), args, )?; @@ -355,6 +403,7 @@ fn merge_captures_with_variables( captures: Option<&RecordBatch>, schema: SchemaRef, params: &[FieldRef], + used_param_indices: Option<&[usize]>, variables: &[&dyn Fn() -> Result], ) -> Result { if variables.len() < params.len() { @@ -365,23 +414,56 @@ fn merge_captures_with_variables( ); } + let push_param_arrays = |columns: &mut Vec| -> Result<()> { + match used_param_indices { + Some(indices) => { + for &i in indices { + columns.push(variables[i]()?); + } + } + None => { + for arg in &variables[..params.len()] { + columns.push(arg()?); + } + } + } + Ok(()) + }; + let columns = match captures { Some(captures) => { let mut columns = captures.columns().to_vec(); - - for arg in &variables[..params.len()] { - columns.push(arg()?); - } - + push_param_arrays(&mut columns)?; + columns + } + None => { + let mut columns = Vec::with_capacity( + used_param_indices + .map(<[usize]>::len) + .unwrap_or(params.len()), + ); + push_param_arrays(&mut columns)?; columns } - None => variables - .iter() - .take(params.len()) - .map(|arg| arg()) - 
.collect::>()?, }; + if columns.is_empty() { + // Constant lambda body with no captures and no used parameters. We + // still need a row count for the merged batch, so evaluate one + // variable just to derive it. This is essentially free in the common + // case (the variables already exist as closures over arrays the + // caller computed up front). + let row_count = match variables.first() { + Some(first) => first()?.len(), + None => 0, + }; + return Ok(RecordBatch::try_new_with_options( + schema, + vec![], + &RecordBatchOptions::new().with_row_count(Some(row_count)), + )?); + } + Ok(RecordBatch::try_new(schema, columns)?) } diff --git a/datafusion/physical-expr/src/expressions/lambda.rs b/datafusion/physical-expr/src/expressions/lambda.rs index 9275821ae9150..45edca09d1b9b 100644 --- a/datafusion/physical-expr/src/expressions/lambda.rs +++ b/datafusion/physical-expr/src/expressions/lambda.rs @@ -31,7 +31,7 @@ use arrow::{ }; use datafusion_common::{ HashMap, plan_err, - tree_node::{Transformed, TreeNode, TreeNodeRecursion}, + tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor}, }; use datafusion_common::{HashSet, Result, internal_err}; use datafusion_expr::ColumnarValue; @@ -43,6 +43,15 @@ pub struct LambdaExpr { body: Arc, projected_body: Arc, projection: Vec, + /// Subset of `params` (by name) that the body actually references, + /// computed with nested-lambda shadow tracking. Empty when no parameter + /// is referenced by this lambda's own body. + /// + /// The higher-order function uses this to only evaluate and push the + /// parameters the body actually needs into the merged evaluation batch, + /// which keeps the body's compressed column indices aligned with the + /// batch layout produced at runtime. 
+ used_params: HashSet, } // Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 [https://github.com/apache/datafusion/issues/13196] @@ -60,7 +69,7 @@ impl Hash for LambdaExpr { } impl LambdaExpr { - /// Create a new lambda expression with the given parameters and body + /// Create a new lambda expression with the given parameters and body. pub fn try_new(params: Vec, body: Arc) -> Result { if !all_unique(¶ms) { return plan_err!( @@ -75,27 +84,30 @@ impl LambdaExpr { } fn new(params: Vec, body: Arc) -> Self { - let mut used_column_indices = HashSet::new(); + let own_params: HashSet = params.iter().cloned().collect(); - body.apply(|node| { - if let Some(col) = node.downcast_ref::() { - used_column_indices.insert(col.index()); - } else if let Some(var) = node.downcast_ref::() { - used_column_indices.insert(var.index()); - } - - Ok(TreeNodeRecursion::Continue) - }) - .expect("closure should be infallible"); + let mut visitor = CollectUsedVisitor { + own_params: &own_params, + used_indices: HashSet::new(), + used_param_names: HashSet::new(), + shadow_stack: Vec::new(), + }; + body.visit(&mut visitor).expect("visitor is infallible"); + let CollectUsedVisitor { + used_indices, + used_param_names, + .. 
+ } = visitor; - let mut projection = used_column_indices.into_iter().collect::>(); + let mut projection = used_indices.into_iter().collect::>(); projection.sort(); let column_index_map = projection .iter() + .copied() .enumerate() - .map(|(projected, original)| (*original, projected)) + .map(|(new_idx, original)| (original, new_idx)) .collect::>(); let projected_body = Arc::clone(&body) @@ -129,6 +141,7 @@ impl LambdaExpr { body, projected_body, projection, + used_params: used_param_names, } } @@ -149,6 +162,67 @@ impl LambdaExpr { pub(crate) fn projected_body(&self) -> &Arc { &self.projected_body } + + /// Subset of [`params`](Self::params) (by name) that the body actually + /// references, taking nested-lambda shadowing into account. Used by the + /// higher-order function evaluator to skip evaluating/pushing parameters + /// the lambda body does not need, so that unused declared parameters do + /// not shift the merged batch's column positions out of sync with the + /// body's compressed indices. + pub fn used_params(&self) -> &HashSet { + &self.used_params + } +} + +/// Walks the body of a [`LambdaExpr`] and collects, on a single pass: +/// +/// * `used_indices` — every `Column` / `LambdaVariable` index referenced +/// anywhere in the tree (including inside nested lambdas). This drives +/// the `projection` used to slice the outer batch. +/// * `used_param_names` — the subset of *this* lambda's `own_params` that +/// the body actually references, with nested-lambda parameters shadowing +/// the outer ones. For example, in +/// `(k, v) -> func(col, (k, v2) -> k + v2 + v)` the inner `k` shadows the +/// outer `k`, so only `v` flows up as used. +/// +/// The shadow stack uses `TreeNodeVisitor`'s `f_down` / `f_up` callbacks +/// directly: push a frame when entering a nested [`LambdaExpr`], pop it +/// when leaving. 
+struct CollectUsedVisitor<'a> { + own_params: &'a HashSet, + used_indices: HashSet, + used_param_names: HashSet, + shadow_stack: Vec>, +} + +impl TreeNodeVisitor<'_> for CollectUsedVisitor<'_> { + type Node = Arc; + + fn f_down(&mut self, node: &Self::Node) -> Result { + if let Some(col) = node.downcast_ref::() { + self.used_indices.insert(col.index()); + } else if let Some(var) = node.downcast_ref::() { + self.used_indices.insert(var.index()); + + let name = var.name(); + let shadowed = self.shadow_stack.iter().any(|frame| frame.contains(name)); + if !shadowed && self.own_params.contains(name) { + self.used_param_names.insert(name.to_string()); + } + } else if let Some(nested) = node.downcast_ref::() { + self.shadow_stack + .push(nested.params.iter().cloned().collect()); + } + + Ok(TreeNodeRecursion::Continue) + } + + fn f_up(&mut self, node: &Self::Node) -> Result { + if node.downcast_ref::().is_some() { + self.shadow_stack.pop(); + } + Ok(TreeNodeRecursion::Continue) + } } impl std::fmt::Display for LambdaExpr { @@ -195,7 +269,7 @@ impl PhysicalExpr for LambdaExpr { } } -/// Create a lambda expression +/// Create a lambda expression. pub fn lambda( params: impl IntoIterator>, body: Arc, @@ -234,10 +308,15 @@ fn check_async_udf(body: &Arc) -> Result<()> { #[cfg(test)] mod tests { - use crate::expressions::{NoOp, lambda::lambda}; - use arrow::{array::RecordBatch, datatypes::Schema}; + use crate::expressions::{Column, LambdaVariable, NoOp, lambda::lambda}; + use arrow::{ + array::RecordBatch, + datatypes::{DataType, Field, Schema}, + }; use std::sync::Arc; + use super::LambdaExpr; + #[test] fn test_lambda_evaluate() { let lambda = lambda(["a"], Arc::new(NoOp::new())).unwrap(); @@ -249,4 +328,79 @@ mod tests { fn test_lambda_duplicate_name() { assert!(lambda(["a", "a"], Arc::new(NoOp::new())).is_err()); } + + /// A two-parameter lambda whose body only references the second + /// parameter (`v`) must report only `v` as used. 
The higher-order + /// function uses this set to push only `v` into the merged batch, so + /// the body's compressed `LambdaVariable` index for `v` lines up with + /// the batch layout. + #[test] + fn test_used_params_collects_only_referenced_param() { + let v_field = Arc::new(Field::new("v", DataType::Int32, true)); + let body = Arc::new(LambdaVariable::new(1, Arc::clone(&v_field))); + + let lambda = + LambdaExpr::try_new(vec!["k".to_string(), "v".to_string()], body).unwrap(); + + assert_eq!(lambda.projection(), &[1]); + let used = lambda.used_params(); + assert!(used.contains("v")); + assert!(!used.contains("k")); + assert_eq!(used.len(), 1); + } + + /// Inside a nested lambda that re-declares one of the outer parameter + /// names, only the non-shadowed outer references should be reported as + /// used by the outer lambda. In + /// `(k, v) -> func(col, (k, v2) -> k + v2 + v)` the inner `k` shadows + /// the outer `k`, so the outer lambda must only see `v` as used. + #[test] + fn test_used_params_handles_shadowing_inside_nested_lambda() { + let outer_k_field = Arc::new(Field::new("k", DataType::Int32, true)); + let outer_v_field = Arc::new(Field::new("v", DataType::Int32, true)); + let inner_v2_field = Arc::new(Field::new("v2", DataType::Int32, true)); + + // Inner lambda body references "k" (inner's), "v2" (inner's), and + // "v" (outer's). Build it directly with the dense compressed + // indices the inner LambdaExpr::new would produce: sorted referenced + // indices, so the names alone matter here — what matters for + // shadow tracking is the names, not the indices. 
+ let inner_body: Arc = + Arc::new(crate::expressions::BinaryExpr::new( + Arc::new(crate::expressions::BinaryExpr::new( + Arc::new(LambdaVariable::new(1, Arc::clone(&outer_k_field))), + datafusion_expr::Operator::Plus, + Arc::new(LambdaVariable::new(2, Arc::clone(&inner_v2_field))), + )), + datafusion_expr::Operator::Plus, + Arc::new(LambdaVariable::new(0, Arc::clone(&outer_v_field))), + )); + let inner_lambda = Arc::new( + LambdaExpr::try_new(vec!["k".to_string(), "v2".to_string()], inner_body) + .unwrap(), + ); + + // Outer body wraps the inner lambda in a binary op next to a + // regular column reference so the walk has something non-trivial + // to descend through. The outer body references the inner lambda + // via `inner_lambda`. + let outer_body: Arc = + Arc::new(crate::expressions::BinaryExpr::new( + Arc::new(Column::new("col", 0)), + datafusion_expr::Operator::Plus, + inner_lambda, + )); + + let outer_lambda = + LambdaExpr::try_new(vec!["k".to_string(), "v".to_string()], outer_body) + .unwrap(); + + let used = outer_lambda.used_params(); + assert!(used.contains("v"), "outer's `v` should be reported as used"); + assert!( + !used.contains("k"), + "outer's `k` is shadowed inside the nested lambda and should not be reported as used" + ); + assert_eq!(used.len(), 1); + } } diff --git a/datafusion/physical-expr/src/higher_order_function.rs b/datafusion/physical-expr/src/higher_order_function.rs index 7390eb33a0922..7fa312c48c41b 100644 --- a/datafusion/physical-expr/src/higher_order_function.rs +++ b/datafusion/physical-expr/src/higher_order_function.rs @@ -345,7 +345,7 @@ impl PhysicalExpr for HigherOrderFunctionExpr { .filter(|i| *i < batch.num_columns()) .collect::>(); - Ok(ValueOrLambda::Lambda(LambdaArgument::new( + Ok(ValueOrLambda::Lambda(LambdaArgument::new_with_used_params( params, Arc::clone(lambda.projected_body()), if projection.is_empty() { @@ -353,6 +353,7 @@ impl PhysicalExpr for HigherOrderFunctionExpr { } else { Some(batch.project(&projection)?) 
}, + Some(lambda.used_params().clone()), ))) } ArgSlot::Value => {