[Bug]: Impossible to load model to use it for training #49

Open
4 tasks done
edofazza opened this issue May 27, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@edofazza

🐛 Bug

I'm trying to load a trained model to use it for testing, but I'm getting an error.
Thank you.

To Reproduce

import torch as th
import os
from rllte.xplore.reward import RND, Disagreement, RIDE
from rllte.env import make_mario_env
from rllte.agent import PPO, DDPG

if __name__ == '__main__':
    n_steps: int = 2048 * 16
    device = 'cuda' if th.cuda.is_available() else 'cpu'
    envs = make_mario_env('SuperMarioBros-1-1-v0', device=device, num_envs=1,
                          asynchronous=False, frame_stack=4, gray_scale=True)
    print(device, envs.observation_space, envs.action_space)
    # create the intrinsic reward module
    #irs = Disagreement(envs, device=device)
    # create the PPO agent
    agent = PPO(envs,
                device=device,
                batch_size=512,
                n_epochs=10,
                num_steps=n_steps//8,
                pretraining=True)
    agent.policy.load_state_dict(th.load("ride_1_1_1507328.pth", map_location=th.device('cpu')),)
    agent.eval(100)

Relevant log output / Error message

/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gym/envs/registration.py:555: UserWarning: WARN: The environment SuperMarioBros-1-1-v0 is out of date. You should consider upgrading to version `v3`.
  logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gym/envs/registration.py:627: UserWarning: WARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']
  logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.metadata to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.metadata` for environment variables or `env.get_wrapper_attr('metadata')` that will search the reminding wrappers.
  logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.single_observation_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_observation_space` for environment variables or `env.get_wrapper_attr('single_observation_space')` that will search the reminding wrappers.
  logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.single_action_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_action_space` for environment variables or `env.get_wrapper_attr('single_action_space')` that will search the reminding wrappers.
  logger.warn(
cpu Box(0, 255, (4, 84, 84), uint8) Discrete(7)
Traceback (most recent call last):
  File "/Users/edoardofazzari/Documents/GitHub/got-it-memorized/src/tests.py", line 22, in <module>
    agent.policy.load_state_dict(th.load("/Users/edoardofazzari/Documents/GitHub/got-it-memorized/src/ride_1_1_1507328.pth",
  File "/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2103, in load_state_dict
    raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")
TypeError: Expected state_dict to be dict-like, got <class 'rllte.common.utils.ExportModel'>.
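The last line of the traceback suggests that `ride_1_1_1507328.pth` holds a whole pickled `rllte.common.utils.ExportModel` object, not a state_dict. `th.load` therefore returns a module, and `load_state_dict` rejects anything that is not dict-like. The sketch below uses plain PyTorch to show that distinction and a small normalizing helper. It is a sketch under stated assumptions, not rllte's own API: `to_state_dict` and `save_whole_module` are hypothetical helper names, and an `nn.Sequential` stands in for `ExportModel`, whose internal layout I have not verified.

```python
import io
from collections.abc import Mapping

import torch as th
import torch.nn as nn


def to_state_dict(checkpoint):
    """Hypothetical helper: return a state_dict for whatever th.load returned.

    `checkpoint` is either a dict-like state_dict (from
    `th.save(model.state_dict(), path)`) or a whole pickled nn.Module
    (from `th.save(model, path)`).
    """
    if isinstance(checkpoint, nn.Module):
        # A whole module was pickled; extract its parameters and buffers.
        return checkpoint.state_dict()
    if isinstance(checkpoint, Mapping):
        # Already a state_dict; pass it through unchanged.
        return checkpoint
    raise TypeError(f"Unsupported checkpoint type: {type(checkpoint)}")


def save_whole_module(module):
    """Hypothetical helper: pickle an entire module into memory, like th.save(model, path)."""
    buf = io.BytesIO()
    th.save(module, buf)
    buf.seek(0)
    return buf


# Stand-in for the exported model: a plain nn.Sequential, NOT rllte's ExportModel.
source = nn.Sequential(nn.Linear(4, 2))
target = nn.Sequential(nn.Linear(4, 2))

# weights_only=False is needed to unpickle a full module on recent PyTorch
# releases; only use it for checkpoint files you trust.
loaded = th.load(save_whole_module(source), weights_only=False)

# Passing `loaded` directly would raise the same TypeError as in the report;
# converting it first gives load_state_dict the mapping it expects.
target.load_state_dict(to_state_dict(loaded))
```

Even after this conversion, `agent.policy.load_state_dict(...)` only succeeds if the exported module's parameter names match the policy's; with the default `strict=True`, a mismatch raises a `RuntimeError` listing the missing and unexpected keys. If the training script can be changed, saving `agent.policy.state_dict()` instead of the exported object avoids that mismatch.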

System Info

No response

Checklist

  • I have checked that there is no similar issue in the repo
  • I have read the documentation
  • I have provided a minimal working example to reproduce the bug
  • I've used the markdown code blocks for both code and stack traces.
@edofazza edofazza added the bug Something isn't working label May 27, 2024
@yuanmingqi
Contributor
