From 480bb572d407ebc59449d08389489261289254f4 Mon Sep 17 00:00:00 2001 From: Hadrian Date: Fri, 26 Jun 2026 22:57:52 -0400 Subject: [PATCH 1/5] first pass --- .../simplify_expressions/expr_simplifier.rs | 84 +++++++++++++++++++ .../src/simplify_expressions/utils.rs | 53 +++++++++++- 2 files changed, 136 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 39c8541b51b2f..5e14b71089c8b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1095,6 +1095,12 @@ impl TreeNodeRewriter for Simplifier<'_> { Transformed::yes(*right) } + expr @ Expr::BinaryExpr(_) + if has_associative_op(&expr) && has_adjacent_literals(&expr) => + { + Transformed::yes(reassociate_literals(expr)) + } + // // Rules for Multiply // @@ -2355,6 +2361,63 @@ fn simplify_right_is_one_case( } } +fn reassociate_literals(expr: Expr) -> Expr { + fn flatten(op: Operator, expr: Expr, out: &mut Vec) { + match expr { + Expr::BinaryExpr(expr) if expr.op == op => { + flatten(op, *expr.left, out); + flatten(op, *expr.right, out); + } + expr => out.push(expr), + } + } + + let op = match &expr { + Expr::BinaryExpr(expr) => expr.op, + _ => unreachable!(), + }; + let mut exprs = Vec::new(); + flatten(op, expr, &mut exprs); + let mut exprs = exprs.into_iter(); + + let mut out = exprs.next().unwrap(); + let mut lit = None; + for expr in exprs { + if matches!(expr, Expr::Literal(_, _)) { + if let Some(left) = lit { + lit = Some(Expr::BinaryExpr(BinaryExpr { + left: Box::new(left), + op, + right: Box::new(expr), + })); + } else { + lit = Some(expr); + } + } else { + if let Some(lit) = lit.take() { + out = Expr::BinaryExpr(BinaryExpr { + left: Box::new(out), + op, + right: Box::new(lit), + }); + } + out = Expr::BinaryExpr(BinaryExpr { + left: Box::new(out), + op, + right: Box::new(expr), + }); + } + } + if let Some(lit) = lit.take() { + out = Expr::BinaryExpr(BinaryExpr { + left: Box::new(out), + op, + right: Box::new(lit), + }); + } + out +} + #[cfg(test)] mod tests { use super::*; @@ -3368,6 +3431,27 @@ mod tests { assert_eq!(simplify(expr_eq), lit(true)); } + #[test] + fn test_simplify_nested_associative_expr() { + let plus_expr = (col("c1") + lit(1)) + (lit(2) + col("c2")); + assert_eq!(simplify(plus_expr), col("c1") + lit(3) + col("c2")); + + let mixed_expr = + col("c1") * col("c2") + lit(1) + lit(2) + (lit(3) + lit(4) * col("c3")); + assert_eq!( + simplify(mixed_expr), + col("c1") * col("c2") + lit(6) + lit(4) * col("c3") + ); + + let concat = |left, right| binary_expr(left, Operator::StringConcat, right); + let concat_expr = + concat(concat(concat(col("c1"), lit("a")), lit("b")), col("c2")); + assert_eq!( + simplify(concat_expr), + concat(concat(col("c1"), lit("ab")), col("c2")) + ); + } + #[test] fn test_simplify_regex() { // malformed regex diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index b0908b47602f7..06b600c9279e9 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -18,7 +18,10 @@ //! Utility functions for expression simplification use arrow::datatypes::i256; -use datafusion_common::{Result, ScalarValue, internal_err}; +use datafusion_common::{ + Result, ScalarValue, internal_err, + tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, +}; use datafusion_expr::{ Case, Expr, Like, Operator, expr::{Between, BinaryExpr, InList}, @@ -290,6 +293,54 @@ pub fn is_lit(expr: &Expr) -> bool { matches!(expr, Expr::Literal(_, _)) } +pub fn has_associative_op(expr: &Expr) -> bool { + let op = match expr { + Expr::BinaryExpr(expr) => expr.op, + _ => unreachable!(), + }; + // TODO: add other associative ops + // TODO: check types (float addition isn't associative) + matches!(op, Operator::Plus | Operator::StringConcat) +} + +pub fn has_adjacent_literals(expr: &Expr) -> bool { + struct AdjacentLiteralVisitor { + op: Operator, + last_expr_was_literal: bool, + } + + impl<'n> TreeNodeVisitor<'n> for AdjacentLiteralVisitor { + type Node = Expr; + + fn f_down(&mut self, node: &'n Self::Node) -> Result { + match node { + Expr::BinaryExpr(expr) if expr.op == self.op => { + Ok(TreeNodeRecursion::Continue) + } + Expr::Literal(_, _) => { + if self.last_expr_was_literal { + Ok(TreeNodeRecursion::Stop) + } else { + self.last_expr_was_literal = true; + Ok(TreeNodeRecursion::Continue) + } + } + _ => Ok(TreeNodeRecursion::Jump), + } + } + } + + let op = match expr { + Expr::BinaryExpr(expr) => expr.op, + _ => unreachable!(), + }; + let mut visitor = AdjacentLiteralVisitor { + op, + last_expr_was_literal: false, + }; + expr.visit(&mut visitor).unwrap() == TreeNodeRecursion::Stop +} + /// Checks if `eq_expr` is `A = L1` and `ne_expr` is `A != L2` where L1 != L2. /// This pattern can be simplified to just `A = L1` since if A equals L1 /// and L1 is different from L2, then A is automatically not equal to L2. From 81032a9c9aeab1a8d37cdcb36946591a776ee78b Mon Sep 17 00:00:00 2001 From: Hadrian Date: Fri, 26 Jun 2026 23:12:23 -0400 Subject: [PATCH 2/5] clean up code --- .../simplify_expressions/expr_simplifier.rs | 31 +++++-------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 5e14b71089c8b..495543a88dcb9 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -39,12 +39,12 @@ use datafusion_common::{ metadata::FieldMetadata, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; -use datafusion_expr::expr::HigherOrderFunction; use datafusion_expr::{ BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Like, Operator, Volatility, and, binary::BinaryTypeCoercer, lit, or, preimage::PreimageResult, }; use datafusion_expr::{Cast, TryCast, simplify::ExprSimplifyResult}; +use datafusion_expr::{binary_expr, expr::HigherOrderFunction}; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ expr::{InList, InSubquery}, @@ -2383,39 +2383,24 @@ fn reassociate_literals(expr: Expr) -> Expr { let mut out = exprs.next().unwrap(); let mut lit = None; for expr in exprs { - if matches!(expr, Expr::Literal(_, _)) { + if is_lit(&expr) { if let Some(left) = lit { - lit = Some(Expr::BinaryExpr(BinaryExpr { - left: Box::new(left), - op, - right: Box::new(expr), - })); + lit = Some(binary_expr(left, op, expr)); } else { lit = Some(expr); } } else { if let Some(lit) = lit.take() { - out = Expr::BinaryExpr(BinaryExpr { - left: Box::new(out), - op, - right: Box::new(lit), - }); + out = binary_expr(out, op, lit); } - out = Expr::BinaryExpr(BinaryExpr { - left: Box::new(out), - op, - right: Box::new(expr), - }); + out = binary_expr(out, op, expr); } } if let Some(lit) = lit.take() { - out = Expr::BinaryExpr(BinaryExpr { - left: Box::new(out), - op, - right: Box::new(lit), - }); + binary_expr(out, op, lit) + } else { + out } - out } #[cfg(test)] From d8a3be540ca9e0fcc65071dccb78403af39590e5 Mon Sep 17 00:00:00 2001 From: Hadrian Date: Sat, 27 Jun 2026 07:19:28 -0400 Subject: [PATCH 3/5] check type --- .../src/simplify_expressions/expr_simplifier.rs | 3 ++- .../optimizer/src/simplify_expressions/utils.rs | 11 ++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 495543a88dcb9..dadf7160311ca 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1096,7 +1096,8 @@ impl TreeNodeRewriter for Simplifier<'_> { } expr @ Expr::BinaryExpr(_) - if has_associative_op(&expr) && has_adjacent_literals(&expr) => + if has_associative_op(&expr, self.info)? + && has_adjacent_literals(&expr) => { Transformed::yes(reassociate_literals(expr)) } diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 06b600c9279e9..344f792eebe42 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -26,6 +26,7 @@ use datafusion_expr::{ Case, Expr, Like, Operator, expr::{Between, BinaryExpr, InList}, expr_fn::{and, bitwise_and, bitwise_or, or}, + simplify::SimplifyContext, }; pub static POWS_OF_TEN: [i128; 38] = [ @@ -293,14 +294,18 @@ pub fn is_lit(expr: &Expr) -> bool { matches!(expr, Expr::Literal(_, _)) } -pub fn has_associative_op(expr: &Expr) -> bool { +pub fn has_associative_op(expr: &Expr, info: &SimplifyContext) -> Result { let op = match expr { Expr::BinaryExpr(expr) => expr.op, _ => unreachable!(), }; + let datatype = info.get_data_type(expr)?; // TODO: add other associative ops - // TODO: check types (float addition isn't associative) - matches!(op, Operator::Plus | Operator::StringConcat) + match op { + Operator::Plus => Ok(datatype.is_integer()), + Operator::StringConcat => Ok(datatype.is_string() || datatype.is_binary()), + _ => Ok(false), + } } pub fn has_adjacent_literals(expr: &Expr) -> bool { From 7c4e8e9a4cba4512044b6590ec73b6b87b4eb6f3 Mon Sep 17 00:00:00 2001 From: Hadrian Date: Sat, 27 Jun 2026 18:12:44 -0400 Subject: [PATCH 4/5] check types --- .../simplify_expressions/expr_simplifier.rs | 104 +++++++++++++----- .../src/simplify_expressions/utils.rs | 59 ++++++---- 2 files changed, 113 insertions(+), 50 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index dadf7160311ca..5b69d05437ea8 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1095,11 +1095,9 @@ impl TreeNodeRewriter for Simplifier<'_> { Transformed::yes(*right) } - expr @ Expr::BinaryExpr(_) - if has_associative_op(&expr, self.info)? - && has_adjacent_literals(&expr) => - { - Transformed::yes(reassociate_literals(expr)) + // (A + 1) + 2 -> A + (1 + 2) + expr if is_associative_with_adjacent_literals(&expr, self.info) => { + Transformed::yes(reassociate_literals(expr, self.info)) } // @@ -2362,23 +2360,35 @@ fn simplify_right_is_one_case( } } -fn reassociate_literals(expr: Expr) -> Expr { - fn flatten(op: Operator, expr: Expr, out: &mut Vec) { - match expr { - Expr::BinaryExpr(expr) if expr.op == op => { - flatten(op, *expr.left, out); - flatten(op, *expr.right, out); +fn reassociate_literals(expr: Expr, info: &SimplifyContext) -> Expr { + fn flatten( + op: Operator, + datatype: &DataType, + info: &SimplifyContext, + expr: Expr, + out: &mut Vec, + ) { + match &expr { + Expr::BinaryExpr(binary) + if binary.op == op + && matches!(info.get_data_type(&expr), Ok(dt) if &dt == datatype) => + { + let Expr::BinaryExpr(binary) = expr else { + unreachable!() + }; + flatten(op, datatype, info, *binary.left, out); + flatten(op, datatype, info, *binary.right, out); } - expr => out.push(expr), + _ => out.push(expr), } } - let op = match &expr { - Expr::BinaryExpr(expr) => expr.op, - _ => unreachable!(), + let (op, datatype) = match (&expr, info.get_data_type(&expr)) { + (Expr::BinaryExpr(expr), Ok(datatype)) => (expr.op, datatype), + _ => return expr, }; let mut exprs = Vec::new(); - flatten(op, expr, &mut exprs); + flatten(op, &datatype, info, expr, &mut exprs); let mut exprs = exprs.into_iter(); let mut out = exprs.next().unwrap(); @@ -3418,23 +3428,56 @@ mod tests { } #[test] - fn test_simplify_nested_associative_expr() { - let plus_expr = (col("c1") + lit(1)) + (lit(2) + col("c2")); - assert_eq!(simplify(plus_expr), col("c1") + lit(3) + col("c2")); + fn test_simplify_nested_literals_associative() { + assert_change( + col("c3") + lit(1i64) + lit(2i64) + lit(3i64) + col("c3_non_null"), + col("c3") + lit(6i64) + col("c3_non_null"), + ); + assert_change( + (col("c3") + lit(1i64)) + (lit(2i64) + col("c3_non_null")), + col("c3") + lit(3i64) + col("c3_non_null"), + ); + assert_change(lit(1i64) + lit(2i64) + col("c3"), lit(3i64) + col("c3")); + assert_change(col("c4") + lit(1u32) + lit(2u32), col("c4") + lit(3u32)); + assert_change( + col("c3") + lit(1i64) + col("c3_non_null") + lit(2i64) + lit(3i64), + col("c3") + lit(1i64) + col("c3_non_null") + lit(5i64), + ); + assert_change( + col("c3") * col("c3") + lit(2i64) + (lit(3i64) + lit(4i64) * col("c3")), + col("c3") * col("c3") + lit(5i64) + lit(4i64) * col("c3"), + ); + assert_change(col("c3") * lit(2i64) * lit(3i64), col("c3") * lit(6i64)); + assert_change( + (col("c3") * lit(2i64)) * (lit(3i64) * col("c3_non_null")), + col("c3") * lit(6i64) * col("c3_non_null"), + ); - let mixed_expr = - col("c1") * col("c2") + lit(1) + lit(2) + (lit(3) + lit(4) * col("c3")); - assert_eq!( - simplify(mixed_expr), - col("c1") * col("c2") + lit(6) + lit(4) * col("c3") + let cat = |left, right| binary_expr(left, Operator::StringConcat, right); + assert_change( + simplify(cat(cat(cat(col("c1"), lit("a")), lit("b")), col("c1"))), + cat(cat(col("c1"), lit("ab")), col("c1")), ); - let concat = |left, right| binary_expr(left, Operator::StringConcat, right); - let concat_expr = - concat(concat(concat(col("c1"), lit("a")), lit("b")), col("c2")); - assert_eq!( - simplify(concat_expr), - concat(concat(col("c1"), lit("ab")), col("c2")) + assert_change(col("c3") & lit(12i64) & lit(10i64), col("c3") & lit(8i64)); + assert_change( + (col("c4") & lit(12u32)) & (lit(10u32) & col("c4_non_null")), + col("c4") & lit(8u32) & col("c4_non_null"), + ); + + assert_change(col("c3") | lit(1i64) | lit(2i64), col("c3") | lit(3i64)); + assert_change(col("c4") ^ lit(1u32) ^ lit(3u32), col("c4") ^ lit(2u32)); + + assert_no_change(col("c3") + lit(1i64)); + assert_no_change(lit(1i64) + col("c3")); + assert_no_change(col("c6") + lit(1.0) + lit(2.0)); + assert_no_change(col("c6") * lit(3.0) * lit(4.0)); + assert_no_change(col("c4") + lit(5u32) + lit(6u8)); + assert_no_change(lit(7i64) + col("c3") + lit(8i64)); + + assert_change( + col("c4") + lit(1u32) + lit(2i64) + lit(3i64), + col("c4") + lit(1u32) + lit(5i64), ); } @@ -3755,6 +3798,7 @@ mod tests { Field::new("c3_non_null", DataType::Int64, false), Field::new("c4_non_null", DataType::UInt32, false), Field::new("c5", DataType::FixedSizeBinary(3), true), + Field::new("c6", DataType::Float64, true), ] .into(), HashMap::new(), diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 344f792eebe42..65a9d06c40ee2 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -17,7 +17,7 @@ //! Utility functions for expression simplification -use arrow::datatypes::i256; +use arrow::datatypes::{DataType, i256}; use datafusion_common::{ Result, ScalarValue, internal_err, tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, @@ -294,30 +294,40 @@ pub fn is_lit(expr: &Expr) -> bool { matches!(expr, Expr::Literal(_, _)) } -pub fn has_associative_op(expr: &Expr, info: &SimplifyContext) -> Result { - let op = match expr { - Expr::BinaryExpr(expr) => expr.op, - _ => unreachable!(), - }; - let datatype = info.get_data_type(expr)?; - // TODO: add other associative ops +pub fn is_associative(op: Operator, datatype: &DataType) -> bool { match op { - Operator::Plus => Ok(datatype.is_integer()), - Operator::StringConcat => Ok(datatype.is_string() || datatype.is_binary()), - _ => Ok(false), + Operator::Plus | Operator::Multiply => datatype.is_integer(), + Operator::StringConcat + | Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor => true, + _ => false, } } -pub fn has_adjacent_literals(expr: &Expr) -> bool { - struct AdjacentLiteralVisitor { - op: Operator, +pub fn is_associative_with_adjacent_literals( + expr: &Expr, + info: &SimplifyContext, +) -> bool { + struct AdjacentLiteralVisitor<'a> { last_expr_was_literal: bool, + op: Operator, + datatype: DataType, + info: &'a SimplifyContext, } - impl<'n> TreeNodeVisitor<'n> for AdjacentLiteralVisitor { + impl<'a, 'n> TreeNodeVisitor<'n> for AdjacentLiteralVisitor<'a> { type Node = Expr; fn f_down(&mut self, node: &'n Self::Node) -> Result { + match self.info.get_data_type(node) { + Ok(datatype) if datatype == self.datatype => {} + _ => { + self.last_expr_was_literal = false; + return Ok(TreeNodeRecursion::Jump); + } + } + match node { Expr::BinaryExpr(expr) if expr.op == self.op => { Ok(TreeNodeRecursion::Continue) @@ -330,18 +340,27 @@ pub fn has_adjacent_literals(expr: &Expr) -> bool { Ok(TreeNodeRecursion::Continue) } } - _ => Ok(TreeNodeRecursion::Jump), + _ => { + self.last_expr_was_literal = false; + Ok(TreeNodeRecursion::Jump) + } } } } - let op = match expr { - Expr::BinaryExpr(expr) => expr.op, - _ => unreachable!(), + let (op, datatype) = match (expr, info.get_data_type(expr)) { + (Expr::BinaryExpr(expr), Ok(datatype)) => (expr.op, datatype), + _ => return false, }; + if !is_associative(op, &datatype) { + return false; + } + let mut visitor = AdjacentLiteralVisitor { - op, last_expr_was_literal: false, + op, + datatype, + info, }; expr.visit(&mut visitor).unwrap() == TreeNodeRecursion::Stop } From 9408aed9007b13f55c0322791e3401f4130c38cf Mon Sep 17 00:00:00 2001 From: Hadrian Date: Sun, 28 Jun 2026 11:33:06 -0400 Subject: [PATCH 5/5] fix typo --- .../optimizer/src/simplify_expressions/expr_simplifier.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 5b69d05437ea8..32b51a2fc7756 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -3455,7 +3455,7 @@ mod tests { let cat = |left, right| binary_expr(left, Operator::StringConcat, right); assert_change( - simplify(cat(cat(cat(col("c1"), lit("a")), lit("b")), col("c1"))), + cat(cat(cat(col("c1"), lit("a")), lit("b")), col("c1")), cat(cat(col("c1"), lit("ab")), col("c1")), );