Skip to content

Commit

Permalink
add extended rearrange method
Browse files Browse the repository at this point in the history
  • Loading branch information
Ogban Ugot committed Mar 11, 2024
1 parent e03a14e commit bc2e480
Showing 1 changed file with 154 additions and 36 deletions.
190 changes: 154 additions & 36 deletions ivy_lint/formatters/function_ordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
r"|ivy_tests/test_ivy/(?!.*(?:__init__\.py|conftest\.py|helpers/.*|test_frontends/config/.*$)).*)"
)

EXTENDED_FILE_PATTERN = re.compile(
r"(ivy/functional/backends/(?!.*(?:config\.py|__init__\.py)$).*"
r"|ivy/functional/stateful/(?!.*(?:config\.py|__init__\.py)$).*"
r"|ivy/functional/ivy/(?!.*(?:config\.py|__init__\.py)$).*)"
)


def class_build_dependency_graph(nodes_with_comments: List[Tuple[str, ast.AST]]) -> nx.DiGraph:
"""
Expand All @@ -28,7 +34,7 @@ def class_build_dependency_graph(nodes_with_comments: List[Tuple[str, ast.AST]])
Parameters
----------
nodes_with_comments
nodes_with_comments
A list of code nodes extracted from the source code.
Returns
Expand All @@ -53,9 +59,9 @@ def contains_any_name(code: str, names: List[str]) -> bool:
Parameters
----------
code
code
The code string to search for names in.
names
names
A list of names to search for within the code.
Returns
Expand All @@ -75,7 +81,7 @@ def extract_names_from_assignment(node: ast.Assign) -> List[str]:
Parameters
----------
node
node
The assignment node from the AST.
Returns
Expand Down Expand Up @@ -112,7 +118,7 @@ def assignment_build_dependency_graph(nodes_with_comments: List[Tuple[str, ast.A
Parameters
----------
nodes_with_comments
nodes_with_comments
A list of tuples containing source code and corresponding AST nodes.
Returns
Expand Down Expand Up @@ -148,7 +154,7 @@ def has_st_composite_decorator(node: ast.FunctionDef) -> bool:
Parameters
----------
node
node
An Abstract Syntax Tree (AST) node representing a function definition.
Returns
Expand All @@ -171,9 +177,9 @@ def related_helper_function(assignment_name: str, nodes_with_comments: List[Tupl
Parameters
----------
assignment_name
assignment_name
The name of the assignment that you want to find a related helper function for.
nodes_with_comments
nodes_with_comments
A list of tuples, where each tuple contains a string of code and the corresponding AST node with comments.
Returns
Expand All @@ -183,7 +189,7 @@ def related_helper_function(assignment_name: str, nodes_with_comments: List[Tupl
for _, node in nodes_with_comments:
if isinstance(node, (ast.FunctionDef, ast.ClassDef)) and hasattr(node, "name"):
if node.name.startswith("_") and contains_any_name(
ast.dump(node), [assignment_name]
ast.dump(node), [assignment_name]
):
return node.name
return None
Expand All @@ -197,7 +203,7 @@ def _is_assignment_target_an_attribute(node: ast.Assign) -> bool:
Parameters
----------
node
node
The assignment node being analyzed.
Returns
Expand All @@ -218,7 +224,7 @@ def _remove_existing_headers(self, source_code: str) -> str:
Parameters
----------
source_code
source_code
The original source code containing headers.
Returns
Expand All @@ -228,17 +234,17 @@ def _remove_existing_headers(self, source_code: str) -> str:
return HEADER_PATTERN.sub("", source_code)

def _extract_node_with_leading_comments(
self, node: ast.AST, source_code: str
self, node: ast.AST, source_code: str
) -> Tuple[str, ast.AST]:
"""
Extracts the portion of the source code containing the leading comments of the provided node.
It preserves the structure and leading comments for the specified node.
Parameters
----------
node
node
The node for which the leading comments need to be extracted.
source_code
source_code
The complete source code containing the specified node and comments.
Returns
Expand Down Expand Up @@ -277,7 +283,7 @@ def _extract_node_with_leading_comments(
return "\n".join(extracted_lines), node

def _extract_all_nodes_with_comments(
self, tree: ast.AST, source_code: str
self, tree: ast.AST, source_code: str
) -> List[Tuple[str, ast.AST]]:
"""
Extracts all nodes with their leading comments from the provided AST tree.
Expand All @@ -286,9 +292,9 @@ def _extract_all_nodes_with_comments(
Parameters
----------
tree
tree
The Abstract Syntax Tree (AST) representing the parsed source code.
source_code
source_code
The complete source code containing the nodes and their comments.
Returns
Expand All @@ -301,7 +307,7 @@ def _extract_all_nodes_with_comments(
for node in tree.body
]

def _rearrange_functions_and_classes(self, source_code: str) -> str:
def _rearrange_functions_and_classes(self, original_source_code: str, extended: bool) -> str:
"""
Rearranges functions and classes in the provided source code following a specific order.
Expand All @@ -310,14 +316,17 @@ def _rearrange_functions_and_classes(self, source_code: str) -> str:
Parameters
----------
source_code
original_source_code
The source code to be reordered.
extended
To call the extended rearrange method
Returns
-------
The reordered source code.
"""
source_code = self._remove_existing_headers(source_code)
source_code = self._remove_existing_headers(original_source_code)

tree = ast.parse(source_code)
nodes_with_comments = self._extract_all_nodes_with_comments(tree, source_code)
Expand Down Expand Up @@ -348,7 +357,7 @@ def _is_assignment_dependent_on_assignment(node: ast.Assign) -> bool:
Parameters
----------
node
node
The assignment node to be checked.
Returns
Expand All @@ -366,7 +375,7 @@ def _is_assignment_dependent_on_function_or_class(node: ast.Assign) -> bool:
Parameters
----------
node
node
The assignment node to be checked.
Returns
Expand Down Expand Up @@ -394,7 +403,7 @@ def sort_key(item: Tuple[str, ast.AST]) -> Tuple[float, int, str]:
Parameters
----------
item
item
The item containing source code and the associated AST node.
Returns
Expand Down Expand Up @@ -439,8 +448,8 @@ def sort_key(item: Tuple[str, ast.AST]) -> Tuple[float, int, str]:
i
for i, (_, n) in enumerate(nodes_with_comments)
if isinstance(n, (ast.FunctionDef, ast.ClassDef))
and hasattr(n, "name")
and n.name == related_function
and hasattr(n, "name")
and n.name == related_function
][0]
return (6, function_position, target_str)

Expand Down Expand Up @@ -474,10 +483,10 @@ def sort_key(item: Tuple[str, ast.AST]) -> Tuple[float, int, str]:
# Check and add module-level docstring
docstring_added = False
if (
isinstance(tree, ast.Module)
and tree.body
and isinstance(tree.body[0], ast.Expr)
and isinstance(tree.body[0].value, ast.Str)
isinstance(tree, ast.Module)
and tree.body
and isinstance(tree.body[0], ast.Expr)
and isinstance(tree.body[0].value, ast.Str)
):
docstring = ast.get_docstring(tree, clean=False)
if docstring:
Expand All @@ -495,9 +504,9 @@ def sort_key(item: Tuple[str, ast.AST]) -> Tuple[float, int, str]:
for code, node in nodes_sorted:
# If the docstring was added at the beginning, skip the node
if (
docstring_added
and isinstance(node, ast.Expr)
and isinstance(node.value, ast.Str)
docstring_added
and isinstance(node, ast.Expr)
and isinstance(node.value, ast.Str)
):
continue

Expand Down Expand Up @@ -534,6 +543,114 @@ def sort_key(item: Tuple[str, ast.AST]) -> Tuple[float, int, str]:

reordered_code = black.format_str(reordered_code, mode=black.Mode())

if extended:
return self._extended_rearrange_functions_and_classes(original_source_code, reordered_code, sort_key)
return reordered_code

def _extended_rearrange_functions_and_classes(self, original_source_code: str, reordered_code: str, sort_key) -> str:
"""
Extends the _rearrange_functions_and_classes function to cater to both ivy and stateful files.
This method utilizes the reordered code from _rearrange_functions_and_classes ensuring that classes, helpers and
assignment ordering is maintained, however it reorders the main functions keeping the organizational structure
typically found in ivy and stateful files.
Parameters
----------
original_source_code
The original source code to be reordered.
reordered_code
The reordered_code from _rearrange_functions_and_classes.
sort_key: the sort function from _rearrange_functions_and_classes
Returns
-------
The reordered source code.
"""
section_headers = [
"# Array API Standard",
"# Autograd",
"# Optimizer Steps",
"# Optimizer Updates",
"# Array Printing",
"# Device Queries",
"# Retrieval",
"# Conversions",
"# Memory",
"# Utilization",
"# Availability",
"# Default Device",
"# Device Allocation",
"# Function Splitting",
"# Profiler",
]

tree = ast.parse(original_source_code)
nodes_with_comments = self._extract_all_nodes_with_comments(tree, original_source_code)

sorted_sections = {}
current_section = []
current_header = ""
for code, node in nodes_with_comments:
if isinstance(node, ast.FunctionDef):
if node.name.startswith("_") or has_st_composite_decorator(node):
continue
for header in section_headers:
if code.strip().startswith(header):
current_header = header
current_section = []
break
current_section.append((code, node))
sorted_sections[current_header] = current_section

for header, section in sorted_sections.items():
sorted_section = sorted(section, key=sort_key)
sorted_sections[header] = sorted_section

reordered_code_list_main = []
for header, section in sorted_sections.items():
if header == "":
reordered_code_list_main.extend(code for code, _ in section)
else:
header = header.strip("#")
reordered_code_list_main.append(f"#{header}")
pattern = re.compile(rf"\s?#{re.escape(header)}\s?#?")
reordered_code_list_main.extend(pattern.sub("", code) for code, _ in section)

tree = ast.parse(reordered_code)
nodes_with_comments = self._extract_all_nodes_with_comments(tree, reordered_code)

reordered_code_list_before = []
reordered_code_list_after = []
previous_was_main = False
prev_was_assignment = False
for code, node in nodes_with_comments:
if isinstance(node, ast.Assign):
if prev_was_assignment:
code = code.strip()
if previous_was_main:
reordered_code_list_after.append(code)
else:
reordered_code_list_before.append(code)
prev_was_assignment = True
elif isinstance(node, ast.FunctionDef):
if node.name.startswith("_") or has_st_composite_decorator(node):
reordered_code_list_before.append(code)
continue
if code.strip().startswith("# --- Main"):
previous_was_main = True
else:
reordered_code_list_before.append(code)

reordered_code_list_before.extend(reordered_code_list_main)
reordered_code_list_before.extend(reordered_code_list_after)
reordered_code = "\n".join(reordered_code_list_before).strip()
if not reordered_code.endswith("\n"):
reordered_code += "\n"

reordered_code = black.format_str(reordered_code, mode=black.Mode())
return reordered_code

def _format_file(self, filename: str) -> bool:
Expand All @@ -542,24 +659,25 @@ def _format_file(self, filename: str) -> bool:
Parameters
----------
filename
filename
The path to the Python file to be formatted.
Returns
-------
True if formatting is successful, False otherwise.
"""
if FILE_PATTERN.match(filename) is None:
if FILE_PATTERN.match(filename) or EXTENDED_FILE_PATTERN.match(filename):
extended = True if EXTENDED_FILE_PATTERN.match(filename) else False
else:
return False

try:
with open(filename, "r", encoding="utf-8") as f:
original_code = f.read()

if not original_code.strip():
return False

reordered_code = self._rearrange_functions_and_classes(original_code)
reordered_code = self._rearrange_functions_and_classes(original_code, extended)

with open(filename, "w", encoding="utf-8") as f:
f.write(reordered_code)
Expand Down

0 comments on commit bc2e480

Please sign in to comment.