Skip to content

Commit

Permalink
better type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
FilippoAiraldi committed Jan 1, 2023
1 parent 68fa1f3 commit 3b0d4db
Show file tree
Hide file tree
Showing 10 changed files with 338 additions and 231 deletions.
14 changes: 9 additions & 5 deletions examples/casadi_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@
N1 = metanet.Node(name="N1")
N2 = metanet.Node(name="N2")
N3 = metanet.Node(name="N3")
L1 = metanet.Link(2, lanes, L, rho_max, rho_crit_sym, v_free_sym, a_sym, name="L1")
L2 = metanet.Link(1, lanes, L, rho_max, rho_crit_sym, v_free_sym, a_sym, name="L2")
O1 = metanet.MeteredOnRamp(C[0], name="O1")
O2 = metanet.SimpleMeteredOnRamp(C[1], name="O2")
D3 = metanet.CongestedDestination(name="D3")
L1 = metanet.Link[cs.SX](
2, lanes, L, rho_max, rho_crit_sym, v_free_sym, a_sym, name="L1"
)
L2 = metanet.Link[cs.SX](
1, lanes, L, rho_max, rho_crit_sym, v_free_sym, a_sym, name="L2"
)
O1 = metanet.MeteredOnRamp[cs.SX](C[0], name="O1")
O2 = metanet.SimpleMeteredOnRamp[cs.SX](C[1], name="O2")
D3 = metanet.CongestedDestination[cs.SX](name="D3")


# build and validate network
Expand Down
26 changes: 10 additions & 16 deletions src/sym_metanet/blocks/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
from itertools import count
from typing import Dict, Generic, Optional, Set, TypeVar, ClassVar
from typing import ClassVar, Dict, Generic, Optional, Set

from sym_metanet.util.types import VarType


class ElementBase:
Expand Down Expand Up @@ -33,15 +35,7 @@ def __repr__(self) -> str:
return f"<{self.name}: {self.__class__.__name__}>"


sym_var = TypeVar("sym_var")
sym_var.__doc__ = (
"Variable that can also be numerical or symbolic, "
"depending on the engine. Should be indexable as an array "
"in case of vector quantities."
)


class ElementWithVars(ElementBase, Generic[sym_var], ABC):
class ElementWithVars(ElementBase, Generic[VarType], ABC):
"""Base class for any element with states, actions or disturbances."""

__slots__ = ("states", "next_states", "actions", "disturbances")
Expand All @@ -59,10 +53,10 @@ def __init__(self, name: Optional[str] = None) -> None:
of the class' instancies.
"""
super().__init__(name=name)
self.states: Optional[Dict[str, sym_var]] = None
self.next_states: Optional[Dict[str, sym_var]] = None
self.actions: Optional[Dict[str, sym_var]] = None
self.disturbances: Optional[Dict[str, sym_var]] = None
self.states: Optional[Dict[str, VarType]] = None
self.next_states: Optional[Dict[str, VarType]] = None
self.actions: Optional[Dict[str, VarType]] = None
self.disturbances: Optional[Dict[str, VarType]] = None

@property
def has_states(self) -> bool:
Expand Down Expand Up @@ -94,12 +88,12 @@ def init_vars(self, *args, **kwargs) -> None:
)

@abstractmethod
def step_dynamics(self, *args, **kwargs) -> Dict[str, sym_var]:
def step_dynamics(self, *args, **kwargs) -> Dict[str, VarType]:
"""Internal method for stepping the element's dynamics by one time step.
Returns
-------
Dict[str, sym_var]
Dict[str, VarType]
A dict with the states at the next time step.
Raises
Expand Down
30 changes: 17 additions & 13 deletions src/sym_metanet/blocks/destinations.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Collection, Dict, Optional, Tuple

from sym_metanet.blocks.base import ElementWithVars, sym_var
from sym_metanet.blocks.base import ElementWithVars
from sym_metanet.engines.core import EngineBase, get_current_engine
from sym_metanet.util.funcs import first
from sym_metanet.util.types import VarType

if TYPE_CHECKING:
from sym_metanet.blocks.links import Link
from sym_metanet.blocks.nodes import Node
from sym_metanet.network import Network


class Destination(ElementWithVars[sym_var]):
class Destination(ElementWithVars[VarType]):
"""Ideal congestion-free destination, representing a sink where cars can leave the
highway with no congestion (i.e., no slowing down due to downstream density)."""

def init_vars(self, *args, **kwargs) -> None:
"""Initializes no variable in the ideal destination."""

def step_dynamics(self, *args, **kwargs) -> Dict[str, sym_var]:
def step_dynamics(self, *args, **kwargs) -> Dict[str, VarType]:
"""No dynamics to steps in the ideal destination."""
return {}

def get_density(self, net: "Network", **kwargs) -> sym_var:
def get_density(self, net: "Network", **kwargs) -> VarType:
"""Computes the (downstream) density induced by the ideal destination.
Parameters
Expand All @@ -30,22 +32,24 @@ def get_density(self, net: "Network", **kwargs) -> sym_var:
Returns
-------
sym_var
symbolic variable
The destination's downstream density.
"""
return self._get_entering_link(net=net).states["rho"][-1]

def _get_entering_link(self, net: "Network") -> "Link":
def _get_entering_link(self, net: "Network") -> "Link[VarType]":
"""Internal utility to fetch the link entering this destination (can only be
one)."""
links_up = net.in_links(net.destinations[self])
links_up: Collection[Tuple["Node", "Node", "Link[VarType]"]] = net.in_links(
net.destinations[self] # type: ignore[index]
)
assert (
len(links_up) == 1
), "Internal error. Only one link can enter a destination."
return first(links_up)[-1]


class CongestedDestination(Destination[sym_var]):
class CongestedDestination(Destination[VarType]):
"""Destination with a downstream density scenario to emulate congestions, that is,
cars cannot exit freely the highway but must slow down and, possibly, create a
congestion."""
Expand All @@ -54,7 +58,7 @@ class CongestedDestination(Destination[sym_var]):

def init_vars(
self,
init_conditions: Optional[Dict[str, sym_var]] = None,
init_conditions: Optional[Dict[str, VarType]] = None,
engine: Optional[EngineBase] = None,
) -> None:
"""Initializes
Expand All @@ -74,15 +78,15 @@ def init_vars(
"""
if engine is None:
engine = get_current_engine()
self.disturbances: Dict[str, sym_var] = {
self.disturbances: Dict[str, VarType] = {
"d": engine.var(f"d_{self.name}")
if init_conditions is None or "d" not in init_conditions
else init_conditions["d"]
}

def get_density(
self, net: "Network", engine: Optional[EngineBase] = None, **kwargs
) -> sym_var:
) -> VarType:
"""Computes the (downstream) density induced by the congested destination.
Parameters
Expand All @@ -94,7 +98,7 @@ def get_density(
Returns
-------
sym_var
variable
The destination's downstream density.
"""
if engine is None:
Expand Down
80 changes: 42 additions & 38 deletions src/sym_metanet/blocks/links.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Collection, Dict, Optional, Tuple, Union

from sym_metanet.blocks.base import ElementWithVars, sym_var
from sym_metanet.blocks.base import ElementWithVars
from sym_metanet.blocks.origins import MeteredOnRamp
from sym_metanet.engines.core import EngineBase, get_current_engine
from sym_metanet.util.funcs import first
from sym_metanet.util.types import VarType

if TYPE_CHECKING:
from sym_metanet.blocks.nodes import Node
from sym_metanet.network import Network


class Link(ElementWithVars[sym_var]):
class Link(ElementWithVars[VarType]):
"""Highway link between two nodes [1, Section 3.2.1]. Links represent stretch of
highway with similar traffic characteristics and no road changes (e.g., same number
of lanes and maximum speed).
Expand All @@ -26,13 +28,13 @@ class Link(ElementWithVars[sym_var]):
def __init__(
self,
nb_segments: int,
lanes: sym_var,
length: sym_var,
maximum_density: sym_var,
critical_density: sym_var,
free_flow_velocity: sym_var,
a: sym_var,
turnrate: sym_var = 1.0,
lanes: Union[VarType, int],
length: Union[VarType, float],
maximum_density: Union[VarType, float],
critical_density: Union[VarType, float],
free_flow_velocity: Union[VarType, float],
a: Union[VarType, float],
turnrate: Union[VarType, float] = 1.0,
name: Optional[str] = None,
) -> None:
"""Creates an instance of a METANET link.
Expand All @@ -41,20 +43,20 @@ def __init__(
----------
nb_segments : int
Number of segments in this highway link, i.e., `N`.
lanes : int or symbolic
lanes : int or variable
Number of lanes in each segment, i.e., `lam`.
lengths : float or symbolic
lengths : float or variable
Length of each segment in the link, i.e., `L`.
maximum density : float or symbolic
maximum density : float or variable
Maximum density that the link can withstand, i.e., `rho_max`.
critical_densities : float or symbolic
critical_densities : float or variable
Critical density at which the traffic flow is maximal, i.e., `rho_crit`.
free_flow_velocities : float or symbolic
free_flow_velocities : float or variable
Average speed of cars when traffic is freely flowing, i.e., `v_free`.
a : float or symbolic
a : float or variable
Model parameter in the computations of the equivalent speed [1, Equation
3.4].
turnrate : float or symbolic, optional
turnrate : float or variable, optional
Fraction of the total flow that enters this link via the upstream node. Only
relevant if multiple exiting links are attached to the same node, in order
to split the flow according to these rates. Needs not be normalized. By
Expand All @@ -79,7 +81,7 @@ def __init__(

def init_vars(
self,
init_conditions: Optional[Dict[str, sym_var]] = None,
init_conditions: Optional[Dict[str, VarType]] = None,
engine: Optional[EngineBase] = None,
) -> None:
"""For each segment in the link, initializes
Expand All @@ -100,7 +102,7 @@ def init_vars(
init_conditions = {}
if engine is None:
engine = get_current_engine()
self.states: Dict[str, sym_var] = {
self.states: Dict[str, VarType] = {
name: (
init_conditions[name]
if name in init_conditions
Expand All @@ -109,7 +111,7 @@ def init_vars(
for name in ("rho", "v")
}

def get_flow(self, engine: Optional[EngineBase] = None, **kwargs) -> sym_var:
def get_flow(self, engine: Optional[EngineBase] = None, **kwargs) -> VarType:
"""Gets the flow in this link's segments.
Parameters
Expand All @@ -119,7 +121,7 @@ def get_flow(self, engine: Optional[EngineBase] = None, **kwargs) -> sym_var:
Returns
-------
sym_var
variable
The flow in this link.
"""
if engine is None:
Expand All @@ -131,46 +133,46 @@ def get_flow(self, engine: Optional[EngineBase] = None, **kwargs) -> sym_var:
def step_dynamics(
self,
net: "Network",
tau: sym_var,
eta: sym_var,
kappa: sym_var,
T: sym_var,
delta: Optional[sym_var] = None,
phi: Optional[sym_var] = None,
tau: Union[VarType, float],
eta: Union[VarType, float],
kappa: Union[VarType, float],
T: Union[VarType, float],
delta: Union[None, VarType, float] = None,
phi: Union[None, VarType, float] = None,
engine: Optional[EngineBase] = None,
**kwargs,
) -> Dict[str, sym_var]:
) -> Dict[str, VarType]:
"""Steps the dynamics of this link.
Parameters
----------
net : Network
The network the link belongs to.
tau : sym_var
tau : float or variable
Model parameter for the speed relaxation term.
eta : sym_var
eta : float or variable
Model parameter for the speed anticipation term.
kappa : sym_var
kappa : float or variable
Model parameter for the speed anticipation term.
T : sym_var
T : float or variable
Sampling time.
delta : sym_var, optional
delta : float or variable, optional
Model parameter for merging phenomenum. By default, not considered.
phi : sym_var, optional
phi : float or variable, optional
Model parameter for lane drop phenomenum. By defaul, not considered.
engine : EngineBase, optional
The engine to be used. If `None`, the current engine is used.
Returns
-------
Dict[str, sym_var]
Dict[str, variable]
A dict with the states of the link (speeds and densities) at the next time
step.
"""
if engine is None:
engine = get_current_engine()

node_up, node_down = net.nodes_by_link[self]
node_up, node_down = net.nodes_by_link[self] # type: ignore[index]
rho = self.states["rho"]
v = self.states["v"]
q = self.get_flow(engine=engine)
Expand Down Expand Up @@ -204,10 +206,12 @@ def step_dynamics(
# check for lane drops in the next link (only if one link downstream)
lanes_drop = None
if phi is not None:
links_down = net.out_links(node_down)
links_down: Collection[
Tuple["Node", "Node", "Link[VarType]"]
] = net.out_links(node_down)
if len(links_down) == 1:
link_down = first(links_down)[-1]
lanes_drop = self.lam - link_down.lam
lanes_drop = self.lam - link_down.lam # type: ignore[operator]
if lanes_drop == 0:
lanes_drop = None

Expand Down
Loading

0 comments on commit 3b0d4db

Please sign in to comment.