diff --git a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
index e84257ea67..656171c947 100644
--- a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
+++ b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
@@ -159,12 +159,54 @@ impl SparkBloomFilter {
         self.bits.to_bytes()
     }
 
+    /// Returns the bits data carried by `buf`, skipping Spark's header if present.
+    ///
+    /// Spark's full serialization format is a 12-byte header (version +
+    /// num_hash_functions + num_words) followed by the bits data. A buffer whose
+    /// length is exactly that header plus `self.bits.byte_size()` is treated as
+    /// Spark's format and its header is skipped without being inspected. Any other
+    /// buffer is returned unchanged (Comet format: bits data only).
+    ///
+    /// The returned slice borrows from `buf`. The lifetime is written out because
+    /// elision would tie the result to `&self`, which does not compile when the
+    /// function returns `buf`.
+    fn extract_bits_from_spark_format<'a>(&self, buf: &'a [u8]) -> &'a [u8] {
+        // version (4) + num_hash_functions (4) + num_words (4)
+        const SPARK_HEADER_SIZE: usize = 12;
+
+        let expected_bits_size = self.bits.byte_size();
+        if buf.len() == SPARK_HEADER_SIZE + expected_bits_size {
+            // Spark's full format: skip the header and keep only the bits data.
+            &buf[SPARK_HEADER_SIZE..]
+        } else {
+            // Already just bits data (Comet format), or a length the caller rejects.
+            buf
+        }
+    }
+
+    /// Merges the bits data in `other` into this filter.
+    ///
+    /// `other` may be bits data only (Comet format) or Spark's full serialization
+    /// format, in which case its header is skipped first.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the extracted bits data is not exactly `self.bits.byte_size()`
+    /// bytes long.
     pub fn merge_filter(&mut self, other: &[u8]) {
+        // Extract bits data if other is in Spark's full serialization format
+        let bits_data = self.extract_bits_from_spark_format(other);
+
         assert_eq!(
-            other.len(),
+            bits_data.len(),
+            self.bits.byte_size(),
+            "Cannot merge SparkBloomFilters with different lengths. Expected {} bytes, got {} bytes (full buffer size: {} bytes)",
             self.bits.byte_size(),
-            "Cannot merge SparkBloomFilters with different lengths."
+            bits_data.len(),
+            other.len()
         );
-        self.bits.merge_bits(other);
+        // NOTE(review): the Spark header is never validated and the bits are passed
+        // through as-is; confirm merge_bits expects the byte order Spark writes.
+        self.bits.merge_bits(bits_data);
     }
 }