Skip to content

Commit 3b0d4db

Browse files
better type hints
1 parent 68fa1f3 commit 3b0d4db

File tree

10 files changed

+338
-231
lines changed

10 files changed

+338
-231
lines changed

examples/casadi_functions.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,15 @@
2929
N1 = metanet.Node(name="N1")
3030
N2 = metanet.Node(name="N2")
3131
N3 = metanet.Node(name="N3")
32-
L1 = metanet.Link(2, lanes, L, rho_max, rho_crit_sym, v_free_sym, a_sym, name="L1")
33-
L2 = metanet.Link(1, lanes, L, rho_max, rho_crit_sym, v_free_sym, a_sym, name="L2")
34-
O1 = metanet.MeteredOnRamp(C[0], name="O1")
35-
O2 = metanet.SimpleMeteredOnRamp(C[1], name="O2")
36-
D3 = metanet.CongestedDestination(name="D3")
32+
L1 = metanet.Link[cs.SX](
33+
2, lanes, L, rho_max, rho_crit_sym, v_free_sym, a_sym, name="L1"
34+
)
35+
L2 = metanet.Link[cs.SX](
36+
1, lanes, L, rho_max, rho_crit_sym, v_free_sym, a_sym, name="L2"
37+
)
38+
O1 = metanet.MeteredOnRamp[cs.SX](C[0], name="O1")
39+
O2 = metanet.SimpleMeteredOnRamp[cs.SX](C[1], name="O2")
40+
D3 = metanet.CongestedDestination[cs.SX](name="D3")
3741

3842

3943
# build and validate network

src/sym_metanet/blocks/base.py

+10-16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from abc import ABC, abstractmethod
22
from itertools import count
3-
from typing import Dict, Generic, Optional, Set, TypeVar, ClassVar
3+
from typing import ClassVar, Dict, Generic, Optional, Set
4+
5+
from sym_metanet.util.types import VarType
46

57

68
class ElementBase:
@@ -33,15 +35,7 @@ def __repr__(self) -> str:
3335
return f"<{self.name}: {self.__class__.__name__}>"
3436

3537

36-
sym_var = TypeVar("sym_var")
37-
sym_var.__doc__ = (
38-
"Variable that can also be numerical or symbolic, "
39-
"depending on the engine. Should be indexable as an array "
40-
"in case of vector quantities."
41-
)
42-
43-
44-
class ElementWithVars(ElementBase, Generic[sym_var], ABC):
38+
class ElementWithVars(ElementBase, Generic[VarType], ABC):
4539
"""Base class for any element with states, actions or disturbances."""
4640

4741
__slots__ = ("states", "next_states", "actions", "disturbances")
@@ -59,10 +53,10 @@ def __init__(self, name: Optional[str] = None) -> None:
5953
of the class' instancies.
6054
"""
6155
super().__init__(name=name)
62-
self.states: Optional[Dict[str, sym_var]] = None
63-
self.next_states: Optional[Dict[str, sym_var]] = None
64-
self.actions: Optional[Dict[str, sym_var]] = None
65-
self.disturbances: Optional[Dict[str, sym_var]] = None
56+
self.states: Optional[Dict[str, VarType]] = None
57+
self.next_states: Optional[Dict[str, VarType]] = None
58+
self.actions: Optional[Dict[str, VarType]] = None
59+
self.disturbances: Optional[Dict[str, VarType]] = None
6660

6761
@property
6862
def has_states(self) -> bool:
@@ -94,12 +88,12 @@ def init_vars(self, *args, **kwargs) -> None:
9488
)
9589

9690
@abstractmethod
97-
def step_dynamics(self, *args, **kwargs) -> Dict[str, sym_var]:
91+
def step_dynamics(self, *args, **kwargs) -> Dict[str, VarType]:
9892
"""Internal method for stepping the element's dynamics by one time step.
9993
10094
Returns
10195
-------
102-
Dict[str, sym_var]
96+
Dict[str, VarType]
10397
A dict with the states at the next time step.
10498
10599
Raises

src/sym_metanet/blocks/destinations.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,28 @@
1-
from typing import TYPE_CHECKING, Dict, Optional
1+
from typing import TYPE_CHECKING, Collection, Dict, Optional, Tuple
22

3-
from sym_metanet.blocks.base import ElementWithVars, sym_var
3+
from sym_metanet.blocks.base import ElementWithVars
44
from sym_metanet.engines.core import EngineBase, get_current_engine
55
from sym_metanet.util.funcs import first
6+
from sym_metanet.util.types import VarType
67

78
if TYPE_CHECKING:
89
from sym_metanet.blocks.links import Link
10+
from sym_metanet.blocks.nodes import Node
911
from sym_metanet.network import Network
1012

1113

12-
class Destination(ElementWithVars[sym_var]):
14+
class Destination(ElementWithVars[VarType]):
1315
"""Ideal congestion-free destination, representing a sink where cars can leave the
1416
highway with no congestion (i.e., no slowing down due to downstream density)."""
1517

1618
def init_vars(self, *args, **kwargs) -> None:
1719
"""Initializes no variable in the ideal destination."""
1820

19-
def step_dynamics(self, *args, **kwargs) -> Dict[str, sym_var]:
21+
def step_dynamics(self, *args, **kwargs) -> Dict[str, VarType]:
2022
"""No dynamics to steps in the ideal destination."""
2123
return {}
2224

23-
def get_density(self, net: "Network", **kwargs) -> sym_var:
25+
def get_density(self, net: "Network", **kwargs) -> VarType:
2426
"""Computes the (downstream) density induced by the ideal destination.
2527
2628
Parameters
@@ -30,22 +32,24 @@ def get_density(self, net: "Network", **kwargs) -> sym_var:
3032
3133
Returns
3234
-------
33-
sym_var
35+
symbolic variable
3436
The destination's downstream density.
3537
"""
3638
return self._get_entering_link(net=net).states["rho"][-1]
3739

38-
def _get_entering_link(self, net: "Network") -> "Link":
40+
def _get_entering_link(self, net: "Network") -> "Link[VarType]":
3941
"""Internal utility to fetch the link entering this destination (can only be
4042
one)."""
41-
links_up = net.in_links(net.destinations[self])
43+
links_up: Collection[Tuple["Node", "Node", "Link[VarType]"]] = net.in_links(
44+
net.destinations[self] # type: ignore[index]
45+
)
4246
assert (
4347
len(links_up) == 1
4448
), "Internal error. Only one link can enter a destination."
4549
return first(links_up)[-1]
4650

4751

48-
class CongestedDestination(Destination[sym_var]):
52+
class CongestedDestination(Destination[VarType]):
4953
"""Destination with a downstream density scenario to emulate congestions, that is,
5054
cars cannot exit freely the highway but must slow down and, possibly, create a
5155
congestion."""
@@ -54,7 +58,7 @@ class CongestedDestination(Destination[sym_var]):
5458

5559
def init_vars(
5660
self,
57-
init_conditions: Optional[Dict[str, sym_var]] = None,
61+
init_conditions: Optional[Dict[str, VarType]] = None,
5862
engine: Optional[EngineBase] = None,
5963
) -> None:
6064
"""Initializes
@@ -74,15 +78,15 @@ def init_vars(
7478
"""
7579
if engine is None:
7680
engine = get_current_engine()
77-
self.disturbances: Dict[str, sym_var] = {
81+
self.disturbances: Dict[str, VarType] = {
7882
"d": engine.var(f"d_{self.name}")
7983
if init_conditions is None or "d" not in init_conditions
8084
else init_conditions["d"]
8185
}
8286

8387
def get_density(
8488
self, net: "Network", engine: Optional[EngineBase] = None, **kwargs
85-
) -> sym_var:
89+
) -> VarType:
8690
"""Computes the (downstream) density induced by the congested destination.
8791
8892
Parameters
@@ -94,7 +98,7 @@ def get_density(
9498
9599
Returns
96100
-------
97-
sym_var
101+
variable
98102
The destination's downstream density.
99103
"""
100104
if engine is None:

src/sym_metanet/blocks/links.py

+42-38
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1-
from typing import TYPE_CHECKING, Dict, Optional
1+
from typing import TYPE_CHECKING, Collection, Dict, Optional, Tuple, Union
22

3-
from sym_metanet.blocks.base import ElementWithVars, sym_var
3+
from sym_metanet.blocks.base import ElementWithVars
44
from sym_metanet.blocks.origins import MeteredOnRamp
55
from sym_metanet.engines.core import EngineBase, get_current_engine
66
from sym_metanet.util.funcs import first
7+
from sym_metanet.util.types import VarType
78

89
if TYPE_CHECKING:
10+
from sym_metanet.blocks.nodes import Node
911
from sym_metanet.network import Network
1012

1113

12-
class Link(ElementWithVars[sym_var]):
14+
class Link(ElementWithVars[VarType]):
1315
"""Highway link between two nodes [1, Section 3.2.1]. Links represent stretch of
1416
highway with similar traffic characteristics and no road changes (e.g., same number
1517
of lanes and maximum speed).
@@ -26,13 +28,13 @@ class Link(ElementWithVars[sym_var]):
2628
def __init__(
2729
self,
2830
nb_segments: int,
29-
lanes: sym_var,
30-
length: sym_var,
31-
maximum_density: sym_var,
32-
critical_density: sym_var,
33-
free_flow_velocity: sym_var,
34-
a: sym_var,
35-
turnrate: sym_var = 1.0,
31+
lanes: Union[VarType, int],
32+
length: Union[VarType, float],
33+
maximum_density: Union[VarType, float],
34+
critical_density: Union[VarType, float],
35+
free_flow_velocity: Union[VarType, float],
36+
a: Union[VarType, float],
37+
turnrate: Union[VarType, float] = 1.0,
3638
name: Optional[str] = None,
3739
) -> None:
3840
"""Creates an instance of a METANET link.
@@ -41,20 +43,20 @@ def __init__(
4143
----------
4244
nb_segments : int
4345
Number of segments in this highway link, i.e., `N`.
44-
lanes : int or symbolic
46+
lanes : int or variable
4547
Number of lanes in each segment, i.e., `lam`.
46-
lengths : float or symbolic
48+
lengths : float or variable
4749
Length of each segment in the link, i.e., `L`.
48-
maximum density : float or symbolic
50+
maximum density : float or variable
4951
Maximum density that the link can withstand, i.e., `rho_max`.
50-
critical_densities : float or symbolic
52+
critical_densities : float or variable
5153
Critical density at which the traffic flow is maximal, i.e., `rho_crit`.
52-
free_flow_velocities : float or symbolic
54+
free_flow_velocities : float or variable
5355
Average speed of cars when traffic is freely flowing, i.e., `v_free`.
54-
a : float or symbolic
56+
a : float or variable
5557
Model parameter in the computations of the equivalent speed [1, Equation
5658
3.4].
57-
turnrate : float or symbolic, optional
59+
turnrate : float or variable, optional
5860
Fraction of the total flow that enters this link via the upstream node. Only
5961
relevant if multiple exiting links are attached to the same node, in order
6062
to split the flow according to these rates. Needs not be normalized. By
@@ -79,7 +81,7 @@ def __init__(
7981

8082
def init_vars(
8183
self,
82-
init_conditions: Optional[Dict[str, sym_var]] = None,
84+
init_conditions: Optional[Dict[str, VarType]] = None,
8385
engine: Optional[EngineBase] = None,
8486
) -> None:
8587
"""For each segment in the link, initializes
@@ -100,7 +102,7 @@ def init_vars(
100102
init_conditions = {}
101103
if engine is None:
102104
engine = get_current_engine()
103-
self.states: Dict[str, sym_var] = {
105+
self.states: Dict[str, VarType] = {
104106
name: (
105107
init_conditions[name]
106108
if name in init_conditions
@@ -109,7 +111,7 @@ def init_vars(
109111
for name in ("rho", "v")
110112
}
111113

112-
def get_flow(self, engine: Optional[EngineBase] = None, **kwargs) -> sym_var:
114+
def get_flow(self, engine: Optional[EngineBase] = None, **kwargs) -> VarType:
113115
"""Gets the flow in this link's segments.
114116
115117
Parameters
@@ -119,7 +121,7 @@ def get_flow(self, engine: Optional[EngineBase] = None, **kwargs) -> sym_var:
119121
120122
Returns
121123
-------
122-
sym_var
124+
variable
123125
The flow in this link.
124126
"""
125127
if engine is None:
@@ -131,46 +133,46 @@ def get_flow(self, engine: Optional[EngineBase] = None, **kwargs) -> sym_var:
131133
def step_dynamics(
132134
self,
133135
net: "Network",
134-
tau: sym_var,
135-
eta: sym_var,
136-
kappa: sym_var,
137-
T: sym_var,
138-
delta: Optional[sym_var] = None,
139-
phi: Optional[sym_var] = None,
136+
tau: Union[VarType, float],
137+
eta: Union[VarType, float],
138+
kappa: Union[VarType, float],
139+
T: Union[VarType, float],
140+
delta: Union[None, VarType, float] = None,
141+
phi: Union[None, VarType, float] = None,
140142
engine: Optional[EngineBase] = None,
141143
**kwargs,
142-
) -> Dict[str, sym_var]:
144+
) -> Dict[str, VarType]:
143145
"""Steps the dynamics of this link.
144146
145147
Parameters
146148
----------
147149
net : Network
148150
The network the link belongs to.
149-
tau : sym_var
151+
tau : float or variable
150152
Model parameter for the speed relaxation term.
151-
eta : sym_var
153+
eta : float or variable
152154
Model parameter for the speed anticipation term.
153-
kappa : sym_var
155+
kappa : float or variable
154156
Model parameter for the speed anticipation term.
155-
T : sym_var
157+
T : float or variable
156158
Sampling time.
157-
delta : sym_var, optional
159+
delta : float or variable, optional
158160
Model parameter for merging phenomenum. By default, not considered.
159-
phi : sym_var, optional
161+
phi : float or variable, optional
160162
Model parameter for lane drop phenomenum. By defaul, not considered.
161163
engine : EngineBase, optional
162164
The engine to be used. If `None`, the current engine is used.
163165
164166
Returns
165167
-------
166-
Dict[str, sym_var]
168+
Dict[str, variable]
167169
A dict with the states of the link (speeds and densities) at the next time
168170
step.
169171
"""
170172
if engine is None:
171173
engine = get_current_engine()
172174

173-
node_up, node_down = net.nodes_by_link[self]
175+
node_up, node_down = net.nodes_by_link[self] # type: ignore[index]
174176
rho = self.states["rho"]
175177
v = self.states["v"]
176178
q = self.get_flow(engine=engine)
@@ -204,10 +206,12 @@ def step_dynamics(
204206
# check for lane drops in the next link (only if one link downstream)
205207
lanes_drop = None
206208
if phi is not None:
207-
links_down = net.out_links(node_down)
209+
links_down: Collection[
210+
Tuple["Node", "Node", "Link[VarType]"]
211+
] = net.out_links(node_down)
208212
if len(links_down) == 1:
209213
link_down = first(links_down)[-1]
210-
lanes_drop = self.lam - link_down.lam
214+
lanes_drop = self.lam - link_down.lam # type: ignore[operator]
211215
if lanes_drop == 0:
212216
lanes_drop = None
213217

0 commit comments

Comments
 (0)