-
Notifications
You must be signed in to change notification settings - Fork 0
/
LoadMountainCar.py
55 lines (53 loc) · 1.5 KB
/
LoadMountainCar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from DuelingNetwork.model import Net
import numpy as np
import matplotlib.pyplot as plt
import gym
import random
import torch
from torch import nn
from torch import optim
from collections import namedtuple
import warnings
import time
import torch.nn.functional as F
from DuelingNetwork.model import Net
import numpy as np
import matplotlib.pyplot as plt
import gym
import random
import torch
from torch import nn
from torch import optim
from collections import namedtuple
import warnings
import time
import torch.nn.functional as F
from frameSave import save_frames_as_gif
# Roll out a trained agent on MountainCar-v0 greedily and save the episode as a GIF.
#
# SECURITY NOTE(review): torch.load unpickles the whole model object, which can
# execute arbitrary code -- only load "pth/MountainCar.pth" from a trusted source.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine and keeps the weights on the same device as the CPU observation
# tensors built below (otherwise model(state) would hit a device mismatch).
model = torch.load("pth/MountainCar.pth", map_location="cpu")
model.eval()  # inference mode; set once, outside the rollout loop

ENV = 'MountainCar-v0'
MAX_STEPS = 200  # hard cap on the number of env steps in the rollout
env = gym.make(ENV)


def _to_state(obs):
    """Convert a gym observation (NumPy array) to a float tensor with a leading batch dim of 1."""
    return torch.unsqueeze(torch.from_numpy(obs).type(torch.FloatTensor), 0)


observation = env.reset()
state = _to_state(observation)
frames = []
try:
    for _ in range(MAX_STEPS):
        with torch.no_grad():
            # Greedy action: index of the largest model output along dim 1
            # (presumably Q-values), reshaped to (1, 1).
            action = model(state).max(1)[1].view(1, 1)
        observation_next, _, done, _ = env.step(action.item())
        # Capture the frame after every step, including the terminal one, so the
        # GIF also contains the episode's final state.
        frames.append(env.render(mode="rgb_array"))
        if done:
            break
        # The raw observation is used directly as the next network input.
        state = _to_state(observation_next)
        # NOTE(review): presumably paces an on-screen render window; it is not
        # needed for building the GIF itself -- confirm before removing.
        time.sleep(0.01)
finally:
    # Release the environment (and any render window) even if the rollout fails.
    env.close()

save_frames_as_gif(frames, filename='MountainCar.gif')