-
Notifications
You must be signed in to change notification settings - Fork 11
/
rsu-metc-classify-merged-network.py
executable file
·68 lines (55 loc) · 1.95 KB
/
rsu-metc-classify-merged-network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/usr/bin/env python3
#
# This work is inspired by this NN architecture:
# https://github.com/Realtime-Action-Recognition/Realtime-Action-Recognition/blob/master/three_stream_model.py
#
from classify_dataset import evaluate, get_merged_model, IMG_SIZE, N_CLASSES, batch_size
from dataset_utilities import get_weights_dict
from generator_rgb_with_of import merged_dataset_from_directories
import os
import tensorflow as tf
# Dataset locations, relative to the current working directory. Adjust
# _DATASET_ROOT (or the individual paths) to match where the preprocessed
# RSU-METC data lives on your machine.
_DATASET_ROOT = 'dataset-metc/RSU_METC_dataset_preprocessed'
# trainval split: RGB frames plus the paired "of" images (presumably optical
# flow, given generator_rgb_with_of) — split into train/val further below.
rgb_dir = _DATASET_ROOT + '/frames/trainval'
of_dir = _DATASET_ROOT + '/of/trainval'
# Held-out test split, laid out the same way as trainval.
test_rgb_dir = _DATASET_ROOT + '/frames/test'
test_of_dir = _DATASET_ROOT + '/of/test'
# Class names are the stringified indices "0" .. str(N_CLASSES - 1)
# (presumably matching the per-class sub-directory names — TODO confirm).
CLASS_NAMES = list(map(str, range(N_CLASSES)))
# Decoding/batching options shared by every split. They live in one place so
# that train, validation and test batches are always built the same way.
_COMMON_DS_KWARGS = {
    'image_size': IMG_SIZE,
    'label_mode': 'categorical',
    'crop_to_aspect_ratio': False,
    'batch_size': batch_size,
}

# Train and validation are both carved out of the same trainval directories.
# The two calls MUST use identical validation_split and seed values; if they
# drift apart, the subsets can overlap (presumably the split is derived from
# the seed as in Keras' image_dataset_from_directory — TODO confirm in
# generator_rgb_with_of).
VALIDATION_SPLIT = 0.2
SPLIT_SEED = 123

train_ds = merged_dataset_from_directories(
    rgb_dir,
    of_dir,
    validation_split=VALIDATION_SPLIT,
    subset="training",
    seed=SPLIT_SEED,
    shuffle=True,
    **_COMMON_DS_KWARGS)
val_ds = merged_dataset_from_directories(
    rgb_dir,
    of_dir,
    validation_split=VALIDATION_SPLIT,
    subset="validation",
    seed=SPLIT_SEED,
    shuffle=True,
    **_COMMON_DS_KWARGS)
# The test set comes from its own directory pair and is not shuffled, so its
# element order is fixed from run to run.
test_ds = merged_dataset_from_directories(
    test_rgb_dir,
    test_of_dir,
    seed=SPLIT_SEED,
    shuffle=False,
    **_COMMON_DS_KWARGS)
# To improve performance, prefetch batches so input loading overlaps with
# model execution. The buffer size is chosen dynamically by the tf.data
# runtime. tf.data.AUTOTUNE is the stable name; the older
# tf.data.experimental.AUTOTUNE alias is legacy.
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.prefetch(buffer_size=AUTOTUNE)
# Per-class weights, derived from the trainval RGB directory for the class
# names above.
# NOTE(review): rgb_dir also contains the validation subset, so the weights
# presumably reflect train+val class counts rather than train alone — confirm
# this is intended.
weights_dict = get_weights_dict(rgb_dir, CLASS_NAMES)
# Merged (RGB + "of") network defined in classify_dataset.
model = get_merged_model()
# Identifier passed to evaluate() for this experiment.
RUN_NAME = "rsu-metc-merged"
evaluate(
    RUN_NAME,
    train_ds,
    val_ds,
    test_ds,
    weights_dict,
    model=model,
)