Skip to content

Commit

Permalink
Merge pull request #19 from ami-iit/feature/remove_plucker
Browse files Browse the repository at this point in the history
Remove references of Plucker coordinates
  • Loading branch information
diegoferigo authored Sep 23, 2022
2 parents 154904d + ea409d1 commit 74b57e8
Show file tree
Hide file tree
Showing 10 changed files with 172 additions and 169 deletions.
101 changes: 52 additions & 49 deletions src/jaxsim/math/adjoint.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,75 @@
import jax.numpy as jnp

import jaxsim.typing as jtp
from jaxsim.sixd import so3

from .quaternion import Quaternion
from .skew import Skew


class Adjoint:
@staticmethod
def rotate_x(theta: float) -> jtp.Matrix:
def from_quaternion_and_translation(
quaternion: jtp.Vector = jnp.array([1.0, 0, 0, 0]),
translation: jtp.Vector = jnp.zeros(3),
inverse: bool = False,
normalize_quaternion: bool = False,
) -> jtp.Matrix:

c = jnp.cos(theta).squeeze()
s = jnp.sin(theta).squeeze()
assert quaternion.size == 4
assert translation.size == 3

return jnp.array(
[
[1, 0, 0, 0, 0, 0],
[0, c, s, 0, 0, 0],
[0, -s, c, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, c, s],
[0, 0, 0, 0, -s, c],
]
Q_sixd = so3.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(quaternion))
Q_sixd = Q_sixd if not normalize_quaternion else Q_sixd.normalize()

return Adjoint.from_rotation_and_translation(
rotation=Q_sixd.as_matrix(), translation=translation, inverse=inverse
)

@staticmethod
def rotate_y(theta: float) -> jtp.Matrix:
def from_rotation_and_translation(
rotation: jtp.Matrix = jnp.eye(3),
translation: jtp.Vector = jnp.zeros(3),
inverse: bool = False,
) -> jtp.Matrix:

c = jnp.cos(theta).squeeze()
s = jnp.sin(theta).squeeze()

return jnp.array(
[
[c, 0, -s, 0, 0, 0],
[0, 1, 0, 0, 0, 0],
[s, 0, c, 0, 0, 0],
[0, 0, 0, c, 0, -s],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, s, 0, c],
]
)
assert rotation.shape == (3, 3)
assert translation.size == 3

@staticmethod
def rotate_z(theta: float) -> jtp.Matrix:
A_R_B = rotation.squeeze()
A_o_B = translation.squeeze()

c = jnp.cos(theta).squeeze()
s = jnp.sin(theta).squeeze()
if not inverse:
X = A_X_B = jnp.block(
[
[A_R_B, Skew.wedge(A_o_B) @ A_R_B],
[jnp.zeros(shape=(3, 3)), A_R_B],
]
)
else:
X = B_X_A = jnp.block(
[
[A_R_B.T, -A_R_B.T @ Skew.wedge(A_o_B)],
[jnp.zeros(shape=(3, 3)), A_R_B.T],
]
)

return jnp.array(
[
[c, s, 0, 0, 0, 0],
[-s, c, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 0, c, s, 0],
[0, 0, 0, -s, c, 0],
[0, 0, 0, 0, 0, 1],
]
)
return X

@staticmethod
def translate(direction: jtp.Vector) -> jtp.Matrix:
def to_transform(adjoint: jtp.Matrix) -> jtp.Matrix:

X = adjoint.squeeze()
assert X.shape == (6, 6)

x, y, z = direction
R = X[0:3, 0:3]
o_x_R = X[0:3, 3:6]

return jnp.array(
H = jnp.block(
[
[1, 0, 0, 0, z, -y],
[0, 1, 0, -z, 0, x],
[0, 0, 1, y, -x, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 1],
[R, Skew.vee(matrix=o_x_R @ R.T)],
[0, 0, 0, 1],
]
)

return H
44 changes: 33 additions & 11 deletions src/jaxsim/math/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from jaxsim.parsers.descriptions import JointDescriptor, JointGenericAxis, JointType

from .adjoint import Adjoint
from .plucker import Plucker
from .rotation import Rotation


Expand All @@ -27,46 +26,69 @@ def jcalc(
elif code is JointType.R:

jtyp: JointGenericAxis
Xj = Plucker.from_rot_and_trans(
dcm=Rotation.from_axis_angle(vector=(q * jtyp.axis)),
translation=jnp.zeros(3),

Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.from_axis_angle(vector=(q * jtyp.axis)), inverse=True
)

S = jnp.vstack(jnp.hstack([jnp.zeros(3), jtyp.axis.squeeze()]))

elif code is JointType.P:

jtyp: JointGenericAxis
Xj = Adjoint.translate(direction=(q * jtyp.axis))

Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array(q * jtyp.axis), inverse=True
)

S = jnp.vstack(jnp.hstack([jtyp.axis.squeeze(), jnp.zeros(3)]))

elif code is JointType.Rx:

Xj = Adjoint.rotate_x(theta=q)
Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.x(theta=q), inverse=True
)

S = jnp.vstack([0, 0, 0, 1.0, 0, 0])

elif code is JointType.Ry:

Xj = Adjoint.rotate_y(theta=q)
Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.y(theta=q), inverse=True
)

S = jnp.vstack([0, 0, 0, 0, 1.0, 0])

elif code is JointType.Rz:

Xj = Adjoint.rotate_z(theta=q)
Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.z(theta=q), inverse=True
)

S = jnp.vstack([0, 0, 0, 0, 0, 1.0])

elif code is JointType.Px:

Xj = Adjoint.translate(direction=jnp.hstack([q, 0.0, 0.0]))
Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array([q, 0.0, 0.0]), inverse=True
)

S = jnp.vstack([1.0, 0, 0, 0, 0, 0])

elif code is JointType.Py:

Xj = Adjoint.translate(direction=jnp.hstack([0.0, q, 0.0]))
Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array([0.0, q, 0.0]), inverse=True
)

S = jnp.vstack([0, 1.0, 0, 0, 0, 0])

elif code is JointType.Pz:

Xj = Adjoint.translate(direction=jnp.hstack([0.0, 0.0, q]))
Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array([0.0, 0.0, q]), inverse=True
)

S = jnp.vstack([0, 0, 1.0, 0, 0, 0])

else:
Expand Down
69 changes: 33 additions & 36 deletions src/jaxsim/math/quaternion.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,34 @@
import jax.lax
import jax.numpy as jnp

import jaxsim.typing as jtp

from .skew import Skew
from jaxsim.sixd import so3


class Quaternion:
@staticmethod
def to_dcm(quaternion: jtp.Vector) -> jtp.Matrix:

q = quaternion / jnp.linalg.norm(quaternion)

q0s = q[0] * q[0]
q1s = q[1] * q[1]
q2s = q[2] * q[2]
q3s = q[3] * q[3]
q01 = q[0] * q[1]
q02 = q[0] * q[2]
q03 = q[0] * q[3]
q12 = q[1] * q[2]
q13 = q[3] * q[1]
q23 = q[2] * q[3]

R = 2 * jnp.array(
[
[q0s + q1s - 0.5, q12 + q03, q13 - q02],
[q12 - q03, q0s + q2s - 0.5, q23 + q01],
[q13 + q02, q23 - q01, q0s + q3s - 0.5],
]
)
def to_xyzw(wxyz: jtp.Vector) -> jtp.Vector:

return R.squeeze()
return wxyz.squeeze()[jnp.array([1, 2, 3, 0])]

@staticmethod
def from_dcm(dcm: jtp.Matrix) -> jtp.Vector:
def to_wxyz(xyzw: jtp.Vector) -> jtp.Vector:

R = dcm.squeeze()
return xyzw.squeeze()[jnp.array([3, 0, 1, 2])]

tr = jnp.trace(R)
v = -Skew.vee(R)
@staticmethod
def to_dcm(quaternion: jtp.Vector) -> jtp.Matrix:

q = jnp.vstack([(tr + 1) / 2.0, v])
return so3.SO3.from_quaternion_xyzw(
xyzw=Quaternion.to_xyzw(quaternion)
).as_matrix()

return jnp.vstack(q) / jnp.linalg.norm(q)
@staticmethod
def from_dcm(dcm: jtp.Matrix) -> jtp.Vector:

return Quaternion.to_wxyz(
xyzw=so3.SO3.from_matrix(matrix=dcm).as_quaternion_xyzw()
)

@staticmethod
def derivative(
Expand All @@ -53,11 +39,13 @@ def derivative(
) -> jtp.Vector:

w = omega.squeeze()
qw, qx, qy, qz = quaternion.squeeze()
quaternion = quaternion.squeeze()

def Q_body(q: jtp.Vector) -> jtp.Matrix:

if omega_in_body_fixed:
qw, qx, qy, qz = q

Q = jnp.array(
return jnp.array(
[
[qw, -qx, -qy, -qz],
[qx, qw, -qz, qy],
Expand All @@ -66,9 +54,11 @@ def derivative(
]
)

else:
def Q_inertial(q: jtp.Vector) -> jtp.Matrix:

Q = jnp.array(
qw, qx, qy, qz = q

return jnp.array(
[
[qw, -qx, -qy, -qz],
[qx, qw, qz, -qy],
Expand All @@ -77,6 +67,13 @@ def derivative(
]
)

Q = jax.lax.cond(
pred=omega_in_body_fixed,
true_fun=Q_body,
false_fun=Q_inertial,
operand=quaternion,
)

qd = 0.5 * (
Q
@ jnp.hstack(
Expand Down
Loading

0 comments on commit 74b57e8

Please sign in to comment.