2424import sqlite3 as sqlite
2525import unittest
2626
27+ from test .support import import_helper
2728from test .support .os_helper import TESTFN , unlink
2829
2930from .util import memory_database , cx_limit , with_tracebacks
3031from .util import MemoryDatabaseMixin
3132
33+ # TODO(picnixz): increase test coverage for other callbacks
34+ # such as 'func', 'step', 'finalize', and 'collation'.
35+
3236
3337class CollationTests (MemoryDatabaseMixin , unittest .TestCase ):
3438
@@ -129,8 +133,55 @@ def test_deregister_collation(self):
129133 self .assertEqual (str (cm .exception ), 'no such collation sequence: mycoll' )
130134
131135
136+ class AuthorizerTests (MemoryDatabaseMixin , unittest .TestCase ):
137+
138+ def assert_not_authorized (self , func , / , * args , ** kwargs ):
139+ with self .assertRaisesRegex (sqlite .DatabaseError , "not authorized" ):
140+ func (* args , ** kwargs )
141+
142+ # When a handler has an invalid signature, the exception raised is
143+ # the same that would be raised if the handler "negatively" replied.
144+
145+ def test_authorizer_invalid_signature (self ):
146+ self .cx .execute ("create table if not exists test(a number)" )
147+ self .cx .set_authorizer (lambda : None )
148+ self .assert_not_authorized (self .cx .execute , "select * from test" )
149+
150+ # Tests for checking that callback context mutations do not crash.
151+ # Regression tests for https://github.com/python/cpython/issues/142830.
152+
153+ @with_tracebacks (ZeroDivisionError , regex = "hello world" )
154+ def test_authorizer_concurrent_mutation_in_call (self ):
155+ self .cx .execute ("create table if not exists test(a number)" )
156+
157+ def handler (* a , ** kw ):
158+ self .cx .set_authorizer (None )
159+ raise ZeroDivisionError ("hello world" )
160+
161+ self .cx .set_authorizer (handler )
162+ self .assert_not_authorized (self .cx .execute , "select * from test" )
163+
164+ @with_tracebacks (OverflowError )
165+ def test_authorizer_concurrent_mutation_with_overflown_value (self ):
166+ _testcapi = import_helper .import_module ("_testcapi" )
167+ self .cx .execute ("create table if not exists test(a number)" )
168+
169+ def handler (* a , ** kw ):
170+ self .cx .set_authorizer (None )
171+ # We expect 'int' at the C level, so this one will raise
172+ # when converting via PyLong_Int().
173+ return _testcapi .INT_MAX + 1
174+
175+ self .cx .set_authorizer (handler )
176+ self .assert_not_authorized (self .cx .execute , "select * from test" )
177+
178+
132179class ProgressTests (MemoryDatabaseMixin , unittest .TestCase ):
133180
181+ def assert_interrupted (self , func , / , * args , ** kwargs ):
182+ with self .assertRaisesRegex (sqlite .OperationalError , "interrupted" ):
183+ func (* args , ** kwargs )
184+
134185 def test_progress_handler_used (self ):
135186 """
136187 Test that the progress handler is invoked once it is set.
@@ -219,7 +270,7 @@ def bad_progress():
219270 create table foo(a, b)
220271 """ )
221272
222- def test_progress_handler_keyword_args (self ):
273+ def test_set_progress_handler_keyword_args (self ):
223274 regex = (
224275 r"Passing keyword argument 'progress_handler' to "
225276 r"_sqlite3.Connection.set_progress_handler\(\) is deprecated. "
@@ -231,6 +282,43 @@ def test_progress_handler_keyword_args(self):
231282 self .con .set_progress_handler (progress_handler = lambda : None , n = 1 )
232283 self .assertEqual (cm .filename , __file__ )
233284
285+ # When a handler has an invalid signature, the exception raised is
286+ # the same that would be raised if the handler "negatively" replied.
287+
288+ def test_progress_handler_invalid_signature (self ):
289+ self .cx .execute ("create table if not exists test(a number)" )
290+ self .cx .set_progress_handler (lambda x : None , 1 )
291+ self .assert_interrupted (self .cx .execute , "select * from test" )
292+
293+ # Tests for checking that callback context mutations do not crash.
294+ # Regression tests for https://github.com/python/cpython/issues/142830.
295+
296+ @with_tracebacks (ZeroDivisionError , regex = "hello world" )
297+ def test_progress_handler_concurrent_mutation_in_call (self ):
298+ self .cx .execute ("create table if not exists test(a number)" )
299+
300+ def handler (* a , ** kw ):
301+ self .cx .set_progress_handler (None , 1 )
302+ raise ZeroDivisionError ("hello world" )
303+
304+ self .cx .set_progress_handler (handler , 1 )
305+ self .assert_interrupted (self .cx .execute , "select * from test" )
306+
307+ def test_progress_handler_concurrent_mutation_in_conversion (self ):
308+ self .cx .execute ("create table if not exists test(a number)" )
309+
310+ class Handler :
311+ def __bool__ (_ ):
312+ # clear the progress handler
313+ self .cx .set_progress_handler (None , 1 )
314+ raise ValueError # force PyObject_True() to fail
315+
316+ self .cx .set_progress_handler (Handler .__init__ , 1 )
317+ self .assert_interrupted (self .cx .execute , "select * from test" )
318+
319+ # Running with tracebacks makes the second execution of this
320+ # function raise another exception because of a database change.
321+
234322
235323class TraceCallbackTests (MemoryDatabaseMixin , unittest .TestCase ):
236324
@@ -352,7 +440,7 @@ def test_trace_bad_handler(self):
352440 cx .set_trace_callback (lambda stmt : 5 / 0 )
353441 cx .execute ("select 1" )
354442
355- def test_trace_keyword_args (self ):
443+ def test_set_trace_callback_keyword_args (self ):
356444 regex = (
357445 r"Passing keyword argument 'trace_callback' to "
358446 r"_sqlite3.Connection.set_trace_callback\(\) is deprecated. "
@@ -364,6 +452,35 @@ def test_trace_keyword_args(self):
364452 self .con .set_trace_callback (trace_callback = lambda : None )
365453 self .assertEqual (cm .filename , __file__ )
366454
455+ # When a handler has an invalid signature, the exception raised is
456+ # the same that would be raised if the handler "negatively" replied,
457+ # but for the trace handler, exceptions are never re-raised (only
458+ # printed when needed).
459+
460+ @with_tracebacks (
461+ TypeError ,
462+ regex = r".*<lambda>\(\) missing 6 required positional arguments" ,
463+ )
464+ def test_trace_handler_invalid_signature (self ):
465+ self .cx .execute ("create table if not exists test(a number)" )
466+ self .cx .set_trace_callback (lambda x , y , z , t , a , b , c : None )
467+ self .cx .execute ("select * from test" )
468+
469+ # Tests for checking that callback context mutations do not crash.
470+ # Regression tests for https://github.com/python/cpython/issues/142830.
471+
472+ @with_tracebacks (ZeroDivisionError , regex = "hello world" )
473+ def test_trace_callback_concurrent_mutation_in_call (self ):
474+ self .cx .execute ("create table if not exists test(a number)" )
475+
476+ def handler (statement ):
477+ # clear the progress handler
478+ self .cx .set_trace_callback (None )
479+ raise ZeroDivisionError ("hello world" )
480+
481+ self .cx .set_trace_callback (handler )
482+ self .cx .execute ("select * from test" )
483+
367484
368485if __name__ == "__main__" :
369486 unittest .main ()
0 commit comments