Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ recursive = "0.1.1"
regex = "1.12"
rstest = "0.26.1"
serde_json = "1"
sha2 = "^0.10.9"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because datafusion-spark now uses `sha2` directly, we declare it once as a shared workspace dependency.

sqlparser = { version = "0.59.0", default-features = false, features = ["std", "visitor"] }
strum = "0.27.2"
strum_macros = "0.27.2"
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ md-5 = { version = "^0.10.0", optional = true }
num-traits = { workspace = true }
rand = { workspace = true }
regex = { workspace = true, optional = true }
sha2 = { version = "^0.10.9", optional = true }
sha2 = { workspace = true, optional = true }
unicode-segmentation = { version = "^1.7.1", optional = true }
uuid = { version = "1.19", features = ["v4"], optional = true }

Expand Down
1 change: 1 addition & 0 deletions datafusion/spark/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ log = { workspace = true }
percent-encoding = "2.3.2"
rand = { workspace = true }
sha1 = "0.10"
sha2 = { workspace = true }
url = { workspace = true }

[dev-dependencies]
Expand Down
249 changes: 88 additions & 161 deletions datafusion/spark/src/function/hash/sha2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,30 @@
// specific language governing permissions and limitations
// under the License.

extern crate datafusion_functions;

use crate::function::error_utils::{
invalid_arg_count_exec_err, unsupported_data_type_exec_err,
};
use crate::function::math::hex::spark_sha2_hex;
use arrow::array::{ArrayRef, AsArray, StringArray};
use arrow::array::{ArrayRef, AsArray, BinaryArrayType, Int32Array, StringArray};
use arrow::datatypes::{DataType, Int32Type};
use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err};
use datafusion_expr::Signature;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
pub use datafusion_functions::crypto::basic::{sha224, sha256, sha384, sha512};
use datafusion_common::types::{
NativeType, logical_binary, logical_int32, logical_string,
};
use datafusion_common::utils::take_function_args;
use datafusion_common::{Result, internal_err};
use datafusion_expr::{
Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
TypeSignatureClass, Volatility,
};
use datafusion_functions::utils::make_scalar_function;
use sha2::{self, Digest};
use std::any::Any;
use std::fmt::Write;
use std::sync::Arc;

/// Spark-compatible `sha2` scalar function.
///
/// Differs from DataFusion version in allowing array input for bit lengths, and
/// also hex encoding the output.
///
/// <https://spark.apache.org/docs/latest/api/sql/index.html#sha2>
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkSha2 {
// Argument signature built in `new()` and handed back by reference from the
// `ScalarUDFImpl` implementation.
signature: Signature,
// Alternative names returned by `aliases()`; `new()` initialises this empty.
aliases: Vec<String>,
}

impl Default for SparkSha2 {
Expand All @@ -46,8 +50,21 @@ impl Default for SparkSha2 {
impl SparkSha2 {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
aliases: vec![],
signature: Signature::coercible(
vec![
Coercion::new_implicit(
TypeSignatureClass::Native(logical_binary()),
vec![TypeSignatureClass::Native(logical_string())],
NativeType::Binary,
),
Coercion::new_implicit(
TypeSignatureClass::Native(logical_int32()),
vec![TypeSignatureClass::Integer],
NativeType::Int32,
Comment on lines -49 to +63
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving away from `Signature::user_defined`. We also coerce string inputs to binary, which simplifies the implementation because only the raw bytes are needed either way.

),
],
Volatility::Immutable,
),
}
}
}
Expand All @@ -65,163 +82,73 @@ impl ScalarUDFImpl for SparkSha2 {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types[1].is_null() {
return Ok(DataType::Null);
}
Ok(match arg_types[0] {
DataType::Utf8View
| DataType::LargeUtf8
| DataType::Utf8
| DataType::Binary
| DataType::BinaryView
| DataType::LargeBinary => DataType::Utf8,
DataType::Null => DataType::Null,
_ => {
return exec_err!(
"{} function can only accept strings or binary arrays.",
self.name()
);
}
})
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Utf8)
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| {
internal_datafusion_err!("Expected 2 arguments for function sha2")
})?;

sha2(args)
make_scalar_function(sha2_impl, vec![])(&args.args)
}
}

fn aliases(&self) -> &[String] {
&self.aliases
}
fn sha2_impl(args: &[ArrayRef]) -> Result<ArrayRef> {
let [values, bit_lengths] = take_function_args("sha2", args)?;

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 2 {
return Err(invalid_arg_count_exec_err(
self.name(),
(2, 2),
arg_types.len(),
));
let bit_lengths = bit_lengths.as_primitive::<Int32Type>();
let output = match values.data_type() {
DataType::Binary => sha2_binary_impl(&values.as_binary::<i32>(), bit_lengths),
DataType::LargeBinary => {
sha2_binary_impl(&values.as_binary::<i64>(), bit_lengths)
}
let expr_type = match &arg_types[0] {
DataType::Utf8View
| DataType::LargeUtf8
| DataType::Utf8
| DataType::Binary
| DataType::BinaryView
| DataType::LargeBinary
| DataType::Null => Ok(arg_types[0].clone()),
_ => Err(unsupported_data_type_exec_err(
self.name(),
"String, Binary",
&arg_types[0],
)),
}?;
let bit_length_type = if arg_types[1].is_numeric() {
Ok(DataType::Int32)
} else if arg_types[1].is_null() {
Ok(DataType::Null)
} else {
Err(unsupported_data_type_exec_err(
self.name(),
"Numeric Type",
&arg_types[1],
))
}?;

Ok(vec![expr_type, bit_length_type])
}
DataType::BinaryView => sha2_binary_impl(&values.as_binary_view(), bit_lengths),
dt => return internal_err!("Unsupported datatype for sha2: {dt}"),
};
Ok(output)
}

pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> {
match args {
[
ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)),
ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))),
] => compute_sha2(
bit_length_arg,
&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))],
),
[
ColumnarValue::Array(expr_arg),
ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))),
] => compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)]),
[
ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)),
ColumnarValue::Array(bit_length_arg),
] => {
let arr: StringArray = bit_length_arg
.as_primitive::<Int32Type>()
.iter()
.map(|bit_length| {
match sha2([
ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())),
ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
])
.unwrap()
{
ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str,
ColumnarValue::Array(arr) => arr
.as_string::<i32>()
.iter()
.map(|str| str.unwrap().to_string())
.next(), // first element
_ => unreachable!(),
}
})
.collect();
Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef))
}
[
ColumnarValue::Array(expr_arg),
ColumnarValue::Array(bit_length_arg),
] => {
let expr_iter = expr_arg.as_string::<i32>().iter();
let bit_length_iter = bit_length_arg.as_primitive::<Int32Type>().iter();
let arr: StringArray = expr_iter
.zip(bit_length_iter)
.map(|(expr, bit_length)| {
match sha2([
ColumnarValue::Scalar(ScalarValue::Utf8(Some(
expr.unwrap().to_string(),
))),
ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
])
.unwrap()
{
ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str,
ColumnarValue::Array(arr) => arr
.as_string::<i32>()
.iter()
.map(|str| str.unwrap().to_string())
.next(), // first element
_ => unreachable!(),
}
})
.collect();
Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef))
}
_ => exec_err!("Unsupported argument types for sha2 function"),
}
/// Hashes each binary value with the SHA-2 variant selected by the bit length
/// on the same row and returns the digests as lowercase hex strings in a
/// `StringArray`.
///
/// Bit lengths 224, 384 and 512 select SHA-224, SHA-384 and SHA-512; both 0
/// and 256 select SHA-256. A row is null when the value is null, the bit
/// length is null, or the bit length is any other number.
///
/// `values` and `bit_lengths` are zipped row by row, so iteration stops at the
/// shorter of the two; presumably the caller guarantees equal lengths — TODO
/// confirm.
fn sha2_binary_impl<'a, BinaryArrType>(
values: &BinaryArrType,
bit_lengths: &Int32Array,
) -> ArrayRef
where
BinaryArrType: BinaryArrayType<'a>,
{
let array = values
.iter()
.zip(bit_lengths.iter())
.map(|(value, bit_length)| match (value, bit_length) {
(Some(value), Some(224)) => {
let mut digest = sha2::Sha224::default();
digest.update(value);
Some(hex_encode(digest.finalize()))
}
// A bit length of 0 is treated the same as 256.
(Some(value), Some(0 | 256)) => {
let mut digest = sha2::Sha256::default();
digest.update(value);
Some(hex_encode(digest.finalize()))
}
(Some(value), Some(384)) => {
let mut digest = sha2::Sha384::default();
digest.update(value);
Some(hex_encode(digest.finalize()))
}
Comment on lines +121 to +134
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We directly use the sha2 crate to do the hashing now; previously we used functions exposed by datafusion-functions but that seemed like unnecessary indirection

(Some(value), Some(512)) => {
let mut digest = sha2::Sha512::default();
digest.update(value);
Some(hex_encode(digest.finalize()))
}
// Unknown bit-lengths go to null, same as in Spark
// (this arm also covers a null value or a null bit length).
_ => None,
})
.collect::<StringArray>();
Arc::new(array)
}

fn compute_sha2(
bit_length_arg: i32,
expr_arg: &[ColumnarValue],
) -> Result<ColumnarValue> {
match bit_length_arg {
0 | 256 => sha256(expr_arg),
224 => sha224(expr_arg),
384 => sha384(expr_arg),
512 => sha512(expr_arg),
_ => {
// Return null for unsupported bit lengths instead of error, because spark sha2 does not
// error out for this.
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
}
/// Encodes `data` as a lowercase hexadecimal string, two digits per byte.
///
/// Accepts anything that can be viewed as a byte slice (digest output,
/// `Vec<u8>`, `&[u8]`, `&str`, ...). An empty input yields an empty string.
fn hex_encode<T: AsRef<[u8]>>(data: T) -> String {
    let bytes = data.as_ref();
    // Every byte becomes exactly two hex digits, so the final size is known.
    let mut s = String::with_capacity(bytes.len() * 2);
    for b in bytes {
        // Writing to a string never errors, so we can unwrap here.
        write!(&mut s, "{b:02x}").unwrap();
    }
    s
}
55 changes: 55 additions & 0 deletions datafusion/sqllogictest/test_files/spark/hash/sha2.slt
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,58 @@ SELECT sha2(expr, bit_length) FROM VALUES ('foo',0::INT), ('bar',224::INT), ('ba
967004d25de4abc1bd6a7c9a216254a5ac0733e8ad96dc9f1ea0fad9619da7c32d654ec8ad8ba2f9b5728fed6633bd91
8c6be9ed448a34883a13a13f4ead4aefa036b67dcda59020c01e57ea075ea8a4792d428f2c6fd0c09d1c49994d6c22789336e062188df29572ed07e7f9779c52
NULL

# All string types
query T
SELECT sha2(arrow_cast('foo', 'Utf8'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length);
----
0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae

query T
SELECT sha2(arrow_cast('foo', 'LargeUtf8'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length);
----
0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae

query T
SELECT sha2(arrow_cast('foo', 'Utf8View'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length);
----
0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae

# All binary types
query T
SELECT sha2(arrow_cast('foo', 'Binary'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length);
----
0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae

query T
SELECT sha2(arrow_cast('foo', 'LargeBinary'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length);
----
0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae

query T
SELECT sha2(arrow_cast('foo', 'BinaryView'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length);
----
0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae


# Null cases
query T
select sha2(null, 0);
----
NULL

query T
select sha2('a', null);
----
NULL

query T
select sha2('a', null::int);
----
NULL