Skip to content

Commit 8fea9c5

Browse files
committed
Validate breaking changes in return_types.
0 parents  commit 8fea9c5

File tree

3 files changed

+263
-0
lines changed

3 files changed

+263
-0
lines changed

.github/workflows/test.yml

Whitespace-only changes.

action.yml

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
name: 'ComfyUI Node Validator'
2+
description: 'Validates ComfyUI custom nodes for breaking changes'
3+
branding:
4+
icon: 'check-circle'
5+
color: 'green'
6+
7+
inputs:
8+
base_ref:
9+
description: 'Base branch to compare against'
10+
required: false
11+
default: 'main'
12+
13+
runs:
14+
using: 'composite'
15+
steps:
16+
- name: Checkout PR
17+
uses: actions/checkout@v3
18+
with:
19+
path: pr_repo
20+
21+
- name: Checkout base
22+
uses: actions/checkout@v3
23+
with:
24+
ref: ${{ inputs.base_ref }}
25+
path: base_repo
26+
27+
- name: Set up Python
28+
uses: actions/setup-python@v4
29+
with:
30+
python-version: '3.10'
31+
32+
- name: Install dependencies
33+
shell: bash
34+
run: |
35+
python -m pip install --upgrade pip
36+
pip install typing-extensions
37+
38+
- name: Run validation
39+
shell: bash
40+
run: |
41+
python ${{ github.action_path }}/src/validate_nodes.py base_repo pr_repo

src/validate_nodes.py

+222
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
from typing import Type, Dict, Any, Set, Tuple
2+
import sys
3+
import importlib.util
4+
from pathlib import Path
5+
from dataclasses import dataclass
6+
from enum import Enum
7+
import os
8+
9+
class BreakingChangeType(Enum):
10+
RETURN_TYPES_CHANGED = "Return types changed"
11+
RETURN_TYPES_REORDERED = "Return types reordered"
12+
INPUT_REMOVED = "Required input removed"
13+
INPUT_TYPE_CHANGED = "Input type changed"
14+
NODE_REMOVED = "Node removed"
15+
FUNCTION_CHANGED = "Entry point function changed"
16+
17+
@dataclass
18+
class BreakingChange:
19+
node_name: str
20+
change_type: BreakingChangeType
21+
details: str
22+
base_value: Any = None
23+
pr_value: Any = None
24+
25+
def load_node_mappings(repo_path: str) -> Dict[str, Type]:
26+
"""
27+
Load NODE_CLASS_MAPPINGS from a repository's __init__.py
28+
"""
29+
init_path = os.path.join(repo_path, "__init__.py")
30+
31+
if not os.path.exists(init_path):
32+
raise FileNotFoundError(f"Could not find __init__.py in {repo_path}")
33+
34+
# Add the repo path to system path temporarily
35+
sys.path.insert(0, os.path.dirname(repo_path))
36+
37+
try:
38+
# Load the module
39+
spec = importlib.util.spec_from_file_location("module", init_path)
40+
if spec is None or spec.loader is None:
41+
raise ImportError(f"Failed to load module from {init_path}")
42+
43+
module = importlib.util.module_from_spec(spec)
44+
spec.loader.exec_module(module)
45+
46+
# Get NODE_CLASS_MAPPINGS
47+
mappings = getattr(module, "NODE_CLASS_MAPPINGS", {})
48+
if not mappings:
49+
raise AttributeError("NODE_CLASS_MAPPINGS not found in __init__.py")
50+
51+
return mappings
52+
53+
finally:
54+
# Remove the temporary path
55+
sys.path.pop(0)
56+
57+
58+
def get_node_classes(module) -> Dict[str, Type]:
59+
"""Extract node classes from module using NODE_CLASS_MAPPINGS."""
60+
return getattr(module, "NODE_CLASS_MAPPINGS", {})
61+
62+
def compare_return_types(node_name: str, base_class: Type, pr_class: Type) -> list[BreakingChange]:
63+
"""Compare RETURN_TYPES between base and PR versions of a node."""
64+
changes = []
65+
base_types = getattr(base_class, "RETURN_TYPES", tuple())
66+
pr_types = getattr(pr_class, "RETURN_TYPES", tuple())
67+
68+
if len(base_types) != len(pr_types):
69+
changes.append(BreakingChange(
70+
node_name=node_name,
71+
change_type=BreakingChangeType.RETURN_TYPES_CHANGED,
72+
details=f"Number of return types changed from {len(base_types)} to {len(pr_types)}",
73+
base_value=base_types,
74+
pr_value=pr_types
75+
))
76+
return changes
77+
78+
# Check for type changes and reordering
79+
base_types_set = set(base_types)
80+
pr_types_set = set(pr_types)
81+
82+
if base_types_set != pr_types_set:
83+
changes.append(BreakingChange(
84+
node_name=node_name,
85+
change_type=BreakingChangeType.RETURN_TYPES_CHANGED,
86+
details="Return types changed",
87+
base_value=base_types,
88+
pr_value=pr_types
89+
))
90+
elif base_types != pr_types:
91+
changes.append(BreakingChange(
92+
node_name=node_name,
93+
change_type=BreakingChangeType.RETURN_TYPES_REORDERED,
94+
details="Return types were reordered",
95+
base_value=base_types,
96+
pr_value=pr_types
97+
))
98+
99+
return changes
100+
101+
def compare_input_types(node_name: str, base_class: Type, pr_class: Type) -> list[BreakingChange]:
102+
"""Compare INPUT_TYPES between base and PR versions of a node."""
103+
changes = []
104+
105+
base_inputs = base_class.INPUT_TYPES().get("required", {})
106+
pr_inputs = pr_class.INPUT_TYPES().get("required", {})
107+
108+
# Check for removed inputs
109+
for input_name, input_config in base_inputs.items():
110+
if input_name not in pr_inputs:
111+
changes.append(BreakingChange(
112+
node_name=node_name,
113+
change_type=BreakingChangeType.INPUT_REMOVED,
114+
details=f"Required input '{input_name}' was removed",
115+
base_value=input_config,
116+
pr_value=None
117+
))
118+
continue
119+
120+
# Check input type changes
121+
if pr_inputs[input_name][0] != input_config[0]:
122+
changes.append(BreakingChange(
123+
node_name=node_name,
124+
change_type=BreakingChangeType.INPUT_TYPE_CHANGED,
125+
details=f"Input type changed for '{input_name}'",
126+
base_value=input_config[0],
127+
pr_value=pr_inputs[input_name][0]
128+
))
129+
130+
return changes
131+
132+
def compare_function(node_name: str, base_class: Type, pr_class: Type) -> list[BreakingChange]:
133+
"""Compare FUNCTION attribute between base and PR versions of a node."""
134+
changes = []
135+
136+
base_function = getattr(base_class, "FUNCTION", None)
137+
pr_function = getattr(pr_class, "FUNCTION", None)
138+
139+
if base_function != pr_function:
140+
changes.append(BreakingChange(
141+
node_name=node_name,
142+
change_type=BreakingChangeType.FUNCTION_CHANGED,
143+
details="Entry point function changed",
144+
base_value=base_function,
145+
pr_value=pr_function
146+
))
147+
148+
return changes
149+
150+
def compare_nodes(base_nodes: Dict[str, Type], pr_nodes: Dict[str, Type]) -> list[BreakingChange]:
151+
"""Compare two versions of nodes for breaking changes."""
152+
changes = []
153+
154+
# Check for removed nodes
155+
for node_name in base_nodes:
156+
if node_name not in pr_nodes:
157+
changes.append(BreakingChange(
158+
node_name=node_name,
159+
change_type=BreakingChangeType.NODE_REMOVED,
160+
details="Node was removed",
161+
))
162+
continue
163+
164+
base_class = base_nodes[node_name]
165+
pr_class = pr_nodes[node_name]
166+
167+
changes.extend(compare_return_types(node_name, base_class, pr_class))
168+
changes.extend(compare_input_types(node_name, base_class, pr_class))
169+
changes.extend(compare_function(node_name, base_class, pr_class))
170+
171+
return changes
172+
173+
def format_breaking_changes(changes: list[BreakingChange]) -> str:
174+
"""Format breaking changes into a clear error message."""
175+
if not changes:
176+
return "✅ No breaking changes detected"
177+
178+
output = ["❌ Breaking changes detected:\n"]
179+
180+
# Group changes by node
181+
changes_by_node = {}
182+
for change in changes:
183+
if change.node_name not in changes_by_node:
184+
changes_by_node[change.node_name] = []
185+
changes_by_node[change.node_name].append(change)
186+
187+
# Format each node's changes
188+
for node_name, node_changes in changes_by_node.items():
189+
output.append(f"Node: {node_name}")
190+
for change in node_changes:
191+
output.append(f" • {change.change_type.value}: {change.details}")
192+
if change.base_value is not None:
193+
output.append(f" - Base: {change.base_value}")
194+
if change.pr_value is not None:
195+
output.append(f" - PR: {change.pr_value}")
196+
output.append("")
197+
198+
return "\n".join(output)
199+
200+
def main():
201+
if len(sys.argv) != 3:
202+
print("Usage: validate_nodes.py <base_repo_path> <pr_repo_path>")
203+
sys.exit(1)
204+
205+
base_path = sys.argv[1]
206+
pr_path = sys.argv[2]
207+
208+
try:
209+
base_nodes = load_node_mappings(base_path)
210+
pr_nodes = load_node_mappings(pr_path)
211+
except Exception as e:
212+
print(f"❌ Error loading nodes: {str(e)}")
213+
sys.exit(1)
214+
215+
breaking_changes = compare_nodes(base_nodes, pr_nodes)
216+
print(format_breaking_changes(breaking_changes))
217+
218+
if breaking_changes:
219+
sys.exit(1)
220+
221+
if __name__ == "__main__":
222+
main()

0 commit comments

Comments
 (0)