1
1
"""Re-implementation of :func:`black.format_str` as a line generator"""
2
2
3
- from typing import Generator
3
+ from typing import Generator , List
4
4
from black import get_future_imports , detect_target_versions , decode_bytes
5
5
from black .lines import Line , EmptyLineTracker
6
6
from black .linegen import transform_line , LineGenerator
10
10
from black .parsing import lib2to3_parse
11
11
12
12
13
- def format_str_to_lines (
13
+ def format_str_to_chunks ( # pylint: disable=too-many-locals
14
14
src_contents : str , * , mode : Mode
15
- ) -> Generator [str , None , None ]: # pylint: disable=too-many-locals
15
+ ) -> Generator [List [ str ] , None , None ]:
16
16
"""Reformat a string and yield each line of new contents
17
17
18
18
This is a re-implementation of :func:`black.format_str` modified to be a generator
19
- which yields each resulting line instead of concatenating them into a single string.
19
+ which yields each resulting chunk as a list of lines instead of concatenating them
20
+ into a single string.
20
21
21
22
"""
22
23
src_node = lib2to3_parse (src_contents .lstrip (), mode .target_versions )
@@ -42,20 +43,22 @@ def format_str_to_lines(
42
43
}
43
44
num_chars = 0
44
45
for current_line in lines .visit (src_node ):
45
- for _ in range ( after ) :
46
- yield empty_line
47
- num_chars += after * empty_line_len
46
+ if after :
47
+ yield after * [ empty_line ]
48
+ num_chars += after * empty_line_len
48
49
before , after = elt .maybe_empty_lines (current_line )
49
- for _ in range (before ):
50
- yield empty_line
51
- num_chars += before * empty_line_len
52
- for line in transform_line (
53
- current_line , mode = mode , features = split_line_features
54
- ):
55
- line_str = str (line )
56
- yield line_str
57
- num_chars += len (line_str )
50
+ if before :
51
+ yield before * [empty_line ]
52
+ num_chars += before * empty_line_len
53
+ lines = [
54
+ str (line )
55
+ for line in transform_line (
56
+ current_line , mode = mode , features = split_line_features
57
+ )
58
+ ]
59
+ yield lines
60
+ num_chars += sum (len (line ) for line in lines )
58
61
if not num_chars :
59
62
normalized_content , _ , newline = decode_bytes (src_contents .encode ("utf-8" ))
60
63
if "\n " in normalized_content :
61
- yield newline
64
+ yield [ newline ]
0 commit comments