diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 39c8541b51b2f..32b51a2fc7756 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -39,12 +39,12 @@ use datafusion_common::{ metadata::FieldMetadata, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; -use datafusion_expr::expr::HigherOrderFunction; use datafusion_expr::{ BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Like, Operator, Volatility, and, binary::BinaryTypeCoercer, lit, or, preimage::PreimageResult, }; use datafusion_expr::{Cast, TryCast, simplify::ExprSimplifyResult}; +use datafusion_expr::{binary_expr, expr::HigherOrderFunction}; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ expr::{InList, InSubquery}, @@ -1095,6 +1095,11 @@ impl TreeNodeRewriter for Simplifier<'_> { Transformed::yes(*right) } + // (A + 1) + 2 -> A + (1 + 2) + expr if is_associative_with_adjacent_literals(&expr, self.info) => { + Transformed::yes(reassociate_literals(expr, self.info)) + } + // // Rules for Multiply // @@ -2355,6 +2360,60 @@ fn simplify_right_is_one_case( } } +fn reassociate_literals(expr: Expr, info: &SimplifyContext) -> Expr { + fn flatten( + op: Operator, + datatype: &DataType, + info: &SimplifyContext, + expr: Expr, + out: &mut Vec, + ) { + match &expr { + Expr::BinaryExpr(binary) + if binary.op == op + && matches!(info.get_data_type(&expr), Ok(dt) if &dt == datatype) => + { + let Expr::BinaryExpr(binary) = expr else { + unreachable!() + }; + flatten(op, datatype, info, *binary.left, out); + flatten(op, datatype, info, *binary.right, out); + } + _ => out.push(expr), + } + } + + let (op, datatype) = match (&expr, info.get_data_type(&expr)) { + (Expr::BinaryExpr(expr), Ok(datatype)) => (expr.op, datatype), + _ => return expr, + }; + let mut exprs = Vec::new(); + flatten(op, &datatype, info, expr, &mut exprs); + let mut exprs = exprs.into_iter(); + + let mut out = exprs.next().unwrap(); + let mut lit = None; + for expr in exprs { + if is_lit(&expr) { + if let Some(left) = lit { + lit = Some(binary_expr(left, op, expr)); + } else { + lit = Some(expr); + } + } else { + if let Some(lit) = lit.take() { + out = binary_expr(out, op, lit); + } + out = binary_expr(out, op, expr); + } + } + if let Some(lit) = lit.take() { + binary_expr(out, op, lit) + } else { + out + } +} + #[cfg(test)] mod tests { use super::*; @@ -3368,6 +3427,60 @@ mod tests { assert_eq!(simplify(expr_eq), lit(true)); } + #[test] + fn test_simplify_nested_literals_associative() { + assert_change( + col("c3") + lit(1i64) + lit(2i64) + lit(3i64) + col("c3_non_null"), + col("c3") + lit(6i64) + col("c3_non_null"), + ); + assert_change( + (col("c3") + lit(1i64)) + (lit(2i64) + col("c3_non_null")), + col("c3") + lit(3i64) + col("c3_non_null"), + ); + assert_change(lit(1i64) + lit(2i64) + col("c3"), lit(3i64) + col("c3")); + assert_change(col("c4") + lit(1u32) + lit(2u32), col("c4") + lit(3u32)); + assert_change( + col("c3") + lit(1i64) + col("c3_non_null") + lit(2i64) + lit(3i64), + col("c3") + lit(1i64) + col("c3_non_null") + lit(5i64), + ); + assert_change( + col("c3") * col("c3") + lit(2i64) + (lit(3i64) + lit(4i64) * col("c3")), + col("c3") * col("c3") + lit(5i64) + lit(4i64) * col("c3"), + ); + assert_change(col("c3") * lit(2i64) * lit(3i64), col("c3") * lit(6i64)); + assert_change( + (col("c3") * lit(2i64)) * (lit(3i64) * col("c3_non_null")), + col("c3") * lit(6i64) * col("c3_non_null"), + ); + + let cat = |left, right| binary_expr(left, Operator::StringConcat, right); + assert_change( + cat(cat(cat(col("c1"), lit("a")), lit("b")), col("c1")), + cat(cat(col("c1"), lit("ab")), col("c1")), + ); + + assert_change(col("c3") & lit(12i64) & lit(10i64), col("c3") & lit(8i64)); + assert_change( + (col("c4") & lit(12u32)) & (lit(10u32) & col("c4_non_null")), + col("c4") & lit(8u32) & col("c4_non_null"), + ); + + assert_change(col("c3") | lit(1i64) | lit(2i64), col("c3") | lit(3i64)); + assert_change(col("c4") ^ lit(1u32) ^ lit(3u32), col("c4") ^ lit(2u32)); + + assert_no_change(col("c3") + lit(1i64)); + assert_no_change(lit(1i64) + col("c3")); + assert_no_change(col("c6") + lit(1.0) + lit(2.0)); + assert_no_change(col("c6") * lit(3.0) * lit(4.0)); + assert_no_change(col("c4") + lit(5u32) + lit(6u8)); + assert_no_change(lit(7i64) + col("c3") + lit(8i64)); + + assert_change( + col("c4") + lit(1u32) + lit(2i64) + lit(3i64), + col("c4") + lit(1u32) + lit(5i64), + ); + } + #[test] fn test_simplify_regex() { // malformed regex @@ -3685,6 +3798,7 @@ mod tests { Field::new("c3_non_null", DataType::Int64, false), Field::new("c4_non_null", DataType::UInt32, false), Field::new("c5", DataType::FixedSizeBinary(3), true), + Field::new("c6", DataType::Float64, true), ] .into(), HashMap::new(), diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index b0908b47602f7..65a9d06c40ee2 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -17,12 +17,16 @@ //! Utility functions for expression simplification -use arrow::datatypes::i256; -use datafusion_common::{Result, ScalarValue, internal_err}; +use arrow::datatypes::{DataType, i256}; +use datafusion_common::{ + Result, ScalarValue, internal_err, + tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, +}; use datafusion_expr::{ Case, Expr, Like, Operator, expr::{Between, BinaryExpr, InList}, expr_fn::{and, bitwise_and, bitwise_or, or}, + simplify::SimplifyContext, }; pub static POWS_OF_TEN: [i128; 38] = [ @@ -290,6 +294,77 @@ pub fn is_lit(expr: &Expr) -> bool { matches!(expr, Expr::Literal(_, _)) } +pub fn is_associative(op: Operator, datatype: &DataType) -> bool { + match op { + Operator::Plus | Operator::Multiply => datatype.is_integer(), + Operator::StringConcat + | Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor => true, + _ => false, + } +} + +pub fn is_associative_with_adjacent_literals( + expr: &Expr, + info: &SimplifyContext, +) -> bool { + struct AdjacentLiteralVisitor<'a> { + last_expr_was_literal: bool, + op: Operator, + datatype: DataType, + info: &'a SimplifyContext, + } + + impl<'a, 'n> TreeNodeVisitor<'n> for AdjacentLiteralVisitor<'a> { + type Node = Expr; + + fn f_down(&mut self, node: &'n Self::Node) -> Result { + match self.info.get_data_type(node) { + Ok(datatype) if datatype == self.datatype => {} + _ => { + self.last_expr_was_literal = false; + return Ok(TreeNodeRecursion::Jump); + } + } + + match node { + Expr::BinaryExpr(expr) if expr.op == self.op => { + Ok(TreeNodeRecursion::Continue) + } + Expr::Literal(_, _) => { + if self.last_expr_was_literal { + Ok(TreeNodeRecursion::Stop) + } else { + self.last_expr_was_literal = true; + Ok(TreeNodeRecursion::Continue) + } + } + _ => { + self.last_expr_was_literal = false; + Ok(TreeNodeRecursion::Jump) + } + } + } + } + + let (op, datatype) = match (expr, info.get_data_type(expr)) { + (Expr::BinaryExpr(expr), Ok(datatype)) => (expr.op, datatype), + _ => return false, + }; + if !is_associative(op, &datatype) { + return false; + } + + let mut visitor = AdjacentLiteralVisitor { + last_expr_was_literal: false, + op, + datatype, + info, + }; + expr.visit(&mut visitor).unwrap() == TreeNodeRecursion::Stop +} + /// Checks if `eq_expr` is `A = L1` and `ne_expr` is `A != L2` where L1 != L2. /// This pattern can be simplified to just `A = L1` since if A equals L1 /// and L1 is different from L2, then A is automatically not equal to L2.