diff --git a/datafusion/spark/src/function/string/encode.rs b/datafusion/spark/src/function/string/encode.rs new file mode 100644 index 0000000000000..7bc3783d979ec --- /dev/null +++ b/datafusion/spark/src/function/string/encode.rs @@ -0,0 +1,360 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::borrow::Cow; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, AsArray, BinaryBuilder, StringArray, StringArrayType, +}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, exec_err, plan_err}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; + +/// Spark-compatible `encode` expression. +/// Encodes a string or binary value into binary using the specified character encoding. +/// Binary input is interpreted as UTF-8 with lossy conversion (invalid bytes become U+FFFD). +/// +/// # Target Spark version +/// Emulates Spark 3.5 semantics: +/// - Accepts canonical charset names and common aliases (`UTF8`, `LATIN1`, +/// `ISO88591`, `ASCII`, `UTF-32BE`, etc.). 
+/// - Unmappable characters (non-ASCII in `US-ASCII`, code points above
+///   `U+00FF` in `ISO-8859-1`) are silently replaced with `?`.
+/// - `UTF-32` is an alias for `UTF-32BE` (no BOM), matching both Spark 3.5
+///   and Spark 4.1.
+///
+/// # Spark 4.0 differences (not implemented)
+/// Spark 4.0 tightened `encode` in two ways, each gated by a `spark.sql.legacy.*`
+/// config that can restore the 3.5 behavior:
+///
+/// - Charset whitelist — rejects aliases with `INVALID_PARAMETER_VALUE.CHARSET`.
+///   Controlled by `spark.sql.legacy.javaCharsets`.
+/// - Unmappable characters — raises `MALFORMED_CHARACTER_CODING`.
+///   Controlled by `spark.sql.legacy.codingErrorAction`.
+///
+/// TODO: wire both configs so Spark 4.0 behavior can be selected at runtime.
+/// See the "Upgrading from Spark SQL 3.5 to 4.0" section of the Spark SQL
+/// migration guide for both changes.
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkEncode {
+    signature: Signature,
+}
+
+impl Default for SparkEncode {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkEncode {
+    /// Creates the UDF with a user-defined, immutable signature; argument
+    /// checking is done by `coerce_types`.
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+        }
+    }
+}
+
+/// Encodes a single string value using the specified charset.
+///
+/// `charset` is matched against upper-case names only, so callers must
+/// upper-case it first. Unmappable characters are silently replaced with `?`
+/// (Spark 3.5 behavior).
+///
+/// # Errors
+/// Returns an execution error when `charset` is not one of the supported names.
+fn encode_string(s: &str, charset: &str) -> Result<Vec<u8>> {
+    match charset {
+        "UTF-8" | "UTF8" => Ok(s.as_bytes().to_vec()),
+        "US-ASCII" | "ASCII" => Ok(s
+            .chars()
+            .map(|c| if c.is_ascii() { c as u8 } else { b'?' })
+            .collect()),
+        "ISO-8859-1" | "ISO88591" | "LATIN1" => Ok(s
+            .chars()
+            .map(|c| {
+                // ISO-8859-1 maps code points 0..=255 to the identical byte.
+                let cp = c as u32;
+                if cp <= 255 { cp as u8 } else { b'?' }
+            })
+            .collect()),
+        "UTF-16BE" | "UTF16BE" => {
+            // A string never has more UTF-16 code units than UTF-8 bytes, so
+            // `s.len() * 2` output bytes is a safe upper bound.
+            let mut bytes = Vec::with_capacity(s.len() * 2);
+            for code_unit in s.encode_utf16() {
+                bytes.extend_from_slice(&code_unit.to_be_bytes());
+            }
+            Ok(bytes)
+        }
+        "UTF-16LE" | "UTF16LE" => {
+            let mut bytes = Vec::with_capacity(s.len() * 2);
+            for code_unit in s.encode_utf16() {
+                bytes.extend_from_slice(&code_unit.to_le_bytes());
+            }
+            Ok(bytes)
+        }
+        "UTF-16" | "UTF16" => {
+            // BOM (big-endian marker) followed by UTF-16BE encoded bytes
+            let mut bytes = Vec::with_capacity(2 + s.len() * 2);
+            bytes.extend_from_slice(&[0xFE, 0xFF]);
+            for code_unit in s.encode_utf16() {
+                bytes.extend_from_slice(&code_unit.to_be_bytes());
+            }
+            Ok(bytes)
+        }
+        // Spark treats UTF-32 as UTF-32BE (no BOM), matching Spark 3.5 and 4.1.
+        "UTF-32" | "UTF32" | "UTF-32BE" | "UTF32BE" => {
+            // `s.len()` counts UTF-8 bytes, which is >= the number of chars.
+            let mut bytes = Vec::with_capacity(s.len() * 4);
+            for c in s.chars() {
+                bytes.extend_from_slice(&(c as u32).to_be_bytes());
+            }
+            Ok(bytes)
+        }
+        "UTF-32LE" | "UTF32LE" => {
+            let mut bytes = Vec::with_capacity(s.len() * 4);
+            for c in s.chars() {
+                bytes.extend_from_slice(&(c as u32).to_le_bytes());
+            }
+            Ok(bytes)
+        }
+        _ => exec_err!(
+            "Unsupported charset for encode: '{}'. Supported: US-ASCII, ISO-8859-1, UTF-8, UTF-16, UTF-16BE, UTF-16LE, UTF-32, UTF-32BE, UTF-32LE",
+            charset
+        ),
+    }
+}
+
+/// The charset for each row. `Constant` is a single uppercased charset; `PerRow`
+/// uppercases each row on demand, yielding `None` for null entries.
+enum Charsets<'a> {
+    Constant(&'a str),
+    PerRow(&'a StringArray),
+}
+
+impl Charsets<'_> {
+    /// Returns the uppercased charset for row `i`, or `None` when that row's
+    /// charset is null. `Constant` ignores `i`.
+    fn get(&self, i: usize) -> Option<Cow<'_, str>> {
+        match self {
+            Charsets::Constant(charset) => Some(Cow::Borrowed(charset)),
+            Charsets::PerRow(array) => {
+                (!array.is_null(i)).then(|| Cow::Owned(array.value(i).to_uppercase()))
+            }
+        }
+    }
+}
+
+/// Returns a `Binary` array of `len` rows that are all null.
+fn null_binary_array(len: usize) -> ArrayRef {
+    arrow::array::new_null_array(&DataType::Binary, len)
+}
+
+/// Encodes each row to binary. A row is null if its value or charset is null;
+/// otherwise it is `encode_string(decode(i), charset)`. `decode` yields the row's
+/// value as text — a borrow for string input, or lossy UTF-8 for binary input
+/// (invalid bytes become U+FFFD, matching Spark).
+///
+/// # Errors
+/// Propagates the first `encode_string` error (unsupported charset).
+fn encode_rows<'a>(
+    len: usize,
+    is_null: impl Fn(usize) -> bool,
+    decode: impl Fn(usize) -> Cow<'a, str>,
+    charsets: &Charsets,
+) -> Result<ArrayRef> {
+    let mut builder = BinaryBuilder::with_capacity(len, len * 4);
+    for i in 0..len {
+        match (is_null(i), charsets.get(i)) {
+            (false, Some(charset)) => {
+                builder.append_value(&encode_string(&decode(i), &charset)?)
+            }
+            _ => builder.append_null(),
+        }
+    }
+    Ok(Arc::new(builder.finish()))
+}
+
+/// Encodes a string array; values are borrowed without copying.
+fn encode_array<'a, S: StringArrayType<'a>>(
+    array: &S,
+    charsets: &Charsets,
+) -> Result<ArrayRef> {
+    encode_rows(
+        array.len(),
+        |i| array.is_null(i),
+        |i| Cow::Borrowed(array.value(i)),
+        charsets,
+    )
+}
+
+/// Encodes a binary array; each value is first decoded as lossy UTF-8.
+fn encode_binary_array<'a, B: arrow::array::BinaryArrayType<'a>>(
+    array: &'a B,
+    charsets: &Charsets,
+) -> Result<ArrayRef> {
+    encode_rows(
+        array.len(),
+        |i| array.is_null(i),
+        |i| String::from_utf8_lossy(array.value(i)),
+        charsets,
+    )
+}
+
+/// Routes `arr` to the string or binary encoder according to its data type.
+/// A `Null`-typed array yields an all-null result; any other type is an error.
+fn encode_dispatch(arr: &ArrayRef, charsets: &Charsets) -> Result<ArrayRef> {
+    match arr.data_type() {
+        DataType::Utf8 => encode_array(&arr.as_string::<i32>(), charsets),
+        DataType::LargeUtf8 => encode_array(&arr.as_string::<i64>(), charsets),
+        DataType::Utf8View => encode_array(&arr.as_string_view(), charsets),
+        DataType::Binary => encode_binary_array(&arr.as_binary::<i32>(), charsets),
+        DataType::LargeBinary => encode_binary_array(&arr.as_binary::<i64>(), charsets),
+        DataType::BinaryView => encode_binary_array(&arr.as_binary_view(), charsets),
+        DataType::Null => Ok(null_binary_array(arr.len())),
+        dt => exec_err!("encode expects a string or binary argument, got {dt:?}"),
+    }
+}
+
+/// The uppercased charset from a scalar, or `None` if it is null.
+fn scalar_charset(scalar: &ScalarValue) -> Result<Option<String>> {
+    match scalar {
+        ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) | ScalarValue::Utf8View(s) => {
+            Ok(s.as_ref().map(|s| s.to_uppercase()))
+        }
+        ScalarValue::Null => Ok(None),
+        other => exec_err!("encode charset argument must be a string, got {other:?}"),
+    }
+}
+
+impl ScalarUDFImpl for SparkEncode {
+    fn name(&self) -> &str {
+        "encode"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// Validates the two argument types. The value keeps its own string,
+    /// binary or null type; the charset is always coerced to `Utf8`.
+    ///
+    /// # Errors
+    /// Returns a plan error when the argument count is not two, when the value
+    /// is not a string/binary/null type, or when the charset is not a
+    /// string/null type.
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        let [value_type, charset_type] = take_function_args(self.name(), arg_types)?;
+
+        let value_type = match value_type {
+            DataType::Utf8
+            | DataType::LargeUtf8
+            | DataType::Utf8View
+            | DataType::Binary
+            | DataType::LargeBinary
+            | DataType::BinaryView
+            | DataType::Null => value_type.clone(),
+            other => {
+                return plan_err!(
+                    "encode expects a string or binary first argument, got {other:?}"
+                );
+            }
+        };
+
+        match charset_type {
+            DataType::Utf8
+            | DataType::LargeUtf8
+            | DataType::Utf8View
+            | DataType::Null => {}
+            other => {
+                return plan_err!(
+                    "encode expects a string charset second argument, got {other:?}"
+                );
+            }
+        }
+
+        Ok(vec![value_type, DataType::Utf8])
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        datafusion_common::internal_err!(
+            "return_type should not be called, use return_field_from_args instead"
+        )
+    }
+
+    /// The result is a `Binary` field named after the function; it is nullable
+    /// when any argument field is nullable.
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
+        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
+        Ok(Arc::new(Field::new(
+            self.name(),
+            DataType::Binary,
+            nullable,
+        )))
+    }
+
+    /// Encodes the value argument with the charset argument. Returns a scalar
+    /// when both arguments are scalars, otherwise an array.
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let [value, charset_arg] = take_function_args(self.name(), args.args)?;
+
+        // Row count comes from the first array argument; with two scalars the
+        // work is done on a single-row array and unwrapped again at the end.
+        let len = [&value, &charset_arg]
+            .into_iter()
+            .find_map(|arg| match arg {
+                ColumnarValue::Array(a) => Some(a.len()),
+                _ => None,
+            });
+        let inferred_length = len.unwrap_or(1);
+        let is_scalar = len.is_none();
+
+        let value_arr = match value {
+            ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(inferred_length)?,
+            ColumnarValue::Array(array) => array,
+        };
+
+        // A null charset yields a null result for that row (Spark NullIntolerant).
+        let result = match &charset_arg {
+            ColumnarValue::Scalar(scalar) => match scalar_charset(scalar)? {
+                Some(charset) => {
+                    encode_dispatch(&value_arr, &Charsets::Constant(&charset))?
+                }
+                None => null_binary_array(value_arr.len()),
+            },
+            // `coerce_types` coerces the charset to `Utf8`, hence the i32 offsets.
+            ColumnarValue::Array(charset_array) => encode_dispatch(
+                &value_arr,
+                &Charsets::PerRow(charset_array.as_string::<i32>()),
+            )?,
+        };
+
+        if is_scalar {
+            ScalarValue::try_from_array(&result, 0).map(ColumnarValue::Scalar)
+        } else {
+            Ok(ColumnarValue::Array(result))
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_encode_return_field_nullable() {
+        let func = SparkEncode::new();
+
+        let nullable = func
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &[
+                    Arc::new(Field::new("input", DataType::Utf8, true)),
+                    Arc::new(Field::new("charset", DataType::Utf8, false)),
+                ],
+                scalar_arguments: &[None, None],
+            })
+            .unwrap();
+        assert!(nullable.is_nullable());
+        assert_eq!(nullable.data_type(), &DataType::Binary);
+
+        let non_nullable = func
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &[
+                    Arc::new(Field::new("input", DataType::Utf8, false)),
+                    Arc::new(Field::new("charset", DataType::Utf8, false)),
+                ],
+                scalar_arguments: &[None, None],
+            })
+            .unwrap();
+        assert!(!non_nullable.is_nullable());
+    }
+}
diff --git a/datafusion/spark/src/function/string/mod.rs b/datafusion/spark/src/function/string/mod.rs
index bc94c27732c91..a13a6ad739263 100644
--- a/datafusion/spark/src/function/string/mod.rs
+++ b/datafusion/spark/src/function/string/mod.rs
@@ -21,6 +21,7 @@ pub mod char;
 pub mod concat;
 pub mod concat_ws;
 pub mod elt;
+pub mod encode;
 pub mod format_string;
 pub mod ilike;
 pub mod is_valid_utf8;
@@ -45,6 +46,7 @@ make_udf_function!(concat_ws::SparkConcatWs,
concat_ws); make_udf_function!(ilike::SparkILike, ilike); make_udf_function!(length::SparkLengthFunc, length); make_udf_function!(elt::SparkElt, elt); +make_udf_function!(encode::SparkEncode, encode); make_udf_function!(like::SparkLike, like); make_udf_function!(luhn_check::SparkLuhnCheck, luhn_check); make_udf_function!(format_string::FormatStringFunc, format_string); @@ -89,6 +91,11 @@ pub mod expr_fn { "Returns the n-th input (1-indexed), e.g. returns 2nd input when n is 2. The function returns NULL if the index is 0 or exceeds the length of the array.", select_col arg1 arg2 argn )); + export_functions!(( + encode, + "Encodes a string or binary value into binary using the specified character encoding.", + string_or_binary charset + )); export_functions!(( ilike, "Returns true if str matches pattern (case insensitive).", @@ -151,6 +158,7 @@ pub fn functions() -> Vec> { concat(), concat_ws(), elt(), + encode(), ilike(), length(), like(), diff --git a/datafusion/sqllogictest/test_files/spark/string/encode.slt b/datafusion/sqllogictest/test_files/spark/string/encode.slt index 4ad02316f4f3f..4f17536c05187 100644 --- a/datafusion/sqllogictest/test_files/spark/string/encode.slt +++ b/datafusion/sqllogictest/test_files/spark/string/encode.slt @@ -15,13 +15,182 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. 
-# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT encode('abc', 'utf-8'); -## PySpark 3.5.5 Result: {'encode(abc, utf-8)': bytearray(b'abc'), 'typeof(encode(abc, utf-8))': 'binary', 'typeof(abc)': 'string', 'typeof(utf-8)': 'string'} -#query -#SELECT encode('abc'::string, 'utf-8'::string); +# UTF-8 encoding +query ? +SELECT encode('Spark SQL'::string, 'utf-8'::string); +---- +537061726b2053514c + +# US-ASCII encoding +query ? +SELECT encode('Hello'::string, 'us-ascii'::string); +---- +48656c6c6f + +# ISO-8859-1 encoding (ï = 0xEF in ISO-8859-1) +query ? +SELECT encode('naïve'::string, 'iso-8859-1'::string); +---- +6e61ef7665 + +# UTF-16BE encoding +query ? +SELECT encode('AB'::string, 'utf-16be'::string); +---- +00410042 + +# UTF-16LE encoding +query ? +SELECT encode('AB'::string, 'utf-16le'::string); +---- +41004200 + +# UTF-16 with BOM (FEFF prefix followed by UTF-16BE bytes) +query ? +SELECT encode('Spark SQL'::string, 'utf-16'::string); +---- +feff0053007000610072006b002000530051004c + +# UTF-32BE encoding +query ? +SELECT encode('A'::string, 'utf-32be'::string); +---- +00000041 + +# UTF-32LE encoding +query ? +SELECT encode('AB'::string, 'utf-32le'::string); +---- +4100000042000000 + +# UTF-32 (Spark 3.5 / 4.1: no BOM, identical to UTF-32BE) +query ? +SELECT encode('A'::string, 'utf-32'::string); +---- +00000041 + +# Case-insensitive charset +query ? +SELECT encode('hello'::string, 'Utf-8'::string); +---- +68656c6c6f + +# US-ASCII: unmappable characters are silently replaced with '?' (Spark 3.5 behavior) +query ? +SELECT encode('é'::string, 'us-ascii'::string); +---- +3f + +# ISO-8859-1: code points above U+00FF are replaced with '?' (Ā is U+0100) +query ? +SELECT encode('Ā'::string, 'iso-8859-1'::string); +---- +3f + +# Multi-byte input (emoji 😀 = U+1F600) across charsets +query ? +SELECT encode('😀'::string, 'utf-8'::string); +---- +f09f9880 + +query ? 
+SELECT encode('😀'::string, 'utf-16be'::string); +---- +d83dde00 + +query ? +SELECT encode('😀'::string, 'utf-16le'::string); +---- +3dd800de + +query ? +SELECT encode('😀'::string, 'utf-16'::string); +---- +feffd83dde00 + +query ? +SELECT encode('😀'::string, 'utf-32be'::string); +---- +0001f600 + +query ? +SELECT encode('😀'::string, 'utf-32le'::string); +---- +00f60100 + +query ? +SELECT encode('😀'::string, 'utf-32'::string); +---- +0001f600 + +# Utf8View column input +query ? +SELECT encode(arrow_cast(s, 'Utf8View'), 'utf-8'::string) FROM (VALUES ('foo'::string), ('bar'::string)) AS t(s); +---- +666f6f +626172 + +# Binary input is decoded as UTF-8 (lossily), then re-encoded with the charset +query ? +SELECT encode(arrow_cast('Hello', 'Binary'), 'utf-8'::string); +---- +48656c6c6f + +query ? +SELECT encode(arrow_cast('Hello', 'LargeBinary'), 'utf-8'::string); +---- +48656c6c6f + +query ? +SELECT encode(arrow_cast('Hello', 'BinaryView'), 'utf-8'::string); +---- +48656c6c6f + +# Binary input with invalid UTF-8 bytes: each invalid byte becomes U+FFFD (ef bf bd). +# Bytes FE and FF are not valid UTF-8 and are replaced; 00 61 00 73 decode as NUL a NUL s. +query ? +SELECT encode(X'FEFF00610073', 'utf-8'::string); +---- +efbfbdefbfbd00610073 + +# NULL input +query ? +SELECT encode(NULL::string, 'utf-8'::string); +---- +NULL + +# Array input with NULLs +query ? +SELECT encode(s, 'utf-8'::string) FROM (VALUES ('hello'::string), (NULL::string), ('world'::string)) AS t(s); +---- +68656c6c6f +NULL +776f726c64 + +# NULL charset returns NULL (Spark is NullIntolerant) +query ? +SELECT encode('hello'::string, NULL::string); +---- +NULL + +# Charset varies per row, and a NULL charset yields NULL for that row +query ? +SELECT encode(a, b) FROM VALUES (X'0a', 'US-ASCII'), (X'deadbeef', NULL) AS t(a, b); +---- +0a +NULL + +# Per-row charsets: each row encoded with its own charset +query ? 
+SELECT encode(a, b) FROM VALUES (X'0a', 'US-ASCII'), (X'deadbeef', 'UTF-16be') AS t(a, b); +---- +0a +07adfffdfffd + +# Error: unsupported charset +statement error Unsupported charset for encode +SELECT encode('hello'::string, 'EBCDIC'::string); + +# Error: no arguments +statement error encode function requires 2 arguments +SELECT encode();