Skip to content

Commit

Permalink
Backlog and new issues fixes (#21)
Browse files Browse the repository at this point in the history
1.    Fix some badges in README and docs (hopefully now for real)
2.    Typing fixes
3.   Update of pre-commit ruff version
      a. Changes necessary with this change
4.    Fix weights not correctly passed to external python functions with NURBS representation
5.    Fix change of signature for splinepy io functions
6.    Rename function `get_control_points` to `get_parameter_values`
  • Loading branch information
clemens-fricke authored Jan 16, 2025
2 parents 5cac6e1 + 78168a3 commit b223219
Show file tree
Hide file tree
Showing 21 changed files with 100 additions and 79 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/pre_commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,8 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
- name: Setup python
uses: actions/setup-python@v3
with:
python-version: '3.11'
- uses: pre-commit/[email protected]
18 changes: 7 additions & 11 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
default_language_version:
python: python3.10
python: python3.11
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
rev: v5.0.0
hooks:
- id: end-of-file-fixer
- id: check-yaml
Expand All @@ -19,12 +19,8 @@ repos:
- id: detect-private-key
- id: end-of-file-fixer
- id: mixed-line-ending
- repo: https://github.com/psf/black
rev: 23.7.0
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.291
hooks:
- id: ruff
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Reinforcement Learning based Shape Optimization (ReLeSO)

[![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/tataratat/releso/pypi_upload)](https://github.com/tataratat/releso)
[![pypi_upload](https://github.com/tataratat/releso/actions/workflows/build_and_upload_wheels.yml/badge.svg)](https://github.com/tataratat/releso)
[![Read the docs](https://readthedocs.org/projects/releso/badge/?version=latest)](https://releso.readthedocs.io/en/latest/?badge=latest)
[![PyPI - Version](https://img.shields.io/pypi/v/releso)](https://pypi.org/project/releso/)
[![Python Versions](https://img.shields.io/pypi/pyversions/releso)](https://pypi.org/project/releso/)
Expand All @@ -10,7 +10,7 @@ Releso is a Library/Framework for
Reinforcement Learning based Shape Optimization. Please look into the
Documentation for information on how it works. The instruction on how the
documentation can be built is given below as well as the instruction on how the
package can be installed.
package can be installed.
Alternatively, it can be installed from `pip` via `pip install releso`.


Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Reinforcement Learning based Shape Optimization (ReLeSO)
|Build Status| |Documentation Status|
|PyPI| |Python| |License|

.. |Build Status| image:: https://img.shields.io/github/actions/workflow/status/tataratat/releso/pypi_upload
.. |Build Status| image:: https://github.com/tataratat/releso/actions/workflows/build_and_upload_wheels.yml/badge.svg
:target: https://github.com/tataratat/releso
:alt: PyPI - Version

Expand Down
4 changes: 2 additions & 2 deletions examples/nutils_converging_channel_incremental.hjson
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
// "control_points": [[0.0,0.0],[0.5,0.0],[1.0,0.0],[0.0,0.2],[0.5,0.2],[1.0,0.2],[0.0,0.4],[0.5,0.4],[1.0,0.4]]
},
"mesh": {
"path": "examples/2DChannelTria.msh",
"path": "2DChannelTria.msh",
"dimensions": 2
},
"discrete_actions": true,
Expand All @@ -39,7 +39,7 @@
"run_on_reset": true,
"use_communication_interface": true,
"working_directory": "./",
"python_file_path": "examples/poiseuille_flow_channel_shear_thinning.py",
"python_file_path": "poiseuille_flow_channel_shear_thinning.py",
"add_step_information": true,
"additional_observations": 3
}
Expand Down
16 changes: 9 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ repository = "https://github.com/clemens-fricke/releso.git"

[project.optional-dependencies]
all = [
"splinepy[all]",
"splinepy[all]>=0.1.2",
"torchvision",
]
test = [
Expand All @@ -50,7 +50,7 @@ docs = [
"sphinx-jsonschema",
]
dev = [
"splinepy[all]",
"splinepy[all]>=0.1.2",
"torchvision",
"pytest",
"pytest-cov",
Expand Down Expand Up @@ -82,6 +82,10 @@ line-length = 79

[tool.ruff]
line-length = 79
preview = true
exclude = ["examples", "tests/samples"]

[tool.lint]
# Adding "ANN" would be really nice, but is to much work currently.
# Adding "PT" would be nice, for pytest naming convention.
# Adding "SIM" would be nice, but is to much work currently.
Expand All @@ -90,18 +94,16 @@ select = [
"ISC", "ICN", "PIE", "T20", "PYI", "Q", "RSE", "TID", "TCH", "INT", "PD",
"PGH", "TRY", "FLY", "NPY", "PERF", "FURB", "LOG"]
ignore = ["N818", "TRY003"]
preview = true
exclude = ["examples"]

[tool.ruff.per-file-ignores]
[tool.lint.per-file-ignores]
"tests/*.py" = ["D", "S", "E", "CPY", "PIE804", "NPY002"]
"__main__.py" = ["T20"]

[tool.ruff.pydocstyle]
[tool.lint.pydocstyle]
convention = "google"

[tool.pytest.ini_options]
addopts = "--ignore=tests/samples -W ignore::DeprecationWarning"
addopts = "--ignore=tests/samples --ignore=examples -W ignore::DeprecationWarning"
markers = [
"torch_test",
]
Expand Down
2 changes: 1 addition & 1 deletion releso/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Current version.
"""
version = "0.1.1"
version = "0.1.2"
10 changes: 5 additions & 5 deletions releso/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class BaseAgent(BaseModel):
#: with a current timestamp is also added.
tensorboard_log: Optional[str]

def get_next_tensorboard_experiment_name(self) -> str:
def get_next_tensorboard_experiment_name(self) -> Optional[str]:
"""Return tensorboard experiment name.
Adds a date and time marker to the tensorboard experiment name so that
Expand Down Expand Up @@ -231,7 +231,7 @@ def get_agent(
else:
raise AgentUnknownException(self.agent_type)

def get_next_tensorboard_experiment_name(self) -> str:
def get_next_tensorboard_experiment_name(self) -> Optional[str]:
"""Return the name of the tensorboard experiment.
The tensorboard experiment name of the original training run if given
Expand All @@ -241,7 +241,7 @@ def get_next_tensorboard_experiment_name(self) -> str:
str: tensorboard experiment name
"""
if self.tesorboard_run_directory:
return self.tesorboard_run_directory
return str(self.tesorboard_run_directory) if not None else None
if self.tensorboard_log is not None:
return super().get_next_tensorboard_experiment_name()
return None
Expand Down Expand Up @@ -324,7 +324,7 @@ def get_agent(
self.get_logger().info(f"Using agent of type {self.agent_type}.")
if normalizer_divisor == 0:
self.get_logger().warning("Normalizer divisor is 0, will use 1.")
normalizer_divisor = 1.0
normalizer_divisor = 1
self.n_steps = max(int(self.n_steps / normalizer_divisor), 1)
return A2C(env=environment, **self.get_additional_kwargs())

Expand Down Expand Up @@ -398,7 +398,7 @@ def get_agent(
self.get_logger().info(f"Using agent of type {self.agent_type}.")
if normalizer_divisor == 0:
self.get_logger().warning("Normalizer divisor is 0, will use 1.")
normalizer_divisor = 1.0
normalizer_divisor = 1
self.n_steps = max(int(self.n_steps / normalizer_divisor), 1)
return PPO(env=environment, **self.get_additional_kwargs())

Expand Down
4 changes: 2 additions & 2 deletions releso/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def save_model(self, file_name: Optional[str] = None) -> str:
path.parent.mkdir(parents=True, exist_ok=True)
save_path = path / "model_end.save"
self._agent.save(save_path)
return save_path
return str(save_path)

def evaluate_model(
self,
Expand Down Expand Up @@ -242,7 +242,7 @@ def _init():

def _create_validation_environment(
self, throw_error_if_none: bool = False
) -> Environment:
) -> Optional[Environment]:
"""Creates a validation environment.
Args:
Expand Down
6 changes: 3 additions & 3 deletions releso/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ def setup(self, environment_id: UUID):
"""
self._actions = self.shape_definition.get_actions()

def get_control_points(self) -> List[List[float]]:
def get_parameter_values(self) -> List[List[float]]:
"""Return all control_points of the spline.
Returns:
List[List[float]]: Nested list of control_points.
"""
return self.shape_definition.get_control_points()
return self.shape_definition.get_parameter_values()

def apply_action(self, action: Union[List[float], int]) -> Optional[Any]:
"""Function that applies a given action to the Spline.
Expand Down Expand Up @@ -117,7 +117,7 @@ def apply(self) -> Optional[Any]:
Returns:
List[List[float]]: Current control points of the shape.
"""
return self.get_control_points()
return self.get_parameter_values()

def is_geometry_changed(self) -> bool:
"""Checks if the geometry was changed with the previous action apply.
Expand Down
8 changes: 4 additions & 4 deletions releso/gym_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
fix for this problem is currently worked on.
"""

from typing import Any, Dict, List, Tuple, Union
from typing import Any, List, Union

import gymnasium

from releso.util.types import StepReturnType


class GymEnvironment(gymnasium.Env):
"""Environment interface class for the gym environment definition.
Expand All @@ -54,9 +56,7 @@ def __init__(self, action_space, observation_space) -> None:
self.action_space = action_space
self.observation_space = observation_space

def step(
self, action: Union[int, List[float]]
) -> Tuple[Any, float, bool, Dict[str, Any]]:
def step(self, action: Union[int, List[float]]) -> StepReturnType:
"""Dummy function definition for gym interface."""

def reset(self) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion releso/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def export_mesh(self, mesh: GustafMeshTypes, space_time: bool = False):
RuntimeError: _description_
"""
if self.mesh_format == "mixd":
mixd.export(mesh, self._export_path_changed, space_time=space_time)
mixd.export(self._export_path_changed, mesh, space_time=space_time)
else:
raise RuntimeError(
f"The requested format {self.mesh_format} is not supported."
Expand Down
9 changes: 5 additions & 4 deletions releso/shape_parameterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,15 +240,16 @@ class ShapeDefinition(BaseModel):
"""Base of shape parameterization, also represents a simple point cloud."""

#: control_points of the shape. These are the base variables used for the
#: optimization. Overwrite `get_actions` and `get_control_points` if
#: optimization. Overwrite `get_actions` and `get_parameter_values` if
#: additional optimization variables are needed. See (WIP) NURBSDefinition.
control_points: List[List[VariableLocation]]

def get_number_of_points(self) -> int:
"""Returns the number of points in the Cube.
Number of control points multiplied by the number of dimensions for
each cp.
each cp. Assumes that all dimensions have the same number of control
points.
Returns:
int: number of points in the geometry
Expand Down Expand Up @@ -297,7 +298,7 @@ def convert_all_control_point_locations_to_variable_locations(
) from None
return new_list

def get_control_points(self) -> List[List[float]]:
def get_parameter_values(self) -> List[List[float]]:
"""Returns the current positions of all control points.
Returns:
Expand Down Expand Up @@ -336,7 +337,7 @@ def get_shape(self) -> Any:
Returns:
Any: Shape that is generated.
"""
return self.get_control_points()
return self.get_parameter_values()

def reset(self) -> None:
"""Resets the shape to the original values."""
Expand Down
20 changes: 17 additions & 3 deletions releso/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def get_shape(self) -> BSpline:
"""
self.get_logger().debug("Creating BSpline.")
self.get_logger().debug(
f"With control_points: {self.get_control_points()}"
f"With control_points: {self.get_parameter_values()}"
)

return BSpline(
Expand All @@ -309,7 +309,7 @@ def get_shape(self) -> BSpline:
space_dim.get_knot_vector()
for space_dim in self.space_dimensions
],
self.get_control_points(),
self.get_parameter_values(),
)


Expand Down Expand Up @@ -430,7 +430,7 @@ def get_shape(self) -> NURBS:
space_dim.get_knot_vector()
for space_dim in self.space_dimensions
],
self.get_control_points(),
self.get_parameter_values()[:-1],
self.get_weights(),
)

Expand All @@ -448,6 +448,20 @@ def get_actions(self) -> List[VariableLocation]:
# raise RuntimeError(f"Actions: {actions}")
return actions


def get_parameter_values(self) -> List[List[float]]:
"""Returns the current positions of all control points with weights as
well.
Returns:
List[List[float]]: Positions of all control points.
"""
control_points = super().get_parameter_values()
control_points.append(
[weight.current_position for weight in self.weights]
)
return control_points

def reset(self) -> None:
"""Resets the spline to the original shape."""
super().reset()
Expand Down
13 changes: 8 additions & 5 deletions releso/util/util_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def which(
def is_exe(fpath):
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

fpath, fname = os.path.split(program)
fpath, _ = os.path.split(program)
if fpath:
if is_exe(program):
return program
Expand Down Expand Up @@ -136,17 +136,20 @@ def get_path_extension() -> str:
# check if slurm task array is running
if (
os.getenv("SLURM_ARRAY_TASK_COUNT")
and int(os.getenv("SLURM_ARRAY_TASK_COUNT")) > 1
and (
(tmp_var := os.getenv("SLURM_ARRAY_TASK_COUNT")) is not None
and int(tmp_var) > 1
)
and os.getenv("SLURM_ARRAY_JOB_ID")
and os.getenv("SLURM_ARRAY_TASK_ID")
):
ret_str = (
os.getenv("SLURM_ARRAY_JOB_ID")
str(os.getenv("SLURM_ARRAY_JOB_ID"))
+ "/"
+ os.getenv("SLURM_ARRAY_TASK_ID")
+ str(os.getenv("SLURM_ARRAY_TASK_ID"))
+ "_"
+ ret_str
)
elif os.getenv("SLURM_JOB_ID"): # default slurm job running
ret_str += "_" + os.getenv("SLURM_JOB_ID")
ret_str += "_" + str(os.getenv("SLURM_JOB_ID"))
return ret_str
3 changes: 3 additions & 0 deletions tests/test_base_parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, List, Union
import pathlib

import pytest

Expand Down Expand Up @@ -130,6 +131,8 @@ def test_base_parser_validation(
)

for file_path in file_paths:
print(file_path)
file_path = pathlib.Path(file_path)
assert file_path.exists()
clean_up_provider(
file_path.parent
Expand Down
Loading

0 comments on commit b223219

Please sign in to comment.