Skip to content

Commit

Permalink
create TemplatedStep duplicate method (#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier authored Feb 20, 2025
1 parent c0b8c44 commit 1244c48
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 66 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**0.1.4 - 2/18/25**

- Implement duplicate_template_step method on TemplatedStep class

**0.1.3 - 1/7/25**

- Validate currently-installed python version during setup
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"docker",
"graphviz",
"loguru",
"layered_config_tree",
"layered_config_tree>=3.0.0",
"networkx",
"pandas",
"pyyaml",
Expand All @@ -67,7 +67,7 @@
"pytest-mock",
]
doc_requirements = [
"sphinx",
"sphinx<8.2.0",
"sphinx-rtd-theme",
"sphinx-autodoc-typehints",
"sphinx-click",
Expand Down
26 changes: 20 additions & 6 deletions src/easylink/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,24 @@ def set_configuration_state(
parent_config, combined_implementations, input_data_config
)

def _duplicate_template_step(self) -> Step:
    """Create a copy of :attr:`template_step` that is parented to this step.

    Returns
    -------
        A deep copy of :attr:`template_step` on which ``set_parent_step(self)``
        has been called.

    Notes
    -----
    Deep-copying the template also copies whatever its :attr:`Step.parent_step`
    refers to, which would leave the duplicate pointing at a *copy* of ``self``
    instead of at ``self``. To avoid that, the parent is explicitly re-set to
    the original (``self``) once the copy has been made.
    """
    duplicate = copy.deepcopy(self.template_step)
    # Point the duplicate back at the original templated step, not at the
    # copy of it that deepcopy produced along the way.
    duplicate.set_parent_step(self)
    return duplicate


class LoopStep(TemplatedStep):
"""A type of :class:`TemplatedStep` that allows for looping.
Expand Down Expand Up @@ -875,9 +893,7 @@ def _update_step_graph(self, num_repeats) -> StepGraph:
edges = []

for i in range(num_repeats):
self.template_step.parent_step = None
updated_step = copy.deepcopy(self.template_step)
updated_step.set_parent_step(self)
updated_step = self._duplicate_template_step()
updated_step.name = f"{self.name}_{self.node_prefix}_{i+1}"
nodes.append(updated_step)
if i > 0:
Expand Down Expand Up @@ -973,9 +989,7 @@ def _update_step_graph(self, num_repeats: int) -> StepGraph:
graph = StepGraph()

for i in range(num_repeats):
self.template_step.parent_step = None
updated_step = copy.deepcopy(self.template_step)
updated_step.set_parent_step(self)
updated_step = self._duplicate_template_step()
updated_step.name = f"{self.name}_{self.node_prefix}_{i+1}"
graph.add_node_from_step(updated_step)
return graph
Expand Down
34 changes: 17 additions & 17 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_environment_configuration_not_found(computing_environment):


@pytest.mark.parametrize(
"key, input",
"key, value",
[
("computing_environment", None),
("computing_environment", "local"),
Expand All @@ -114,20 +114,20 @@ def test_environment_configuration_not_found(computing_environment):
("container_engine", "undefined"),
],
)
def test_required_attributes(mocker, default_config_params, key, input):
def test_required_attributes(mocker, default_config_params, key, value):
mocker.patch("easylink.configuration.Config._validate_environment")
config_params = default_config_params
if input:
config_params["environment"][key] = input
env_dict = {key: input} if input else {}
if value:
config_params["environment"][key] = value
env_dict = {key: value} if value else {}
retrieved = Config(config_params).environment[key]
expected = DEFAULT_ENVIRONMENT["environment"].copy()
expected.update(env_dict)
assert retrieved == expected[key]


@pytest.mark.parametrize(
"input",
"resource_request",
[
# missing
None,
Expand All @@ -137,22 +137,22 @@ def test_required_attributes(mocker, default_config_params, key, input):
{"memory": 100, "cpus": 200, "time_limit": 300},
],
)
def test_implementation_resource_requests(default_config_params, input):
def test_implementation_resource_requests(default_config_params, resource_request):
key = "implementation_resources"
config_params = default_config_params
if input:
config_params["environment"][key] = input
if resource_request:
config_params["environment"][key] = resource_request
config = Config(config_params)
env_dict = {key: input.copy()} if input else {}
env_dict = {key: resource_request.copy()} if resource_request else {}
retrieved = config.environment[key].to_dict()
expected = DEFAULT_ENVIRONMENT["environment"][key].copy()
if input:
if resource_request:
expected.update(env_dict[key])
assert retrieved == expected


@pytest.mark.parametrize(
"input",
"spark_request",
[
# missing
None,
Expand All @@ -179,7 +179,7 @@ def test_implementation_resource_requests(default_config_params, input):
],
)
@pytest.mark.parametrize("requires_spark", [True, False])
def test_spark_requests(default_config_params, input, requires_spark):
def test_spark_requests(default_config_params, spark_request, requires_spark):
key = "spark"
config_params = default_config_params
if requires_spark:
Expand All @@ -188,12 +188,12 @@ def test_spark_requests(default_config_params, input, requires_spark):
"name"
] = "step_1_python_pyspark"

if input:
config_params["environment"][key] = input
if spark_request:
config_params["environment"][key] = spark_request
retrieved = Config(config_params).environment[key].to_dict()
expected_env_dict = {key: input.copy()} if input else {}
expected_env_dict = {key: spark_request.copy()} if spark_request else {}
expected = LayeredConfigTree(SPARK_DEFAULTS, layers=["initial_data", "user"])
if input:
if spark_request:
expected.update(expected_env_dict[key], layer="user")
expected = expected.to_dict()
assert retrieved == expected
Expand Down
149 changes: 108 additions & 41 deletions tests/unit/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,35 +372,31 @@ def test_loop_get_implementation_graph(
{
"step_3": {
"iterate": [
LayeredConfigTree(
{
"implementation": {
"name": "step_3_python_pandas",
"configuration": {},
}
}
),
LayeredConfigTree(
{
"substeps": {
"step_3a": {
"implementation": {
"name": "step_3a_python_pandas",
"configuration": {},
}
{
"implementation": {
"name": "step_3_python_pandas",
"configuration": {},
},
},
{
"substeps": {
"step_3a": {
"implementation": {
"name": "step_3a_python_pandas",
"configuration": {},
},
"step_3b": {
"implementation": {
"name": "step_3b_python_pandas",
"configuration": {},
}
},
"step_3b": {
"implementation": {
"name": "step_3b_python_pandas",
"configuration": {},
},
},
}
),
},
},
],
}
}
},
},
)
step.set_configuration_state(pipeline_params, {}, {})
subgraph = step.get_implementation_graph()
Expand Down Expand Up @@ -574,8 +570,8 @@ def test_parallel_step_get_implementation_graph(
"input_data_file": "input_file_3",
},
],
}
}
},
},
)
step.set_configuration_state(pipeline_params, {}, {})
subgraph = step.get_implementation_graph()
Expand Down Expand Up @@ -633,6 +629,81 @@ def test_parallel_step_get_implementation_graph(
assert edge in subgraph.edges(data=True)


@pytest.mark.parametrize("step_type", ["parallel", "loop"])
def test__duplicate_template_step(step_type):
    """Check that ``_duplicate_template_step`` yields an equivalent but distinct step.

    The comparison is deliberately not exhaustive. Simple attributes (e.g. a
    string) compare equal across two deep copies even though they live on
    different objects, whereas bound methods and some richer attribute types
    do not, so those are either skipped or only checked for matching type.
    """

    def _implementation_config() -> dict:
        # Build a fresh dict on every call so no configuration object is shared
        # between the template step and the templated step.
        return {
            "implementation": {
                "name": "step_implementation",
                "configuration": {},
            },
        }

    template_step = Step("step")
    template_step.set_configuration_state(
        LayeredConfigTree({"step": _implementation_config()}),
        {},
        {},
    )

    # Pick the templated step flavour and the config key it expects.
    step_class, config_key = (
        (LoopStep, "iterate") if step_type == "loop" else (ParallelStep, "parallel")
    )
    step = step_class(template_step)
    step.set_configuration_state(
        LayeredConfigTree({"step": {config_key: [_implementation_config()]}}),
        {},
        {},
    )

    duplicate_template_step = step._duplicate_template_step()
    original = step.template_step

    special_handle_attrs = ["_configuration_state", "configuration_state", "step_graph"]

    # Collect the attributes that can be compared by plain equality.
    comparable_attrs = []
    for attr in dir(original):
        if attr in special_handle_attrs:
            continue
        # dunders are too complicated to check
        if attr.startswith("__"):
            continue
        # methods are bound to each instance, so they never compare equal
        if callable(getattr(original, attr)):
            continue
        comparable_attrs.append(attr)

    for attr in comparable_attrs:
        assert getattr(original, attr) == getattr(duplicate_template_step, attr)

    # The special cases should have matching types but must not compare equal,
    # which implies each one belongs to a different instance.
    for attr in special_handle_attrs:
        original_value = getattr(original, attr)
        duplicate_value = getattr(duplicate_template_step, attr)
        assert isinstance(original_value, type(duplicate_value))
        assert original_value != duplicate_value


@pytest.fixture
def choice_step_params() -> dict[str, Any]:
return {
Expand Down Expand Up @@ -926,20 +997,16 @@ def test_complex_choice_step_get_implementation_graph(
},
"step_6": {
"iterate": [
LayeredConfigTree(
{
"implementation": {
"name": "step_6_python_pandas",
}
}
),
LayeredConfigTree(
{
"implementation": {
"name": "step_6_python_pandas",
}
}
),
{
"implementation": {
"name": "step_6_python_pandas",
},
},
{
"implementation": {
"name": "step_6_python_pandas",
},
},
],
},
},
Expand Down

0 comments on commit 1244c48

Please sign in to comment.