Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 5 additions & 19 deletions src/docformatter/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
# SOFTWARE.
"""This module provides docformatter's Formattor class."""


# Standard Library Imports
import argparse
import collections
Expand Down Expand Up @@ -345,32 +344,20 @@ def _get_function_docstring_newlines( # noqa: PLR0911
return 0


def _get_module_docstring_newlines(black: bool = False) -> int:
def _get_module_docstring_newlines() -> int:
"""Return number of newlines after a module docstring.

docformatter_8.2: One blank line after a module docstring.
docformatter_8.2.1: Two blank lines after a module docstring when in black mode.

Parameters
----------
black : bool
Indicates whether we're using black formatting rules.

Returns
-------
newlines : int
The number of newlines to insert after the docstring.
"""
if black:
return 2

return 1


def _get_newlines_by_type(
tokens: list[tokenize.TokenInfo],
index: int,
black: bool = False,
) -> int:
"""Dispatch to the correct docstring formatter based on context.

Expand All @@ -395,7 +382,7 @@ def _get_newlines_by_type(
return 0
elif _classify.is_module_docstring(tokens, index):
# print("Module")
return _get_module_docstring_newlines(black)
return _get_module_docstring_newlines()
elif _classify.is_class_docstring(tokens, index):
# print("Class")
return _get_class_docstring_newlines(tokens, index)
Expand Down Expand Up @@ -980,9 +967,7 @@ def _do_format_oneline_docstring(
).strip()
if self.args.close_quotes_on_newline and "\n" in summary_wrapped:
summary_wrapped = (
f"{summary_wrapped[:-3]}"
f"\n{indentation}"
f"{summary_wrapped[-3:]}"
f"{summary_wrapped[:-3]}\n{indentation}{summary_wrapped[-3:]}"
)
return summary_wrapped

Expand Down Expand Up @@ -1069,7 +1054,8 @@ def _do_rewrite_docstring_blocks(

_docstring_token = tokens[_docstr_idx]
_blank_line_count = _get_newlines_by_type(
tokens, _docstr_idx, black=self.args.black
tokens,
_docstr_idx,
)

if (
Expand Down
11 changes: 11 additions & 0 deletions tests/_data/string_files/do_format_code.toml
Original file line number Diff line number Diff line change
Expand Up @@ -1233,3 +1233,14 @@ expected="""foo = f'''
bar
'''
"""

[issue_331_black_module_docstring]
source='''"""A."""


pass
'''
expected='''"""A."""

pass
'''
2 changes: 1 addition & 1 deletion tests/_data/string_files/format_functions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ expected = 1
expected = 1

[module_docstring_in_black]
expected = 2
expected = 1

[class_docstring_followed_by_statement]
source = '''
Expand Down
1 change: 1 addition & 0 deletions tests/formatter/test_do_format_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
("ellipses_is_code_line", NO_ARGS),
("do_not_break_f_string_double_quotes", NO_ARGS),
("do_not_break_f_string_single_quotes", NO_ARGS),
("issue_331_black_module_docstring", ["--black", ""]),
],
)
def test_do_format_code(test_key, test_args, args):
Expand Down
32 changes: 16 additions & 16 deletions tests/formatter/test_format_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,19 @@ def _get_docstring_token_and_index(tokens):

@pytest.mark.unit
@pytest.mark.parametrize(
"test_key, black",
"test_key",
[
("module_docstring_followed_by_string", False),
("module_docstring_followed_by_code", False),
("module_docstring_followed_by_comment_then_code", False),
("module_docstring_followed_by_comment_then_string", False),
("module_docstring_in_black", True),
"module_docstring_followed_by_string",
"module_docstring_followed_by_code",
"module_docstring_followed_by_comment_then_code",
"module_docstring_followed_by_comment_then_string",
"module_docstring_in_black",
],
)
def test_module_docstring_newlines(test_key, black):
def test_module_docstring_newlines(test_key):
expected = TEST_STRINGS[test_key]["expected"]

result = _format._get_module_docstring_newlines(black)
result = _format._get_module_docstring_newlines()
assert (
result == expected
), f"\nFailed {test_key}:\nExpected {expected}\nGot {result}"
Expand Down Expand Up @@ -221,23 +221,23 @@ def test_do_remove_preceding_blank_lines(test_key, block):
@pytest.mark.integration
@pytest.mark.order(5)
@pytest.mark.parametrize(
"test_key, black",
"test_key",
[
("get_newlines_by_type_module_docstring", False),
("get_newlines_by_type_module_docstring_black", True),
("get_newlines_by_type_class_docstring", False),
("get_newlines_by_type_function_docstring", False),
("get_newlines_by_type_attribute_docstring", False),
"get_newlines_by_type_module_docstring",
"get_newlines_by_type_module_docstring_black",
"get_newlines_by_type_class_docstring",
"get_newlines_by_type_function_docstring",
"get_newlines_by_type_attribute_docstring",
],
)
def test_get_newlines_by_type(test_key, black):
def test_get_newlines_by_type(test_key):
source = TEST_STRINGS[test_key]["source"]
expected = TEST_STRINGS[test_key]["expected"]

tokens = _get_tokens(source)
index = _get_docstring_token_and_index(tokens)

result = _format._get_newlines_by_type(tokens, index, black)
result = _format._get_newlines_by_type(tokens, index)
assert result == expected, f"\nFailed {test_key}\nExpected {expected}\nGot {result}"


Expand Down