Skip to content

Commit

Permalink
Merge pull request #265 from MyoHub/dev
Browse files Browse the repository at this point in the history
BUGFIX: Manipulation ENV
  • Loading branch information
Vittorio-Caggiano authored Oct 31, 2024
2 parents 2be0106 + b4c9ce5 commit 7927eb0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
2 changes: 1 addition & 1 deletion myosuite/envs/myo/myochallenge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def register_env_with_variants(id, entry_point, max_episode_steps, kwargs):

register_env_with_variants(id='myoChallengeBimanual-v0',
entry_point='myosuite.envs.myo.myochallenge.bimanual_v0:BimanualEnvV1',
max_episode_steps=300,
max_episode_steps=1000,
kwargs={
'model_path': curr_dir + '/../assets/arm/myoarm_bionic_bimanual.xml',
'normalize_act': True,
Expand Down
14 changes: 8 additions & 6 deletions myosuite/envs/myo/myochallenge/bimanual_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from myosuite.envs.myo.base_v0 import BaseV0

CONTACT_TRAJ_MIN_LENGTH = 100
GOAL_CONTACT = 10
MAX_TIME = 10.0


class BimanualEnvV1(BaseV0):
Expand Down Expand Up @@ -116,7 +118,7 @@ def _setup(self,
self.over_max = False
self.max_force = 0
self.goal_touch = 0
self.TARGET_GOAL_TOUCH = 5
self.TARGET_GOAL_TOUCH = GOAL_CONTACT


self.touch_history = []
Expand Down Expand Up @@ -291,10 +293,10 @@ def get_reward_dict(self, obs_dict):
return rwd_dict

def _get_done(self, z):
if self.obs_dict['time'] > 3.0:
if self.obs_dict['time'] > MAX_TIME:
return 1
elif z < 0.3:
self.obs_dict['time'] = 3.0
self.obs_dict['time'] = MAX_TIME
return 1
elif self.rwd_dict and self.rwd_dict['solved']:
return 1
Expand Down Expand Up @@ -448,9 +450,9 @@ def get_touching_objects(model: mujoco.MjModel, data: mujoco.MjData, id_info: Id


def body_id_to_label(body_id, id_info: IdInfo):
if id_info.myo_body_range[0] < body_id < id_info.myo_body_range[1]:
if id_info.myo_body_range[0] <= body_id <= id_info.myo_body_range[1]:
return ObjLabels.MYO
elif id_info.prosth_body_range[0] < body_id < id_info.prosth_body_range[1]:
elif id_info.prosth_body_range[0] <= body_id <= id_info.prosth_body_range[1]:
return ObjLabels.PROSTH
elif body_id == id_info.start_id:
return ObjLabels.START
Expand All @@ -474,5 +476,5 @@ def evaluate_contact_trajectory(contact_trajectory: List[set]):
return ContactTrajIssue.PROSTH_SHORT

# Check if only goal was touching object for the last CONTACT_TRAJ_MIN_LENGTH frames
elif not np.all([{ObjLabels.GOAL} == s for s in contact_trajectory[-CONTACT_TRAJ_MIN_LENGTH:]]):
elif not np.all([{ObjLabels.GOAL} == s for s in contact_trajectory[-GOAL_CONTACT + 2:]]): # Subtract 2 from the calculation to maintain a buffer zone around trajectory boundaries for safety/accuracy.
return ContactTrajIssue.NO_GOAL

0 comments on commit 7927eb0

Please sign in to comment.