-
Notifications
You must be signed in to change notification settings - Fork 4
/
rollout.py
35 lines (27 loc) · 1.24 KB
/
rollout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import dataset
import model
def rollout(model, data, metadata, noise_std):
    """Autoregressively roll out a trajectory with a learned simulator.

    The first ``model.window_size + 1`` frames of ``data["position"]`` seed
    the trajectory. Each later frame is produced from the model's predicted
    acceleration given the most recent ``model.window_size + 1`` frames of
    the trajectory built so far (i.e. the model consumes its own outputs).

    Args:
        model: Learned simulator. Must provide ``parameters()`` (used only to
            find its device), ``eval()``, an integer ``window_size``
            attribute, and be callable on the graph built by
            ``dataset.preprocess``. NOTE: the model is switched to eval mode
            and is not switched back.
        data: Mapping with ``"position"`` (a 3-D tensor whose dim 0 is time;
            presumably ``(time, particles, dim)`` — TODO confirm) and
            ``"particle_type"`` (forwarded unchanged to ``dataset.preprocess``).
        metadata: Mapping with ``"acc_mean"`` and ``"acc_std"`` used to
            de-normalize predicted accelerations; also forwarded to
            ``dataset.preprocess``.
        noise_std: Noise std combined in quadrature with ``acc_std`` when
            de-normalizing the model output.

    Returns:
        Tensor with time on dim 1 (``positions`` permuted by ``(1, 0, 2)``).
        When ``total_time >= model.window_size + 1`` it holds ``total_time``
        frames: the ground-truth seed frames followed by predicted ones.
    """
    device = next(model.parameters()).device
    model.eval()

    window_size = model.window_size + 1
    positions = data["position"]
    total_time = positions.size(0)
    particle_type = data["particle_type"]

    # Seed with ground-truth frames and move time from dim 0 to dim 1.
    traj = positions[:window_size].permute(1, 0, 2)

    # De-normalization constants do not change across steps; build them once
    # instead of re-allocating tensors on every iteration.
    acc_scale = torch.sqrt(torch.tensor(metadata["acc_std"]) ** 2 + noise_std ** 2)
    acc_mean = torch.tensor(metadata["acc_mean"])

    with torch.no_grad():
        for _ in range(total_time - window_size):
            # NOTE(review): the trailing 0.0 presumably disables input noise
            # at rollout time — confirm against dataset.preprocess.
            graph = dataset.preprocess(particle_type, traj[:, -window_size:], None, metadata, 0.0)
            graph = graph.to(device)
            # Bring the prediction back to CPU so it can be combined with traj
            # (assumes data["position"] lives on CPU — TODO confirm).
            acceleration = model(graph).cpu()
            acceleration = acceleration * acc_scale + acc_mean

            # Update velocity first, then position, with an implicit unit step.
            recent_position = traj[:, -1]
            recent_velocity = recent_position - traj[:, -2]
            new_velocity = recent_velocity + acceleration
            new_position = recent_position + new_velocity
            traj = torch.cat((traj, new_position.unsqueeze(1)), dim=1)
    return traj
if __name__ == "__main__":
    # Run the demo rollout only when executed as a script, so importing this
    # module (e.g. to reuse `rollout`) does not build a model, load the
    # dataset, and run inference as a side effect.
    # Prefer the GPU as before, but fall back to CPU instead of crashing.
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    simulator = model.LearnedSimulator()
    simulator = simulator.to(run_device)
    test_dataset = dataset.RolloutDataset("../datasets/WaterDrop", "valid")
    # Keep the predicted trajectory instead of discarding it.
    rollout_traj = rollout(simulator, test_dataset[0], test_dataset.metadata, 0.0)