-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_loader.py
63 lines (44 loc) · 1.52 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import numpy as np
import pandas as pd
import torch
from config import classes_order
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset built from CSV files grouped into class sub-directories.

    ``data_dir`` is expected to contain one sub-directory per class. Every file
    inside a selected class sub-directory is read eagerly with
    ``pandas.read_csv`` when the dataset is constructed. Each item is a dict
    ``{'input': (accel, gyro), 'label': label}`` where ``accel`` and ``gyro``
    are float32 tensors built from CSV columns 1-3 and 4-6 respectively.
    """

    def __init__(self, data_dir, transform=None, *, class_order=None, verbose=False):
        """Resolve the class directories of *data_dir* and load all their files.

        Args:
            data_dir: Root directory holding one sub-directory per class.
            transform: Optional callable applied to each sample dict in
                ``__getitem__``; its return value is what ``__getitem__`` returns.
            class_order: Mapping whose *values* are substrings used to pick class
                sub-directories; the iteration order of the values defines the
                label order (labels are 0, 1, 2, ... in that order, counting
                only classes that were actually found). Defaults to
                ``classes_order`` imported from ``config``.
            verbose: When True, print the resolved class directory list and the
                file listing of each class directory.
        """
        if class_order is None:
            class_order = classes_order
        self.data_dir = data_dir
        self.transform = transform

        classes = self._resolve_class_dirs(data_dir, class_order)
        if verbose:
            print(classes)
        self.classes = classes

        lst_input = []
        lst_label = []
        for label, class_dir in enumerate(classes):
            class_path = os.path.join(data_dir, class_dir)
            # os.listdir order is arbitrary; sort so sample indices are reproducible.
            filenames = sorted(os.listdir(class_path))
            if verbose:
                print(filenames)
            for filename in filenames:
                lst_input.append(pd.read_csv(os.path.join(class_path, filename)))
                lst_label.append(label)
        self.lst_label = lst_label
        self.lst_input = lst_input

    @staticmethod
    def _resolve_class_dirs(data_dir, class_order):
        """Return sub-directory names of *data_dir* ordered by *class_order* values.

        A directory is selected when a value of *class_order* is a substring of
        its name. Each directory is selected at most once, by the first value
        (in *class_order* order) that matches it, so its files are never loaded
        twice under different labels.
        """
        # Only real directories can be listed later; ignore stray files.
        candidates = sorted(
            name
            for name in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, name))
        )
        classes = []
        claimed = set()
        for pattern in class_order.values():
            for name in candidates:
                if pattern in name and name not in claimed:
                    claimed.add(name)
                    classes.append(name)
        return classes

    def __len__(self):
        """Return the number of loaded files (one sample per file)."""
        return len(self.lst_label)

    def __getitem__(self, index):
        """Return ``{'input': (accel, gyro), 'label': label}`` for sample *index*.

        Column 0 of the CSV is skipped (presumably a timestamp — TODO confirm
        against the data schema); columns 1-3 become ``accel`` and columns 4-6
        become ``gyro``, both as float32 tensors. If a ``transform`` was given,
        its result is returned instead of the dict.
        """
        sample = self.lst_input[index]
        # torch.tensor copies, so a transform mutating the tensors in place
        # cannot corrupt the cached DataFrame.
        accel = torch.tensor(np.asarray(sample.iloc[:, 1:4]), dtype=torch.float)
        gyro = torch.tensor(np.asarray(sample.iloc[:, 4:7]), dtype=torch.float)
        data = {'input': (accel, gyro), 'label': self.lst_label[index]}
        if self.transform:
            data = self.transform(data)
        return data