-
Notifications
You must be signed in to change notification settings - Fork 0
/
central_processing.py
executable file
·165 lines (138 loc) · 5.32 KB
/
central_processing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import torch
from typing import Dict, Any
import numpy as np
from beartype.typing import Dict, Optional, Any, NoReturn
from neolibrary.monitoring.logger import NeoLogger
from utils.helpers import timer
from config import config
from neotemplate.base_central_processing import CPNeoTemplate
# Module-level logger named after this module; every CentralProcessing method
# below reports progress and caught exceptions through it.
logger = NeoLogger(__name__)
class CentralProcessing(CPNeoTemplate):
    """
    Central processing unit for preprocessing, postprocessing, predicting and training.

    The constructor takes no arguments; it initialises the base template and
    creates ``test_data``.

    Attributes
    ------------
    test_data: np.ndarray
        Random array of shape (32, 32, 32) drawn with ``np.random.randn``.
        Replace it with data that mimics your own inputs.

    Error handling
    ------------
    ``preprocess``, ``predict_step`` and ``postprocess`` do not propagate
    exceptions: any exception raised inside their ``try`` body is logged via
    ``logger.error`` and the method returns ``None`` instead of an array.

    Warning
    ------------
    Remember to include methods for preprocessing, postprocessing, predict_step or you will get an error.
    """

    def __init__(self) -> None:
        """Constructor for the central processing unit."""
        super().__init__()
        logger.info('Initializing central processing unit')
        # Please make sure this data mimics your own data
        self.test_data = np.random.randn(32, 32, 32)

    @timer
    def preprocess(
        self, data: np.ndarray, extras: Optional[Dict[str, Any]] = None
    ) -> Optional[np.ndarray]:
        """Preprocess the data before training/val/test/predict.

        Parameters
        ----------
        data : np.ndarray
            the data to be preprocessed
        extras: dict, optional
            additional arguments for preprocessing such as resolution information etc.
            If provided, explain in depth in the docstring of the input and the input type.
            Defaults to ``None``, which is replaced by a new empty dict on every call.

            Example of extras:
                resolution [list]: resolution of the image, e.g. {"resolution": [1.0, 1.0, 1.0]}

        Important
        -------
        Extras dictionary is something the researchers need to define. There has to be a proper explanation of what the extras dictionary is and what it contains, as shown in the example above.

        Returns
        -------
        np.ndarray or None
            the preprocessed data, or ``None`` if an exception was raised
            (the exception is logged, not re-raised)
        """
        # Fresh dict per call: a ``= {}`` default is evaluated once and shared
        # by all calls, so any mutation would leak between calls.
        extras = {} if extras is None else extras
        try:
            logger.info(f'Preprocessing data with shape {data.shape}')
            # --------------------- #
            # TODO: Your preprocessing code here
            # --------------------- #
            logger.success('=> Preprocessing completed successfully')
            return data
        except (NameError, ValueError, TypeError, AttributeError, RuntimeError) as e:
            msg = f'I failed preprocessing the image with error: {e}'
            logger.error(msg)
        except Exception as e:
            msg = f'I failed preprocessing the image. Unexpected exception: type={type(e)}, e:{e}'
            logger.error(msg)
        # Only reached after an exception was caught and logged above.
        return None

    @timer
    def predict_step(
        self, data: np.ndarray, model: config.ModelInput
    ) -> Optional[np.ndarray]:
        """
        Predict step function.

        Parameters
        ------------
        data: np.ndarray
            data input
        model: config.ModelInput
            the model to run; it is called as ``model(data)``.

        Returns
        ------------
        np.ndarray or None
            Predictions, or ``None`` if an exception was raised (the exception
            is logged, not re-raised).
        """
        try:
            # NOTE(review): this switches ``self`` to eval mode (``eval`` presumably
            # comes from a torch.nn.Module base via CPNeoTemplate), not necessarily
            # ``model`` itself -- confirm ``model`` is also in eval mode.
            self.eval()
            # Disable gradient tracking for inference.
            with torch.no_grad():
                logger.info(f'Predicting data with shape {data.shape}')
                # NOTE(review): ``data`` is passed through unchanged; confirm that
                # ``model`` accepts this input type (e.g. whether it needs a
                # torch tensor rather than an np.ndarray).
                data = model(data)
                # --------------------- #
                # TODO: Your prediction code here
                # --------------------- #
                logger.success('=> Prediction completed successfully')
                return data
        except (
            NameError,
            ValueError,
            TypeError,
            AttributeError,
            RuntimeError,
        ) as e:
            msg = f'I failed predicting the image with error: {e}'
            logger.error(msg)
        except Exception as e:
            msg = f'I failed predicting the image. Unexpected exception: type={type(e)}, e:{e}'
            logger.error(msg)
        # Only reached after an exception was caught and logged above.
        return None

    @timer
    def postprocess(
        self, data: np.ndarray, extras: Optional[Dict[str, Any]] = None
    ) -> Optional[np.ndarray]:
        """Postprocess the data after training/val/test/predict.

        Parameters
        ----------
        data : np.ndarray
            the data to be postprocessed
        extras: dict, optional
            additional arguments for postprocessing such as resolution information etc.
            If provided, explain in depth in the docstring of the input and the input type.
            Defaults to ``None``, which is replaced by a new empty dict on every call.

            Example of extras:
                resolution [list]: resolution of the image, e.g. {"resolution": [1.0, 1.0, 1.0]}

        Important
        -------
        Extras dictionary is something the researchers need to define. There has to be a proper explanation of what the extras dictionary is and what it contains, as shown in the example above.

        Returns
        -------
        np.ndarray or None
            the postprocessed data, or ``None`` if an exception was raised
            (the exception is logged, not re-raised)
        """
        # Fresh dict per call: a ``= {}`` default is evaluated once and shared
        # by all calls, so any mutation would leak between calls.
        extras = {} if extras is None else extras
        try:
            logger.info(f'Postprocessing data with shape {data.shape}')
            # --------------------- #
            # TODO: Your postprocessing code here
            # --------------------- #
            logger.success('=> Postprocessing completed successfully')
            return data
        except (
            NameError,
            ValueError,
            TypeError,
            AttributeError,
            RuntimeError,
        ) as e:
            msg = f'I failed postprocessing with error {e}'
            logger.error(msg)
        except Exception as e:
            msg = f'I failed postprocessing the image. Unexpected exception: type={type(e)}, e:{e}'
            logger.error(msg)
        # Only reached after an exception was caught and logged above.
        return None