diff --git a/pre_commit_hooks/debug_statement_hook.py b/pre_commit_hooks/debug_statement_hook.py index 7e6be95e..a083bb3c 100644 --- a/pre_commit_hooks/debug_statement_hook.py +++ b/pre_commit_hooks/debug_statement_hook.py @@ -20,6 +20,11 @@ 'wdb', } +DEBUG_CALL_STATEMENTS = { + 'breakpoint', + 'print', +} + class Debug(NamedTuple): line: int @@ -45,7 +50,10 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: def visit_Call(self, node: ast.Call) -> None: """python3.7+ breakpoint()""" - if isinstance(node.func, ast.Name) and node.func.id == 'breakpoint': + if ( + isinstance(node.func, ast.Name) and + node.func.id in DEBUG_CALL_STATEMENTS + ): st = Debug(node.lineno, node.col_offset, node.func.id, 'called') self.breakpoints.append(st) self.generic_visit(node) diff --git a/tests/debug_statement_hook_test.py b/tests/debug_statement_hook_test.py index 5a8e0bb2..da9c218c 100644 --- a/tests/debug_statement_hook_test.py +++ b/tests/debug_statement_hook_test.py @@ -32,6 +32,12 @@ def test_finds_breakpoint(): assert visitor.breakpoints == [Debug(1, 0, 'breakpoint', 'called')] +def test_finds_print(): + visitor = DebugStatementParser() + visitor.visit(ast.parse('print()')) + assert visitor.breakpoints == [Debug(1, 0, 'print', 'called')] + + def test_returns_one_for_failing_file(tmpdir): f_py = tmpdir.join('f.py') f_py.write('def f():\n import pdb; pdb.set_trace()')