"""Data loading for video-text retrieval on MSR-VTT/MSVD using precomputed
I3D (RGB) video features and SoundNet audio features."""
import os
import pickle

import h5py
import numpy as np
import torch
import torch.utils.data as data


class VTTDataset(data.Dataset):
    """Video-to-text dataset that loads precomputed features and captions.

    Supports the MSR-VTT and MSVD datasets. Construction requires:
        1. a pkl file with tokenized caption features, caption lengths,
           and the matching video ids;
        2. a directory of npy files holding per-frame I3D video features
           (plus h5 files with SoundNet audio features).
    Samples are indexed by caption id; each caption is paired with the
    features of its video.
    """
    def __init__(self, cap_pkl, vid_feature_dir):
        # Load tokenized captions, their lengths, and the matching video ids.
        with open(cap_pkl, 'rb') as f:
            self.captions, self.lengths, self.video_ids = pickle.load(f)
        self.vid_feat_dir = vid_feature_dir
    def __getitem__(self, index):
        """Return one training pair: a fused video/audio feature and its caption.

        The video is looked up from the caption's video id, so feature files
        must be stored under names that follow those ids.
        """
        caption = self.captions[index]
        video_id = self.video_ids[index]
        # Activity (I3D RGB) feature, mean-pooled over frames.
        video_path = os.path.join(self.vid_feat_dir, 'video_features',
                                  'msr_vtt-I3D-RGBFeatures-video%s.npy' % video_id)
        video_feat = torch.from_numpy(np.load(video_path))
        video_feat = video_feat.mean(dim=0, keepdim=False)
        # Audio (SoundNet) feature, mean-pooled over time; the context manager
        # closes the h5 file handle promptly.
        audio_path = os.path.join(self.vid_feat_dir, 'audio_features',
                                  'video%s.mp3.soundnet.h5' % video_id)
        with h5py.File(audio_path, 'r') as audio_h5:
            audio_feat = torch.from_numpy(audio_h5['layer24'][()])
        audio_feat = audio_feat.mean(dim=1, keepdim=False)
        # Concatenate the pooled video and audio descriptors.
        video_feat = torch.cat([video_feat, audio_feat])
        return video_feat, caption, index, video_id
def __len__(self):
return len(self.captions)
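
# A hedged usage sketch for VTTDataset (hypothetical paths; the pkl/feature
# layout is assumed to follow the naming used in __getitem__ above):
#
#     dataset = VTTDataset('msr-vtt_captions_train.pkl', '/path/to/msr-vtt')
#     video_feat, caption, idx, video_id = dataset[0]
#     # video_feat: 1D tensor of concatenated I3D + SoundNet features
#     # caption:    1D LongTensor of word indices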

def collate_fn(data):
    """Build mini-batch tensors from a list of dataset samples.

    Args:
        data: list of (video_feat, caption, index, video_id) tuples.
            - video_feat: torch tensor of shape (feat_dim,).
            - caption: torch tensor of variable length.
    Returns:
        videos: torch tensor of shape (batch_size, feat_dim).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
        ids: tuple of caption indices, in sorted order.
    """
    # Sort the batch by caption length, longest first (needed for packing).
    data.sort(key=lambda x: len(x[1]), reverse=True)
    videos, captions, ids, video_ids = zip(*data)
    # Merge video features (tuple of 1D tensors -> 2D tensor).
    videos = torch.stack(videos, 0)
    # Merge captions (tuple of 1D tensors -> zero-padded 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]
    return videos, targets, lengths, ids

def get_vtt_loader(cap_pkl, feature, opt, batch_size=100, shuffle=True, num_workers=2):
v2t = VTTDataset(cap_pkl, feature)
data_loader = torch.utils.data.DataLoader(dataset=v2t,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
pin_memory=True,
collate_fn=collate_fn)
return data_loader
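
# A hedged usage sketch for get_vtt_loader (paths and opt are hypothetical;
# opt is only passed through, so any namespace works):
#
#     loader = get_vtt_loader('msr-vtt_captions_train.pkl', '/path/to/msr-vtt',
#                             opt, batch_size=128)
#     for videos, targets, lengths, ids in loader:
#         pass  # videos: (B, feat_dim); targets: (B, max_len), zero-padded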

def get_loaders(data_name, vocab, crop_size, batch_size, workers, opt):
    dpath = os.path.join(opt.data_path, data_name)
    if not opt.data_name.endswith('vtt'):
        # Guard against silently returning undefined loaders.
        raise ValueError('Unsupported dataset: %s' % data_name)
    train_caption_pkl_path = '/hdd2/mithun/VTT/vsepp_data/msr-vtt/captions_pkl/msr-vtt_captions_train.pkl'
    val_caption_pkl_path = '/hdd2/mithun/VTT/vsepp_data/msr-vtt/captions_pkl/msr-vtt_captions_val.pkl'
    feature_path = dpath
    train_loader = get_vtt_loader(train_caption_pkl_path, feature_path, opt, batch_size, True, workers)
    val_loader = get_vtt_loader(val_caption_pkl_path, feature_path, opt, batch_size, False, workers)
    return train_loader, val_loader

def get_test_loader(split_name, data_name, vocab, crop_size, batch_size,
                    workers, opt):
    dpath = os.path.join(opt.data_path, data_name)
    if not opt.data_name.endswith('vtt'):
        raise ValueError('Unsupported dataset: %s' % data_name)
    test_caption_pkl_path = '/hdd2/mithun/VTT/vsepp_data/msr-vtt/captions_pkl/msr-vtt_captions_' + split_name + '.pkl'
    feature_path = dpath
    # Evaluation data is not shuffled, so results stay reproducible.
    test_loader = get_vtt_loader(test_caption_pkl_path, feature_path, opt, batch_size, False, workers)
    return test_loader
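

if __name__ == '__main__':
    # Minimal smoke test of collate_fn on synthetic tensors; the 2048-d
    # feature size is a hypothetical stand-in for the pooled I3D + SoundNet
    # concatenation, and no real MSR-VTT files are touched.
    fake_batch = [
        (torch.randn(2048), torch.tensor([1, 4, 5, 2]), 0, 10),
        (torch.randn(2048), torch.tensor([1, 7, 2]), 1, 11),
    ]
    videos, targets, lengths, ids = collate_fn(fake_batch)
    print(videos.shape)   # torch.Size([2, 2048])
    print(targets)        # (2, 4) LongTensor, shorter caption zero-padded
    print(lengths, ids)   # [4, 3] (0, 1)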