-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Simplify Spark sha2 implementation
#19475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,26 +15,30 @@ | |
| // specific language governing permissions and limitations | ||
| // under the License. | ||
|
|
||
| extern crate datafusion_functions; | ||
|
|
||
| use crate::function::error_utils::{ | ||
| invalid_arg_count_exec_err, unsupported_data_type_exec_err, | ||
| }; | ||
| use crate::function::math::hex::spark_sha2_hex; | ||
| use arrow::array::{ArrayRef, AsArray, StringArray}; | ||
| use arrow::array::{ArrayRef, AsArray, BinaryArrayType, Int32Array, StringArray}; | ||
| use arrow::datatypes::{DataType, Int32Type}; | ||
| use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err}; | ||
| use datafusion_expr::Signature; | ||
| use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility}; | ||
| pub use datafusion_functions::crypto::basic::{sha224, sha256, sha384, sha512}; | ||
| use datafusion_common::types::{ | ||
| NativeType, logical_binary, logical_int32, logical_string, | ||
| }; | ||
| use datafusion_common::utils::take_function_args; | ||
| use datafusion_common::{Result, internal_err}; | ||
| use datafusion_expr::{ | ||
| Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, | ||
| TypeSignatureClass, Volatility, | ||
| }; | ||
| use datafusion_functions::utils::make_scalar_function; | ||
| use sha2::{self, Digest}; | ||
| use std::any::Any; | ||
| use std::fmt::Write; | ||
| use std::sync::Arc; | ||
|
|
||
| /// Differs from DataFusion version in allowing array input for bit lengths, and | ||
| /// also hex encoding the output. | ||
| /// | ||
| /// <https://spark.apache.org/docs/latest/api/sql/index.html#sha2> | ||
| #[derive(Debug, PartialEq, Eq, Hash)] | ||
| pub struct SparkSha2 { | ||
| signature: Signature, | ||
| aliases: Vec<String>, | ||
| } | ||
|
|
||
| impl Default for SparkSha2 { | ||
|
|
@@ -46,8 +50,21 @@ impl Default for SparkSha2 { | |
| impl SparkSha2 { | ||
| pub fn new() -> Self { | ||
| Self { | ||
| signature: Signature::user_defined(Volatility::Immutable), | ||
| aliases: vec![], | ||
| signature: Signature::coercible( | ||
| vec![ | ||
| Coercion::new_implicit( | ||
| TypeSignatureClass::Native(logical_binary()), | ||
| vec![TypeSignatureClass::Native(logical_string())], | ||
| NativeType::Binary, | ||
| ), | ||
| Coercion::new_implicit( | ||
| TypeSignatureClass::Native(logical_int32()), | ||
| vec![TypeSignatureClass::Integer], | ||
| NativeType::Int32, | ||
|
Comment on lines -49 to +63
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Moving away from user_defined; also we cast strings to binary to simplify the implementation, as we only need raw bytes either way. |
||
| ), | ||
| ], | ||
| Volatility::Immutable, | ||
| ), | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -65,163 +82,73 @@ impl ScalarUDFImpl for SparkSha2 { | |
| &self.signature | ||
| } | ||
|
|
||
| fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { | ||
| if arg_types[1].is_null() { | ||
| return Ok(DataType::Null); | ||
| } | ||
| Ok(match arg_types[0] { | ||
| DataType::Utf8View | ||
| | DataType::LargeUtf8 | ||
| | DataType::Utf8 | ||
| | DataType::Binary | ||
| | DataType::BinaryView | ||
| | DataType::LargeBinary => DataType::Utf8, | ||
| DataType::Null => DataType::Null, | ||
| _ => { | ||
| return exec_err!( | ||
| "{} function can only accept strings or binary arrays.", | ||
| self.name() | ||
| ); | ||
| } | ||
| }) | ||
| fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
| Ok(DataType::Utf8) | ||
| } | ||
|
|
||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| { | ||
| internal_datafusion_err!("Expected 2 arguments for function sha2") | ||
| })?; | ||
|
|
||
| sha2(args) | ||
| make_scalar_function(sha2_impl, vec![])(&args.args) | ||
| } | ||
| } | ||
|
|
||
| fn aliases(&self) -> &[String] { | ||
| &self.aliases | ||
| } | ||
| fn sha2_impl(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
| let [values, bit_lengths] = take_function_args("sha2", args)?; | ||
|
|
||
| fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { | ||
| if arg_types.len() != 2 { | ||
| return Err(invalid_arg_count_exec_err( | ||
| self.name(), | ||
| (2, 2), | ||
| arg_types.len(), | ||
| )); | ||
| let bit_lengths = bit_lengths.as_primitive::<Int32Type>(); | ||
| let output = match values.data_type() { | ||
| DataType::Binary => sha2_binary_impl(&values.as_binary::<i32>(), bit_lengths), | ||
| DataType::LargeBinary => { | ||
| sha2_binary_impl(&values.as_binary::<i64>(), bit_lengths) | ||
| } | ||
| let expr_type = match &arg_types[0] { | ||
| DataType::Utf8View | ||
| | DataType::LargeUtf8 | ||
| | DataType::Utf8 | ||
| | DataType::Binary | ||
| | DataType::BinaryView | ||
| | DataType::LargeBinary | ||
| | DataType::Null => Ok(arg_types[0].clone()), | ||
| _ => Err(unsupported_data_type_exec_err( | ||
| self.name(), | ||
| "String, Binary", | ||
| &arg_types[0], | ||
| )), | ||
| }?; | ||
| let bit_length_type = if arg_types[1].is_numeric() { | ||
| Ok(DataType::Int32) | ||
| } else if arg_types[1].is_null() { | ||
| Ok(DataType::Null) | ||
| } else { | ||
| Err(unsupported_data_type_exec_err( | ||
| self.name(), | ||
| "Numeric Type", | ||
| &arg_types[1], | ||
| )) | ||
| }?; | ||
|
|
||
| Ok(vec![expr_type, bit_length_type]) | ||
| } | ||
| DataType::BinaryView => sha2_binary_impl(&values.as_binary_view(), bit_lengths), | ||
| dt => return internal_err!("Unsupported datatype for sha2: {dt}"), | ||
| }; | ||
| Ok(output) | ||
| } | ||
|
|
||
| pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> { | ||
| match args { | ||
| [ | ||
| ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), | ||
| ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))), | ||
| ] => compute_sha2( | ||
| bit_length_arg, | ||
| &[ColumnarValue::from(ScalarValue::Utf8(expr_arg))], | ||
| ), | ||
| [ | ||
| ColumnarValue::Array(expr_arg), | ||
| ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))), | ||
| ] => compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)]), | ||
| [ | ||
| ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), | ||
| ColumnarValue::Array(bit_length_arg), | ||
| ] => { | ||
| let arr: StringArray = bit_length_arg | ||
| .as_primitive::<Int32Type>() | ||
| .iter() | ||
| .map(|bit_length| { | ||
| match sha2([ | ||
| ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())), | ||
| ColumnarValue::Scalar(ScalarValue::Int32(bit_length)), | ||
| ]) | ||
| .unwrap() | ||
| { | ||
| ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, | ||
| ColumnarValue::Array(arr) => arr | ||
| .as_string::<i32>() | ||
| .iter() | ||
| .map(|str| str.unwrap().to_string()) | ||
| .next(), // first element | ||
| _ => unreachable!(), | ||
| } | ||
| }) | ||
| .collect(); | ||
| Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) | ||
| } | ||
| [ | ||
| ColumnarValue::Array(expr_arg), | ||
| ColumnarValue::Array(bit_length_arg), | ||
| ] => { | ||
| let expr_iter = expr_arg.as_string::<i32>().iter(); | ||
| let bit_length_iter = bit_length_arg.as_primitive::<Int32Type>().iter(); | ||
| let arr: StringArray = expr_iter | ||
| .zip(bit_length_iter) | ||
| .map(|(expr, bit_length)| { | ||
| match sha2([ | ||
| ColumnarValue::Scalar(ScalarValue::Utf8(Some( | ||
| expr.unwrap().to_string(), | ||
| ))), | ||
| ColumnarValue::Scalar(ScalarValue::Int32(bit_length)), | ||
| ]) | ||
| .unwrap() | ||
| { | ||
| ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, | ||
| ColumnarValue::Array(arr) => arr | ||
| .as_string::<i32>() | ||
| .iter() | ||
| .map(|str| str.unwrap().to_string()) | ||
| .next(), // first element | ||
| _ => unreachable!(), | ||
| } | ||
| }) | ||
| .collect(); | ||
| Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) | ||
| } | ||
| _ => exec_err!("Unsupported argument types for sha2 function"), | ||
| } | ||
| fn sha2_binary_impl<'a, BinaryArrType>( | ||
| values: &BinaryArrType, | ||
| bit_lengths: &Int32Array, | ||
| ) -> ArrayRef | ||
| where | ||
| BinaryArrType: BinaryArrayType<'a>, | ||
| { | ||
| let array = values | ||
| .iter() | ||
| .zip(bit_lengths.iter()) | ||
| .map(|(value, bit_length)| match (value, bit_length) { | ||
| (Some(value), Some(224)) => { | ||
| let mut digest = sha2::Sha224::default(); | ||
| digest.update(value); | ||
| Some(hex_encode(digest.finalize())) | ||
| } | ||
| (Some(value), Some(0 | 256)) => { | ||
| let mut digest = sha2::Sha256::default(); | ||
| digest.update(value); | ||
| Some(hex_encode(digest.finalize())) | ||
| } | ||
| (Some(value), Some(384)) => { | ||
| let mut digest = sha2::Sha384::default(); | ||
| digest.update(value); | ||
| Some(hex_encode(digest.finalize())) | ||
| } | ||
|
Comment on lines +121 to +134
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We directly use the |
||
| (Some(value), Some(512)) => { | ||
| let mut digest = sha2::Sha512::default(); | ||
| digest.update(value); | ||
| Some(hex_encode(digest.finalize())) | ||
| } | ||
| // Unknown bit-lengths go to null, same as in Spark | ||
| _ => None, | ||
| }) | ||
| .collect::<StringArray>(); | ||
| Arc::new(array) | ||
| } | ||
|
|
||
| fn compute_sha2( | ||
| bit_length_arg: i32, | ||
| expr_arg: &[ColumnarValue], | ||
| ) -> Result<ColumnarValue> { | ||
| match bit_length_arg { | ||
| 0 | 256 => sha256(expr_arg), | ||
| 224 => sha224(expr_arg), | ||
| 384 => sha384(expr_arg), | ||
| 512 => sha512(expr_arg), | ||
| _ => { | ||
| // Return null for unsupported bit lengths instead of error, because spark sha2 does not | ||
| // error out for this. | ||
| return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); | ||
| } | ||
| fn hex_encode<T: AsRef<[u8]>>(data: T) -> String { | ||
| let mut s = String::with_capacity(data.as_ref().len() * 2); | ||
| for b in data.as_ref() { | ||
| // Writing to a string never errors, so we can unwrap here. | ||
| write!(&mut s, "{b:02x}").unwrap(); | ||
| } | ||
| .map(|hashed| spark_sha2_hex(&[hashed]).unwrap()) | ||
| s | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because datafusion-spark now uses sha2 directly, we extract it as a common dependency