diff --git a/Cargo.lock b/Cargo.lock index 22ec582536069..a106b43b47989 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2687,6 +2687,7 @@ dependencies = [ "percent-encoding", "rand 0.9.2", "sha1", + "sha2", "url", ] diff --git a/Cargo.toml b/Cargo.toml index 10fc88b7057c8..688640f227e67 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -181,6 +181,7 @@ recursive = "0.1.1" regex = "1.12" rstest = "0.26.1" serde_json = "1" +sha2 = "^0.10.9" sqlparser = { version = "0.59.0", default-features = false, features = ["std", "visitor"] } strum = "0.27.2" strum_macros = "0.27.2" diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 3e832691f96b0..d67e21d6656b3 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -84,7 +84,7 @@ md-5 = { version = "^0.10.0", optional = true } num-traits = { workspace = true } rand = { workspace = true } regex = { workspace = true, optional = true } -sha2 = { version = "^0.10.9", optional = true } +sha2 = { workspace = true, optional = true } unicode-segmentation = { version = "^1.7.1", optional = true } uuid = { version = "1.19", features = ["v4"], optional = true } diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml index 09959db41fe60..63a0000198d53 100644 --- a/datafusion/spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -53,6 +53,7 @@ log = { workspace = true } percent-encoding = "2.3.2" rand = { workspace = true } sha1 = "0.10" +sha2 = { workspace = true } url = { workspace = true } [dev-dependencies] diff --git a/datafusion/spark/src/function/hash/sha2.rs b/datafusion/spark/src/function/hash/sha2.rs index 1f17275062778..a7ce5d7eb0ae0 100644 --- a/datafusion/spark/src/function/hash/sha2.rs +++ b/datafusion/spark/src/function/hash/sha2.rs @@ -15,26 +15,30 @@ // specific language governing permissions and limitations // under the License. -extern crate datafusion_functions; - -use crate::function::error_utils::{ - invalid_arg_count_exec_err, unsupported_data_type_exec_err, -}; -use crate::function::math::hex::spark_sha2_hex; -use arrow::array::{ArrayRef, AsArray, StringArray}; +use arrow::array::{ArrayRef, AsArray, BinaryArrayType, Int32Array, StringArray}; use arrow::datatypes::{DataType, Int32Type}; -use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err}; -use datafusion_expr::Signature; -use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility}; -pub use datafusion_functions::crypto::basic::{sha224, sha256, sha384, sha512}; +use datafusion_common::types::{ + NativeType, logical_binary, logical_int32, logical_string, +}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use sha2::{self, Digest}; use std::any::Any; +use std::fmt::Write; use std::sync::Arc; +/// Differs from DataFusion version in allowing array input for bit lengths, and +/// also hex encoding the output. +/// /// #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkSha2 { signature: Signature, - aliases: Vec, } impl Default for SparkSha2 { @@ -46,8 +50,21 @@ impl Default for SparkSha2 { impl SparkSha2 { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), - aliases: vec![], + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Native(logical_binary()), + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Binary, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ), + ], + Volatility::Immutable, + ), } } } @@ -65,163 +82,73 @@ impl ScalarUDFImpl for SparkSha2 { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types[1].is_null() { - return Ok(DataType::Null); - } - Ok(match arg_types[0] { - DataType::Utf8View - | DataType::LargeUtf8 - | DataType::Utf8 - | DataType::Binary - | DataType::BinaryView - | DataType::LargeBinary => DataType::Utf8, - DataType::Null => DataType::Null, - _ => { - return exec_err!( - "{} function can only accept strings or binary arrays.", - self.name() - ); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| { - internal_datafusion_err!("Expected 2 arguments for function sha2") - })?; - - sha2(args) + make_scalar_function(sha2_impl, vec![])(&args.args) } +} - fn aliases(&self) -> &[String] { - &self.aliases - } +fn sha2_impl(args: &[ArrayRef]) -> Result { + let [values, bit_lengths] = take_function_args("sha2", args)?; - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 { - return Err(invalid_arg_count_exec_err( - self.name(), - (2, 2), - arg_types.len(), - )); + let bit_lengths = bit_lengths.as_primitive::(); + let output = match values.data_type() { + DataType::Binary => sha2_binary_impl(&values.as_binary::(), bit_lengths), + DataType::LargeBinary => { + sha2_binary_impl(&values.as_binary::(), bit_lengths) } - let expr_type = match &arg_types[0] { - DataType::Utf8View - | DataType::LargeUtf8 - | DataType::Utf8 - | DataType::Binary - | DataType::BinaryView - | DataType::LargeBinary - | DataType::Null => Ok(arg_types[0].clone()), - _ => Err(unsupported_data_type_exec_err( - self.name(), - "String, Binary", - &arg_types[0], - )), - }?; - let bit_length_type = if arg_types[1].is_numeric() { - Ok(DataType::Int32) - } else if arg_types[1].is_null() { - Ok(DataType::Null) - } else { - Err(unsupported_data_type_exec_err( - self.name(), - "Numeric Type", - &arg_types[1], - )) - }?; - - Ok(vec![expr_type, bit_length_type]) - } + DataType::BinaryView => sha2_binary_impl(&values.as_binary_view(), bit_lengths), + dt => return internal_err!("Unsupported datatype for sha2: {dt}"), + }; + Ok(output) } -pub fn sha2(args: [ColumnarValue; 2]) -> Result { - match args { - [ - ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), - ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))), - ] => compute_sha2( - bit_length_arg, - &[ColumnarValue::from(ScalarValue::Utf8(expr_arg))], - ), - [ - ColumnarValue::Array(expr_arg), - ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))), - ] => compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)]), - [ - ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), - ColumnarValue::Array(bit_length_arg), - ] => { - let arr: StringArray = bit_length_arg - .as_primitive::() - .iter() - .map(|bit_length| { - match sha2([ - ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())), - ColumnarValue::Scalar(ScalarValue::Int32(bit_length)), - ]) - .unwrap() - { - ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, - ColumnarValue::Array(arr) => arr - .as_string::() - .iter() - .map(|str| str.unwrap().to_string()) - .next(), // first element - _ => unreachable!(), - } - }) - .collect(); - Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) - } - [ - ColumnarValue::Array(expr_arg), - ColumnarValue::Array(bit_length_arg), - ] => { - let expr_iter = expr_arg.as_string::().iter(); - let bit_length_iter = bit_length_arg.as_primitive::().iter(); - let arr: StringArray = expr_iter - .zip(bit_length_iter) - .map(|(expr, bit_length)| { - match sha2([ - ColumnarValue::Scalar(ScalarValue::Utf8(Some( - expr.unwrap().to_string(), - ))), - ColumnarValue::Scalar(ScalarValue::Int32(bit_length)), - ]) - .unwrap() - { - ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, - ColumnarValue::Array(arr) => arr - .as_string::() - .iter() - .map(|str| str.unwrap().to_string()) - .next(), // first element - _ => unreachable!(), - } - }) - .collect(); - Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) - } - _ => exec_err!("Unsupported argument types for sha2 function"), - } +fn sha2_binary_impl<'a, BinaryArrType>( + values: &BinaryArrType, + bit_lengths: &Int32Array, +) -> ArrayRef +where + BinaryArrType: BinaryArrayType<'a>, +{ + let array = values + .iter() + .zip(bit_lengths.iter()) + .map(|(value, bit_length)| match (value, bit_length) { + (Some(value), Some(224)) => { + let mut digest = sha2::Sha224::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + (Some(value), Some(0 | 256)) => { + let mut digest = sha2::Sha256::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + (Some(value), Some(384)) => { + let mut digest = sha2::Sha384::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + (Some(value), Some(512)) => { + let mut digest = sha2::Sha512::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + // Unknown bit-lengths go to null, same as in Spark + _ => None, + }) + .collect::(); + Arc::new(array) } -fn compute_sha2( - bit_length_arg: i32, - expr_arg: &[ColumnarValue], -) -> Result { - match bit_length_arg { - 0 | 256 => sha256(expr_arg), - 224 => sha224(expr_arg), - 384 => sha384(expr_arg), - 512 => sha512(expr_arg), - _ => { - // Return null for unsupported bit lengths instead of error, because spark sha2 does not - // error out for this. - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); - } +fn hex_encode>(data: T) -> String { + let mut s = String::with_capacity(data.as_ref().len() * 2); + for b in data.as_ref() { + // Writing to a string never errors, so we can unwrap here. + write!(&mut s, "{b:02x}").unwrap(); } - .map(|hashed| spark_sha2_hex(&[hashed]).unwrap()) + s } diff --git a/datafusion/sqllogictest/test_files/spark/hash/sha2.slt b/datafusion/sqllogictest/test_files/spark/hash/sha2.slt index 7690a38773b04..07f70947fe926 100644 --- a/datafusion/sqllogictest/test_files/spark/hash/sha2.slt +++ b/datafusion/sqllogictest/test_files/spark/hash/sha2.slt @@ -75,3 +75,58 @@ SELECT sha2(expr, bit_length) FROM VALUES ('foo',0::INT), ('bar',224::INT), ('ba 967004d25de4abc1bd6a7c9a216254a5ac0733e8ad96dc9f1ea0fad9619da7c32d654ec8ad8ba2f9b5728fed6633bd91 8c6be9ed448a34883a13a13f4ead4aefa036b67dcda59020c01e57ea075ea8a4792d428f2c6fd0c09d1c49994d6c22789336e062188df29572ed07e7f9779c52 NULL + +# All string types +query T +SELECT sha2(arrow_cast('foo', 'Utf8'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + +query T +SELECT sha2(arrow_cast('foo', 'LargeUtf8'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + +query T +SELECT sha2(arrow_cast('foo', 'Utf8View'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + +# All binary types +query T +SELECT sha2(arrow_cast('foo', 'Binary'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + +query T +SELECT sha2(arrow_cast('foo', 'LargeBinary'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + +query T +SELECT sha2(arrow_cast('foo', 'BinaryView'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + + +# Null cases +query T +select sha2(null, 0); +---- +NULL + +query T +select sha2('a', null); +---- +NULL + +query T +select sha2('a', null::int); +---- +NULL