Skip to content

Commit

Permalink
update dashboard
Browse files Browse the repository at this point in the history
  • Loading branch information
Marco Zanotti committed Jan 8, 2024
1 parent 9e2ada7 commit ef1add3
Show file tree
Hide file tree
Showing 6 changed files with 998 additions and 133 deletions.
20 changes: 15 additions & 5 deletions dashboard/R/generate_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ generate_ts_forecast <- function(data, method, params, n_future, seed = 1992) {
window_size = params$window_size
) |>
set_engine("window_function", window_function = mean, na.rm = TRUE) |>
fit(value ~ date, data = data |> select(-id))
fit(value ~ date, data = data |> select(-id, -frequency))

} else if (method == "ETS") {

Expand All @@ -28,11 +28,21 @@ generate_ts_forecast <- function(data, method, params, n_future, seed = 1992) {
smooth_seasonal = params$smooth_seasonal
) |>
set_engine("ets") |>
fit(value ~ date, data = data |> select(-id))
fit(value ~ date, data = data |> select(-id, -frequency))

} else if (method == "ARIMA") {

check_parameters(method, params)
wkfl_fit <- arima_reg(
non_seasonal_ar = params$non_seasonal_ar,
non_seasonal_differences = params$non_seasonal_differences,
non_seasonal_ma = params$non_seasonal_ma,
seasonal_ar = params$seasonal_ar,
seasonal_differences = params$seasonal_differences,
seasonal_ma = params$seasonal_ma
) |>
set_engine("arima") |>
fit(value ~ date, data = data |> select(-id, -frequency))

} else {
stop(paste("Unknown method", method))
Expand Down Expand Up @@ -60,7 +70,7 @@ generate_ml_forecast <- function(data, method, params, n_future, seed = 1992) {
future_tbl <- data |>
future_frame(.date_var = date, .length_out = n_future)

ml_rcp <- recipe(value ~ ., data = data |> select(-id)) |>
ml_rcp <- recipe(value ~ ., data = data |> select(-id, -frequency)) |>
step_timeseries_signature(date) |>
step_normalize(date_index.num) |>
step_zv(all_predictors()) |>
Expand All @@ -77,7 +87,7 @@ generate_ml_forecast <- function(data, method, params, n_future, seed = 1992) {
wkfl_fit <- workflow() |>
add_recipe(ml_rcp) |>
add_model(model_spec) |>
fit(data = data |> select(-id))
fit(data = data |> select(-id, -frequency))

} else if (method == "Elastic Net") {

Expand All @@ -91,7 +101,7 @@ generate_ml_forecast <- function(data, method, params, n_future, seed = 1992) {
wkfl_fit <- workflow() |>
add_recipe(ml_rcp) |>
add_model(model_spec) |>
fit(data = data |> select(-id))
fit(data = data |> select(-id, -frequency))

} else {
stop(paste("Unknown method", method))
Expand Down
122 changes: 122 additions & 0 deletions dashboard/R/get_data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Build one of the bundled example series as a tibble with the columns
# `date`, `id`, `frequency` and `value`.
#
# dataset_name: name of the example series to build; the supported names are
#   the ones tested below.
# Returns the tibble; any other name stops with "Unknown dataset <name>".
get_data <- function(dataset_name) {

  if (dataset_name == "Air Passengers") { # Monthly
    return(tibble(
      date = seq.Date(as.Date("1949-01-01"), as.Date("1960-12-01"), by = "month"),
      id = "Air Passengers",
      frequency = "month",
      value = as.numeric(datasets::AirPassengers)
    ))
  }

  if (dataset_name == "Electricity Demand") { # Half-Hourly
    elec <- tsibbledata::vic_elec
    return(tibble(
      date = elec$Time,
      id = "Electricity Demand",
      frequency = "half-hour",
      value = elec$Demand
    ))
  }

  if (dataset_name == "Stock Price") { # Daily
    # keep only the Apple rows once, then read both columns from them
    aapl <- tsibbledata::gafa_stock |> filter(Symbol == "AAPL")
    return(tibble(
      date = pull(aapl, Date),
      id = "Apple Stock Price",
      frequency = "bus-day",
      value = pull(aapl, Adj_Close)
    ))
  }

  if (dataset_name == "Tobacco Prod") { # Quarterly
    tobacco <- tsibbledata::aus_production |>
      drop_na() |>
      pull(Tobacco)
    return(tibble(
      date = seq.Date(as.Date("1950-01-01"), as.Date("1998-04-01"), by = "quarter"),
      id = "Tobacco Prod",
      frequency = "quarter",
      value = tobacco
    ))
  }

  if (dataset_name == "EU Population") { # Yearly
    eu_pop <- tsibbledata::global_economy |>
      filter(Country == "European Union") |>
      pull(Population)
    return(tibble(
      date = seq.Date(as.Date("1960-01-01"), as.Date("2017-01-01"), by = "year"),
      id = "EU Population",
      frequency = "year",
      value = eu_pop
    ))
  }

  if (dataset_name == "People Traffic") { # Weekly
    # total passengers per week
    weekly_passengers <- tsibbledata::ansett |>
      group_by(Week) |>
      summarise(value = sum(Passengers)) |>
      pull(value)
    return(tibble(
      date = seq.Date(as.Date("2000-01-01"), as.Date("2005-06-01"), by = "week"),
      id = "People Traffic",
      frequency = "week",
      value = weekly_passengers
    ))
  }

  stop(paste("Unknown dataset", dataset_name))

}

# Optionally impute the `value` column of a series.
#
# data:   data frame with a `value` column.
# params: list whose `impute` element switches imputation on or off.
# freq:   numeric frequency of the series, compared against the row count.
# Returns `data` untouched when `params$impute` is FALSE; otherwise `data`
# with `value` replaced by the output of ts_impute_vec().
impute_data <- function(data, params, freq) {

  if (params$impute == FALSE) {
    return(data)
  }

  # whole multiples of `freq` that fit in the series: period 1 when the
  # series is shorter than one `freq`, period 2 otherwise
  full_cycles <- trunc(nrow(data) / freq)
  impute_period <- ifelse(full_cycles < 1, 1, 2)

  imputed <- mutate(
    data,
    value = ts_impute_vec(value, period = impute_period, lambda = "auto")
  )
  return(imputed)

}

# Apply the requested transformations to the `value` column of a series.
#
# data:   data frame with a `value` column.
# params: list with one logical flag per transformation; every name listed in
#         the option "tsf.dashboard.transfs" must be present. The flags read
#         here are `log`, `boxcox`, `norm`, `stand`, `diff` and `sdiff`.
# freq:   numeric frequency of the series, used as the lag of the seasonal
#         difference.
# Returns `data` unchanged when no flag is TRUE. Otherwise the enabled
# transformations are applied in this fixed order: log -> Box-Cox ->
# normalization -> standardization -> differencing -> seasonal differencing.
# The two differencing steps drop the rows that become NA.
# Stops with "Unknown transformations!" when a required flag is missing.
transform_data <- function(data, params, freq) {

  trf_prm <- getOption("tsf.dashboard.transfs")
  if (!all(trf_prm %in% names(params))) {
    stop(paste("Unknown transformations!"))
  }

  transf_params <- c(
    params$log, params$boxcox, params$norm,
    params$stand, params$diff, params$sdiff
  ) |> as.logical()

  # Nothing requested: hand the data back as is.
  # (`!all(x) == FALSE` must not be used here: `!` binds looser than `==`,
  # so it reduces to `all(x)` and skips the work when every flag is TRUE.)
  if (!any(transf_params)) {
    return(data)
  }

  data_transf <- data

  if (params$log) { # Log
    data_transf <- data_transf |> mutate(value = log1p(value))
  }
  if (params$boxcox) { # Box-Cox
    # shift by 1 before the Box-Cox transform
    data_transf <- data_transf |>
      mutate(value = box_cox_vec(value + 1, lambda = "auto"))
  }
  if (params$norm) { # Normalization
    data_transf <- data_transf |> mutate(value = normalize_vec(value))
  }
  if (params$stand) { # Standardization
    data_transf <- data_transf |> mutate(value = standardize_vec(value))
  }
  if (params$diff) { # Differencing
    data_transf <- data_transf |>
      mutate(value = diff_vec(value, difference = 1)) |>
      drop_na()
  }
  if (params$sdiff) { # Seasonal differencing
    data_transf <- data_transf |>
      mutate(value = diff_vec(value, difference = 1, lag = freq)) |>
      drop_na()
  }

  return(data_transf)

}

# Optionally clean the `value` column of a series from anomalies.
#
# data:   data frame with a `value` column.
# params: list whose `clean` element switches cleaning on or off.
# Returns `data` untouched when `params$clean` is FALSE; otherwise `data`
# with `value` replaced by the output of ts_clean_vec().
clean_data <- function(data, params) {

  if (params$clean == FALSE) {
    return(data)
  }

  cleaned <- mutate(data, value = ts_clean_vec(value))
  return(cleaned)

}

2 changes: 1 addition & 1 deletion dashboard/R/packages.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Packages
pkgs <- c(
"tidyverse",
"tidyverse", "tsibbledata", "janitor",
"forecast", "prophet", "glmnet", "earth", "kernlab", "kknn",
"randomForest", "ranger", "xgboost", "treesnip", "lightgbm", "catboost",
"Cubist", "rules",
Expand Down
40 changes: 37 additions & 3 deletions dashboard/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,14 @@ set_options <- function() {
tsf.dashboard.methods_params = list(
"Rolling Average" = c("window_size"),
"ETS" = c("error", "trend", "season", "damping", "smooth_level", "smooth_trend", "smooth_season"),
"ARIMA" = c(""),
"ARIMA" = c(
"non_seasonal_ar", "non_seasonal_differences", "non_seasonal_ma",
"seasonal_ar", "seasonal_differences", "seasonal_ma"
),
"Linear Regression" = "none",
"Elastic Net" = c("penalty", "mixture")
)
),
tsf.dashboard.transfs = c("log", "boxcox", "norm", "stand", "diff", "sdiff")
)
toset <- !(names(op.tsf.dashboard) %in% names(op))
if (any(toset)) options(op.tsf.dashboard[toset])
Expand All @@ -123,6 +127,36 @@ set_options <- function() {

}

# Convert a frequency label into its numeric value.
#
# frequency: single character label; the supported labels are the names of
#   the lookup vector below.
# Returns the number associated with the label (e.g. 12 for "month", 52 for
# "week"); an unrecognised label stops with "Unknown frequency <label>".
parse_frequency <- function(frequency) {
  freq_lookup <- c(
    "year" = 1,
    "semester" = 2,
    "quarter" = 4,
    "month" = 12,
    "week" = 52,
    "bus-day" = 252,
    "day" = 365,
    "bus-hour" = 252 * 24,
    "hour" = 365 * 24,
    "bus-half-hour" = 252 * 48,
    "half-hour" = 365 * 48
  )

  position <- match(frequency, names(freq_lookup))
  if (is.na(position)) {
    stop(paste("Unknown frequency", frequency))
  }

  # `[[` drops the name so a bare number is returned
  return(freq_lookup[[position]])
}

# function to understand if the method is a time series or a machine learning one
parse_method <- function(method) {

Expand All @@ -145,7 +179,7 @@ check_parameters <- function(method, params) {

mtd_prm <- getOption("tsf.dashboard.methods_params")[[method]]
if (!all(mtd_prm %in% names(params))) {
stop(paste("Parameters for method", method, "are not correct!"))
stop(paste("Parameters for", method, "are not correct!"))
}

}
Expand Down
Loading

0 comments on commit ef1add3

Please sign in to comment.