diff --git a/examples/sb3_imitation.py b/examples/sb3_imitation.py index 49dc3189..fc12f8db 100644 --- a/examples/sb3_imitation.py +++ b/examples/sb3_imitation.py @@ -177,6 +177,7 @@ def close_env(): policy_kwargs=policy_kwargs, verbose=2, tensorboard_log=f"logs/{args.experiment_name}", + device="cpu" # seed=args.seed // Not currently supported as stable_baselines_wrapper.py seed() method is not yet implemented. ) @@ -190,6 +191,7 @@ def close_env(): rng=rng, policy=learner.policy, custom_logger=logger, + device="cpu" ) print("Starting Imitation Learning Training using BC:") bc_trainer.train(n_epochs=args.bc_epochs)