diff --git a/src/rod/__init__.py b/src/rod/__init__.py index 3533ff9..0291906 100644 --- a/src/rod/__init__.py +++ b/src/rod/__init__.py @@ -115,9 +115,9 @@ def check_compatible_sdformat(specification_version: str) -> None: if not GazeboHelper.has_gazebo(): return - else: - cmdline = GazeboHelper.get_gazebo_executable() - logging.info(f"Calling sdformat through '{cmdline} sdf'") + + cmdline = GazeboHelper.get_gazebo_executable() + logging.info(f"Calling sdformat through '{cmdline} sdf'") output_sdf_version = packaging.version.Version( xmltodict.parse( diff --git a/src/rod/builder/primitive_builder.py b/src/rod/builder/primitive_builder.py index 6e34062..58f9d97 100644 --- a/src/rod/builder/primitive_builder.py +++ b/src/rod/builder/primitive_builder.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import abc import dataclasses -from typing import Optional, Union +from typing import Optional import numpy as np import numpy.typing as npt @@ -14,7 +16,7 @@ class PrimitiveBuilder(abc.ABC): name: str mass: float - element: Union[rod.Model, rod.Link, rod.Inertial, rod.Collision, rod.Visual] = ( + element: rod.Model | rod.Link | rod.Inertial | rod.Collision | rod.Visual = ( dataclasses.field( default=None, init=False, repr=False, hash=False, compare=False ) @@ -22,7 +24,7 @@ class PrimitiveBuilder(abc.ABC): def build( self, - ) -> Union[rod.Model, rod.Link, rod.Inertial, rod.Collision, rod.Visual]: + ) -> rod.Model | rod.Link | rod.Inertial | rod.Collision | rod.Visual: return self.element # ================ @@ -43,9 +45,9 @@ def _geometry(self) -> rod.Geometry: def build_model( self, - name: Optional[str] = None, - pose: Optional[rod.Pose] = None, - ) -> "PrimitiveBuilder": + name: str | None = None, + pose: rod.Pose | None = None, + ) -> PrimitiveBuilder: self._check_element() self.element = self._model(name=name, pose=pose) @@ -54,16 +56,16 @@ def build_model( def build_link( self, - name: Optional[str] = None, - pose: Optional[rod.Pose] = None, - ) -> "PrimitiveBuilder": + name: str | None = None, + pose: rod.Pose | None = None, + ) -> PrimitiveBuilder: self._check_element() self.element = self._link(name=name, pose=pose) return self - def build_inertial(self, pose: Optional[rod.Pose] = None) -> "PrimitiveBuilder": + def build_inertial(self, pose: rod.Pose | None = None) -> PrimitiveBuilder: self._check_element() self.element = self._inertial(pose=pose) @@ -72,9 +74,9 @@ def build_inertial(self, pose: Optional[rod.Pose] = None) -> "PrimitiveBuilder": def build_visual( self, - name: Optional[str] = None, - pose: Optional[rod.Pose] = None, - ) -> "PrimitiveBuilder": + name: str | None = None, + pose: rod.Pose | None = None, + ) -> PrimitiveBuilder: self._check_element() self.element = self._visual(name=name, pose=pose) @@ -83,9 +85,9 @@ def build_visual( def build_collision( self, - name: Optional[str] = None, - pose: Optional[rod.Pose] = None, - ) -> "PrimitiveBuilder": + name: str | None = None, + pose: rod.Pose | None = None, + ) -> PrimitiveBuilder: self._check_element() self.element = self._collision(name=name, pose=pose) @@ -98,10 +100,10 @@ def build_collision( def add_link( self, - name: Optional[str] = None, - pose: Optional[rod.Pose] = None, - link: Optional[rod.Link] = None, - ) -> "PrimitiveBuilder": + name: str | None = None, + pose: rod.Pose | None = None, + link: rod.Link | None = None, + ) -> PrimitiveBuilder: if not isinstance(self.element, rod.Model): raise ValueError(type(self.element)) @@ -116,9 +118,9 @@ def add_link( def add_inertial( self, - pose: Optional[rod.Pose] = None, - inertial: Optional[rod.Inertial] = None, - ) -> "PrimitiveBuilder": + pose: rod.Pose | None = None, + inertial: rod.Inertial | None = None, + ) -> PrimitiveBuilder: if not isinstance(self.element, (rod.Model, rod.Link)): raise ValueError(type(self.element)) @@ -144,11 +146,11 @@ def add_inertial( def add_visual( self, - name: Optional[str] = None, + name: str | None = None, use_inertial_pose: bool = True, - pose: Optional[rod.Pose] = None, - visual: Optional[rod.Visual] = None, - ) -> "PrimitiveBuilder": + pose: rod.Pose | None = None, + visual: rod.Visual | None = None, + ) -> PrimitiveBuilder: if not isinstance(self.element, (rod.Model, rod.Link)): raise ValueError(type(self.element)) @@ -180,11 +182,11 @@ def add_visual( def add_collision( self, - name: Optional[str] = None, + name: str | None = None, use_inertial_pose: bool = True, - pose: Optional[rod.Pose] = None, - collision: Optional[rod.Collision] = None, - ) -> "PrimitiveBuilder": + pose: rod.Pose | None = None, + collision: rod.Collision | None = None, + ) -> PrimitiveBuilder: if not isinstance(self.element, (rod.Model, rod.Link)): raise ValueError(type(self.element)) @@ -224,8 +226,8 @@ def add_collision( def _model( self, - name: Optional[str] = None, - pose: Optional[rod.Pose] = None, + name: str | None = None, + pose: rod.Pose | None = None, ) -> rod.Model: name = name if name is not None else self.name logging.debug(f"Building model '{name}'") @@ -240,15 +242,15 @@ def _model( def _link( self, - name: Optional[str] = None, - pose: Optional[rod.Pose] = None, + name: str | None = None, + pose: rod.Pose | None = None, ) -> rod.Link: return rod.Link( name=name if name is not None else f"{self.name}_link", pose=pose, ) - def _inertial(self, pose: Optional[rod.Pose] = None) -> rod.Inertial: + def _inertial(self, pose: rod.Pose | None = None) -> rod.Inertial: return rod.Inertial( pose=pose, mass=self.mass, @@ -257,8 +259,8 @@ def _inertial(self, pose: Optional[rod.Pose] = None) -> rod.Inertial: def _visual( self, - name: Optional[str] = None, - pose: Optional[rod.Pose] = None, + name: str | None = None, + pose: rod.Pose | None = None, ) -> rod.Visual: name = name if name is not None else f"{self.name}_visual" @@ -271,7 +273,7 @@ def _visual( def _collision( self, name: Optional[str], - pose: Optional[rod.Pose] = None, + pose: rod.Pose | None = None, ) -> rod.Collision: name = name if name is not None else f"{self.name}_collision" @@ -297,7 +299,7 @@ def build_pose( relative_to: str = None, degrees: bool = None, rotation_format: str = None, - ) -> Optional[rod.Pose]: + ) -> rod.Pose | None: if pos is None and rpy is None: return rod.Pose.from_transform(transform=np.eye(4), relative_to=relative_to) diff --git a/src/rod/builder/primitives.py b/src/rod/builder/primitives.py index b6bb883..b3786f5 100644 --- a/src/rod/builder/primitives.py +++ b/src/rod/builder/primitives.py @@ -1,6 +1,5 @@ import dataclasses import pathlib -from typing import Union import trimesh from numpy.typing import NDArray @@ -63,7 +62,7 @@ def _geometry(self) -> rod.Geometry: @dataclasses.dataclass class MeshBuilder(PrimitiveBuilder): - mesh_path: Union[str, pathlib.Path] + mesh_path: str | pathlib.Path scale: NDArray def __post_init__(self) -> None: diff --git a/src/rod/kinematics/kinematic_tree.py b/src/rod/kinematics/kinematic_tree.py index b0c4672..f66dd96 100644 --- a/src/rod/kinematics/kinematic_tree.py +++ b/src/rod/kinematics/kinematic_tree.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import copy import dataclasses import functools -from typing import Dict, List, Sequence, Tuple, Union +from typing import Dict, List, Sequence, Tuple import numpy as np @@ -12,7 +14,7 @@ @dataclasses.dataclass(frozen=True) class KinematicTree(DirectedTree): - model: "rod.Model" + model: rod.Model joints: List[TreeEdge] = dataclasses.field(default_factory=list) frames: List[TreeFrame] = dataclasses.field(default_factory=list) @@ -46,7 +48,7 @@ def joint_names(self) -> List[str]: return [joint.name() for joint in self.joints] @staticmethod - def build(model: "rod.Model", is_top_level: bool = True) -> "KinematicTree": + def build(model: rod.Model, is_top_level: bool = True) -> KinematicTree: logging.debug(msg=f"Building kinematic tree of model '{model.name}'") if model.model is not None: @@ -199,7 +201,7 @@ def build(model: "rod.Model", is_top_level: bool = True) -> "KinematicTree": new_base_node, additional_frames = KinematicTree.remove_edge( edge=world_to_base_edge, keep_parent=False ) - assert any([f.name() == TreeFrame.WORLD for f in additional_frames]) + assert any(f.name() == TreeFrame.WORLD for f in additional_frames) # Replace the former base node with the new base node nodes_links_dict[new_base_node.name()] = new_base_node diff --git a/src/rod/kinematics/tree_transforms.py b/src/rod/kinematics/tree_transforms.py index 6bbe261..00f4e8a 100644 --- a/src/rod/kinematics/tree_transforms.py +++ b/src/rod/kinematics/tree_transforms.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import dataclasses @@ -12,11 +14,13 @@ @dataclasses.dataclass class TreeTransforms: kinematic_tree: KinematicTree = dataclasses.dataclass(init=False) + _transform_cache: dict[str, npt.NDArray] = dataclasses.field(default_factory=dict) @staticmethod - def build(model: "rod.Model", is_top_level: bool = True) -> "TreeTransforms": - - # Operate on a deep copy of the model to avoid side effects. + def build( + model: rod.Model, + is_top_level: bool = True, + ) -> TreeTransforms: model = copy.deepcopy(model) # Make sure that all elements have a pose attribute with explicit 'relative_to'. @@ -28,60 +32,72 @@ def build(model: "rod.Model", is_top_level: bool = True) -> "TreeTransforms": ) def transform(self, name: str) -> npt.NDArray: - if name == TreeFrame.WORLD: - return np.eye(4) + if name in self._transform_cache: + return self._transform_cache[name] + + self._transform_cache[name] = self._compute_transform(name=name) + return self._transform_cache[name] + + def _compute_transform(self, name: str) -> npt.NDArray: + match name: + case TreeFrame.WORLD: + + return np.eye(4) + + case name if name in {TreeFrame.MODEL, self.kinematic_tree.model.name}: + + relative_to = self.kinematic_tree.model.pose.relative_to + assert relative_to in {None, ""}, (relative_to, name) + return self.kinematic_tree.model.pose.transform() - if name in {TreeFrame.MODEL, self.kinematic_tree.model.name}: - relative_to = self.kinematic_tree.model.pose.relative_to - assert relative_to in {None, ""}, (relative_to, name) - return self.kinematic_tree.model.pose.transform() + case name if name in self.kinematic_tree.joint_names(): - if name in self.kinematic_tree.joint_names(): - edge = self.kinematic_tree.joints_dict[name] - assert edge.name() == name + edge = self.kinematic_tree.joints_dict[name] + assert edge.name() == name - # Get the pose of the frame in which the node's pose is expressed - assert edge._source.pose.relative_to not in {"", None} - x_H_E = edge._source.pose.transform() - W_H_x = self.transform(name=edge._source.pose.relative_to) + # Get the pose of the frame in which the node's pose is expressed + assert edge._source.pose.relative_to not in {"", None} + x_H_E = edge._source.pose.transform() + W_H_x = self.transform(name=edge._source.pose.relative_to) - # Compute the world-to-node transform - # TODO: this assumes all joint positions to be 0 - W_H_E = W_H_x @ x_H_E + # Compute the world-to-node transform + # TODO: this assumes all joint positions to be 0 + W_H_E = W_H_x @ x_H_E - return W_H_E + return W_H_E - if name in self.kinematic_tree.link_names(): + case name if name in self.kinematic_tree.link_names(): - element = self.kinematic_tree.links_dict[name] + element = self.kinematic_tree.links_dict[name] - assert element.name() == name - assert element._source.pose.relative_to not in {"", None} + assert element.name() == name + assert element._source.pose.relative_to not in {"", None} - # Get the pose of the frame in which the link's pose is expressed. - x_H_L = element._source.pose.transform() - W_H_x = self.transform(name=element._source.pose.relative_to) + # Get the pose of the frame in which the link's pose is expressed. + x_H_L = element._source.pose.transform() + W_H_x = self.transform(name=element._source.pose.relative_to) - # Compute the world transform of the link. - W_H_L = W_H_x @ x_H_L - return W_H_L + # Compute the world transform of the link. + W_H_L = W_H_x @ x_H_L + return W_H_L - if name in self.kinematic_tree.frame_names(): + case name if name in self.kinematic_tree.frame_names(): - element = self.kinematic_tree.frames_dict[name] + element = self.kinematic_tree.frames_dict[name] - assert element.name() == name - assert element._source.pose.relative_to not in {"", None} + assert element.name() == name + assert element._source.pose.relative_to not in {"", None} - # Get the pose of the frame in which the frame's pose is expressed. - x_H_F = element._source.pose.transform() - W_H_x = self.transform(name=element._source.pose.relative_to) + # Get the pose of the frame in which the frame's pose is expressed. + x_H_F = element._source.pose.transform() + W_H_x = self.transform(name=element._source.pose.relative_to) - # Compute the world transform of the frame. - W_H_F = W_H_x @ x_H_F - return W_H_F + # Compute the world transform of the frame. + W_H_F = W_H_x @ x_H_F + return W_H_F - raise ValueError(name) + case _: + raise ValueError(name) def relative_transform(self, relative_to: str, name: str) -> npt.NDArray: diff --git a/src/rod/logging.py b/src/rod/logging.py index 5e1565f..7d4ed64 100644 --- a/src/rod/logging.py +++ b/src/rod/logging.py @@ -1,6 +1,5 @@ import enum import logging -from typing import Union import coloredlogs @@ -20,7 +19,7 @@ def _logger() -> logging.Logger: return logging.getLogger(name=LOGGER_NAME) -def set_logging_level(level: Union[int, LoggingLevel] = LoggingLevel.WARNING): +def set_logging_level(level: int | LoggingLevel = LoggingLevel.WARNING): if isinstance(level, int): level = LoggingLevel(level) diff --git a/src/rod/pretty_printer.py b/src/rod/pretty_printer.py index 40ecd85..479d523 100644 --- a/src/rod/pretty_printer.py +++ b/src/rod/pretty_printer.py @@ -30,7 +30,7 @@ def list_to_string(obj: List[Any], level: int = 1) -> str: ] return ( - f"[\n" + "[\n" + ",\n".join(f"{spacing_level}{el!s}" for el in list_str) + f",\n{spacing_level_up}]" ) @@ -45,26 +45,27 @@ def dataclass_to_str(obj: Any, level: int = 1) -> str: for field in dataclasses.fields(obj): attr = getattr(obj, field.name) - if attr is None or attr == "": - continue - - elif isinstance(attr, list): - list_str = DataclassPrettyPrinter.list_to_string( - obj=attr, level=level + 1 - ) - serialization += [(field.name, list_str)] - continue - - elif dataclasses.is_dataclass(attr): - dataclass_str = DataclassPrettyPrinter.dataclass_to_str( - obj=attr, level=level + 1 - ) - serialization += [(field.name, dataclass_str)] - continue - - else: - serialization += [(field.name, f"{attr!s}")] - continue + match attr: + case None | "": + continue + + case list(): + list_str = DataclassPrettyPrinter.list_to_string( + obj=attr, level=level + 1 + ) + serialization += [(field.name, list_str)] + continue + + case _ if dataclasses.is_dataclass(attr): + dataclass_str = DataclassPrettyPrinter.dataclass_to_str( + obj=attr, level=level + 1 + ) + serialization += [(field.name, dataclass_str)] + continue + + case _: + serialization += [(field.name, f"{attr!s}")] + continue spacing = " " * 4 spacing_level = spacing * level diff --git a/src/rod/sdf/element.py b/src/rod/sdf/element.py index 97744c3..138d05d 100644 --- a/src/rod/sdf/element.py +++ b/src/rod/sdf/element.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List import mashumaro.config import mashumaro.mixins.dict @@ -37,7 +37,7 @@ def deserialize_bool(data: str) -> bool: false_vals = {"0", "False", "false"} assert data in true_vals.union(false_vals) - return True if data in true_vals else False + return data in true_vals @staticmethod def serialize_float(data: float) -> str: @@ -50,7 +50,7 @@ def serialize_list(data: List[float]) -> str: return " ".join(np.array(data, dtype=str)) @staticmethod - def deserialize_list(data: str, length: Optional[int] = None) -> List[float]: + def deserialize_list(data: str, length: int | None = None) -> List[float]: assert isinstance(data, str) array = np.atleast_1d(np.array(data.split(sep=" "), dtype=float).squeeze()) diff --git a/src/rod/sdf/link.py b/src/rod/sdf/link.py index 48189d1..dcc8832 100644 --- a/src/rod/sdf/link.py +++ b/src/rod/sdf/link.py @@ -1,5 +1,5 @@ import dataclasses -from typing import List, Optional, Union +from typing import List, Optional import mashumaro import numpy as np @@ -73,11 +73,9 @@ class Link(Element): inertial: Optional[Inertial] = dataclasses.field(default=None) - visual: Optional[Union[Visual, List[Visual]]] = dataclasses.field(default=None) + visual: Optional[Visual | List[Visual]] = dataclasses.field(default=None) - collision: Optional[Union[Collision, List[Collision]]] = dataclasses.field( - default=None - ) + collision: Optional[Collision | List[Collision]] = dataclasses.field(default=None) gravity: Optional[bool] = dataclasses.field( default=None, diff --git a/src/rod/sdf/model.py b/src/rod/sdf/model.py index de33a64..5ba02e1 100644 --- a/src/rod/sdf/model.py +++ b/src/rod/sdf/model.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import dataclasses -from typing import List, Optional, Union +from typing import List, Optional import mashumaro @@ -54,13 +56,13 @@ class Model(Element): pose: Optional[Pose] = dataclasses.field(default=None) - model: Optional[Union["Model", List["Model"]]] = dataclasses.field(default=None) + model: Optional[Model | List[Model]] = dataclasses.field(default=None) - frame: Optional[Union[Frame, List[Frame]]] = dataclasses.field(default=None) + frame: Optional[Frame | List[Frame]] = dataclasses.field(default=None) - link: Optional[Union[Link, List[Link]]] = dataclasses.field(default=None) + link: Optional[Link | List[Link]] = dataclasses.field(default=None) - joint: Optional[Union[Joint, List[Joint]]] = dataclasses.field(default=None) + joint: Optional[Joint | List[Joint]] = dataclasses.field(default=None) def is_fixed_base(self) -> bool: joints_having_world_parent = [j for j in self.joints() if j.parent == "world"] @@ -80,7 +82,7 @@ def get_canonical_link(self) -> str: return self.links()[0].name - def models(self) -> List["Model"]: + def models(self) -> List[Model]: if self.model is None: return [] @@ -155,7 +157,7 @@ def resolve_frames( def switch_frame_convention( self, - frame_convention: "rod.FrameConvention", + frame_convention: rod.FrameConvention, is_top_level: bool = True, explicit_frames: bool = True, attach_frames_to_links: bool = True, diff --git a/src/rod/sdf/sdf.py b/src/rod/sdf/sdf.py index af57a76..5fbc4da 100644 --- a/src/rod/sdf/sdf.py +++ b/src/rod/sdf/sdf.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import dataclasses import os import pathlib -from typing import List, Optional, Union +from typing import List, Optional import mashumaro import packaging.specifiers @@ -19,9 +21,9 @@ class Sdf(Element): version: str = dataclasses.field(metadata=mashumaro.field_options(alias="@version")) - world: Optional[Union[World, List[World]]] = dataclasses.field(default=None) + world: Optional[World | List[World]] = dataclasses.field(default=None) - model: Optional[Union[Model, List[Model]]] = dataclasses.field(default=None) + model: Optional[Model | List[Model]] = dataclasses.field(default=None) def worlds(self) -> List[World]: if self.world is None: @@ -44,7 +46,7 @@ def models(self) -> List[Model]: return self.model @staticmethod - def load(sdf: Union[pathlib.Path, str], is_urdf: Optional[bool] = None) -> "Sdf": + def load(sdf: pathlib.Path | str, is_urdf: bool | None = None) -> Sdf: """ Load an SDF resource. @@ -80,7 +82,7 @@ def load(sdf: Union[pathlib.Path, str], is_urdf: Optional[bool] = None) -> "Sdf" and len(sdf) <= MAX_PATH and pathlib.Path(sdf).is_file() ): - sdf_string = pathlib.Path(sdf).read_text() + sdf_string = pathlib.Path(sdf).read_text(encoding="utf-8") is_urdf = ( is_urdf if is_urdf is not None else pathlib.Path(sdf).suffix == ".urdf" ) @@ -99,8 +101,8 @@ def load(sdf: Union[pathlib.Path, str], is_urdf: Optional[bool] = None) -> "Sdf" # Parse the SDF to dict try: xml_dict = xmltodict.parse(xml_input=sdf_string) - except Exception: - raise ValueError("Failed to parse 'sdf' argument") + except Exception as exc: + raise exc("Failed to parse 'sdf' argument") # Look for the top-level element try: diff --git a/src/rod/sdf/world.py b/src/rod/sdf/world.py index a5b915f..de0c39b 100644 --- a/src/rod/sdf/world.py +++ b/src/rod/sdf/world.py @@ -1,5 +1,5 @@ import dataclasses -from typing import List, Optional, Union +from typing import List, Optional import mashumaro @@ -34,9 +34,9 @@ class World(Element): scene: Scene = dataclasses.field(default_factory=Scene) - model: Optional[Union[Model, List[Model]]] = dataclasses.field(default=None) + model: Optional[Model | List[Model]] = dataclasses.field(default=None) - frame: Optional[Union[Frame, List[Frame]]] = dataclasses.field(default=None) + frame: Optional[Frame | List[Frame]] = dataclasses.field(default=None) def models(self) -> List[Model]: if self.model is None: diff --git a/src/rod/tree/directed_tree.py b/src/rod/tree/directed_tree.py index 7bb2bfe..9ee7f69 100644 --- a/src/rod/tree/directed_tree.py +++ b/src/rod/tree/directed_tree.py @@ -1,7 +1,7 @@ import collections.abc import dataclasses import functools -from typing import Any, Callable, Dict, Iterable, List, Optional, Union +from typing import Any, Callable, Dict, Iterable, List from .tree_elements import DirectedTreeNode @@ -33,7 +33,7 @@ def nodes_dict(self) -> Dict[str, DirectedTreeNode]: @staticmethod def breadth_first_search( root: DirectedTreeNode, - sort_children: Optional[Callable[[Any], Any]] = lambda node: node.name(), + sort_children: Callable[[Any], Any] | None = lambda node: node.name(), ) -> Iterable[DirectedTreeNode]: queue = [root] @@ -69,8 +69,8 @@ def pretty_print(self) -> None: ) def __getitem__( - self, key: Union[int, slice, str] - ) -> Union[DirectedTreeNode, List[DirectedTreeNode]]: + self, key: int | slice | str + ) -> DirectedTreeNode | List[DirectedTreeNode]: # Get the nodes' dictionary (already inserted in order following BFS) nodes_dict = self.nodes_dict @@ -100,7 +100,7 @@ def __iter__(self) -> Iterable[DirectedTreeNode]: def __reversed__(self) -> Iterable[DirectedTreeNode]: yield from reversed(self) - def __contains__(self, item: Union[str, DirectedTreeNode]) -> bool: + def __contains__(self, item: str | DirectedTreeNode) -> bool: if isinstance(item, str): return item in self.nodes_dict.keys() diff --git a/src/rod/tree/tree_elements.py b/src/rod/tree/tree_elements.py index 5426b94..dc08c5f 100644 --- a/src/rod/tree/tree_elements.py +++ b/src/rod/tree/tree_elements.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import abc import dataclasses -from typing import ClassVar, List, Optional, Union +from typing import ClassVar, List, Optional import rod from rod import logging @@ -29,19 +31,19 @@ def __hash__(self): @dataclasses.dataclass(eq=False) class DirectedTreeNode(TreeElement): - parent: Optional["DirectedTreeNode"] = None - children: List["DirectedTreeNode"] = dataclasses.field(default_factory=list) + parent: Optional[DirectedTreeNode] = None + children: List[DirectedTreeNode] = dataclasses.field(default_factory=list) - _source: Optional["rod.Link"] = dataclasses.field(default=None, repr=False) + _source: Optional[rod.Link] = dataclasses.field(default=None, repr=False) def name(self) -> str: return self._source.name - def pose(self) -> "rod.Pose": + def pose(self) -> rod.Pose: if self._source is not None and self._source.pose is not None: return self._source.pose - else: - return rod.Pose(relative_to="world") + + return rod.Pose(relative_to="world") @property def tree_label(self) -> str: @@ -69,7 +71,7 @@ class TreeEdge(TreeElement): _source: Optional[rod.Joint] = dataclasses.field(default=None, repr=False) - def pose(self) -> "rod.Pose": + def pose(self) -> rod.Pose: return self._source.pose def name(self) -> str: @@ -88,7 +90,7 @@ class TreeFrame(TreeElement): WORLD: ClassVar[str] = "world" MODEL: ClassVar[str] = "__model__" - _source: Optional["rod.Frame"] = dataclasses.field(default=None, repr=False) + _source: Optional[rod.Frame] = dataclasses.field(default=None, repr=False) def name(self) -> str: return self._source.name @@ -109,8 +111,8 @@ def __str__(self) -> str: @staticmethod def from_node( node: DirectedTreeNode, - attached_to: Union[DirectedTreeNode, "TreeFrame", TreeEdge] = None, - ) -> "TreeFrame": + attached_to: DirectedTreeNode | TreeFrame | TreeEdge | None = None, + ) -> TreeFrame: attached_to = attached_to if attached_to is not None else node.parent logging.debug( @@ -128,8 +130,8 @@ def from_node( @staticmethod def from_edge( edge: TreeEdge, - attached_to: Union[DirectedTreeNode, "TreeFrame", TreeEdge] = None, - ) -> "TreeFrame": + attached_to: DirectedTreeNode | TreeFrame | TreeEdge | None = None, + ) -> TreeFrame: attached_to = attached_to if attached_to is not None else edge.parent logging.debug( diff --git a/src/rod/urdf/exporter.py b/src/rod/urdf/exporter.py index 606989f..96d4d66 100644 --- a/src/rod/urdf/exporter.py +++ b/src/rod/urdf/exporter.py @@ -1,7 +1,7 @@ import abc import copy import dataclasses -from typing import Any, ClassVar, Dict, List, Set, Union +from typing import Any, ClassVar, Dict, List, Set import numpy as np import xmltodict @@ -24,7 +24,7 @@ class UrdfExporter(abc.ABC): # Whether to inject additional `` elements in the resulting URDF # to preserve fixed joints in case of re-loading into sdformat. # If a list of strings is passed, only the listed fixed joints will be preserved. - gazebo_preserve_fixed_joints: Union[bool, List[str]] = False + gazebo_preserve_fixed_joints: bool | List[str] = False SupportedSdfJointTypes: ClassVar[Set[str]] = { "revolute", @@ -42,10 +42,10 @@ class UrdfExporter(abc.ABC): @staticmethod def sdf_to_urdf_string( - sdf: Union[rod.Sdf, rod.Model], + sdf: rod.Sdf | rod.Model, pretty: bool = False, indent: str = " ", - gazebo_preserve_fixed_joints: Union[bool, List[str]] = False, + gazebo_preserve_fixed_joints: bool | List[str] = False, ) -> str: msg = "This method is deprecated, please use '{}' instead." @@ -57,7 +57,7 @@ def sdf_to_urdf_string( gazebo_preserve_fixed_joints=gazebo_preserve_fixed_joints, ).to_urdf_string(sdf=sdf) - def to_urdf_string(self, sdf: Union[rod.Sdf, rod.Model]) -> str: + def to_urdf_string(self, sdf: rod.Sdf | rod.Model) -> str: """ Convert an in-memory SDF model to a URDF string. diff --git a/src/rod/utils/frame_convention.py b/src/rod/utils/frame_convention.py index f164b78..d98ff0d 100644 --- a/src/rod/utils/frame_convention.py +++ b/src/rod/utils/frame_convention.py @@ -15,7 +15,7 @@ class FrameConvention(enum.IntEnum): def switch_frame_convention( - model: "rod.Model", + model: rod.Model, frame_convention: FrameConvention, is_top_level: bool = True, attach_frames_to_links: bool = True, @@ -70,102 +70,106 @@ def switch_frame_convention( # Define the default reference frames of the different elements # ============================================================= - if frame_convention is FrameConvention.World: - - reference_frame_model = lambda m: "world" - reference_frame_links = lambda l: "world" - reference_frame_frames = lambda f: "world" - reference_frame_joints = lambda j: "world" - reference_frame_visuals = lambda v: "world" - reference_frame_inertials = lambda i, parent_link: "world" - reference_frame_collisions = lambda c: "world" - reference_frame_link_canonical = "world" - - elif frame_convention is FrameConvention.Model: - - reference_frame_model = lambda m: "world" - reference_frame_links = lambda l: "__model__" - reference_frame_frames = lambda f: "__model__" - reference_frame_joints = lambda j: "__model__" - reference_frame_visuals = lambda v: "__model__" - reference_frame_inertials = lambda i, parent_link: "__model__" - reference_frame_collisions = lambda c: "__model__" - reference_frame_link_canonical = "__model__" - - elif frame_convention is FrameConvention.Sdf: - - visual_name_to_parent_link = { - visual_name: parent_link - for d in [{v.name: link for v in link.visuals()} for link in model.links()] - for visual_name, parent_link in d.items() - } - - collision_name_to_parent_link = { - collision_name: parent_link - for d in [ - {c.name: link for c in link.collisions()} for link in model.links() - ] - for collision_name, parent_link in d.items() - } - - reference_frame_model = lambda m: "world" - reference_frame_links = lambda l: "__model__" - reference_frame_frames = lambda f: f.attached_to - reference_frame_joints = lambda j: joint.child - reference_frame_visuals = lambda v: visual_name_to_parent_link[v.name].name - reference_frame_inertials = lambda i, parent_link: parent_link.name - reference_frame_collisions = lambda c: collision_name_to_parent_link[ - c.name - ].name - reference_frame_link_canonical = "__model__" - - elif frame_convention is FrameConvention.Urdf: - - visual_name_to_parent_link = { - visual_name: parent_link - for d in [{v.name: link for v in link.visuals()} for link in model.links()] - for visual_name, parent_link in d.items() - } - - collision_name_to_parent_link = { - collision_name: parent_link - for d in [ - {c.name: link for c in link.collisions()} for link in model.links() - ] - for collision_name, parent_link in d.items() - } - - link_name_to_parent_joint_names = defaultdict(list) - - for j in model.joints(): - if j.child != model.get_canonical_link(): - link_name_to_parent_joint_names[j.child].append(j.name) - else: - # The pose of the canonical link is used to define the origin of - # the URDF joint connecting the world to the robot - assert model.is_fixed_base() - link_name_to_parent_joint_names[j.child].append("world") - - reference_frame_model = lambda m: "world" - reference_frame_links = lambda l: link_name_to_parent_joint_names[l.name][0] - reference_frame_frames = lambda f: f.attached_to - reference_frame_joints = lambda j: j.parent - reference_frame_visuals = lambda v: visual_name_to_parent_link[v.name].name - reference_frame_inertials = lambda i, parent_link: parent_link.name - reference_frame_collisions = lambda c: collision_name_to_parent_link[ - c.name - ].name - - if model.is_fixed_base(): - canonical_link = {l.name: l for l in model.links()}[ - model.get_canonical_link() - ] - reference_frame_link_canonical = reference_frame_links(l=canonical_link) - else: + match frame_convention: + case FrameConvention.World: + reference_frame_model = lambda m: "world" + reference_frame_links = lambda l: "world" + reference_frame_frames = lambda f: "world" + reference_frame_joints = lambda j: "world" + reference_frame_visuals = lambda v: "world" + reference_frame_inertials = lambda i, parent_link: "world" + reference_frame_collisions = lambda c: "world" + reference_frame_link_canonical = "world" + + case FrameConvention.Model: + + reference_frame_model = lambda m: "world" + reference_frame_links = lambda l: "__model__" + reference_frame_frames = lambda f: "__model__" + reference_frame_joints = lambda j: "__model__" + reference_frame_visuals = lambda v: "__model__" + reference_frame_inertials = lambda i, parent_link: "__model__" + reference_frame_collisions = lambda c: "__model__" reference_frame_link_canonical = "__model__" - else: - raise ValueError(frame_convention) + case FrameConvention.Sdf: + + visual_name_to_parent_link = { + visual_name: parent_link + for d in [ + {v.name: link for v in link.visuals()} for link in model.links() + ] + for visual_name, parent_link in d.items() + } + + collision_name_to_parent_link = { + collision_name: parent_link + for d in [ + {c.name: link for c in link.collisions()} for link in model.links() + ] + for collision_name, parent_link in d.items() + } + + reference_frame_model = lambda m: "world" + reference_frame_links = lambda l: "__model__" + reference_frame_frames = lambda f: f.attached_to + reference_frame_joints = lambda j: joint.child + reference_frame_visuals = lambda v: visual_name_to_parent_link[v.name].name + reference_frame_inertials = lambda i, parent_link: parent_link.name + reference_frame_collisions = lambda c: collision_name_to_parent_link[ + c.name + ].name + reference_frame_link_canonical = "__model__" + + case FrameConvention.Urdf: + + visual_name_to_parent_link = { + visual_name: parent_link + for d in [ + {v.name: link for v in link.visuals()} for link in model.links() + ] + for visual_name, parent_link in d.items() + } + + collision_name_to_parent_link = { + collision_name: parent_link + for d in [ + {c.name: link for c in link.collisions()} for link in model.links() + ] + for collision_name, parent_link in d.items() + } + + link_name_to_parent_joint_names = defaultdict(list) + + for j in model.joints(): + if j.child != model.get_canonical_link(): + link_name_to_parent_joint_names[j.child].append(j.name) + else: + # The pose of the canonical link is used to define the origin of + # the URDF joint connecting the world to the robot + assert model.is_fixed_base() + link_name_to_parent_joint_names[j.child].append("world") + + reference_frame_model = lambda m: "world" + reference_frame_links = lambda l: link_name_to_parent_joint_names[l.name][0] + reference_frame_frames = lambda f: f.attached_to + reference_frame_joints = lambda j: j.parent + reference_frame_visuals = lambda v: visual_name_to_parent_link[v.name].name + reference_frame_inertials = lambda i, parent_link: parent_link.name + reference_frame_collisions = lambda c: collision_name_to_parent_link[ + c.name + ].name + + if model.is_fixed_base(): + canonical_link = {l.name: l for l in model.links()}[ + model.get_canonical_link() + ] + reference_frame_link_canonical = reference_frame_links(l=canonical_link) + else: + reference_frame_link_canonical = "__model__" + + case _: + raise ValueError(frame_convention) # ========================================= # Process the reference frames of the model @@ -288,33 +292,37 @@ def find_parent_link_of_frame(frame: rod.Frame, model: rod.Model) -> str: assert isinstance(frame, rod.Frame) - if frame.attached_to in links_dict: - parent = links_dict[frame.attached_to] + match frame.attached_to: + case anchor if anchor in links_dict: + parent = links_dict[frame.attached_to] - elif frame.attached_to in frames_dict: - parent = frames_dict[frame.attached_to] + case anchor if anchor in frames_dict: + parent = frames_dict[frame.attached_to] - elif frame.attached in {model.name, "__model__"}: - return model.get_canonical_link() + case anchor if anchor in {model.name, "__model__"}: + return model.get_canonical_link() - elif frame.attached_to in joints_dict: - raise ValueError("Frames cannot be attached to joints") + case anchor if anchor in joints_dict: + raise ValueError("Frames cannot be attached to joints") - elif frame.attached_to in sub_models_dict: - raise RuntimeError("Model composition not yet supported") + case anchor if anchor in sub_models_dict: + raise RuntimeError("Model composition not yet supported") - else: - raise RuntimeError(f"Failed to find element with name '{frame.attached_to}'") + case _: + raise RuntimeError( + f"Failed to find element with name '{frame.attached_to}'" + ) # At this point, the parent is either a link or another frame. assert isinstance(parent, (rod.Link, rod.Frame)) - # If the parent is a link, can stop searching. - if isinstance(parent, rod.Link): - return parent.name + match parent: + # If the parent is a link, can stop searching. + case parent if isinstance(parent, rod.Link): + return parent.name - # If the parent is another frame, keep looking for the parent link. - if isinstance(parent, rod.Frame): - return find_parent_link_of_frame(frame=parent, model=model) + # If the parent is another frame, keep looking for the parent link. + case parent if isinstance(parent, rod.Frame): + return find_parent_link_of_frame(frame=parent, model=model) raise RuntimeError("This recursive function should never arrive here.") diff --git a/src/rod/utils/gazebo.py b/src/rod/utils/gazebo.py index a1ff7ec..98632bf 100644 --- a/src/rod/utils/gazebo.py +++ b/src/rod/utils/gazebo.py @@ -3,12 +3,16 @@ import shutil import subprocess import tempfile -from typing import Union class GazeboHelper: - @staticmethod - def get_gazebo_executable() -> pathlib.Path: + _cached_executable: pathlib.Path = None + + @classmethod + def get_gazebo_executable(cls) -> pathlib.Path: + if cls._cached_executable is not None: + return cls._cached_executable + gz = shutil.which("gz") ign = shutil.which("ign") @@ -25,25 +29,29 @@ def get_gazebo_executable() -> pathlib.Path: raise TypeError(executable) # Check if the sdf plugin of the simulator is installed - cp = subprocess.run([executable, "sdf", "--help"], capture_output=True) - - if cp.returncode != 0: + try: + subprocess.run( + [executable, "sdf", "--help"], check=True, capture_output=True + ) + except subprocess.CalledProcessError as e: msg = f"Failed to find 'sdf' command part of {executable} installation" - raise RuntimeError(msg) + raise RuntimeError(msg) from e + + cls._cached_executable = executable - return executable + return cls._cached_executable @staticmethod def has_gazebo() -> bool: try: _ = GazeboHelper.get_gazebo_executable() return True - except: + except Exception: return False @staticmethod def process_model_description_with_sdformat( - model_description: Union[str, pathlib.Path] + model_description: str | pathlib.Path, ) -> str: # ============================= # Select the correct input type @@ -65,7 +73,9 @@ def process_model_description_with_sdformat( and len(model_description) <= MAX_PATH and pathlib.Path(model_description).is_file() ): - model_description_string = pathlib.Path(model_description).read_text() + model_description_string = pathlib.Path(model_description).read_text( + encoding="utf-8" + ) # Finally, it must be a SDF/URDF string else: @@ -92,16 +102,20 @@ def process_model_description_with_sdformat( fp.write(model_description_string) fp.close() - cp = subprocess.run( - [str(gazebo_executable), "sdf", "-p", fp.name], - text=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - ) - - if cp.returncode != 0: - print(cp.stdout) - raise RuntimeError("Failed to process the input with sdformat") + try: + cp = subprocess.run( + [str(gazebo_executable), "sdf", "-p", fp.name], + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + check=True, + ) + except subprocess.CalledProcessError as e: + if e.returncode != 0: + print(e.stdout) + raise RuntimeError( + "Failed to process the input with sdformat" + ) from e # Get the resulting SDF string sdf_string = cp.stdout diff --git a/src/rod/utils/resolve_frames.py b/src/rod/utils/resolve_frames.py index 5677966..1604c15 100644 --- a/src/rod/utils/resolve_frames.py +++ b/src/rod/utils/resolve_frames.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import functools -from typing import List, Union +from typing import List import numpy as np @@ -8,7 +10,7 @@ def update_element_with_pose( - element: Element, default_relative_to: Union[str, List[str]], explicit_frames: bool + element: Element, default_relative_to: str | List[str], explicit_frames: bool ) -> None: if not hasattr(element, "pose"): raise ValueError("The input element has no 'pose' attribute") @@ -52,7 +54,7 @@ def update_element_with_pose( def resolve_model_frames( - model: "rod.Model", is_top_level: bool = True, explicit_frames: bool = True + model: rod.Model, is_top_level: bool = True, explicit_frames: bool = True ) -> None: # Close the helper for compactness update_element = functools.partial( diff --git a/src/rod/utils/resolve_uris.py b/src/rod/utils/resolve_uris.py index 5d2d194..8806a52 100644 --- a/src/rod/utils/resolve_uris.py +++ b/src/rod/utils/resolve_uris.py @@ -11,7 +11,7 @@ def resolve_local_uri(uri: str) -> pathlib.Path: try: return resolve_robotics_uri_py.resolve_robotics_uri(uri=uri) - except: + except FileNotFoundError: pass # Remove the prefix of the URI diff --git a/tests/utils_models.py b/tests/utils_models.py index 9c10916..64d99df 100644 --- a/tests/utils_models.py +++ b/tests/utils_models.py @@ -1,7 +1,6 @@ import enum import importlib import pathlib -from typing import Union class Robot(enum.IntEnum): @@ -38,7 +37,7 @@ class ModelFactory: """Factory class providing URDF files used by the tests.""" @staticmethod - def get_model_description(robot: Union[Robot, str]) -> pathlib.Path: + def get_model_description(robot: Robot | str) -> pathlib.Path: """ Get the URDF file of different robots.