diff --git a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py index f7c763e207..28d3532b8b 100644 --- a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py @@ -93,12 +93,19 @@ def _(expr: TypedExpr) -> sge.Expression: def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ + # |x| < 1: The standard formula + sge.If( + this=sge.func("ABS", expr.expr) < sge.convert(1), + true=sge.func("ATANH", expr.expr), + ), + # |x| > 1: Returns NaN sge.If( this=sge.func("ABS", expr.expr) > sge.convert(1), true=constants._NAN, - ) + ), ], - default=sge.func("ATANH", expr.expr), + # |x| = 1: Returns Infinity or -Infinity + default=sge.Mul(this=constants._INF, expression=expr.expr), ) @@ -145,15 +152,11 @@ def _(expr: TypedExpr) -> sge.Expression: @register_unary_op(ops.expm1_op) def _(expr: TypedExpr) -> sge.Expression: - return sge.Case( - ifs=[ - sge.If( - this=expr.expr > constants._FLOAT64_EXP_BOUND, - true=constants._INF, - ) - ], - default=sge.func("EXP", expr.expr), - ) - sge.convert(1) + return sge.If( + this=expr.expr > constants._FLOAT64_EXP_BOUND, + true=constants._INF, + false=sge.func("EXP", expr.expr) - sge.convert(1), + ) @register_unary_op(ops.floor_op) @@ -166,11 +169,22 @@ def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( - this=expr.expr <= sge.convert(0), + this=sge.Is(this=expr.expr, expression=sge.Null()), + true=sge.null(), + ), + # x > 0: The standard formula + sge.If( + this=expr.expr > sge.convert(0), + true=sge.Ln(this=expr.expr), + ), + # x < 0: Returns NaN + sge.If( + this=expr.expr < sge.convert(0), true=constants._NAN, - ) + ), ], - default=sge.Ln(this=expr.expr), + # x == 0: Returns -Infinity. NOTE(review): a NaN input fails every comparison above and also reaches this default; confirm that is intended + default=constants._NEG_INF, ) @@ -179,11 +193,22 @@ def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( - this=expr.expr <= sge.convert(0), +
this=sge.Is(this=expr.expr, expression=sge.Null()), + true=sge.null(), + ), + # x > 0: The standard formula + sge.If( + this=expr.expr > sge.convert(0), + true=sge.Log(this=sge.convert(10), expression=expr.expr), + ), + # x < 0: Returns NaN + sge.If( + this=expr.expr < sge.convert(0), true=constants._NAN, - ) + ), ], - default=sge.Log(this=expr.expr, expression=sge.convert(10)), + # x == 0: Returns -Infinity. NOTE(review): a NaN input fails every comparison above and also reaches this default; confirm that is intended + default=constants._NEG_INF, ) @@ -192,11 +217,22 @@ def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( - this=expr.expr <= sge.convert(-1), + this=sge.Is(this=expr.expr, expression=sge.Null()), + true=sge.null(), + ), + # Domain: x > -1 (The standard formula) + sge.If( + this=expr.expr > sge.convert(-1), + true=sge.Ln(this=sge.convert(1) + expr.expr), + ), + # Out of Domain: x < -1 (Returns NaN) + sge.If( + this=expr.expr < sge.convert(-1), true=constants._NAN, - ) + ), ], - default=sge.Ln(this=sge.convert(1) + expr.expr), + # Boundary: x == -1 (Returns -Infinity). NOTE(review): a NaN input fails every comparison above and also reaches this default; confirm that is intended + default=constants._NEG_INF, ) @@ -608,7 +644,7 @@ def isfinite(arg: TypedExpr) -> sge.Expression: return sge.Not( this=sge.Or( this=sge.IsInf(this=arg.expr), - right=sge.IsNan(this=arg.expr), + expression=sge.IsNan(this=arg.expr), ), ) diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index cefe983e24..d4dc4ecc06 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -674,7 +674,7 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: expressions=[_literal(value=v, dtype=value_type) for v in value] ) return values if len(value) > 0 else _cast(values, sqlglot_type) - elif pd.isna(value): + elif pd.isna(value) or (isinstance(value, pa.Scalar) and not value.is_valid): return _cast(sge.Null(), sqlglot_type) elif dtype == dtypes.JSON_DTYPE: return sge.ParseJSON(this=sge.convert(str(value))) diff --git
a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_arctanh/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_arctanh/out.sql index 197bf59306..dc6de62e7b 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_arctanh/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_arctanh/out.sql @@ -6,9 +6,11 @@ WITH `bfcte_0` AS ( SELECT *, CASE + WHEN ABS(`float64_col`) < 1 + THEN ATANH(`float64_col`) WHEN ABS(`float64_col`) > 1 THEN CAST('NaN' AS FLOAT64) - ELSE ATANH(`float64_col`) + ELSE CAST('Infinity' AS FLOAT64) * `float64_col` END AS `bfcol_1` FROM `bfcte_0` ) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_expm1/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_expm1/out.sql index 076ad584c2..13038bf8e8 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_expm1/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_expm1/out.sql @@ -5,11 +5,7 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - CASE - WHEN `float64_col` > 709.78 - THEN CAST('Infinity' AS FLOAT64) - ELSE EXP(`float64_col`) - END - 1 AS `bfcol_1` + IF(`float64_col` > 709.78, CAST('Infinity' AS FLOAT64), EXP(`float64_col`) - 1) AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_ln/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_ln/out.sql index 776cc33e0f..bd4cfa7c9a 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_ln/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_ln/out.sql @@ -5,7 +5,15 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - CASE WHEN `float64_col` <= 0 THEN CAST('NaN' AS FLOAT64) ELSE LN(`float64_col`) END AS 
`bfcol_1` + CASE + WHEN `float64_col` IS NULL + THEN NULL + WHEN `float64_col` > 0 + THEN LN(`float64_col`) + WHEN `float64_col` < 0 + THEN CAST('NaN' AS FLOAT64) + ELSE CAST('-Infinity' AS FLOAT64) + END AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_log10/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_log10/out.sql index 11a318c22d..c5bbff0e62 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_log10/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_log10/out.sql @@ -6,9 +6,13 @@ WITH `bfcte_0` AS ( SELECT *, CASE - WHEN `float64_col` <= 0 + WHEN `float64_col` IS NULL + THEN NULL + WHEN `float64_col` > 0 + THEN LOG(`float64_col`, 10) + WHEN `float64_col` < 0 THEN CAST('NaN' AS FLOAT64) - ELSE LOG(10, `float64_col`) + ELSE CAST('-Infinity' AS FLOAT64) END AS `bfcol_1` FROM `bfcte_0` ) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_log1p/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_log1p/out.sql index 4297fff227..22e67e24ee 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_log1p/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_log1p/out.sql @@ -6,9 +6,13 @@ WITH `bfcte_0` AS ( SELECT *, CASE - WHEN `float64_col` <= -1 + WHEN `float64_col` IS NULL + THEN NULL + WHEN `float64_col` > -1 + THEN LN(1 + `float64_col`) + WHEN `float64_col` < -1 THEN CAST('NaN' AS FLOAT64) - ELSE LN(1 + `float64_col`) + ELSE CAST('-Infinity' AS FLOAT64) END AS `bfcol_1` FROM `bfcte_0` )