-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_model_advanced.py
129 lines (108 loc) · 5 KB
/
test_model_advanced.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import time
import numpy as np
import pybullet as p # Add pybullet for video recording
from stable_baselines3 import PPO
import imageio # Add imageio to combine images into a video
from utils import parser as parse
from utils.data import plot_observation_data, plot_actions_data
from env.werdna_balance import WerdnaBalance
from env.werdna_stand import WerdnaStand
from env.werdna_advanced import WerdnaAdvanced
def main():
    """Run a trained PPO agent in a Werdna PyBullet environment.

    Loads the checkpoint referenced by config/config_advanced.yaml, rolls the
    policy out for a fixed number of steps with the GUI enabled, optionally
    records every frame to an MP4, and finally plots the sampled observation
    and action histories.
    """
    total_steps = 4096  # Total number of environment steps to roll out
    config = parse.parser("config/config_advanced.yaml")

    robot_model = config['robot_model']
    env_name = config['environment']
    filename = config['filename']
    biases = config['biases']
    # Reconstruct the training run name so the checkpoint can be located
    # under results/<run name>/<filename>.
    tb_log_name = f"{env_name}_" + "_".join(f"{key}{value}" for key, value in biases.items())
    record_video = config.get('record_video', False)  # Optional frame capture

    # Instantiate the requested environment with GUI rendering enabled.
    if env_name == 'werdna_balance':
        env = WerdnaBalance(model=robot_model, render_mode='GUI')
    elif env_name == 'werdna_stand':
        env = WerdnaStand(model=robot_model, render_mode='GUI')
    elif env_name == 'werdna_advanced':
        env = WerdnaAdvanced(model=robot_model, render_mode='GUI')
    else:
        raise ValueError(f"Unknown Environment: {env_name}")

    complete_filename = os.path.join('results', tb_log_name, filename)
    model = PPO.load(complete_filename)
    model.set_env(env)

    if record_video:
        image_folder = os.path.join("video", "frames")
        # makedirs creates intermediate directories, so "video" is made too.
        os.makedirs(image_folder, exist_ok=True)

    # Observation/action histories (sampled every 10 steps) for plotting.
    roll_data = []
    pitch_data = []
    yaw_data = []
    roll_rate_data = []
    pitch_rate_data = []
    yaw_rate_data = []
    position_data = []
    velocity_data = []
    left_wheel_torque_data = []
    right_wheel_torque_data = []
    joint_angle_data = []

    episode_num = 0
    frame_filenames = []

    # Debug-visualizer camera parameters.
    camera_distance = 2.0   # Distance from the robot
    camera_yaw = 50         # Horizontal angle
    camera_pitch = -30      # Vertical angle
    camera_follow_interval = 10  # Steps between camera re-targeting

    while episode_num < 1:  # Run exactly one rollout of total_steps steps
        obs, _ = env.reset()
        total_reward = 0
        episode_num += 1

        for step in range(total_steps):
            time.sleep(1. / 60.)  # Slow the loop to roughly real time
            action, _ = model.predict(obs)
            # Gymnasium's step API returns (obs, reward, terminated,
            # truncated, info); the original unpacked truncated/terminated
            # in the wrong order, so the end-of-episode check below tested
            # the wrong flag.
            obs, reward, terminated, truncated, info = env.step(action)
            total_reward += reward

            # Periodically re-aim the camera at the robot's current position.
            if step % camera_follow_interval == 0:
                robot_position, _ = p.getBasePositionAndOrientation(env.robotID)
                p.resetDebugVisualizerCamera(cameraDistance=camera_distance,
                                             cameraYaw=camera_yaw,
                                             cameraPitch=camera_pitch,
                                             cameraTargetPosition=robot_position)

            if step % 10 == 0:
                # Observation layout assumed: roll, pitch, yaw, their rates,
                # then position and velocity — TODO confirm against the env.
                roll_data.append(obs[0])
                pitch_data.append(obs[1])
                yaw_data.append(obs[2])
                roll_rate_data.append(obs[3])
                pitch_rate_data.append(obs[4])
                yaw_rate_data.append(obs[5])
                position_data.append(obs[6])
                velocity_data.append(obs[7])
                # Rescale normalized actions: wheel torques by 3, hip joint
                # by pi/12 rad — presumably mirrors the env's own action
                # scaling; verify against the env implementation.
                left_wheel_torque_data.append(action[0] * 3)
                right_wheel_torque_data.append(action[1] * 3)
                joint_angle_data.append(action[2] * np.pi / 12)

            if record_video:
                width, height, rgb_img, _, _ = p.getCameraImage(1920, 1080)
                frame_filename = os.path.join(image_folder, f"step_{step}.png")
                frame_filenames.append(frame_filename)
                imageio.imwrite(frame_filename, rgb_img)

            print(f"Step {step + 1} saved, total reward so far: {total_reward}")

            # Restart the episode if it ended for any reason so the rollout
            # always covers total_steps steps.
            if terminated or truncated:
                obs, _ = env.reset()

    if record_video:
        # NOTE(review): "(unknown)" looks like a placeholder — it probably
        # should interpolate the checkpoint/run name; confirm the intended
        # value before relying on the output path.
        video_filename = f"(unknown)_{total_steps}_steps.mp4"
        with imageio.get_writer(video_filename, fps=60) as video:
            for frame_filename in frame_filenames:
                video.append_data(imageio.imread(frame_filename))
        print(f"Video saved to {video_filename}")

        # Remove the intermediate frame images once the video is written.
        for frame_filename in frame_filenames:
            os.remove(frame_filename)

    print(f"The Agent ended with total reward of {total_reward} over {total_steps} timesteps")
    env.close()

    # Plot the sampled observation and action histories.
    plot_observation_data(total_steps, roll_data, pitch_data, yaw_data, roll_rate_data, pitch_rate_data, yaw_rate_data, position_data, velocity_data)
    plot_actions_data(total_steps, left_wheel_torque_data, right_wheel_torque_data, joint_angle_data)
# Script entry point: run the rollout only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()