-
Notifications
You must be signed in to change notification settings - Fork 3
/
train.py
148 lines (120 loc) · 4.75 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# conda create --name modelmaker python=3.9
# conda activate modelmaker
# pip install setuptools==72.1.0 Cython==3.0.11 numpy==1.24.3
# pip install -r requirements.txt
import sys
from datetime import datetime, timedelta
# pylint: disable=no-name-in-module
from configs import models
from data.csv_loader import CSVLoader
from data.tiingo_data_fetcher import DataFetcher
from data.utils.data_preprocessing import preprocess_data
from models.model_factory import ModelFactory
from utils.common import print_colored
def _ask(prompt, default, input_fn):
    """Show *prompt* and return the stripped answer, or *default* if it is blank."""
    return input_fn(prompt).strip() or default


def _prompt_tiingo_query(
    symbol_prompt,
    default_symbol,
    frequency_prompt,
    default_frequency,
    default_end_date,
    input_fn,
):
    """Ask for symbol, frequency, start date and end date, in that order.

    Returns the answers as ``(symbol, start_date, end_date, frequency)``,
    which is the argument order the fetcher methods are called with.
    """
    symbol = _ask(symbol_prompt, default_symbol, input_fn)
    frequency = _ask(frequency_prompt, default_frequency, input_fn)
    start_date = _ask(
        "Enter the start date (YYYY-MM-DD, default: 2021-01-01): ",
        "2021-01-01",
        input_fn,
    )
    end_date = _ask(
        f"Enter the end date (YYYY-MM-DD, default: {default_end_date}): ",
        default_end_date,
        input_fn,
    )
    return symbol, start_date, end_date, frequency


def select_data(fetcher, default_selection=None, file_path=None, *, input_fn=None):
    """Choose a data source (Tiingo stock, Tiingo crypto, or CSV) and load it.

    Args:
        fetcher: Object providing ``fetch_tiingo_stock_data`` and
            ``fetch_tiingo_crypto_data``; each is called with
            ``(symbol, start_date, end_date, frequency)``.
        default_selection: Pre-made menu choice ("1", "2" or "3"). When
            ``None`` the menu is printed and the user is asked. Non-string
            values (e.g. the int ``3``) and padded strings are accepted and
            normalised before comparison.
        file_path: CSV path used for choice "3"; asked for when ``None``.
        input_fn: Callable taking a prompt and returning the answer string.
            Defaults to the builtin ``input``; lets non-interactive callers
            supply answers.

    Returns:
        Whatever the chosen fetcher method or ``CSVLoader.load_csv`` returns.

    Raises:
        SystemExit: With status 1 when the choice is not "1", "2" or "3".
    """
    if input_fn is None:
        input_fn = input

    # The default end date is yesterday, formatted as YYYY-MM-DD.
    default_end_date = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d")

    if default_selection is None:
        print("Select the data source:")
        print("1. Tiingo Stock Data")
        print("2. Tiingo Crypto Data")
        print("3. Load data from CSV file")
        selection = input_fn("Enter your choice (1/2/3): ")
    else:
        selection = default_selection

    # Normalise so that ints and whitespace-padded strings match the menu keys
    # instead of falling through to the "Invalid choice" exit.
    selection = str(selection).strip()

    if selection == "1":
        print("You selected Tiingo Stock Data.")
        symbol, start_date, end_date, frequency = _prompt_tiingo_query(
            "Enter the stock symbol (default: AAPL): ",
            "AAPL",
            "Enter the frequency (daily/weekly/monthly/annually, default: daily): ",
            "daily",
            default_end_date,
            input_fn,
        )
        print(
            f"Fetching Tiingo Stock Data for {symbol} from {start_date} to {end_date} with {frequency} frequency..."
        )
        return fetcher.fetch_tiingo_stock_data(symbol, start_date, end_date, frequency)

    if selection == "2":
        print("You selected Tiingo Crypto Data.")
        symbol, start_date, end_date, frequency = _prompt_tiingo_query(
            "Enter the crypto symbol (default: btcusd): ",
            "btcusd",
            "Enter the frequency (1min/5min/4hour/1day, default: 1day): ",
            "1day",
            default_end_date,
            input_fn,
        )
        print(
            f"Fetching Tiingo Crypto Data for {symbol} from {start_date} to {end_date} with {frequency} frequency..."
        )
        return fetcher.fetch_tiingo_crypto_data(symbol, start_date, end_date, frequency)

    if selection == "3":
        print("You selected to load data from a CSV file.")
        if file_path is None:
            file_path = input_fn("Enter the CSV file path: ").strip()
        return CSVLoader.load_csv(file_path)

    # Exit the program if the user enters an invalid choice
    print_colored("Invalid choice", "error")
    sys.exit(1)
def model_selection_input():
    """Ask which models to train and return the selection.

    The user picks either every entry of ``models`` or a custom,
    comma-separated list of the numbers shown next to each entry.

    Returns:
        ``models`` itself when "all models" is chosen, when the menu choice is
        not recognised, or when a custom selection contains no valid number.
        Otherwise a list of the chosen entries in the order they were entered,
        with repeated numbers collapsed to a single entry.
    """
    print("Select the models to train:")
    print("1. All models")
    print("2. Custom selection")
    model_selection = input("Enter your choice (1/2): ").strip()

    if model_selection == "1":
        return models

    if model_selection != "2":
        print_colored("Invalid choice, defaulting to all models.", "error")
        return models

    # Menu keys are the 1-based positions, as strings, so they can be compared
    # directly with what the user types.
    available_models = {str(i): model for i, model in enumerate(models, start=1)}
    print("Available models to train:")
    for key, value in available_models.items():
        print(f"{key}. {value}")

    selected_models = input(
        "Enter the numbers of the models to train (e.g., 1,3,5): "
    ).strip()

    tokens = [token.strip() for token in selected_models.split(",")]
    tokens = [token for token in tokens if token]

    unknown = [token for token in tokens if token not in available_models]
    if unknown:
        # Report what was dropped rather than discarding it silently.
        print_colored(
            f"Ignoring unknown model numbers: {', '.join(unknown)}", "error"
        )

    # dict.fromkeys keeps first-seen order while removing repeats, so "1,1"
    # does not train the same model twice.
    chosen_keys = list(
        dict.fromkeys(token for token in tokens if token in available_models)
    )

    if not chosen_keys:
        # An empty selection would otherwise train nothing; fall back the same
        # way an invalid menu choice does.
        print_colored("No valid models selected, defaulting to all models.", "error")
        return models

    return [available_models[key] for key in chosen_keys]
def main():
    """Interactively load data, preprocess it, and train the selected models.

    Raises:
        SystemExit: With status 1 when the model selection is empty, so the
            success message is never shown for a run that trained nothing.
            ``select_data`` may also exit on an invalid data-source choice.
    """
    fetcher = DataFetcher()

    # Ask the user for the data source. For non-interactive testing, defaults
    # can be passed instead, e.g. select_data(fetcher, "3", "data/sets/eth.csv").
    data = select_data(fetcher)

    # Normalize and preprocess the data
    data = preprocess_data(data)

    # Initialize ModelFactory
    factory = ModelFactory()

    # Select models to train
    model_types = model_selection_input()

    # Train each selected model. Counting inside the loop works for any
    # iterable that model_selection_input() may return.
    trained_count = 0
    for model_type in model_types:
        print(f"Training {model_type} model...")
        model = factory.create_model(model_type)
        # NOTE(review): the final message implies train() also saves the
        # model — confirm against the model classes.
        model.train(data)
        trained_count += 1

    if trained_count == 0:
        print_colored("No models were selected; nothing was trained.", "error")
        sys.exit(1)

    print_colored("Model training and saving complete!", "success")


if __name__ == "__main__":
    main()