Skip to content

Commit

Permalink
Merge pull request #31 from flferretti/flferretti-patch-1
Browse files Browse the repository at this point in the history
Improve model handling speed
  • Loading branch information
diegoferigo authored May 17, 2024
2 parents 3c535da + 77a5f47 commit 1964e85
Show file tree
Hide file tree
Showing 20 changed files with 347 additions and 301 deletions.
6 changes: 3 additions & 3 deletions src/rod/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def check_compatible_sdformat(specification_version: str) -> None:

if not GazeboHelper.has_gazebo():
return
else:
cmdline = GazeboHelper.get_gazebo_executable()
logging.info(f"Calling sdformat through '{cmdline} sdf'")

cmdline = GazeboHelper.get_gazebo_executable()
logging.info(f"Calling sdformat through '{cmdline} sdf'")

output_sdf_version = packaging.version.Version(
xmltodict.parse(
Expand Down
82 changes: 42 additions & 40 deletions src/rod/builder/primitive_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import abc
import dataclasses
from typing import Optional, Union
from typing import Optional

import numpy as np
import numpy.typing as npt
Expand All @@ -14,15 +16,15 @@ class PrimitiveBuilder(abc.ABC):
name: str
mass: float

element: Union[rod.Model, rod.Link, rod.Inertial, rod.Collision, rod.Visual] = (
element: rod.Model | rod.Link | rod.Inertial | rod.Collision | rod.Visual = (
dataclasses.field(
default=None, init=False, repr=False, hash=False, compare=False
)
)

def build(
self,
) -> Union[rod.Model, rod.Link, rod.Inertial, rod.Collision, rod.Visual]:
) -> rod.Model | rod.Link | rod.Inertial | rod.Collision | rod.Visual:
return self.element

# ================
Expand All @@ -43,9 +45,9 @@ def _geometry(self) -> rod.Geometry:

def build_model(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
) -> "PrimitiveBuilder":
name: str | None = None,
pose: rod.Pose | None = None,
) -> PrimitiveBuilder:
self._check_element()

self.element = self._model(name=name, pose=pose)
Expand All @@ -54,16 +56,16 @@ def build_model(

def build_link(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
) -> "PrimitiveBuilder":
name: str | None = None,
pose: rod.Pose | None = None,
) -> PrimitiveBuilder:
self._check_element()

self.element = self._link(name=name, pose=pose)

return self

def build_inertial(self, pose: Optional[rod.Pose] = None) -> "PrimitiveBuilder":
def build_inertial(self, pose: rod.Pose | None = None) -> PrimitiveBuilder:
self._check_element()

self.element = self._inertial(pose=pose)
Expand All @@ -72,9 +74,9 @@ def build_inertial(self, pose: Optional[rod.Pose] = None) -> "PrimitiveBuilder":

def build_visual(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
) -> "PrimitiveBuilder":
name: str | None = None,
pose: rod.Pose | None = None,
) -> PrimitiveBuilder:
self._check_element()

self.element = self._visual(name=name, pose=pose)
Expand All @@ -83,9 +85,9 @@ def build_visual(

def build_collision(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
) -> "PrimitiveBuilder":
name: str | None = None,
pose: rod.Pose | None = None,
) -> PrimitiveBuilder:
self._check_element()

self.element = self._collision(name=name, pose=pose)
Expand All @@ -98,10 +100,10 @@ def build_collision(

def add_link(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
link: Optional[rod.Link] = None,
) -> "PrimitiveBuilder":
name: str | None = None,
pose: rod.Pose | None = None,
link: rod.Link | None = None,
) -> PrimitiveBuilder:
if not isinstance(self.element, rod.Model):
raise ValueError(type(self.element))

Expand All @@ -116,9 +118,9 @@ def add_link(

def add_inertial(
self,
pose: Optional[rod.Pose] = None,
inertial: Optional[rod.Inertial] = None,
) -> "PrimitiveBuilder":
pose: rod.Pose | None = None,
inertial: rod.Inertial | None = None,
) -> PrimitiveBuilder:
if not isinstance(self.element, (rod.Model, rod.Link)):
raise ValueError(type(self.element))

Expand All @@ -144,11 +146,11 @@ def add_inertial(

def add_visual(
self,
name: Optional[str] = None,
name: str | None = None,
use_inertial_pose: bool = True,
pose: Optional[rod.Pose] = None,
visual: Optional[rod.Visual] = None,
) -> "PrimitiveBuilder":
pose: rod.Pose | None = None,
visual: rod.Visual | None = None,
) -> PrimitiveBuilder:
if not isinstance(self.element, (rod.Model, rod.Link)):
raise ValueError(type(self.element))

Expand Down Expand Up @@ -180,11 +182,11 @@ def add_visual(

def add_collision(
self,
name: Optional[str] = None,
name: str | None = None,
use_inertial_pose: bool = True,
pose: Optional[rod.Pose] = None,
collision: Optional[rod.Collision] = None,
) -> "PrimitiveBuilder":
pose: rod.Pose | None = None,
collision: rod.Collision | None = None,
) -> PrimitiveBuilder:
if not isinstance(self.element, (rod.Model, rod.Link)):
raise ValueError(type(self.element))

Expand Down Expand Up @@ -224,8 +226,8 @@ def add_collision(

def _model(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
name: str | None = None,
pose: rod.Pose | None = None,
) -> rod.Model:
name = name if name is not None else self.name
logging.debug(f"Building model '{name}'")
Expand All @@ -240,15 +242,15 @@ def _model(

def _link(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
name: str | None = None,
pose: rod.Pose | None = None,
) -> rod.Link:
return rod.Link(
name=name if name is not None else f"{self.name}_link",
pose=pose,
)

def _inertial(self, pose: Optional[rod.Pose] = None) -> rod.Inertial:
def _inertial(self, pose: rod.Pose | None = None) -> rod.Inertial:
return rod.Inertial(
pose=pose,
mass=self.mass,
Expand All @@ -257,8 +259,8 @@ def _inertial(self, pose: Optional[rod.Pose] = None) -> rod.Inertial:

def _visual(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
name: str | None = None,
pose: rod.Pose | None = None,
) -> rod.Visual:
name = name if name is not None else f"{self.name}_visual"

Expand All @@ -271,7 +273,7 @@ def _visual(
def _collision(
self,
name: Optional[str],
pose: Optional[rod.Pose] = None,
pose: rod.Pose | None = None,
) -> rod.Collision:
name = name if name is not None else f"{self.name}_collision"

Expand All @@ -297,7 +299,7 @@ def build_pose(
relative_to: str = None,
degrees: bool = None,
rotation_format: str = None,
) -> Optional[rod.Pose]:
) -> rod.Pose | None:
if pos is None and rpy is None:
return rod.Pose.from_transform(transform=np.eye(4), relative_to=relative_to)

Expand Down
3 changes: 1 addition & 2 deletions src/rod/builder/primitives.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import dataclasses
import pathlib
from typing import Union

import trimesh
from numpy.typing import NDArray
Expand Down Expand Up @@ -63,7 +62,7 @@ def _geometry(self) -> rod.Geometry:

@dataclasses.dataclass
class MeshBuilder(PrimitiveBuilder):
mesh_path: Union[str, pathlib.Path]
mesh_path: str | pathlib.Path
scale: NDArray

def __post_init__(self) -> None:
Expand Down
10 changes: 6 additions & 4 deletions src/rod/kinematics/kinematic_tree.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import copy
import dataclasses
import functools
from typing import Dict, List, Sequence, Tuple, Union
from typing import Dict, List, Sequence, Tuple

import numpy as np

Expand All @@ -12,7 +14,7 @@

@dataclasses.dataclass(frozen=True)
class KinematicTree(DirectedTree):
model: "rod.Model"
model: rod.Model

joints: List[TreeEdge] = dataclasses.field(default_factory=list)
frames: List[TreeFrame] = dataclasses.field(default_factory=list)
Expand Down Expand Up @@ -46,7 +48,7 @@ def joint_names(self) -> List[str]:
return [joint.name() for joint in self.joints]

@staticmethod
def build(model: "rod.Model", is_top_level: bool = True) -> "KinematicTree":
def build(model: rod.Model, is_top_level: bool = True) -> KinematicTree:
logging.debug(msg=f"Building kinematic tree of model '{model.name}'")

if model.model is not None:
Expand Down Expand Up @@ -199,7 +201,7 @@ def build(model: "rod.Model", is_top_level: bool = True) -> "KinematicTree":
new_base_node, additional_frames = KinematicTree.remove_edge(
edge=world_to_base_edge, keep_parent=False
)
assert any([f.name() == TreeFrame.WORLD for f in additional_frames])
assert any(f.name() == TreeFrame.WORLD for f in additional_frames)

# Replace the former base node with the new base node
nodes_links_dict[new_base_node.name()] = new_base_node
Expand Down
Loading

0 comments on commit 1964e85

Please sign in to comment.