This project provides a Temporal Fusion Transformer (TFT) framework for time series forecasting. The TFT architecture handles multiple related time series with temporal dependencies and efficiently combines heterogeneous data sources for future prediction. With this project, users can plug in their own data, configure preprocessing, and train the TFT model quickly and easily.
- Easy integration of custom datasets
- Flexible configuration for preprocessing and model parameters
- Automated training and evaluation pipeline
- Baseline comparison and visualization of predictions
To install the necessary dependencies, run:

```bash
git clone https://github.com/anhphan2705/temporal_fusion_transformer_plugnplay.git
cd temporal_fusion_transformer_plugnplay
python -m venv tft_env
source tft_env/bin/activate  # On Windows, use `tft_env\Scripts\activate`
pip install -r requirements.txt
```
- To use your own data, place it in the `./data` directory.
- Ensure that your data is formatted appropriately, and put any necessary preprocessing scripts in the `./datasets` directory.
- Import your preprocessing module in `./tools/data_process.py` and add an option for it in the data pipeline.
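As a rough sketch of what that wiring could look like, the snippet below registers a custom preprocessor under a name the pipeline can select. All function, column, and registry names here are illustrative assumptions, not the project's actual API:

```python
# Hypothetical sketch of registering a custom preprocessor in tools/data_process.py.
# Names (preprocess_my_dataset, PREPROCESSORS, series_id, date) are illustrative.
import pandas as pd

def preprocess_my_dataset(csv_path: str) -> pd.DataFrame:
    """Example custom preprocessor: load, sort, and add a per-series time index."""
    df = pd.read_csv(csv_path)
    df["date"] = pd.to_datetime(df["date"])
    df = df.sort_values(["series_id", "date"]).reset_index(drop=True)
    # TFT-style data loaders typically expect an integer time index per series.
    df["time_idx"] = df.groupby("series_id").cumcount()
    return df

# Illustrative registry so the pipeline can pick a preprocessor by name.
PREPROCESSORS = {
    "my_dataset": preprocess_my_dataset,
}

def data_pipeline(dataset_name: str, csv_path: str) -> pd.DataFrame:
    return PREPROCESSORS[dataset_name](csv_path)
```

Adding your function to such a registry keeps `data_process.py` unchanged except for one import and one dictionary entry.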
Make a copy of the `sample_config.yaml` file and set the desired parameters for training, hyperparameter tuning, and logging. This file includes paths for data, model checkpoints, and logging directories, as well as hyperparameters for the TFT model.
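The authoritative schema is the repo's own `sample_config.yaml`; purely as an illustration of the kind of structure described above, such a file often groups paths and hyperparameters along these lines (every key name below is an assumption, not the project's actual schema):

```yaml
# Illustrative structure only -- copy the real sample_config.yaml and edit that.
data:
  path: ./data/my_dataset.csv
logging:
  checkpoint_dir: ./checkpoints
  log_dir: ./logs
model:
  hidden_size: 64
  dropout: 0.1
  learning_rate: 0.001
training:
  max_epochs: 50
  batch_size: 128
```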
To train the model, run:

```bash
python main.py --mode train --config configs/your_config.yaml
```
This will start the training pipeline, which includes data loading, model initialization, hyperparameter tuning, and final training.
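A minimal sketch of the mode dispatch behind these commands, assuming `argparse`-style flags that match the invocations shown (the `train`/`evaluate` function bodies are placeholders, not the project's implementation):

```python
# Hypothetical sketch of main.py's --mode dispatch; only the flag names
# (--mode, --config, --model) come from the commands in this README.
import argparse

def train(config_path: str) -> str:
    # Placeholder for: load data, init model, tune hyperparameters, final training.
    return f"training with {config_path}"

def evaluate(config_path: str, model_path: str) -> str:
    # Placeholder for: load checkpoint, run evaluation against the baseline.
    return f"evaluating {model_path} with {config_path}"

def main(argv=None) -> str:
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["train", "eval"], required=True)
    parser.add_argument("--config", required=True)
    parser.add_argument("--model", default=None)
    args = parser.parse_args(argv)
    if args.mode == "train":
        return train(args.config)
    return evaluate(args.config, args.model)

print(main(["--mode", "train", "--config", "configs/your_config.yaml"]))
```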
To evaluate the model, run:

```bash
python main.py --mode eval --config configs/your_config.yaml --model path_to_your_model
```
The training pipeline evaluates the model against a baseline model. Validation metrics (loss, RMSE, MAE) are logged during training and evaluation.
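For reference, the two error metrics named above can be sketched in a few lines, independent of the pipeline's actual logging code:

```python
# Standalone definitions of the validation metrics mentioned above (RMSE and MAE).
import numpy as np

def rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Root mean squared error: penalizes large errors quadratically."""
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

def mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean absolute error: average magnitude of the errors."""
    return float(np.mean(np.abs(y_true - y_pred)))

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.0, 2.0, 5.0, 4.0])
print(rmse(y_true, y_pred))  # 1.0
print(mae(y_true, y_pred))   # 0.5
```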
After training, the script generates plots comparing the trained model's predictions with actual data and the baseline model's predictions. These plots are saved in the specified logs directory.
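A hedged sketch of such a comparison plot, using synthetic data; the array names and output path are illustrative, not the pipeline's actual variables:

```python
# Illustrative comparison plot: model predictions vs. actuals vs. a naive baseline.
import os
import matplotlib
matplotlib.use("Agg")  # render without a display, e.g. on a training server
import matplotlib.pyplot as plt
import numpy as np

t = np.arange(50)
actual = np.sin(t / 5.0)
model_pred = actual + np.random.default_rng(0).normal(0, 0.1, t.size)
baseline_pred = np.full(t.size, actual.mean())  # naive mean baseline

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(t, actual, label="actual")
ax.plot(t, model_pred, label="TFT prediction")
ax.plot(t, baseline_pred, linestyle="--", label="baseline")
ax.set_xlabel("time step")
ax.set_ylabel("value")
ax.legend()
os.makedirs("logs", exist_ok=True)  # hypothetical logs directory
fig.savefig("logs/prediction_comparison.png")
```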
The `events.out.tfevents...` files generated by PyTorch Lightning are meant to be read and visualized using TensorBoard. To read and interpret a log file, follow these steps:
1. Install TensorBoard (if not already installed):

   ```bash
   pip install tensorboard
   ```

2. Navigate to the directory containing your log files:

   ```bash
   cd path_to_your_log_directory
   ```

3. Run TensorBoard:

   ```bash
   tensorboard --logdir=.
   ```

4. Open TensorBoard in your browser: go to `http://localhost:6006/`. You should see the TensorBoard dashboard, which visualizes the metrics, hyperparameters, and other logged data from your training process.
Contributions are welcome! Please feel free to submit a Pull Request. If you are adding new data preprocessing scripts, place them in the `./datasets` directory and update `tools/data_process.py` to include an option for your preprocessing.
This project is licensed under the Apache-2.0 License. See the LICENSE file for details.