Skip to content

Commit

Permalink
dev - new visu script
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasFrey96 committed Mar 12, 2024
1 parent 76174db commit 6318ba6
Show file tree
Hide file tree
Showing 11 changed files with 262 additions and 26 deletions.
118 changes: 109 additions & 9 deletions hitchhiking_rotations/cfgs/cfg_cube_image_to_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def get_cfg_cube_image_to_pose(device):
shared_trainer_cfg = {
"_target_": "hitchhiking_rotations.utils.Trainer",
"lr": 0.001,
"optimizer": "SGD",
"optimizer": "Adam",
"logger": "${logger}",
"verbose": "${verbose}",
"device": device,
Expand All @@ -17,7 +17,7 @@ def get_cfg_cube_image_to_pose(device):
return {
"verbose": True,
"batch_size": 32,
"epochs": 100,
"epochs": 1000,
"training_data": {
"_target_": "hitchhiking_rotations.datasets.CubeImageToPoseDataset",
"mode": "train",
Expand Down Expand Up @@ -45,7 +45,7 @@ def get_cfg_cube_image_to_pose(device):
"metrics": ["l1", "l2", "geodesic_distance", "chordal_distance"],
},
"trainers": {
"r9_l1": {
"r9_svd_l1": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand All @@ -55,7 +55,7 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model9}",
},
},
"r9_l2": {
"r9_svd_l2": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand All @@ -65,7 +65,7 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model9}",
},
},
"r9_geodesic_distance": {
"r9_svd_geodesic_distance": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand All @@ -75,7 +75,7 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model9}",
},
},
"r9_chordal_distance": {
"r9_svd_chordal_distance": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand All @@ -85,7 +85,67 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model9}",
},
},
"r9_l1": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:flatten}",
"postprocess_pred_loss": "${u:flatten}",
"postprocess_pred_logging": "${u:procrustes_to_rotmat}",
"loss": "${u:l1}",
"model": "${model9}",
},
},
"r9_l2": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:flatten}",
"postprocess_pred_loss": "${u:flatten}",
"postprocess_pred_logging": "${u:procrustes_to_rotmat}",
"loss": "${u:l2}",
"model": "${model9}",
},
},
"r9_geodesic_distance": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
"postprocess_pred_loss": "${u:n_3x3}",
"postprocess_pred_logging": "${u:procrustes_to_rotmat}",
"loss": "${u:geodesic_distance}",
"model": "${model9}",
},
},
"r9_chordal_distance": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
"postprocess_pred_loss": "${u:n_3x3}",
"postprocess_pred_logging": "${u:procrustes_to_rotmat}",
"loss": "${u:chordal_distance}",
"model": "${model9}",
},
},
"r6_l1": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:rotmat_to_gramschmidt_f}",
"postprocess_pred_loss": "${u:flatten}",
"postprocess_pred_logging": "${u:gramschmidt_to_rotmat}",
"loss": "${u:l1}",
"model": "${model6}",
},
},
"r6_l2": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:rotmat_to_gramschmidt_f}",
"postprocess_pred_loss": "${u:flatten}",
"postprocess_pred_logging": "${u:gramschmidt_to_rotmat}",
"loss": "${u:l2}",
"model": "${model6}",
},
},
"r6_gso_l1": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand All @@ -95,7 +155,7 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model6}",
},
},
"r6_l2": {
"r6_gso_l2": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand All @@ -105,7 +165,7 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model6}",
},
},
"r6_geodesic_distance": {
"r6_gso_geodesic_distance": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand All @@ -115,7 +175,7 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model6}",
},
},
"r6_chordal_distance": {
"r6_gso_chordal_distance": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand Down Expand Up @@ -175,6 +235,46 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model4}",
},
},
"quat_rf_cosine_distance": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:rotmat_to_quaternion_rand_flip}",
"postprocess_pred_loss": "${u:passthrough}",
"postprocess_pred_logging": "${u:quaternion_to_rotmat}",
"loss": "${u:cosine_distance}",
"model": "${model4}",
},
},
"quat_rf_l2": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:rotmat_to_quaternion_rand_flip}",
"postprocess_pred_loss": "${u:passthrough}",
"postprocess_pred_logging": "${u:quaternion_to_rotmat}",
"loss": "${u:l2}",
"model": "${model4}",
},
},
"quat_rf_l1": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:rotmat_to_quaternion_rand_flip}",
"postprocess_pred_loss": "${u:passthrough}",
"postprocess_pred_logging": "${u:quaternion_to_rotmat}",
"loss": "${u:l1}",
"model": "${model4}",
},
},
"quat_rf_l2_dp": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:rotmat_to_quaternion_rand_flip}",
"postprocess_pred_loss": "${u:passthrough}",
"postprocess_pred_logging": "${u:quaternion_to_rotmat}",
"loss": "${u:l2_dp}",
"model": "${model4}",
},
},
"rotvec_l1": {
**shared_trainer_cfg,
**{
Expand Down
16 changes: 8 additions & 8 deletions hitchhiking_rotations/datasets/cube_data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ def __init__(self, height: int, width: int):
<mujoco>
<worldbody>
<light name="top" pos="0 0 0"/>
<body name="cube" euler="0 0 0">
<body name="cube" euler="0 0 0" pos="0 0 0">
<joint type="ball" stiffness="0" damping="0" frictionloss="0" armature="0"/>
<geom type="box" size="0.1 0.1 0.1" pos="0 0 0" rgba="0.5 0.5 0.5 1"/>
<geom type="box" size="1 1 0.01" pos="0 0 0.9" rgba="1 0 0 1"/>
<geom type="box" size="1 1 0.01" pos="0 0 -0.99" rgba="0 0 1 1"/>
<geom type="box" size="0.01 1 1" pos="0.99 0 0" rgba="0 1 0 1"/>
<geom type="box" size="0.01 1 1" pos="-0.99 0 0" rgba="0 0.6 0.6 1"/>
<geom type="box" size="1 0.01 1" pos="0 0.99 0" rgba="0.6 0.6 0 1"/>
<geom type="box" size="1 0.01 1" pos="0 -0.99 0" rgba="0.6 0 0.6 1"/>
<geom type="box" size="0.1 0.1 0.1" pos="0 0 0" rgba="0.5 0.5 0.5 1"/>
<geom type="box" size="1 1 0.01" pos="0 0 0.9" rgba="1 0 0 1"/>
<geom type="box" size="1 1 0.01" pos="0 0 -0.99" rgba="0 0 1 1"/>
<geom type="box" size="0.01 1 1" pos="0.99 0 0" rgba="0 1 0 1"/>
<geom type="box" size="0.01 1 1" pos="-0.99 0 0" rgba="0 0.6 0.6 1"/>
<geom type="box" size="1 0.01 1" pos="0 0.99 0" rgba="0.6 0.6 0 1"/>
<geom type="box" size="1 0.01 1" pos="0 -0.99 0" rgba="0.6 0 0.6 1"/>
</body>
</worldbody>
</mujoco>
Expand Down
12 changes: 12 additions & 0 deletions hitchhiking_rotations/datasets/cube_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,15 @@ def __init__(self, mode, dataset_size, device):

def __getitem__(self, idx):
return roma.unitquat_to_rotmat(self.quats[idx]).type(torch.float32), self.imgs[idx].type(torch.float32) / 255


if __name__ == "__main__":
from PIL import Image
import numpy as np
from hitchhiking_rotations import HITCHHIKING_ROOT_DIR

dataset = CubeImageToPoseDataset("train", 2048, "cpu")
for i in range(10):
img, quat = dataset[i]
img = Image.fromarray(np.uint8(img.cpu().numpy() * 255))
img.save(join(HITCHHIKING_ROOT_DIR, "results", f"example_img_{i}.png"))
2 changes: 1 addition & 1 deletion hitchhiking_rotations/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
from .logger import OrientationLogger
from .trainer import Trainer
from .loading import *
from .helper import passthrough, flatten
from .helper import passthrough, flatten, n_3x3
from .notation import RotRep
4 changes: 4 additions & 0 deletions hitchhiking_rotations/utils/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ def passthrough(*x):

def flatten(x):
return x.reshape(x.shape[0], -1)


def n_3x3(x):
return x.reshape(-1, 3, 3)
3 changes: 3 additions & 0 deletions hitchhiking_rotations/utils/notation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ class RotRep(Enum):
SVD = "$\mathbb{R}^9$+SVD"
QUAT_C = "Quat$^+$"
QUAT = "Quat"
QUAT_RF = "Quat+RF"
EULER = "Euler"
EXP = "Exp"
ROTMAT = "$\mathbb{R}^9$"
RSIX = "$\mathbb{R}^6$"

def __str__(self):
return "%s" % self.value
10 changes: 10 additions & 0 deletions scripts/run_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from hitchhiking_rotations import HITCHHIKING_ROOT_DIR
import os

p = os.path.join(HITCHHIKING_ROOT_DIR, "scripts", "train.py")

for seed in range(10):
os.system(f"python3 {p} --experiment cube_image_to_pose --seed {seed}")

for seed in range(10):
os.system(f"python3 {p} --experiment pose_to_cube_image --seed {seed}")
11 changes: 7 additions & 4 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"--experiment",
type=str,
choices=["cube_image_to_pose", "pose_to_cube_image", "pcd_to_pose"] + fourier_choices,
default="pose_to_cube_image",
default="cube_image_to_pose",
help="Experiment Configuration",
)
parser.add_argument(
Expand Down Expand Up @@ -93,9 +93,12 @@

trainer.train_batch(x.clone(), target.clone(), epoch)

if cfg_exp.verbose:
scores = [t.logger.get_score("train", "loss") for t in trainers.values()]
bar.set_postfix({"running_train_loss": np.array(scores).mean()})
try:
if cfg_exp.verbose:
scores = [t.logger.get_score("train", "loss") for t in trainers.values()]
bar.set_postfix({"running_train_loss": np.array(scores).mean()})
except:
pass

if validate_every_n > 0 and epoch % validate_every_n == 0:
# Perform validation
Expand Down
4 changes: 2 additions & 2 deletions visu/figure_12a.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
df = pd.DataFrame.from_dict(df_res)

mapping = {
"r9": RotRep.SVD,
"r6": RotRep.GSO,
"r9_svd": RotRep.SVD,
"r6_gso": RotRep.GSO,
"quat_c": RotRep.QUAT_C,
"rotvec": RotRep.EXP,
"euler": RotRep.EULER,
Expand Down
4 changes: 2 additions & 2 deletions visu/figure_12b.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@

if rename_and_filter:
mapping = {
"r9": RotRep.SVD,
"r6": RotRep.GSO,
"r9_svo": RotRep.SVD,
"r6_gso": RotRep.GSO,
"quat_c": RotRep.QUAT_C,
"quat_rf": str(RotRep.QUAT) + "_rf",
"rotvec": RotRep.EXP,
Expand Down
Loading

0 comments on commit 6318ba6

Please sign in to comment.