From 78e4afac109aaae6727dcf4c87ff3defb5243ca9 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Thu, 25 Dec 2025 19:12:32 +0200 Subject: [PATCH 1/5] Optimize column_encryption_policy checks in recv_results_rows There's no point in checking a global policy for every single value decoded, nor for every row decoded. Adjusted the code to check it only once per recv_results_rows() call: decode_row() is now defined either as it is today when column_encryption_policy is enabled, or in a much simpler form without all those extra checks. Added a unit test from Copilot. Fixes: https://github.com/scylladb/python-driver/issues/582 Signed-off-by: Yaniv Kaul --- cassandra/protocol.py | 31 +++- .../unit/test_protocol_decode_optimization.py | 155 ++++++++++++++++++ 2 files changed, 177 insertions(+), 9 deletions(-) create mode 100644 tests/unit/test_protocol_decode_optimization.py diff --git a/cassandra/protocol.py b/cassandra/protocol.py index e574965de8..5f77818c70 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -719,24 +719,37 @@ def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, rows = [self.recv_row(f, len(column_metadata)) for _ in range(rowcount)] self.column_names = [c[2] for c in column_metadata] self.column_types = [c[3] for c in column_metadata] - col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata] - def decode_val(val, col_md, col_desc): - uses_ce = column_encryption_policy and column_encryption_policy.contains_column(col_desc) - col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3] - raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val - return col_type.from_binary(raw_bytes, protocol_version) + if column_encryption_policy: + col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata] - def decode_row(row): - return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs)) + def decode_val(val, 
col_md, col_desc): + uses_ce = column_encryption_policy.contains_column(col_desc) + if uses_ce: + col_type = column_encryption_policy.column_type(col_desc) + raw_bytes = column_encryption_policy.decrypt(col_desc, val) + return col_type.from_binary(raw_bytes, protocol_version) + else: + return col_md[3].from_binary(val, protocol_version) + + def decode_row(row): + return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs)) + else: + def decode_row(row): + return tuple(col_md[3].from_binary(val, protocol_version) for val, col_md in zip(row, column_metadata)) try: self.parsed_rows = [decode_row(row) for row in rows] except Exception: + if not column_encryption_policy: + col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata] for row in rows: for val, col_md, col_desc in zip(row, column_metadata, col_descs): try: - decode_val(val, col_md, col_desc) + if column_encryption_policy: + decode_val(val, col_md, col_desc) + else: + col_md[3].from_binary(val, protocol_version) except Exception as e: raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2], col_md[3].cql_parameterized_type(), diff --git a/tests/unit/test_protocol_decode_optimization.py b/tests/unit/test_protocol_decode_optimization.py new file mode 100644 index 0000000000..e0fd81fe3e --- /dev/null +++ b/tests/unit/test_protocol_decode_optimization.py @@ -0,0 +1,155 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import unittest +from unittest.mock import Mock + +from cassandra import ProtocolVersion +from cassandra.cqltypes import Int32Type, UTF8Type +from cassandra.marshal import int32_pack +from cassandra.policies import ColDesc +from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS + + +class DecodeOptimizationTest(unittest.TestCase): + """ + Tests to verify the optimization of column_encryption_policy checks + in recv_results_rows. The optimization checks if the policy exists once + per result message, avoiding the redundant 'column_encryption_policy and ...' + check for every value. + """ + + def _create_mock_result_metadata(self): + """Create mock result metadata for testing""" + return [ + ('keyspace1', 'table1', 'col1', Int32Type), + ('keyspace1', 'table1', 'col2', UTF8Type), + ] + + def _create_mock_result_message(self): + """Create a mock result message with data""" + msg = ResultMessage(kind=RESULT_KIND_ROWS) + msg.column_metadata = self._create_mock_result_metadata() + msg.recv_results_metadata = Mock() + msg.recv_row = Mock(side_effect=[ + [int32_pack(42), b'hello'], + [int32_pack(100), b'world'], + ]) + return msg + + def _create_mock_stream(self): + """Create a mock stream for reading rows""" + # Pack rowcount (2 rows) + data = int32_pack(2) + return io.BytesIO(data) + + def test_decode_without_encryption_policy(self): + """ + Test that decoding works correctly without column encryption policy. + This should use the optimized simple path. 
+ """ + msg = self._create_mock_result_message() + f = self._create_mock_stream() + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, None) + + # Verify results + self.assertEqual(len(msg.parsed_rows), 2) + self.assertEqual(msg.parsed_rows[0][0], 42) + self.assertEqual(msg.parsed_rows[0][1], 'hello') + self.assertEqual(msg.parsed_rows[1][0], 100) + self.assertEqual(msg.parsed_rows[1][1], 'world') + + def test_decode_with_encryption_policy_no_encrypted_columns(self): + """ + Test that decoding works with encryption policy when no columns are encrypted. + """ + msg = self._create_mock_result_message() + f = self._create_mock_stream() + + # Create mock encryption policy that has no encrypted columns + mock_policy = Mock() + mock_policy.contains_column = Mock(return_value=False) + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) + + # Verify results + self.assertEqual(len(msg.parsed_rows), 2) + self.assertEqual(msg.parsed_rows[0][0], 42) + self.assertEqual(msg.parsed_rows[0][1], 'hello') + + # Verify contains_column was called for each value (but policy existence check happens once) + # Should be called 4 times (2 rows × 2 columns) + self.assertEqual(mock_policy.contains_column.call_count, 4) + + def test_decode_with_encryption_policy_with_encrypted_column(self): + """ + Test that decoding works with encryption policy when one column is encrypted. 
+ """ + msg = self._create_mock_result_message() + f = self._create_mock_stream() + + # Create mock encryption policy where first column is encrypted + mock_policy = Mock() + def contains_column_side_effect(col_desc): + return col_desc.col == 'col1' + mock_policy.contains_column = Mock(side_effect=contains_column_side_effect) + mock_policy.column_type = Mock(return_value=Int32Type) + mock_policy.decrypt = Mock(side_effect=lambda col_desc, val: val) + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) + + # Verify results + self.assertEqual(len(msg.parsed_rows), 2) + self.assertEqual(msg.parsed_rows[0][0], 42) + self.assertEqual(msg.parsed_rows[0][1], 'hello') + + # Verify contains_column was called for each value (but policy existence check happens once) + # Should be called 4 times (2 rows × 2 columns) + self.assertEqual(mock_policy.contains_column.call_count, 4) + + # Verify decrypt was called for each encrypted value (2 rows * 1 encrypted column) + self.assertEqual(mock_policy.decrypt.call_count, 2) + + def test_optimization_efficiency(self): + """ + Verify that the optimization checks policy existence once per result message. + The key optimization is checking 'if column_encryption_policy:' once, + rather than 'column_encryption_policy and ...' for every value. + """ + msg = self._create_mock_result_message() + + # Create more rows to make the check pattern clear + msg.recv_row = Mock(side_effect=[ + [int32_pack(i), f'text{i}'.encode()] for i in range(100) + ]) + + # Create mock stream with 100 rows + f = io.BytesIO(int32_pack(100)) + + mock_policy = Mock() + mock_policy.contains_column = Mock(return_value=False) + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) + + # With optimization: policy existence checked once, contains_column called per value + # = 100 rows * 2 columns = 200 calls to contains_column + # The key is we avoid checking 'column_encryption_policy and ...' 
200 times + self.assertEqual(mock_policy.contains_column.call_count, 200, + "contains_column should be called for each value when policy exists") + + +if __name__ == '__main__': + unittest.main() From 44574c650edb658ab9900eacf5a357c3f4ec0245 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sun, 4 Jan 2026 17:59:21 +0200 Subject: [PATCH 2/5] (improvement)Optimize column_encryption_policy checks in Cython's unpack_row() function Very similar to the native Python code: separate the two cases. If the column encryption (CE) policy is not enabled, the code is substantially simplified; if it is, it's slightly more elaborate. Decided to have two loops in two functions, one for each case, for performance reasons, even if readability-wise it's not as great. AI agreed with me: "Recommendation: Keep it as is. In high-performance Cython code like this, duplicating a small block of code [...]" Fixes: https://github.com/scylladb/python-driver/issues/639 Signed-off-by: Yaniv Kaul --- cassandra/obj_parser.pyx | 47 ++++++++++++++++++++++++++++++++++------ cassandra/row_parser.pyx | 8 +++++-- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/cassandra/obj_parser.pyx b/cassandra/obj_parser.pyx index cf43771dd7..2d366fc5bb 100644 --- a/cassandra/obj_parser.pyx +++ b/cassandra/obj_parser.pyx @@ -31,7 +31,10 @@ cdef class ListParser(ColumnParser): cdef Py_ssize_t i, rowcount rowcount = read_int(reader) cdef RowParser rowparser = TupleRowParser() - return [rowparser.unpack_row(reader, desc) for i in range(rowcount)] + if desc.column_encryption_policy: + return [rowparser.unpack_ce_row(reader, desc) for i in range(rowcount)] + else: + return [rowparser.unpack_row(reader, desc) for i in range(rowcount)] cdef class LazyParser(ColumnParser): @@ -47,7 +50,10 @@ def parse_rows_lazy(BytesIOReader reader, ParseDesc desc): cdef Py_ssize_t i, rowcount rowcount = read_int(reader) cdef RowParser rowparser = TupleRowParser() - return (rowparser.unpack_row(reader, desc) for i in 
range(rowcount)) + if desc.column_encryption_policy: + return (rowparser.unpack_ce_row(reader, desc) for i in range(rowcount)) + else: + return (rowparser.unpack_row(reader, desc) for i in range(rowcount)) cdef class TupleRowParser(RowParser): @@ -55,9 +61,11 @@ cdef class TupleRowParser(RowParser): Parse a single returned row into a tuple of objects: (obj1, ..., objN) + If CE (Column encryption) policy is enabled - use unpack_ce_row(), + otherwsise use unpack_row() """ - cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc): + cpdef unpack_ce_row(self, BytesIOReader reader, ParseDesc desc): assert desc.rowsize >= 0 cdef Buffer buf @@ -73,9 +81,9 @@ cdef class TupleRowParser(RowParser): # Deserialize bytes to python object deserializer = desc.deserializers[i] - coldesc = desc.coldescs[i] - uses_ce = ce_policy and ce_policy.contains_column(coldesc) try: + coldesc = desc.coldescs[i] + uses_ce = ce_policy.contains_column(coldesc) if uses_ce: col_type = ce_policy.column_type(coldesc) decrypted_bytes = ce_policy.decrypt(coldesc, to_bytes(&buf)) @@ -84,11 +92,36 @@ cdef class TupleRowParser(RowParser): val = from_binary(deserializer, &newbuf, desc.protocol_version) else: val = from_binary(deserializer, &buf, desc.protocol_version) + # Insert new object into tuple + tuple_set(res, i, val) + except Exception as e: + raise DriverException('Failed decoding result column "%s" of type %s: %s' % (desc.colnames[i], + desc.coltypes[i].cql_parameterized_type(), + str(e))) + + return res + + cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc): + assert desc.rowsize >= 0 + + cdef Buffer buf + cdef Py_ssize_t i, rowsize = desc.rowsize + cdef Deserializer deserializer + cdef tuple res = tuple_new(desc.rowsize) + + for i in range(rowsize): + # Read the next few bytes + get_buf(reader, &buf) + + # Deserialize bytes to python object + deserializer = desc.deserializers[i] + try: + val = from_binary(deserializer, &buf, desc.protocol_version) + # Insert new object into tuple 
+ tuple_set(res, i, val) except Exception as e: raise DriverException('Failed decoding result column "%s" of type %s: %s' % (desc.colnames[i], desc.coltypes[i].cql_parameterized_type(), str(e))) - # Insert new object into tuple - tuple_set(res, i, val) return res diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx index 88277a4593..1308f5b2ce 100644 --- a/cassandra/row_parser.pyx +++ b/cassandra/row_parser.pyx @@ -44,7 +44,11 @@ def make_recv_results_rows(ColumnParser colparser): reader.buf_ptr = reader.buf reader.pos = 0 rowcount = read_int(reader) - for i in range(rowcount): - rowparser.unpack_row(reader, desc) + if desc.column_encryption_policy: + for i in range(rowcount): + rowparser.unpack_ce_row(reader, desc) + else: + for i in range(rowcount): + rowparser.unpack_row(reader, desc) return recv_results_rows From bb07e9badef76e799f4b522a2ce04cc17d284384 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Mon, 5 Jan 2026 11:07:44 +0200 Subject: [PATCH 3/5] (improvement)Optimize column_encryption_policy checks: tests Add tests, respond to review feedback on added tests. Signed-off-by: Yaniv Kaul --- tests/unit/test_protocol.py | 135 ++++++++++++++- .../unit/test_protocol_decode_optimization.py | 155 ------------------ 2 files changed, 133 insertions(+), 157 deletions(-) delete mode 100644 tests/unit/test_protocol_decode_optimization.py diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py index 9704811239..ea12fa7b5a 100644 --- a/tests/unit/test_protocol.py +++ b/tests/unit/test_protocol.py @@ -12,22 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import io import unittest from unittest.mock import Mock from cassandra import ProtocolVersion, UnsupportedOperation +from cassandra.cqltypes import Int32Type, UTF8Type from cassandra.protocol import ( PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation, _PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG, _PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG, - BatchMessage + BatchMessage, + ResultMessage, RESULT_KIND_ROWS ) from cassandra.query import BatchType -from cassandra.marshal import uint32_unpack +from cassandra.marshal import uint32_unpack, int32_pack from cassandra.cluster import ContinuousPagingOptions import pytest +from cassandra.policies import ColDesc class MessageTest(unittest.TestCase): @@ -189,3 +193,130 @@ def test_batch_message_with_keyspace(self): (b'\x00\x03',), (b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',)) ) + +class ResultTest(unittest.TestCase): + """ + Tests to verify the optimization of column_encryption_policy checks + in recv_results_rows. The optimization checks if the policy exists once + per result message, avoiding the redundant 'column_encryption_policy and ...' + check for every value. + """ + + def _create_mock_result_metadata(self): + """Create mock result metadata for testing""" + return [ + ('keyspace1', 'table1', 'col1', Int32Type), + ('keyspace1', 'table1', 'col2', UTF8Type), + ] + + def _create_mock_result_message(self): + """Create a mock result message with data""" + msg = ResultMessage(kind=RESULT_KIND_ROWS) + msg.column_metadata = self._create_mock_result_metadata() + msg.recv_results_metadata = Mock() + msg.recv_row = Mock(side_effect=[ + [int32_pack(42), b'hello'], + [int32_pack(100), b'world'], + ]) + return msg + + def _create_mock_stream(self): + """Create a mock stream for reading rows""" + # Pack rowcount (2 rows) + data = int32_pack(2) + return io.BytesIO(data) + + def test_decode_without_encryption_policy(self): + """ + Test that decoding works correctly without column encryption policy. 
+ This should use the optimized simple path. + """ + msg = self._create_mock_result_message() + f = self._create_mock_stream() + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, None) + + # Verify results + self.assertEqual(len(msg.parsed_rows), 2) + self.assertEqual(msg.parsed_rows[0][0], 42) + self.assertEqual(msg.parsed_rows[0][1], 'hello') + self.assertEqual(msg.parsed_rows[1][0], 100) + self.assertEqual(msg.parsed_rows[1][1], 'world') + + def test_decode_with_encryption_policy_no_encrypted_columns(self): + """ + Test that decoding works with encryption policy when no columns are encrypted. + """ + msg = self._create_mock_result_message() + f = self._create_mock_stream() + + # Create mock encryption policy that has no encrypted columns + mock_policy = Mock() + mock_policy.contains_column = Mock(return_value=False) + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) + + # Verify results + self.assertEqual(len(msg.parsed_rows), 2) + self.assertEqual(msg.parsed_rows[0][0], 42) + self.assertEqual(msg.parsed_rows[0][1], 'hello') + + # Verify contains_column was called for each value (but policy existence check happens once) + # Should be called 4 times (2 rows × 2 columns) + self.assertEqual(mock_policy.contains_column.call_count, 4) + + def test_decode_with_encryption_policy_with_encrypted_column(self): + """ + Test that decoding works with encryption policy when one column is encrypted. 
+ """ + msg = self._create_mock_result_message() + f = self._create_mock_stream() + + # Create mock encryption policy where first column is encrypted + mock_policy = Mock() + def contains_column_side_effect(col_desc): + return col_desc.col == 'col1' + mock_policy.contains_column = Mock(side_effect=contains_column_side_effect) + mock_policy.column_type = Mock(return_value=Int32Type) + mock_policy.decrypt = Mock(side_effect=lambda col_desc, val: val) + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) + + # Verify results + self.assertEqual(len(msg.parsed_rows), 2) + self.assertEqual(msg.parsed_rows[0][0], 42) + self.assertEqual(msg.parsed_rows[0][1], 'hello') + + # Verify contains_column was called for each value (but policy existence check happens once) + # Should be called 4 times (2 rows × 2 columns) + self.assertEqual(mock_policy.contains_column.call_count, 4) + + # Verify decrypt was called for each encrypted value (2 rows * 1 encrypted column) + self.assertEqual(mock_policy.decrypt.call_count, 2) + + def test_optimization_efficiency(self): + """ + Verify that the optimization checks policy existence once per result message. + The key optimization is checking 'if column_encryption_policy:' once, + rather than 'column_encryption_policy and ...' for every value. + """ + msg = self._create_mock_result_message() + + # Create more rows to make the check pattern clear + msg.recv_row = Mock(side_effect=[ + [int32_pack(i), f'text{i}'.encode()] for i in range(100) + ]) + + # Create mock stream with 100 rows + f = io.BytesIO(int32_pack(100)) + + mock_policy = Mock() + mock_policy.contains_column = Mock(return_value=False) + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) + + # With optimization: policy existence checked once, contains_column called per value + # = 100 rows * 2 columns = 200 calls to contains_column + # The key is we avoid checking 'column_encryption_policy and ...' 
200 times + self.assertEqual(mock_policy.contains_column.call_count, 200, + "contains_column should be called for each value when policy exists") diff --git a/tests/unit/test_protocol_decode_optimization.py b/tests/unit/test_protocol_decode_optimization.py deleted file mode 100644 index e0fd81fe3e..0000000000 --- a/tests/unit/test_protocol_decode_optimization.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright DataStax, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import io -import unittest -from unittest.mock import Mock - -from cassandra import ProtocolVersion -from cassandra.cqltypes import Int32Type, UTF8Type -from cassandra.marshal import int32_pack -from cassandra.policies import ColDesc -from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS - - -class DecodeOptimizationTest(unittest.TestCase): - """ - Tests to verify the optimization of column_encryption_policy checks - in recv_results_rows. The optimization checks if the policy exists once - per result message, avoiding the redundant 'column_encryption_policy and ...' - check for every value. 
- """ - - def _create_mock_result_metadata(self): - """Create mock result metadata for testing""" - return [ - ('keyspace1', 'table1', 'col1', Int32Type), - ('keyspace1', 'table1', 'col2', UTF8Type), - ] - - def _create_mock_result_message(self): - """Create a mock result message with data""" - msg = ResultMessage(kind=RESULT_KIND_ROWS) - msg.column_metadata = self._create_mock_result_metadata() - msg.recv_results_metadata = Mock() - msg.recv_row = Mock(side_effect=[ - [int32_pack(42), b'hello'], - [int32_pack(100), b'world'], - ]) - return msg - - def _create_mock_stream(self): - """Create a mock stream for reading rows""" - # Pack rowcount (2 rows) - data = int32_pack(2) - return io.BytesIO(data) - - def test_decode_without_encryption_policy(self): - """ - Test that decoding works correctly without column encryption policy. - This should use the optimized simple path. - """ - msg = self._create_mock_result_message() - f = self._create_mock_stream() - - msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, None) - - # Verify results - self.assertEqual(len(msg.parsed_rows), 2) - self.assertEqual(msg.parsed_rows[0][0], 42) - self.assertEqual(msg.parsed_rows[0][1], 'hello') - self.assertEqual(msg.parsed_rows[1][0], 100) - self.assertEqual(msg.parsed_rows[1][1], 'world') - - def test_decode_with_encryption_policy_no_encrypted_columns(self): - """ - Test that decoding works with encryption policy when no columns are encrypted. 
- """ - msg = self._create_mock_result_message() - f = self._create_mock_stream() - - # Create mock encryption policy that has no encrypted columns - mock_policy = Mock() - mock_policy.contains_column = Mock(return_value=False) - - msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) - - # Verify results - self.assertEqual(len(msg.parsed_rows), 2) - self.assertEqual(msg.parsed_rows[0][0], 42) - self.assertEqual(msg.parsed_rows[0][1], 'hello') - - # Verify contains_column was called for each value (but policy existence check happens once) - # Should be called 4 times (2 rows × 2 columns) - self.assertEqual(mock_policy.contains_column.call_count, 4) - - def test_decode_with_encryption_policy_with_encrypted_column(self): - """ - Test that decoding works with encryption policy when one column is encrypted. - """ - msg = self._create_mock_result_message() - f = self._create_mock_stream() - - # Create mock encryption policy where first column is encrypted - mock_policy = Mock() - def contains_column_side_effect(col_desc): - return col_desc.col == 'col1' - mock_policy.contains_column = Mock(side_effect=contains_column_side_effect) - mock_policy.column_type = Mock(return_value=Int32Type) - mock_policy.decrypt = Mock(side_effect=lambda col_desc, val: val) - - msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) - - # Verify results - self.assertEqual(len(msg.parsed_rows), 2) - self.assertEqual(msg.parsed_rows[0][0], 42) - self.assertEqual(msg.parsed_rows[0][1], 'hello') - - # Verify contains_column was called for each value (but policy existence check happens once) - # Should be called 4 times (2 rows × 2 columns) - self.assertEqual(mock_policy.contains_column.call_count, 4) - - # Verify decrypt was called for each encrypted value (2 rows * 1 encrypted column) - self.assertEqual(mock_policy.decrypt.call_count, 2) - - def test_optimization_efficiency(self): - """ - Verify that the optimization checks policy existence once per result message. 
- The key optimization is checking 'if column_encryption_policy:' once, - rather than 'column_encryption_policy and ...' for every value. - """ - msg = self._create_mock_result_message() - - # Create more rows to make the check pattern clear - msg.recv_row = Mock(side_effect=[ - [int32_pack(i), f'text{i}'.encode()] for i in range(100) - ]) - - # Create mock stream with 100 rows - f = io.BytesIO(int32_pack(100)) - - mock_policy = Mock() - mock_policy.contains_column = Mock(return_value=False) - - msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) - - # With optimization: policy existence checked once, contains_column called per value - # = 100 rows * 2 columns = 200 calls to contains_column - # The key is we avoid checking 'column_encryption_policy and ...' 200 times - self.assertEqual(mock_policy.contains_column.call_count, 200, - "contains_column should be called for each value when policy exists") - - -if __name__ == '__main__': - unittest.main() From d67815cf9cc976c795fedbcd05ff013cb497c9b6 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sun, 18 Jan 2026 14:47:18 +0200 Subject: [PATCH 4/5] query: split Column Encryption in the bind path Split BoundStatement.bind() into CE and non-CE loops to avoid per-value CE checks when no policy is configured. In the CE loop, use a single uses_ce branch to select type serialization and optional encryption for each column. 
Signed-off-by: Yaniv Kaul --- cassandra/query.py | 65 +++++++++++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 21 deletions(-) diff --git a/cassandra/query.py b/cassandra/query.py index 6c6878fdb4..fd165469a2 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -636,28 +636,51 @@ def bind(self, values): self.raw_values = values self.values = [] - for value, col_spec in zip(values, col_meta): - if value is None: - self.values.append(None) - elif value is UNSET_VALUE: - if proto_version >= 4: - self._append_unset_value() + if ce_policy: + for value, col_spec in zip(values, col_meta): + if value is None: + self.values.append(None) + elif value is UNSET_VALUE: + if proto_version >= 4: + self._append_unset_value() + else: + raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) else: - raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) - else: - try: - col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) - uses_ce = ce_policy and ce_policy.contains_column(col_desc) - col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type - col_bytes = col_type.serialize(value, proto_version) - if uses_ce: - col_bytes = ce_policy.encrypt(col_desc, col_bytes) - self.values.append(col_bytes) - except (TypeError, struct.error) as exc: - actual_type = type(value) - message = ('Received an argument of invalid type for column "%s". 
' - 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) - raise TypeError(message) + try: + col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) + uses_ce = ce_policy.contains_column(col_desc) + if uses_ce: + col_type = ce_policy.column_type(col_desc) + col_bytes = col_type.serialize(value, proto_version) + col_bytes = ce_policy.encrypt(col_desc, col_bytes) + else: + col_type = col_spec.type + col_bytes = col_type.serialize(value, proto_version) + self.values.append(col_bytes) + except (TypeError, struct.error) as exc: + actual_type = type(value) + message = ('Received an argument of invalid type for column "%s". ' + 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) + raise TypeError(message) + else: + for value, col_spec in zip(values, col_meta): + if value is None: + self.values.append(None) + elif value is UNSET_VALUE: + if proto_version >= 4: + self._append_unset_value() + else: + raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) + else: + try: + col_type = col_spec.type + col_bytes = col_type.serialize(value, proto_version) + self.values.append(col_bytes) + except (TypeError, struct.error) as exc: + actual_type = type(value) + message = ('Received an argument of invalid type for column "%s". ' + 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) + raise TypeError(message) if proto_version >= 4: diff = col_meta_len - len(self.values) From d87f917715ec1aa935bddce000a0f2919ef433dd Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Tue, 27 Jan 2026 18:16:40 +0200 Subject: [PATCH 5/5] Renamed unpack_row() -> unpack_plain_row(), unpack_ce_row() -> unpack_col_encrypted_row() Per review comments. 
Signed-off-by: Yaniv Kaul --- cassandra/numpy_parser.pyx | 4 ++-- cassandra/obj_parser.pyx | 16 ++++++++-------- cassandra/parsing.pxd | 2 +- cassandra/parsing.pyx | 2 +- cassandra/row_parser.pyx | 4 ++-- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/cassandra/numpy_parser.pyx b/cassandra/numpy_parser.pyx index 030c2c65c7..faa6d6b93f 100644 --- a/cassandra/numpy_parser.pyx +++ b/cassandra/numpy_parser.pyx @@ -97,7 +97,7 @@ cdef _parse_rows(BytesIOReader reader, ParseDesc desc, cdef Py_ssize_t i for i in range(rowcount): - unpack_row(reader, desc, arrs) + unpack_plain_row(reader, desc, arrs) ### Helper functions to create NumPy arrays and array descriptors @@ -144,7 +144,7 @@ def make_array(coltype, array_size): @cython.boundscheck(False) @cython.wraparound(False) -cdef inline int unpack_row( +cdef inline int unpack_plain_row( BytesIOReader reader, ParseDesc desc, ArrDesc *arrays) except -1: cdef Buffer buf cdef Py_ssize_t i, rowsize = desc.rowsize diff --git a/cassandra/obj_parser.pyx b/cassandra/obj_parser.pyx index 2d366fc5bb..4f924225f1 100644 --- a/cassandra/obj_parser.pyx +++ b/cassandra/obj_parser.pyx @@ -32,9 +32,9 @@ cdef class ListParser(ColumnParser): rowcount = read_int(reader) cdef RowParser rowparser = TupleRowParser() if desc.column_encryption_policy: - return [rowparser.unpack_ce_row(reader, desc) for i in range(rowcount)] + return [rowparser.unpack_col_encrypted_row(reader, desc) for i in range(rowcount)] else: - return [rowparser.unpack_row(reader, desc) for i in range(rowcount)] + return [rowparser.unpack_plain_row(reader, desc) for i in range(rowcount)] cdef class LazyParser(ColumnParser): @@ -51,9 +51,9 @@ def parse_rows_lazy(BytesIOReader reader, ParseDesc desc): rowcount = read_int(reader) cdef RowParser rowparser = TupleRowParser() if desc.column_encryption_policy: - return (rowparser.unpack_ce_row(reader, desc) for i in range(rowcount)) + return (rowparser.unpack_col_encrypted_row(reader, desc) for i in range(rowcount)) 
else: - return (rowparser.unpack_row(reader, desc) for i in range(rowcount)) + return (rowparser.unpack_plain_row(reader, desc) for i in range(rowcount)) cdef class TupleRowParser(RowParser): @@ -61,11 +61,11 @@ cdef class TupleRowParser(RowParser): Parse a single returned row into a tuple of objects: (obj1, ..., objN) - If CE (Column encryption) policy is enabled - use unpack_ce_row(), - otherwsise use unpack_row() + If CE (Column encryption) policy is enabled - use unpack_col_encrypted_row(), + otherwise use unpack_plain_row() """ - cpdef unpack_ce_row(self, BytesIOReader reader, ParseDesc desc): + cpdef unpack_col_encrypted_row(self, BytesIOReader reader, ParseDesc desc): assert desc.rowsize >= 0 cdef Buffer buf @@ -101,7 +101,7 @@ cdef class TupleRowParser(RowParser): return res - cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc): + cpdef unpack_plain_row(self, BytesIOReader reader, ParseDesc desc): assert desc.rowsize >= 0 cdef Buffer buf diff --git a/cassandra/parsing.pxd b/cassandra/parsing.pxd index 27dc368b07..6a93f4104e 100644 --- a/cassandra/parsing.pxd +++ b/cassandra/parsing.pxd @@ -28,5 +28,5 @@ cdef class ColumnParser: cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc) cdef class RowParser: - cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc) + cpdef unpack_plain_row(self, BytesIOReader reader, ParseDesc desc) diff --git a/cassandra/parsing.pyx b/cassandra/parsing.pyx index 954767d227..5f9af7c7df 100644 --- a/cassandra/parsing.pyx +++ b/cassandra/parsing.pyx @@ -39,7 +39,7 @@ cdef class ColumnParser: cdef class RowParser: """Parser for a single row""" - cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc): + cpdef unpack_plain_row(self, BytesIOReader reader, ParseDesc desc): """ Unpack a single row of data in a ResultMessage. 
""" diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx index 1308f5b2ce..daa87ca362 100644 --- a/cassandra/row_parser.pyx +++ b/cassandra/row_parser.pyx @@ -46,9 +46,9 @@ def make_recv_results_rows(ColumnParser colparser): rowcount = read_int(reader) if desc.column_encryption_policy: for i in range(rowcount): - rowparser.unpack_ce_row(reader, desc) + rowparser.unpack_col_encrypted_row(reader, desc) else: for i in range(rowcount): - rowparser.unpack_row(reader, desc) + rowparser.unpack_plain_row(reader, desc) return recv_results_rows