# example_data_class.py (forked from gzerveas/mvts_transformer)
import glob
import os
import re
import logging
from multiprocessing import Pool

import pandas as pd

# NOTE: BaseData is provided by the mvts_transformer project (defined in src/datasets/data.py);
# adjust this import to wherever BaseData lives in your copy.
from datasets.data import BaseData

logger = logging.getLogger('__main__')


class MachineData(BaseData):
    """
    Dataset class for Machine dataset.

    Attributes:
        all_df: dataframe indexed by ID, with multiple rows corresponding to the same index (sample).
            Each row is a time step; each column contains either metadata (e.g. timestamp) or a feature.
        feature_df: contains the subset of columns of `all_df` which correspond to selected features
        feature_names: names of columns contained in `feature_df` (same as feature_df.columns)
        all_IDs: IDs contained in `all_df`/`feature_df` (same as all_df.index.unique())
        max_seq_len: maximum sequence (time series) length. If None, script argument `max_seq_len` will be used.
            (Moreover, the script argument overrides this attribute)
    """
    def __init__(self, root_dir, file_list=None, pattern=None, n_proc=1, limit_size=None, config=None):

        self.set_num_processes(n_proc=n_proc)

        self.all_df = self.load_all(root_dir, file_list=file_list, pattern=pattern)
        self.all_df = self.all_df.sort_values(by=['machine_record_index'])  # ensure ordering (dataset is typically presorted)
        self.all_df = self.all_df.set_index('machine_record_index')
        self.all_IDs = self.all_df.index.unique()  # all sample (session) IDs
        self.max_seq_len = 66

        if limit_size is not None:
            if limit_size > 1:
                limit_size = int(limit_size)
            else:  # interpret as proportion if in (0, 1]
                limit_size = int(limit_size * len(self.all_IDs))
            self.all_IDs = self.all_IDs[:limit_size]
            self.all_df = self.all_df.loc[self.all_IDs]

        # must match the columns kept by `select_columns` (the original listed 'feed_speed',
        # which is not among the kept columns and would raise a KeyError)
        self.feature_names = ['wire_feed_speed', 'current', 'voltage', 'motor_current', 'power']
        self.feature_df = self.all_df[self.feature_names]
    def load_all(self, root_dir, file_list=None, pattern=None):
        """
        Loads dataset from .csv files contained in `root_dir` into a dataframe, optionally filtered by `pattern`.

        Args:
            root_dir: directory containing all individual .csv files
            file_list: optionally, provide a list of file paths within `root_dir` to consider.
                Otherwise, the entire `root_dir` contents will be used.
            pattern: optionally, apply regex string to select a subset of files
        Returns:
            all_df: a single (possibly concatenated) dataframe with all data corresponding to specified files
        """

        # Each file name corresponds to a different date; file names also distinguish tools (A, B) etc.
        # Select paths for training and evaluation
        if file_list is None:
            data_paths = glob.glob(os.path.join(root_dir, '*'))  # list of all paths
        else:
            data_paths = [os.path.join(root_dir, p) for p in file_list]
        if len(data_paths) == 0:
            raise Exception('No files found using: {}'.format(os.path.join(root_dir, '*')))

        if pattern is None:
            # by default use all available paths
            selected_paths = data_paths
        else:
            selected_paths = list(filter(lambda x: re.search(pattern, x), data_paths))

        input_paths = [p for p in selected_paths if os.path.isfile(p) and p.endswith('.csv')]
        if len(input_paths) == 0:
            raise Exception("No .csv files found using pattern: '{}'".format(pattern))

        if self.n_proc > 1:
            # Load in parallel
            _n_proc = min(self.n_proc, len(input_paths))  # no more processes than files are needed here
            logger.info("Loading {} dataset files using {} parallel processes ...".format(len(input_paths), _n_proc))
            with Pool(processes=_n_proc) as pool:
                all_df = pd.concat(pool.map(MachineData.load_single, input_paths))
        else:  # read 1 file at a time
            all_df = pd.concat(MachineData.load_single(path) for path in input_paths)

        return all_df
    @staticmethod
    def load_single(filepath):
        df = MachineData.read_data(filepath)
        df = MachineData.select_columns(df)
        num_nan = df.isna().sum().sum()
        if num_nan > 0:
            logger.warning("{} nan values in {} will be replaced by 0".format(num_nan, filepath))
            df = df.fillna(0)

        return df
    @staticmethod
    def read_data(filepath):
        """Reads a single .csv, which typically contains a day of data from various machine sessions."""
        df = pd.read_csv(filepath)
        return df
    @staticmethod
    def select_columns(df):
        """Renames and selects the feature columns, and repairs corrupted power values."""
        df = df.rename(columns={"per_energy": "power"})
        # Sometimes 'diff_time' is not measured correctly (is 0), and power ('per_energy') becomes infinite
        is_error = df['power'] > 1e16
        df.loc[is_error, 'power'] = df.loc[is_error, 'true_energy'] / df['diff_time'].median()

        df['machine_record_index'] = df['machine_record_index'].astype(int)
        keep_cols = ['machine_record_index', 'wire_feed_speed', 'current', 'voltage', 'motor_current', 'power']
        df = df[keep_cols]

        return df
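

# --- Usage sketch (not part of the original file) ---
# A minimal illustration of how this class could be exercised on its own, assuming a directory
# of .csv files containing the columns referenced above ('machine_record_index', 'wire_feed_speed',
# 'current', 'voltage', 'motor_current', 'per_energy', 'true_energy', 'diff_time'). The directory
# path and the 'machine' factory key below are assumptions for illustration; in mvts_transformer,
# data classes are normally selected through the `data_factory` dictionary in src/datasets/data.py
# rather than instantiated directly.
#
#     data_factory['machine'] = MachineData
#
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    data = MachineData(root_dir='path/to/machine_csvs', pattern='train', n_proc=1)
    print(data.feature_df.head())
    print("Number of samples: {}".format(len(data.all_IDs)))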