Skip to content

Commit

Permalink
start migration from dm_robotics
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdavidfagan committed Dec 22, 2023
1 parent c5a3b36 commit 917589b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
14 changes: 14 additions & 0 deletions mujoco_controllers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""A very basic abstract base class for Mujoco Controller"""

import abc

class MujocoController(abc.ABC):
"""Abstract class for Mujoco Controller"""

@abc.abstractmethod
def compute_control_output(self):
pass

@abc.abstractmethod
def is_converged(self):
pass
19 changes: 14 additions & 5 deletions mujoco_controllers/osc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from mujoco import viewer

from dm_control import composer, mjcf
from dm_robotics.transformations.transformations import mat_to_quat, quat_to_mat, quat_to_euler
from dm_robotics.transformations import transformations as tr

from mujoco_controllers.build_env import construct_physics
Expand Down Expand Up @@ -71,7 +70,10 @@ def current_eef_position(self):

@property
def current_eef_quat(self):
return mat_to_quat(self.physics.bind(self.eef_site).xmat.reshape(3,3).copy())
quat = np.zeros(4,)
rot_mat = self.physics.bind(self.eef_site).xmat.copy()
mujoco.mju_mat2Quat(quat, rot_mat)
return quat

@property
def current_eef_velocity(self):
Expand Down Expand Up @@ -132,13 +134,20 @@ def _orientation_error(
quat: np.ndarray,
quat_des: np.ndarray,
) -> np.ndarray:
quat_err = tr.quat_mul(quat, tr.quat_conj(quat_des))
quat_err /= np.linalg.norm(quat_err)

quat_conj = np.zeros(4,)
mujoco.mju_negQuat(quat_conj, quat_des)
quat_conj /= np.linalg.norm(quat_conj)

quat_err = np.zeros(4,)
mujoco.mju_mulQuat(quat_err, quat, quat_conj)

axis_angle = tr.quat_to_axisangle(quat_err)
if quat_err[0] < 0.0:
angle = np.linalg.norm(axis_angle) - 2 * np.pi
else:
angle = np.linalg.norm(axis_angle)
print(axis_angle * angle)
return axis_angle * angle

def current_orientation_error(self):
Expand Down Expand Up @@ -252,5 +261,5 @@ def is_converged(self):

pre_pick_height = 0.6
pick_height = 0.15
default_quat = mat_to_quat(R.from_euler('xyz', [0, 180, 0], degrees=True).as_matrix())
default_quat = (R.from_euler('xyz', [0, 180, 0], degrees=True).as_matrix()).as_quat()

0 comments on commit 917589b

Please sign in to comment.