diff --git a/myosuite/envs/myo/myochallenge/__init__.py b/myosuite/envs/myo/myochallenge/__init__.py index a34b6633..0ef2185c 100644 --- a/myosuite/envs/myo/myochallenge/__init__.py +++ b/myosuite/envs/myo/myochallenge/__init__.py @@ -33,7 +33,7 @@ def register_env_with_variants(id, entry_point, max_episode_steps, kwargs): register_env_with_variants(id='myoChallengeBimanual-v0', entry_point='myosuite.envs.myo.myochallenge.bimanual_v0:BimanualEnvV1', - max_episode_steps=300, + max_episode_steps=1000, kwargs={ 'model_path': curr_dir + '/../assets/arm/myoarm_bionic_bimanual.xml', 'normalize_act': True, diff --git a/myosuite/envs/myo/myochallenge/bimanual_v0.py b/myosuite/envs/myo/myochallenge/bimanual_v0.py index 7bf225ee..31ab7cf2 100644 --- a/myosuite/envs/myo/myochallenge/bimanual_v0.py +++ b/myosuite/envs/myo/myochallenge/bimanual_v0.py @@ -18,6 +18,8 @@ from myosuite.envs.myo.base_v0 import BaseV0 CONTACT_TRAJ_MIN_LENGTH = 100 +GOAL_CONTACT = 10 +MAX_TIME = 10.0 class BimanualEnvV1(BaseV0): @@ -116,7 +118,7 @@ def _setup(self, self.over_max = False self.max_force = 0 self.goal_touch = 0 - self.TARGET_GOAL_TOUCH = 5 + self.TARGET_GOAL_TOUCH = GOAL_CONTACT self.touch_history = [] @@ -291,10 +293,10 @@ def get_reward_dict(self, obs_dict): return rwd_dict def _get_done(self, z): - if self.obs_dict['time'] > 3.0: + if self.obs_dict['time'] > MAX_TIME: return 1 elif z < 0.3: - self.obs_dict['time'] = 3.0 + self.obs_dict['time'] = MAX_TIME return 1 elif self.rwd_dict and self.rwd_dict['solved']: return 1 @@ -448,9 +450,9 @@ def get_touching_objects(model: mujoco.MjModel, data: mujoco.MjData, id_info: Id def body_id_to_label(body_id, id_info: IdInfo): - if id_info.myo_body_range[0] < body_id < id_info.myo_body_range[1]: + if id_info.myo_body_range[0] <= body_id <= id_info.myo_body_range[1]: return ObjLabels.MYO - elif id_info.prosth_body_range[0] < body_id < id_info.prosth_body_range[1]: + elif id_info.prosth_body_range[0] <= body_id <= id_info.prosth_body_range[1]: 
return ObjLabels.PROSTH elif body_id == id_info.start_id: return ObjLabels.START @@ -474,5 +476,5 @@ def evaluate_contact_trajectory(contact_trajectory: List[set]): return ContactTrajIssue.PROSTH_SHORT # Check if only goal was touching object for the last CONTACT_TRAJ_MIN_LENGTH frames - elif not np.all([{ObjLabels.GOAL} == s for s in contact_trajectory[-CONTACT_TRAJ_MIN_LENGTH:]]): + elif not np.all([{ObjLabels.GOAL} == s for s in contact_trajectory[-GOAL_CONTACT + 2:]]): # Only the last (GOAL_CONTACT minus 2) entries are checked; the slice start is moved in by 2 frames to leave a buffer at the trajectory boundary. return ContactTrajIssue.NO_GOAL