forked from SrikarValluri/aimbot-detection
-
Notifications
You must be signed in to change notification settings - Fork 5
/
save_cnn_output.py
78 lines (58 loc) · 2.2 KB
/
save_cnn_output.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2
import sys
import re
if(len(sys.argv) != 3):
framesDir = "./no_hacks_data_nn"
outputDir = "./no_hacks_data_tensor"
else:
framesDir = sys.argv[1]
outputDir = sys.argv[2]
if(not os.path.exists(framesDir)):
print("Frames directory is invalid.")
sys.exit()
if(not os.path.exists(outputDir)):
print("Output directory is invalid.")
sys.exit()
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
model.eval()
model.fc = nn.Identity()
allVideos = []
subfolders = [ f for f in os.scandir(framesDir) if f.is_dir() ]
# subfolders = sorted(subfolders, key=lambda x: float(re.findall(r'[\d\.]+', x.name)[-1]) + 60 * float(re.findall(r'[\d\.]+', x.name)[-2]))
# print(float(re.findall(r'[\d\.]+', subfolders[0].name)[-1]))
for folder in subfolders:
l = len([f for f in os.scandir(folder.path) if f.is_file()] )
if l != 60:
print(l, folder.name)
with torch.no_grad():
for folder in subfolders:
print(f'Extracting Features from clip at {(folder.name).replace("_", ":")}\t\t\t\n', end='', flush=True)
frames = [ f for f in os.scandir(folder.path) if f.is_file() ]
assert len(frames) == 60, f'{folder.name} has an incorrect number of frames ({len(frames)})'
frames = sorted(frames, key=lambda x: int(re.findall(r'\d+', x.name)[0]))
video = []
for frame in frames:
image = cv2.imread(frame.path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
transform = transforms.ToTensor()
tensor = transform(image)
tensor = tensor.unsqueeze(0)
frameFeatures = model(tensor)
video.append(frameFeatures.squeeze())
video = torch.stack(video)
allVideos.append(video)
print("\nSaving Features")
allVideos = torch.stack(allVideos)
print(allVideos.shape)
torch.save(allVideos, os.path.join(outputDir, "clips.pt"))
print("Features Saved!")