Skip to content

Commit

Permalink
[fix] fix warm start for position retargeting with euler angle initia…
Browse files Browse the repository at this point in the history
…lized dummy joint angles
  • Loading branch information
yuzheqin committed Sep 17, 2024
1 parent b33035a commit a80caf5
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 49 deletions.
2 changes: 1 addition & 1 deletion dex_retargeting/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.3"
__version__ = "0.4.4"
24 changes: 24 additions & 0 deletions dex_retargeting/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,24 @@
from pathlib import Path
from typing import Optional

import numpy as np

OPERATOR2MANO_RIGHT = np.array(
[
[0, 0, -1],
[-1, 0, 0],
[0, 1, 0],
]
)

OPERATOR2MANO_LEFT = np.array(
[
[0, 0, -1],
[1, 0, 0],
[0, -1, 0],
]
)


class RobotName(enum.Enum):
allegro = enum.auto()
Expand Down Expand Up @@ -59,3 +77,9 @@ def get_default_config_path(
else:
config_name = f"{robot_name_str}_{hand_type_str}.yml"
return config_path / config_name


OPERATOR2MANO = {
HandType.right: OPERATOR2MANO_RIGHT,
HandType.left: OPERATOR2MANO_LEFT,
}
14 changes: 13 additions & 1 deletion dex_retargeting/robot_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,19 @@ def get_joint_index(self, name: str):
def get_link_index(self, name: str):
if name not in self.link_names:
raise ValueError(f"{name} is not a link name. Valid link names: \n{self.link_names}")
return self.model.getFrameId(name)
return self.model.getFrameId(name, pin.BODY)

def get_joint_parent_child_frames(self, joint_name: str):
joint_id = self.model.getFrameId(joint_name)
parent_id = self.model.frames[joint_id].parent
child_id = -1
for idx, frame in enumerate(self.model.frames):
if frame.previousFrame == joint_id:
child_id = idx
if child_id == -1:
raise ValueError(f"Can not find child link of {joint_name}")

return parent_id, child_id

# -------------------------------------------------------------------------- #
# Kinematics function
Expand Down
48 changes: 28 additions & 20 deletions dex_retargeting/seq_retarget.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from pytransform3d import rotations

from dex_retargeting.constants import OPERATOR2MANO, HandType
from dex_retargeting.optimizer import Optimizer
from dex_retargeting.optimizer_utils import LPFilter

Expand Down Expand Up @@ -39,40 +40,37 @@ def __init__(
# Warm started
self.is_warm_started = False

# TODO: hack here
self.scene = None

def warm_start(self, wrist_pos: np.ndarray, wrist_orientation: np.ndarray, global_rot: np.array):
def warm_start(
self,
wrist_pos: np.ndarray,
wrist_quat: np.ndarray,
hand_type: HandType = HandType.right,
is_mano_convention: bool = False,
):
"""
Initialize the wrist joint pose using analytical computation instead of retargeting optimization.
This function is specifically for position retargeting with the flying robot hand, i.e. has 6D free joint
You are not expected to use this function for vector retargeting, e.g. when you are working on teleoperation
Args:
wrist_pos: position of the hand wrist, typically from human hand pose
wrist_orientation: orientation of the hand orientation, typically from human hand pose in MANO convention
global_rot:
wrist_quat: quaternion of the hand wrist, the same convention as the operator frame definition if not is_mano_convention
hand_type: hand type, used to determine the operator2mano matrix
is_mano_convention: whether the wrist_quat is in mano convention
"""
# This function can only be used when the first joints of robot are free joints
if len(wrist_pos) != 3:
raise ValueError(f"Wrist pos:{wrist_pos} is not a 3-dim vector.")
if len(wrist_orientation) != 3:
raise ValueError(f"Wrist orientation:{wrist_orientation} is not a 3-dim vector.")

if np.linalg.norm(wrist_orientation) < 1e-3:
mat = np.eye(3)
else:
mat = rotations.matrix_from_compact_axis_angle(wrist_orientation)
if len(wrist_pos) != 3:
raise ValueError(f"Wrist pos: {wrist_pos} is not a 3-dim vector.")
if len(wrist_quat) != 4:
raise ValueError(f"Wrist quat: {wrist_quat} is not a 4-dim vector.")

operator2mano = OPERATOR2MANO[hand_type] if is_mano_convention else np.eye(3)
robot = self.optimizer.robot
operator2mano = np.array([[0, 0, -1], [-1, 0, 0], [0, 1, 0]])
mat = global_rot.T @ mat @ operator2mano
target_wrist_pose = np.eye(4)
target_wrist_pose[:3, :3] = mat
target_wrist_pose[:3, :3] = rotations.matrix_from_quaternion(wrist_quat) @ operator2mano.T
target_wrist_pose[:3, 3] = wrist_pos

wrist_link_name = self.optimizer.wrist_link_name
wrist_link_id = self.optimizer.robot.get_link_index(wrist_link_name)
name_list = [
"dummy_x_translation_joint",
"dummy_y_translation_joint",
Expand All @@ -81,6 +79,9 @@ def warm_start(self, wrist_pos: np.ndarray, wrist_orientation: np.ndarray, globa
"dummy_y_rotation_joint",
"dummy_z_rotation_joint",
]
wrist_link_id = robot.get_joint_parent_child_frames(name_list[5])[1]

# Set the dummy joints angles to zero
old_qpos = robot.q0
new_qpos = old_qpos.copy()
for num, joint_name in enumerate(self.optimizer.target_joint_names):
Expand Down Expand Up @@ -128,6 +129,13 @@ def set_qpos(self, robot_qpos: np.ndarray):
target_qpos = robot_qpos[self.optimizer.idx_pin2target]
self.last_qpos = target_qpos

def get_qpos(self, fixed_qpos: np.ndarray | None = None):
robot_qpos = np.zeros(self.optimizer.robot.dof)
robot_qpos[self.optimizer.idx_pin2target] = self.last_qpos
if fixed_qpos is not None:
robot_qpos[self.optimizer.idx_pin2fixed] = fixed_qpos
return robot_qpos

def verbose(self):
min_value = self.optimizer.opt.last_optimum_value()
print(f"Retargeting {self.num_retargeting} times takes: {self.accumulated_time}s")
Expand Down
54 changes: 27 additions & 27 deletions example/position_retargeting/hand_robot_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,22 @@
from pathlib import Path
from typing import Dict, List

import numpy as np
import cv2
from tqdm import trange
import numpy as np
import sapien
import transforms3d.quaternions
from hand_viewer import HandDatasetSAPIENViewer
from pytransform3d import rotations
from tqdm import trange

from dex_retargeting import yourdfpy as urdf
from dex_retargeting.constants import RobotName, HandType, get_default_config_path, RetargetingType
from dex_retargeting.constants import (
HandType,
RetargetingType,
RobotName,
get_default_config_path,
)
from dex_retargeting.retargeting_config import RetargetingConfig
from dex_retargeting.seq_retarget import SeqRetargeting
from hand_viewer import HandDatasetSAPIENViewer

ROBOT2MANO = np.array(
[
[0, 0, -1],
[-1, 0, 0],
[0, 1, 0],
]
)
ROBOT2MANO_POSE = sapien.Pose(q=transforms3d.quaternions.mat2quat(ROBOT2MANO))


def prepare_position_retargeting(joint_pos: np.array, link_hand_indices: np.ndarray):
link_pos = joint_pos[link_hand_indices]
return link_pos


def prepare_vector_retargeting(joint_pos: np.array, link_hand_indices_pairs: np.ndarray):
joint_pos = joint_pos @ ROBOT2MANO
origin_link_pos = joint_pos[link_hand_indices_pairs[0]]
task_link_pos = joint_pos[link_hand_indices_pairs[1]]
return task_link_pos - origin_link_pos


class RobotHandDatasetSAPIENViewer(HandDatasetSAPIENViewer):
Expand All @@ -45,6 +29,7 @@ def __init__(self, robot_names: List[RobotName], hand_type: HandType, headless=F
self.robot_file_names: List[str] = []
self.retargetings: List[SeqRetargeting] = []
self.retarget2sapien: List[np.ndarray] = []
self.hand_type = hand_type

# Load optimizer and filter
loader = self.scene.create_urdf_loader()
Expand Down Expand Up @@ -126,7 +111,22 @@ def render_dexycb_data(self, data: Dict, fps=5, y_offset=0.8):
robot_names = "_".join(robot_names)
video_path = Path(__file__).parent.resolve() / f"data/{robot_names}_video.mp4"
writer = cv2.VideoWriter(
str(video_path), cv2.VideoWriter_fourcc(*"mp4v"), 30.0, (self.camera.get_width(), self.camera.get_height())
str(video_path),
cv2.VideoWriter_fourcc(*"mp4v"),
30.0,
(self.camera.get_width(), self.camera.get_height()),
)

# Warm start
hand_pose_start = hand_pose[start_frame]
wrist_quat = rotations.quaternion_from_compact_axis_angle(hand_pose_start[0, 0:3])
vertex, joint = self._compute_hand_geometry(hand_pose_start)
for robot, retargeting, retarget2sapien in zip(self.robots, self.retargetings, self.retarget2sapien):
retargeting.warm_start(
joint[0, :],
wrist_quat,
hand_type=self.hand_type,
is_mano_convention=True,
)

# Loop rendering
Expand Down

0 comments on commit a80caf5

Please sign in to comment.