diff --git a/README.md b/README.md
index ceaf5b4..aa8165e 100644
--- a/README.md
+++ b/README.md
@@ -1,43 +1,35 @@
-# Sensitivity Analysis for Multi-Horizon Time Series Forecasting
+# Temporal Saliency Analysis for Multi-Horizon Time Series Forecasting using Deep Learning
 
-## Sensitivity Analysis
+Interpreting a model's behavior is important for understanding its decision-making in practice. However, explaining complex time series forecasting models is challenging because of temporal dependencies between subsequent time steps and the varying importance of input features over time. Many time series forecasting models use input context with a look-back window for better prediction performance. However, the existing studies (1) do not consider the temporal dependencies among the feature vectors in the input window and (2) consider the time dimension separately from the feature dimension when calculating the importance scores. In this work, we propose a novel **Windowed Temporal Saliency Analysis** method to address these issues.
 
-### Definition
-According to [Wikipedia](https://en.wikipedia.org/wiki/Sensitivity_analysis)
-> Sensitivity analysis is the study of how the uncertainty in the output of a mathematical model or system (numerical or otherwise) can be apportioned to different sources of uncertainty in its inputs.
-
-The sensitivity of each input is often represented by a numeric value, called the sensitivity index. Sensitivity indices come in several forms:
-
-1. **First-order indices**: measures the contribution to the output variance by a single model input alone.
-2. **Second-order indices**: measures the contribution to the output variance caused by the interaction of two model inputs.
-3. **Total-order index**: measures the contribution to the output variance caused by a model input, including both its first-order effects (the input varying alone) and all higher-order interactions.
+## Saliency Analysis
 
-### [SALib](https://salib.readthedocs.io/en/latest/user_guide/getting-started.html)
-
-In this work we'll use **SALib** to perform the sensitivity analysis.
-**SALib** is an open source library written in Python for performing sensitivity analyses. SALib provides a decoupled workflow, meaning it does not directly interface with the mathematical or computational model. Instead, SALib is responsible for generating the model inputs, using one of the sample functions, and computing the sensitivity indices from the model outputs, using one of the analyze functions. A typical sensitivity analysis using SALib follows four steps:
+### Definition
 
-1. Determine the model inputs (parameters) and their sample range.
+Saliency Analysis is the study of input feature importance to model output using black-box interpretation techniques. We use the following libraries to perform the saliency analysis.
 
-2. Run the sample function to generate the model inputs.
+### [Captum](https://captum.ai/docs/introduction)
+Captum (“comprehension” in Latin) is an open source library for model interpretability built on PyTorch.
 
-3. Evaluate the model using the generated inputs, saving the model outputs.
+### [Time Interpret (tint)](https://josephenguehard.github.io/time_interpret/build/html/index.html)
 
-4. Run the analyze function on the outputs to compute the sensitivity indices.
+This package expands the Captum library with a specific focus on time series. As such, it includes various interpretability methods specifically designed to handle time series data.
 
 ## Multi-Horizon Forecasting
 
 ### Definition
-Multi-horizon forecasting is the prediction of variables-of-interest at multiple future time steps. It is a crucial challenge in time series machine learning. Most real-world datasets have a time component, and forecasting the future can unlock great value. For example, retailers can use future sales to optimize their supply chain and promotions, investment managers are interested in forecasting the future prices of financial assets to maximize their performance, and healthcare institutions can use the number of future patient admissions to have sufficient personnel and equipment.
+Multi-horizon forecasting is the prediction of variables-of-interest at multiple future time steps. It is a crucial challenge in time series machine learning. Most real-world datasets have a time component, and forecasting the future can unlock great value. For example, retailers can use future sales to optimize their supply chain and promotions, investment managers are interested in forecasting the future prices of financial assets to maximize their performance, and healthcare institutions can use the number of future patient admissions to have sufficient personnel and equipment.
 
-The current `Sensitivity Analysis` methods only count for analyzing sensitivity at a single point in time. However in this work we extend that to allow analyzing sensitivity for multiple input (window) and output (horizon) timesteps.
+We use the following library to implement the time series models:
 
-### [PyTorch Forecasting](https://pytorch-forecasting.readthedocs.io/en/stable/getting-started.html)
+### [Time-Series-Library (TSlib)](https://github.com/thuml/Time-Series-Library)
 
-In this work we use `PyTorch Forecasting` to implement the timeseries models. This framework aims to ease state-of-the-art timeseries forecasting with neural networks for both real-world cases and research alike. The goal is to provide a high-level API with maximum flexibility for professionals and reasonable defaults for beginners.
+TSlib is an open-source library for deep learning researchers, especially for deep time series analysis.
 
 ## How to Reproduce
 
+The module was developed using Python 3.10.
+
 ### Create Virtual Environment
 
 First create a virtual environment with the required libraries. For example, to create an venv named `ml`, you can either use the `Anaconda` library or your locally installed `python`.
diff --git a/interpret.ipynb b/interpret.ipynb
index 0be6a49..fb4a6c3 100644
--- a/interpret.ipynb
+++ b/interpret.ipynb
@@ -298,7 +298,7 @@
  "results = []\n",
  "baseline_mode = \"aug\" # \"zeros\", \"aug\"\n",
  "result_columns = ['batch_index', 'explainer', 'metric', 'area', 'comp', 'suff']\n",
- "output_file = open(\"interpretation_results.csv\", 'w')\n",
+ "output_file = open(os.path.join(result_folder, \"batch_interpretation_results.csv\"), 'w')\n",
  "output_file.write(','.join(result_columns))\n",
  "\n",
  "progress_bar = tqdm(\n",
diff --git a/interpretation_results.csv b/interpretation_results.csv
deleted file mode 100644
index fdc32ba..0000000
--- a/interpretation_results.csv
+++ /dev/null
@@ -1,73 +0,0 @@
-batch_index,explainer,metric,area,comp,suff
-0,deep_lift,mae,0.1,3.041492462158203,4.536799430847168
-0,deep_lift,mse,0.1,0.42437857389450073,0.9564195871353149
-0,deep_lift,mae,0.2,4.30157470703125,5.090987205505371
-0,deep_lift,mse,0.2,0.8447372913360596,1.1792932748794556
-0,deep_lift,mae,0.5,4.672505855560303,4.955289363861084
-0,deep_lift,mse,0.5,1.0035548210144043,1.1028910875320435
-0,feature_ablation,mae,0.1,3.1585612297058105,4.690319538116455
-0,feature_ablation,mse,0.1,0.4561527967453003,1.0265034437179565
-0,feature_ablation,mae,0.2,4.333446979522705,5.2065229415893555
-0,feature_ablation,mse,0.2,0.8564496636390686,1.2322238683700562
-0,feature_ablation,mae,0.5,4.72280216217041,5.210166931152344
-0,feature_ablation,mse,0.5,1.0158641338348389,1.2171478271484375
-1,deep_lift,mae,0.1,3.8082618713378906,2.4431960582733154
-1,deep_lift,mse,0.1,0.7181265354156494,0.2862909734249115
-1,deep_lift,mae,0.2,5.094099044799805,3.6957709789276123
-1,deep_lift,mse,0.2,1.2297897338867188,0.6101534962654114
-1,deep_lift,mae,0.5,5.779333114624023,4.618597984313965
-1,deep_lift,mse,0.5,1.5334383249282837,0.9223220348358154
-1,feature_ablation,mae,0.1,3.612245798110962,2.5774784088134766
-1,feature_ablation,mse,0.1,0.628040075302124,0.32044675946235657
-1,feature_ablation,mae,0.2,4.8915276527404785,3.821155071258545
-1,feature_ablation,mse,0.2,1.1089290380477905,0.6431983113288879
-1,feature_ablation,mae,0.5,5.945882320404053,4.9836745262146
-1,feature_ablation,mse,0.5,1.613616704940796,1.0593359470367432
-2,deep_lift,mae,0.1,1.7236106395721436,2.1221845149993896
-2,deep_lift,mse,0.1,0.13686607778072357,0.29388725757598877
-2,deep_lift,mae,0.2,2.3030991554260254,2.6238250732421875
-2,deep_lift,mse,0.2,0.24028721451759338,0.38215166330337524
-2,deep_lift,mae,0.5,2.612053394317627,2.8026504516601562
-2,deep_lift,mse,0.5,0.31183862686157227,0.3913632333278656
-2,feature_ablation,mae,0.1,1.745111346244812,2.088286876678467
-2,feature_ablation,mse,0.1,0.14331957697868347,0.2835744023323059
-2,feature_ablation,mae,0.2,2.3512990474700928,2.609348773956299
-2,feature_ablation,mse,0.2,0.2579931616783142,0.36483266949653625
-2,feature_ablation,mae,0.5,2.847867488861084,2.91965913772583
-2,feature_ablation,mse,0.5,0.3936021625995636,0.4024607837200165
-3,deep_lift,mae,0.1,1.5281918048858643,1.710274338722229
-3,deep_lift,mse,0.1,0.11594387888908386,0.18924646079540253
-3,deep_lift,mae,0.2,2.149336814880371,2.0852315425872803
-3,deep_lift,mse,0.2,0.22342327237129211,0.2464267611503601
-3,deep_lift,mae,0.5,2.7848215103149414,2.4915027618408203
-3,deep_lift,mse,0.5,0.37463945150375366,0.2963715195655823
-3,feature_ablation,mae,0.1,1.5339816808700562,1.6620172262191772
-3,feature_ablation,mse,0.1,0.1122378259897232,0.17652803659439087
-3,feature_ablation,mae,0.2,2.1121225357055664,2.101673126220703
-3,feature_ablation,mse,0.2,0.2071421593427658,0.24249392747879028
-3,feature_ablation,mae,0.5,2.660428047180176,2.6340620517730713
-3,feature_ablation,mse,0.5,0.33043134212493896,0.31590867042541504
-4,deep_lift,mae,0.1,2.7547824382781982,2.215395927429199
-4,deep_lift,mse,0.1,0.35939887166023254,0.25845402479171753
-4,deep_lift,mae,0.2,3.865361452102661,2.812896966934204
-4,deep_lift,mse,0.2,0.6938855051994324,0.4277902841567993
-4,deep_lift,mae,0.5,4.964759349822998,3.354480504989624
-4,deep_lift,mse,0.5,1.1321254968643188,0.5172992944717407
-4,feature_ablation,mae,0.1,2.3283097743988037,2.2262160778045654
-4,feature_ablation,mse,0.1,0.24862250685691833,0.2732912003993988
-4,feature_ablation,mae,0.2,3.3290135860443115,2.8652126789093018
-4,feature_ablation,mse,0.2,0.5073292851448059,0.41869500279426575
-4,feature_ablation,mae,0.5,4.385256767272949,3.7674756050109863
-4,feature_ablation,mse,0.5,0.8946852087974548,0.643514096736908
-5,deep_lift,mae,0.1,1.2275972366333008,1.6170949935913086
-5,deep_lift,mse,0.1,0.07605863362550735,0.14558453857898712
-5,deep_lift,mae,0.2,1.5316154956817627,1.8780899047851562
-5,deep_lift,mse,0.2,0.12136727571487427,0.17312663793563843
-5,deep_lift,mae,0.5,1.6897684335708618,1.9774854183197021
-5,deep_lift,mse,0.5,0.1458672732114792,0.18247807025909424
-5,feature_ablation,mae,0.1,1.1961545944213867,1.584932565689087
-5,feature_ablation,mse,0.1,0.0690048336982727,0.1391582041978836
-5,feature_ablation,mae,0.2,1.5153186321258545,1.857421636581421
-5,feature_ablation,mse,0.2,0.11593860387802124,0.16239970922470093
-5,feature_ablation,mae,0.5,1.6729271411895752,2.036665678024292
-5,feature_ablation,mse,0.5,0.14078781008720398,0.18773791193962097
\ No newline at end of file
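
Since the per-batch results now go to `batch_interpretation_results.csv` under `result_folder`, a reader may want one row per explainer rather than one per batch. The following is a minimal sketch of such an aggregation, not code from the notebook: the `summarize` helper is hypothetical, and the sample rows are copied from the deleted `interpretation_results.csv` only so the snippet runs on its own. It assumes the file keeps the `result_columns` header shown in the notebook.

```python
import csv
import io
from collections import defaultdict
from statistics import mean

# Four rows copied from the deleted interpretation_results.csv, used as sample input.
SAMPLE = """batch_index,explainer,metric,area,comp,suff
0,deep_lift,mae,0.1,3.041492462158203,4.536799430847168
0,feature_ablation,mae,0.1,3.1585612297058105,4.690319538116455
1,deep_lift,mae,0.1,3.8082618713378906,2.4431960582733154
1,feature_ablation,mae,0.1,3.612245798110962,2.5774784088134766
"""


def summarize(handle):
    """Average comp and suff over batches for each (explainer, metric, area).

    Hypothetical helper: `handle` is any open text file or file-like object
    whose first line is the result_columns header.
    """
    groups = defaultdict(lambda: {"comp": [], "suff": []})
    for row in csv.DictReader(handle):
        key = (row["explainer"], row["metric"], float(row["area"]))
        groups[key]["comp"].append(float(row["comp"]))
        groups[key]["suff"].append(float(row["suff"]))
    # Collapse each list of per-batch values into its mean.
    return {
        key: {name: mean(values) for name, values in columns.items()}
        for key, columns in groups.items()
    }


summary = summarize(io.StringIO(SAMPLE))
for (explainer, metric, area), stats in sorted(summary.items()):
    print(f"{explainer:18s} {metric} area={area:.1f} "
          f"comp={stats['comp']:.4f} suff={stats['suff']:.4f}")
```

To run it on real output, one would pass `open(os.path.join(result_folder, "batch_interpretation_results.csv"), newline="")` instead of the in-memory sample, with `result_folder` defined as in the notebook.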