|
| 1 | +"""Re-implementation of :func:`black.format_str` as a line generator""" |
| 2 | + |
| 3 | +from typing import Generator |
| 4 | +from black import get_future_imports, detect_target_versions, decode_bytes |
| 5 | +from black.lines import Line, EmptyLineTracker |
| 6 | +from black.linegen import transform_line, LineGenerator |
| 7 | +from black.comments import normalize_fmt_off |
| 8 | +from black.mode import Mode |
| 9 | +from black.mode import Feature, supports_feature |
| 10 | +from black.parsing import lib2to3_parse |
| 11 | + |
| 12 | + |
| 13 | +def format_str_to_lines( |
| 14 | + src_contents: str, *, mode: Mode |
| 15 | +) -> Generator[str, None, None]: # pylint: disable=too-many-locals |
| 16 | + """Reformat a string and yield each line of new contents |
| 17 | +
|
| 18 | + This is a re-implementation of :func:`black.format_str` modified to be a generator |
| 19 | + which yields each resulting line instead of concatenating them into a single string. |
| 20 | +
|
| 21 | + """ |
| 22 | + src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) |
| 23 | + future_imports = get_future_imports(src_node) |
| 24 | + if mode.target_versions: |
| 25 | + versions = mode.target_versions |
| 26 | + else: |
| 27 | + versions = detect_target_versions(src_node) |
| 28 | + normalize_fmt_off(src_node) |
| 29 | + lines = LineGenerator( |
| 30 | + mode=mode, |
| 31 | + remove_u_prefix="unicode_literals" in future_imports |
| 32 | + or supports_feature(versions, Feature.UNICODE_LITERALS), |
| 33 | + ) |
| 34 | + elt = EmptyLineTracker(is_pyi=mode.is_pyi) |
| 35 | + empty_line = str(Line(mode=mode)) |
| 36 | + empty_line_len = len(empty_line) |
| 37 | + after = 0 |
| 38 | + split_line_features = { |
| 39 | + feature |
| 40 | + for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF} |
| 41 | + if supports_feature(versions, feature) |
| 42 | + } |
| 43 | + num_chars = 0 |
| 44 | + for current_line in lines.visit(src_node): |
| 45 | + for _ in range(after): |
| 46 | + yield empty_line |
| 47 | + num_chars += after * empty_line_len |
| 48 | + before, after = elt.maybe_empty_lines(current_line) |
| 49 | + for _ in range(before): |
| 50 | + yield empty_line |
| 51 | + num_chars += before * empty_line_len |
| 52 | + for line in transform_line( |
| 53 | + current_line, mode=mode, features=split_line_features |
| 54 | + ): |
| 55 | + line_str = str(line) |
| 56 | + yield line_str |
| 57 | + num_chars += len(line_str) |
| 58 | + if not num_chars: |
| 59 | + normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8")) |
| 60 | + if "\n" in normalized_content: |
| 61 | + yield newline |
0 commit comments