diff --git a/alf/trainers/policy_trainer.py b/alf/trainers/policy_trainer.py
index 23f6d3d1f..9798510a7 100644
--- a/alf/trainers/policy_trainer.py
+++ b/alf/trainers/policy_trainer.py
@@ -922,7 +922,7 @@ def _step(algorithm,
           trans_state,
           metrics,
           render=False,
-          recorder=None,
+          video_writer=None,
           sleep_time_per_step=0,
           selective_criteria_func=None):
     """Perform one step interaction using the outpupt action from ``algorithm``
@@ -940,8 +940,7 @@ def _step(algorithm,
         metrics (StepMetric): a list of metrics that will be updated based
             on ``time_step``.
         render (bool|False): if True, display the frames of ``env`` on a screen.
-        recorder (VideoRecorder|None): recorder the frames of ``env`` and other
-            additional images in prediction step info if present.
+        video_writer (generator|None): images sent to it are appended to the output video.
         sleep_time_per_step (int|0): The sleep time between two frames when
             ``render`` is True.
         selective_criteria_func (callable|None): a callable for determining
@@ -971,24 +970,9 @@ def _step(algorithm,
         time_step, trans_state)
     policy_step = algorithm.predict_step(transformed_time_step,
                                          policy_state)
-    if recorder and selective_criteria_func is None:
-        recorder.capture_frame(policy_step.info, time_step.is_last())
-
-    elif recorder and selective_criteria_func is not None:
-        env_frame = recorder.capture_env_frame()
-        recorder.cache_frame_and_pred_info(env_frame, policy_step.info)
-
-        if time_step.is_last():
-            if selective_criteria_func(
-                    map_structure(lambda x: x.cpu().numpy(),
-                                  metrics[1].latest()),
-                    map_structure(lambda x: x.cpu().numpy(),
-                                  metrics[3].latest())):
-                logging.info(
-                    "+++++++++ Selective Case Discovered! +++++++++++")
-                recorder.generate_video_from_cache()
-            else:
-                recorder.clear_cache()
+    if video_writer is not None:
+        image = env.render(mode="rgb_array")
+        video_writer.send(image)
 
     elif render:
         if env.batch_size > 1:
@@ -1002,6 +986,21 @@ def _step(algorithm,
     return next_time_step, policy_step, trans_state
 
 
+def make_video_writer(video_file: Path, fps: float):
+    import mediapy as media
+
+    frame = yield
+
+    with media.VideoWriter(video_file, fps=fps,
+                           shape=(frame.shape[0], frame.shape[1])) as video_writer:
+        while True:
+            frame = yield
+            if frame is None:
+                break
+            video_writer.add_image(frame)
+    yield
+
+
 @common.mark_eval
 def play(root_dir,
          env,
@@ -1078,19 +1077,13 @@ def play(root_dir,
             Trainer.progress()))
 
     batch_size = env.batch_size
-    recorder = None
+    video_writer = None
     if record_file is not None:
         assert batch_size == 1, 'video recording is not supported for parallel play'
-        # Note that ``VideoRecorder`` will import ``matplotlib`` which might have
-        # some side effects on xserver (if its backend needs graphics).
-        # This is incompatible with RLBench parallel envs >1 (or other
-        # envs requiring xserver) for some unknown reasons, so we have a lazy import here.
-        from alf.utils.video_recorder import VideoRecorder
-        recorder = VideoRecorder(
-            env,
-            last_step_repeats=last_step_repeats,
-            append_blank_frames=append_blank_frames,
-            path=record_file)
+        video_writer = make_video_writer(
+            video_file=record_file,
+            fps=env.metadata["video.frames_per_second"])
+        next(video_writer)
     elif render:
         if batch_size > 1:
             env.envs[0].render(mode='human')
@@ -1147,7 +1140,7 @@ def play(root_dir,
             trans_state=trans_state,
             metrics=metrics,
             render=render,
-            recorder=recorder,
+            video_writer=video_writer,
             sleep_time_per_step=sleep_time_per_step,
             selective_criteria_func=selective_criteria_func)
 
@@ -1171,5 +1164,5 @@ def play(root_dir,
         time_step = next_time_step
 
     env.reset()
-    if recorder:
-        recorder.close()
+    if video_writer:
+        video_writer.send(None)
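For reference, the coroutine protocol used above is: ``play()`` creates the generator and primes it with ``next()``, ``_step()`` pushes one rendered frame per step via ``send(frame)``, and a final ``send(None)`` finalizes the file. Below is a minimal, self-contained sketch of that protocol; it mirrors the generator added in this patch, but the output path, fps, frame count, and random frames are illustrative placeholders, not part of the change.

```python
from pathlib import Path

import mediapy as media
import numpy as np


def make_video_writer(video_file: Path, fps: float):
    """Coroutine: prime with next(), send() frames, send(None) to finalize."""
    # The first frame received only fixes the video shape; it is not written.
    frame = yield
    with media.VideoWriter(
            video_file, fps=fps,
            shape=(frame.shape[0], frame.shape[1])) as video_writer:
        while True:
            frame = yield
            if frame is None:  # sentinel: stop recording and close the file
                break
            video_writer.add_image(frame)
    yield  # absorb the final send(None) so the caller sees no StopIteration


# Usage mirroring play()/_step(): prime, send one frame per step, then send None.
writer = make_video_writer(Path("/tmp/demo.mp4"), fps=30.0)
next(writer)
for _ in range(90):  # stand-in for image = env.render(mode="rgb_array")
    writer.send(np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8))
writer.send(None)
```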