Replace recorder with mediapy based video writer
breakds committed Sep 15, 2023
1 parent e54b191 commit 51c086b
Showing 1 changed file with 28 additions and 35 deletions.
alf/trainers/policy_trainer.py (28 additions, 35 deletions)
@@ -922,7 +922,7 @@ def _step(algorithm,
           trans_state,
           metrics,
           render=False,
-          recorder=None,
+          video_writer=None,
           sleep_time_per_step=0,
           selective_criteria_func=None):
     """Perform one step interaction using the output action from ``algorithm``
@@ -940,8 +940,7 @@ def _step(algorithm,
         metrics (StepMetric): a list of metrics that will be updated based on
             ``time_step``.
         render (bool|False): if True, display the frames of ``env`` on a screen.
-        recorder (VideoRecorder|None): recorder the frames of ``env`` and other
-            additional images in prediction step info if present.
+        video_writer (generator|None): frames sent to it are written to the output video
         sleep_time_per_step (int|0): The sleep time between two frames when
             ``render`` is True.
         selective_criteria_func (callable|None): a callable for determining
@@ -971,24 +970,9 @@ def _step(algorithm,
         time_step, trans_state)
     policy_step = algorithm.predict_step(transformed_time_step, policy_state)
 
-    if recorder and selective_criteria_func is None:
-        recorder.capture_frame(policy_step.info, time_step.is_last())
-
-    elif recorder and selective_criteria_func is not None:
-        env_frame = recorder.capture_env_frame()
-        recorder.cache_frame_and_pred_info(env_frame, policy_step.info)
-
-        if time_step.is_last():
-            if selective_criteria_func(
-                    map_structure(lambda x: x.cpu().numpy(),
-                                  metrics[1].latest()),
-                    map_structure(lambda x: x.cpu().numpy(),
-                                  metrics[3].latest())):
-                logging.info(
-                    "+++++++++ Selective Case Discovered! +++++++++++")
-                recorder.generate_video_from_cache()
-            else:
-                recorder.clear_cache()
+    if video_writer is not None:
+        image = env.render(mode="rgb_array")
+        video_writer.send(image)
 
     elif render:
         if env.batch_size > 1:
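
Design note: the committed path streams one frame per step into the writer coroutine, so memory use stays flat however long the episode runs. For contrast, a buffered sketch (not part of this commit; "demo_buffered.mp4" is a hypothetical path) that collects frames and encodes once with mediapy:

import numpy as np
import mediapy as media

# Collect every frame, then encode in one call; simpler, but the frame
# list grows linearly with episode length.
frames = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(30)]
media.write_video("demo_buffered.mp4", frames, fps=30.0)
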
@@ -1002,6 +986,21 @@ def _step(algorithm,
     return next_time_step, policy_step, trans_state
 
 
+def make_video_writer(video_file: Path, fps: float):
+    import mediapy as media
+
+    frame = yield
+
+    with media.VideoWriter(video_file, fps=fps,
+                           shape=(frame.shape[0], frame.shape[1])) as video_writer:
+        while True:
+            frame = yield
+            if frame is None:
+                break
+            video_writer.add_image(frame)
+        yield
+
+
 @common.mark_eval
 def play(root_dir,
          env,
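
How a caller drives this coroutine, as a minimal sketch (hypothetical path "demo.mp4", synthetic frames). Two protocol details matter: the generator must be primed with next() before it can accept a frame, and the first frame sent is only used to size the mediapy VideoWriter, so send it twice if it should also appear in the video:

import numpy as np
from pathlib import Path

writer = make_video_writer(video_file=Path("demo.mp4"), fps=30.0)
next(writer)  # prime: advance to the first `yield`

first = np.zeros((64, 64, 3), dtype=np.uint8)
writer.send(first)  # opens the VideoWriter with shape (64, 64)
writer.send(first)  # send again so the first frame is actually recorded
for t in range(1, 30):
    writer.send(np.full((64, 64, 3), 8 * t, dtype=np.uint8))
writer.send(None)   # break the loop, close the writer, finalize the file
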
Expand Down Expand Up @@ -1078,19 +1077,13 @@ def play(root_dir,
Trainer.progress()))

batch_size = env.batch_size
recorder = None
video_writer = None
if record_file is not None:
assert batch_size == 1, 'video recording is not supported for parallel play'
# Note that ``VideoRecorder`` will import ``matplotlib`` which might have
# some side effects on xserver (if its backend needs graphics).
# This is incompatible with RLBench parallel envs >1 (or other
# envs requiring xserver) for some unknown reasons, so we have a lazy import here.
from alf.utils.video_recorder import VideoRecorder
recorder = VideoRecorder(
env,
last_step_repeats=last_step_repeats,
append_blank_frames=append_blank_frames,
path=record_file)
video_writer = make_video_writer(
video_file=record_file,
fps=env.metadata["video.frames_per_second"])
next(video_writer)
elif render:
if batch_size > 1:
env.envs[0].render(mode='human')
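
One portability caveat (not from this commit): "video.frames_per_second" is the classic Gym metadata key, while Gymnasium environments publish "render_fps" instead. A tolerant lookup, with lookup_fps being a hypothetical helper name:

def lookup_fps(metadata: dict, default: float = 30.0) -> float:
    # Try the classic Gym key first, then the Gymnasium key.
    for key in ("video.frames_per_second", "render_fps"):
        if key in metadata:
            return float(metadata[key])
    return default

print(lookup_fps({"video.frames_per_second": 25}))  # -> 25.0
print(lookup_fps({"render_fps": 50}))               # -> 50.0
print(lookup_fps({}))                               # -> 30.0 (fallback)
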
@@ -1147,7 +1140,7 @@ def play(root_dir,
             trans_state=trans_state,
             metrics=metrics,
             render=render,
-            recorder=recorder,
+            video_writer=video_writer,
             sleep_time_per_step=sleep_time_per_step,
             selective_criteria_func=selective_criteria_func)
@@ -1171,5 +1164,5 @@ def play(root_dir,
         time_step = next_time_step
 
     env.reset()
-    if recorder:
-        recorder.close()
+    if video_writer:
+        video_writer.send(None)
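
Why send(None) is a clean shutdown: it makes the generator break out of its loop and leave the with block, which closes the mediapy writer and flushes the file; the trailing bare yield then suspends the generator so the send returns instead of raising StopIteration. A stripped-down stand-in (the names sink and frames are mine) showing the same handshake:

def sink():
    frames = []
    frame = yield            # first send: shape probe in the real writer
    while True:
        frame = yield
        if frame is None:
            break
        frames.append(frame)
    print(f"finalized with {len(frames)} frame(s)")
    yield                    # absorbs the final send(None)

w = sink()
next(w)           # prime
w.send("shape")   # sizes the writer; not recorded
w.send("f1")
w.send("f2")
w.send(None)      # -> finalized with 2 frame(s)
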
