diff --git a/pre_commit_hooks/forbid_articles_in_test_filenames.py b/pre_commit_hooks/forbid_articles_in_test_filenames.py new file mode 100755 index 00000000..0c8fd609 --- /dev/null +++ b/pre_commit_hooks/forbid_articles_in_test_filenames.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import subprocess +import sys +from pathlib import Path + +FORBIDDEN = {'a', 'an', 'the'} + + +def git_ls_python_files(): + result = subprocess.run( + ['git', 'ls-files', '*.py'], + capture_output=True, + text=True, + check=True, + ) + return [Path(p) for p in result.stdout.splitlines()] + + +def is_test_file(path: Path) -> bool: + name = path.name + return ( + name.startswith('test_') or + name.startswith('tests_') or + name.endswith('_test.py') + ) + + +def has_forbidden_article(path: Path) -> bool: + parts = path.stem.split('_') + return any(part in FORBIDDEN for part in parts) + + +def main() -> int: + for path in git_ls_python_files(): + if not is_test_file(path): + continue + + if has_forbidden_article(path): + print('ERROR: Forbidden article in test filename:') + print(path) + return 1 + + return 0 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/pre_commit_hooks/tests_should_end_in_test.py b/pre_commit_hooks/tests_should_end_in_test.py index 07af277d..9f3733f6 100644 --- a/pre_commit_hooks/tests_should_end_in_test.py +++ b/pre_commit_hooks/tests_should_end_in_test.py @@ -14,38 +14,22 @@ def main(argv: Sequence[str] | None = None) -> int: '--pytest', dest='pattern', action='store_const', - const=r'.*_test\.py', - default=r'.*_test\.py', + const=r'^tests\/(?:[a-zA-Z0-9_]+\/)*tests_[a-zA-Z0-9_]*\.py$', + default=r'^tests\/(?:[a-zA-Z0-9_]+\/)*tests_[a-zA-Z0-9_]*\.py$', help='(the default) ensure tests match %(const)s', ) - mutex.add_argument( - '--pytest-test-first', - dest='pattern', - action='store_const', - const=r'test_.*\.py', - help='ensure tests match %(const)s', - ) - mutex.add_argument( - '--django', '--unittest', - dest='pattern', - action='store_const', - const=r'test.*\.py', - help='ensure tests match %(const)s', - ) args = parser.parse_args(argv) - retcode = 0 reg = re.compile(args.pattern) for filename in args.filenames: base = os.path.basename(filename) - if ( - not reg.fullmatch(base) and - not base == '__init__.py' and - not base == 'conftest.py' - ): + # Check for files that should be ignored + if base in ('__init__.py', 'conftest.py', 'models.py'): + continue + # Raise an exception if filename doesn't start with 'tests_' and doesn't match the pattern + if not reg.fullmatch(filename): retcode = 1 print(f'{filename} does not match pattern "{args.pattern}"') - return retcode