diff --git a/Project.toml b/Project.toml
index 1148da9..1dc7ec2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "EvoTrees"
 uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
 authors = ["jeremiedb "]
-version = "0.16.5"
+version = "0.16.6"

 [deps]
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
diff --git a/benchmarks/Higgs-logloss.jl b/benchmarks/Higgs-logloss.jl
index 8b5c61d..758ca42 100644
--- a/benchmarks/Higgs-logloss.jl
+++ b/benchmarks/Higgs-logloss.jl
@@ -4,6 +4,7 @@ using CSV
 using DataFrames
 using StatsBase
 using Statistics: mean, std
+using CUDA
 using EvoTrees
 using Solage: Connectors
 using AWS: AWSCredentials, AWSConfig, @service
@@ -33,10 +34,11 @@ dtest = df_tot[end-500_000+1:end, :];
 config = EvoTreeRegressor(
     loss=:logloss,
     nrounds=5000,
-    eta=0.15,
-    nbins=128,
-    max_depth=9,
-    lambda=1.0,
+    eta=0.2,
+    nbins=224,
+    max_depth=11,
+    L2=1,
+    lambda=0.0,
     gamma=0.0,
     rowsample=0.8,
     colsample=0.8,
@@ -44,12 +46,11 @@ config = EvoTreeRegressor(
     rng=123,
 )

-device = "gpu"
+device = "cpu"
 metric = "logloss"
 @time m_evo = fit_evotree(config, dtrain; target_name, fnames=feature_names, deval, metric, device, early_stopping_rounds=200, print_every_n=100);

 p_test = m_evo(dtest);
-@info extrema(p_test)
 logloss_test = mean(-dtest.y .* log.(p_test) .+ (dtest.y .- 1) .* log.(1 .- p_test))
 @info "LogLoss - dtest" logloss_test
 error_test = 1 - mean(round.(Int, p_test) .== dtest.y)
@@ -63,8 +64,8 @@ error_test = 1 - mean(round.(Int, p_test) .== dtest.y)
 @info "train"
 using XGBoost
 params_xgb = Dict(
-    :num_round => 4000,
-    :max_depth => 8,
+    :num_round => 2000,
+    :max_depth => 10,
     :eta => 0.15,
     :objective => "reg:logistic",
     :print_every_n => 5,
@@ -81,8 +82,6 @@ watchlist = Dict("eval" => DMatrix(select(deval, feature_names), deval.y));
 @time m_xgb = xgboost(dtrain_xgb; watchlist, nthread=Threads.nthreads(), verbosity=0, eval_metric="logloss", params_xgb...);

 pred_xgb = XGBoost.predict(m_xgb, DMatrix(select(deval, feature_names)));
-@info extrema(pred_xgb)
-# (1.9394008f-6, 0.9999975f0)
 logloss_test = mean(-dtest.y .* log.(pred_xgb) .+ (dtest.y .- 1) .* log.(1 .- pred_xgb))
 @info "LogLoss - dtest" logloss_test
 error_test = 1 - mean(round.(Int, pred_xgb) .== dtest.y)
diff --git a/src/MLJ.jl b/src/MLJ.jl
index 4711f98..47499dd 100644
--- a/src/MLJ.jl
+++ b/src/MLJ.jl
@@ -165,7 +165,8 @@ A model type for constructing a EvoTreeRegressor, based on [EvoTrees.jl](https:/
 - `nrounds=10`: Number of rounds. It corresponds to the number of trees that will be sequentially stacked. Must be >= 1.
 - `eta=0.1`: Learning rate. Each tree raw predictions are scaled by `eta` prior to be added to the stack of predictions. Must be > 0. A lower `eta` results in slower learning, requiring a higher `nrounds` but typically improves model performance.
-- `lambda::T=0.0`: L2 regularization term on weights. Must be >= 0. Higher lambda can result in a more robust model.
+- `L2::T=0.0`: L2 regularization factor on aggregate gain. Must be >= 0. Higher L2 can result in a more robust model.
+- `lambda::T=0.0`: L2 regularization factor on individual gain. Must be >= 0. Higher lambda can result in a more robust model.
 - `gamma::T=0.0`: Minimum gain improvement needed to perform a node split. Higher gamma can result in a more robust model. Must be >= 0.
 - `alpha::T=0.5`: Loss specific parameter in the [0, 1] range:
   - `:quantile`: target quantile for the regression.
@@ -286,22 +287,23 @@ EvoTreeClassifier is used to perform multi-class classification, using cross-ent

 # Hyper-parameters

-- `nrounds=10`: Number of rounds. It corresponds to the number of trees that will be sequentially stacked. Must be >= 1.
+- `nrounds=10`: Number of rounds. It corresponds to the number of trees that will be sequentially stacked. Must be >= 1.
 - `eta=0.1`: Learning rate. Each tree raw predictions are scaled by `eta` prior to be added to the stack of predictions. Must be > 0. A lower `eta` results in slower learning, requiring a higher `nrounds` but typically improves model performance.
-- `lambda::T=0.0`: L2 regularization term on weights. Must be >= 0. Higher lambda can result in a more robust model.
-- `gamma::T=0.0`: Minimum gain improvement needed to perform a node split. Higher gamma can result in a more robust model. Must be >= 0.
-- `max_depth=5`: Maximum depth of a tree. Must be >= 1. A tree of depth 1 is made of a single prediction leaf.
+- `L2::T=0.0`: L2 regularization factor on aggregate gain. Must be >= 0. Higher L2 can result in a more robust model.
+- `lambda::T=0.0`: L2 regularization factor on individual gain. Must be >= 0. Higher lambda can result in a more robust model.
+- `gamma::T=0.0`: Minimum gain improvement needed to perform a node split. Higher gamma can result in a more robust model. Must be >= 0.
+- `max_depth=5`: Maximum depth of a tree. Must be >= 1. A tree of depth 1 is made of a single prediction leaf. A complete tree of depth N contains `2^(N - 1)` terminal leaves and `2^(N - 1) - 1` split nodes. Compute cost is proportional to `2^max_depth`. Typical optimal values are in the 3 to 9 range.
-- `min_weight=1.0`: Minimum weight needed in a node to perform a split. Matches the number of observations by default or the sum of weights as provided by the `weights` vector. Must be > 0.
-- `rowsample=1.0`: Proportion of rows that are sampled at each iteration to build the tree. Should be in `]0, 1]`.
-- `colsample=1.0`: Proportion of columns / features that are sampled at each iteration to build the tree. Should be in `]0, 1]`.
-- `nbins=32`: Number of bins into which each feature is quantized. Buckets are defined based on quantiles, hence resulting in equal weight bins. Should be between 2 and 255.
+- `min_weight=1.0`: Minimum weight needed in a node to perform a split. Matches the number of observations by default or the sum of weights as provided by the `weights` vector. Must be > 0.
+- `rowsample=1.0`: Proportion of rows that are sampled at each iteration to build the tree. Should be in `]0, 1]`.
+- `colsample=1.0`: Proportion of columns / features that are sampled at each iteration to build the tree. Should be in `]0, 1]`.
+- `nbins=32`: Number of bins into which each feature is quantized. Buckets are defined based on quantiles, hence resulting in equal weight bins. Should be between 2 and 255.
 - `tree_type="binary"` Tree structure to be used. One of:
   - `binary`: Each node of a tree is grown independently. Tree are built depthwise until max depth is reach or if min weight or gain (see `gamma`) stops further node splits.
   - `oblivious`: A common splitting condition is imposed to all nodes of a given depth.
-- `rng=123`: Either an integer used as a seed to the random number generator or an actual random number generator (`::Random.AbstractRNG`).
+- `rng=123`: Either an integer used as a seed to the random number generator or an actual random number generator (`::Random.AbstractRNG`).

 # Internal API

@@ -410,23 +412,24 @@ EvoTreeCount is used to perform Poisson probabilistic regression on count target

 # Hyper-parameters

-- `nrounds=10`: Number of rounds. It corresponds to the number of trees that will be sequentially stacked. Must be >= 1.
+- `nrounds=10`: Number of rounds. It corresponds to the number of trees that will be sequentially stacked. Must be >= 1.
 - `eta=0.1`: Learning rate. Each tree raw predictions are scaled by `eta` prior to be added to the stack of predictions. Must be > 0. A lower `eta` results in slower learning, requiring a higher `nrounds` but typically improves model performance.
-- `lambda::T=0.0`: L2 regularization term on weights. Must be >= 0. Higher lambda can result in a more robust model. Must be >= 0.
-- `gamma::T=0.0`: Minimum gain imprvement needed to perform a node split. Higher gamma can result in a more robust model.
-- `max_depth=5`: Maximum depth of a tree. Must be >= 1. A tree of depth 1 is made of a single prediction leaf.
+- `L2::T=0.0`: L2 regularization factor on aggregate gain. Must be >= 0. Higher L2 can result in a more robust model.
+- `lambda::T=0.0`: L2 regularization factor on individual gain. Must be >= 0. Higher lambda can result in a more robust model.
+- `gamma::T=0.0`: Minimum gain improvement needed to perform a node split. Higher gamma can result in a more robust model.
+- `max_depth=5`: Maximum depth of a tree. Must be >= 1. A tree of depth 1 is made of a single prediction leaf. A complete tree of depth N contains `2^(N - 1)` terminal leaves and `2^(N - 1) - 1` split nodes. Compute cost is proportional to 2^max_depth. Typical optimal values are in the 3 to 9 range.
-- `min_weight=1.0`: Minimum weight needed in a node to perform a split. Matches the number of observations by default or the sum of weights as provided by the `weights` vector. Must be > 0.
-- `rowsample=1.0`: Proportion of rows that are sampled at each iteration to build the tree. Should be `]0, 1]`.
-- `colsample=1.0`: Proportion of columns / features that are sampled at each iteration to build the tree. Should be `]0, 1]`.
-- `nbins=32`: Number of bins into which each feature is quantized. Buckets are defined based on quantiles, hence resulting in equal weight bins. Should be between 2 and 255.
+- `min_weight=1.0`: Minimum weight needed in a node to perform a split. Matches the number of observations by default or the sum of weights as provided by the `weights` vector. Must be > 0.
+- `rowsample=1.0`: Proportion of rows that are sampled at each iteration to build the tree. Should be `]0, 1]`.
+- `colsample=1.0`: Proportion of columns / features that are sampled at each iteration to build the tree. Should be `]0, 1]`.
+- `nbins=32`: Number of bins into which each feature is quantized. Buckets are defined based on quantiles, hence resulting in equal weight bins. Should be between 2 and 255.
 - `monotone_constraints=Dict{Int, Int}()`: Specify monotonic constraints using a dict where the key is the feature index and the value the applicable constraint (-1=decreasing, 0=none, 1=increasing).
 - `tree_type="binary"` Tree structure to be used. One of:
   - `binary`: Each node of a tree is grown independently. Tree are built depthwise until max depth is reach or if min weight or gain (see `gamma`) stops further node splits.
   - `oblivious`: A common splitting condition is imposed to all nodes of a given depth.
-- `rng=123`: Either an integer used as a seed to the random number generator or an actual random number generator (`::Random.AbstractRNG`).
+- `rng=123`: Either an integer used as a seed to the random number generator or an actual random number generator (`::Random.AbstractRNG`).
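The two regularization hyper-parameters documented above play different roles: `lambda` is scaled by the node weight, while the new `L2` adds a constant term to the denominator of the gain and leaf values. As a usage illustration only (not part of the patch; the toy data and variable names below are hypothetical), `L2` is passed alongside `lambda` through the public constructor and the internal `fit_evotree` API:

```julia
using EvoTrees

# hypothetical toy data: 1_000 rows, 4 features, count-valued target
x_train = randn(1_000, 4)
y_train = rand(0:10, 1_000)

# `L2` sits next to `lambda`, as in the hyper-parameter list above
config = EvoTreeCount(
    nrounds=100,
    eta=0.1,
    L2=1.0,
    lambda=0.0,
    max_depth=5,
    nbins=32,
    rng=123,
)

m = fit_evotree(config; x_train, y_train)
p = m(x_train)  # predicted counts
```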

 # Internal API

@@ -539,24 +542,25 @@ EvoTreeGaussian is used to perform Gaussian probabilistic regression, fitting μ

 # Hyper-parameters

-- `nrounds=10`: Number of rounds. It corresponds to the number of trees that will be sequentially stacked. Must be >= 1.
+- `nrounds=10`: Number of rounds. It corresponds to the number of trees that will be sequentially stacked. Must be >= 1.
 - `eta=0.1`: Learning rate. Each tree raw predictions are scaled by `eta` prior to be added to the stack of predictions. Must be > 0. A lower `eta` results in slower learning, requiring a higher `nrounds` but typically improves model performance.
-- `lambda::T=0.0`: L2 regularization term on weights. Must be >= 0. Higher lambda can result in a more robust model.
-- `gamma::T=0.0`: Minimum gain imprvement needed to perform a node split. Higher gamma can result in a more robust model. Must be >= 0.
-- `max_depth=5`: Maximum depth of a tree. Must be >= 1. A tree of depth 1 is made of a single prediction leaf.
+- `L2::T=0.0`: L2 regularization factor on aggregate gain. Must be >= 0. Higher L2 can result in a more robust model.
+- `lambda::T=0.0`: L2 regularization factor on individual gain. Must be >= 0. Higher lambda can result in a more robust model.
+- `gamma::T=0.0`: Minimum gain improvement needed to perform a node split. Higher gamma can result in a more robust model. Must be >= 0.
+- `max_depth=5`: Maximum depth of a tree. Must be >= 1. A tree of depth 1 is made of a single prediction leaf. A complete tree of depth N contains `2^(N - 1)` terminal leaves and `2^(N - 1) - 1` split nodes. Compute cost is proportional to 2^max_depth. Typical optimal values are in the 3 to 9 range.
-- `min_weight=8.0`: Minimum weight needed in a node to perform a split. Matches the number of observations by default or the sum of weights as provided by the `weights` vector. Must be > 0.
-- `rowsample=1.0`: Proportion of rows that are sampled at each iteration to build the tree. Should be in `]0, 1]`.
-- `colsample=1.0`: Proportion of columns / features that are sampled at each iteration to build the tree. Should be in `]0, 1]`.
-- `nbins=32`: Number of bins into which each feature is quantized. Buckets are defined based on quantiles, hence resulting in equal weight bins. Should be between 2 and 255.
+- `min_weight=8.0`: Minimum weight needed in a node to perform a split. Matches the number of observations by default or the sum of weights as provided by the `weights` vector. Must be > 0.
+- `rowsample=1.0`: Proportion of rows that are sampled at each iteration to build the tree. Should be in `]0, 1]`.
+- `colsample=1.0`: Proportion of columns / features that are sampled at each iteration to build the tree. Should be in `]0, 1]`.
+- `nbins=32`: Number of bins into which each feature is quantized. Buckets are defined based on quantiles, hence resulting in equal weight bins. Should be between 2 and 255.
 - `monotone_constraints=Dict{Int, Int}()`: Specify monotonic constraints using a dict where the key is the feature index and the value the applicable constraint (-1=decreasing, 0=none, 1=increasing). !Experimental feature: note that for Gaussian regression, constraints may not be enforce systematically.
 - `tree_type="binary"` Tree structure to be used. One of:
   - `binary`: Each node of a tree is grown independently. Tree are built depthwise until max depth is reach or if min weight or gain (see `gamma`) stops further node splits.
   - `oblivious`: A common splitting condition is imposed to all nodes of a given depth.
-- `rng=123`: Either an integer used as a seed to the random number generator or an actual random number generator (`::Random.AbstractRNG`).
+- `rng=123`: Either an integer used as a seed to the random number generator or an actual random number generator (`::Random.AbstractRNG`).

 # Internal API

@@ -676,24 +680,25 @@ EvoTreeMLE performs maximum likelihood estimation. Assumed distribution is speci
 `loss=:gaussian`: Loss to be be minimized during training. One of:
   - `:gaussian` / `:gaussian_mle`
   - `:logistic` / `:logistic_mle`
-- `nrounds=10`: Number of rounds. It corresponds to the number of trees that will be sequentially stacked. Must be >= 1.
+- `nrounds=10`: Number of rounds. It corresponds to the number of trees that will be sequentially stacked. Must be >= 1.
 - `eta=0.1`: Learning rate. Each tree raw predictions are scaled by `eta` prior to be added to the stack of predictions. Must be > 0. A lower `eta` results in slower learning, requiring a higher `nrounds` but typically improves model performance.
-- `lambda::T=0.0`: L2 regularization term on weights. Must be >= 0. Higher lambda can result in a more robust model.
-- `gamma::T=0.0`: Minimum gain imprvement needed to perform a node split. Higher gamma can result in a more robust model. Must be >= 0.
-- `max_depth=5`: Maximum depth of a tree. Must be >= 1. A tree of depth 1 is made of a single prediction leaf.
+- `L2::T=0.0`: L2 regularization factor on aggregate gain. Must be >= 0. Higher L2 can result in a more robust model.
+- `lambda::T=0.0`: L2 regularization factor on individual gain. Must be >= 0. Higher lambda can result in a more robust model.
+- `gamma::T=0.0`: Minimum gain improvement needed to perform a node split. Higher gamma can result in a more robust model. Must be >= 0.
+- `max_depth=5`: Maximum depth of a tree. Must be >= 1. A tree of depth 1 is made of a single prediction leaf. A complete tree of depth N contains `2^(N - 1)` terminal leaves and `2^(N - 1) - 1` split nodes. Compute cost is proportional to 2^max_depth. Typical optimal values are in the 3 to 9 range.
-- `min_weight=8.0`: Minimum weight needed in a node to perform a split. Matches the number of observations by default or the sum of weights as provided by the `weights` vector. Must be > 0.
-- `rowsample=1.0`: Proportion of rows that are sampled at each iteration to build the tree. Should be in `]0, 1]`.
-- `colsample=1.0`: Proportion of columns / features that are sampled at each iteration to build the tree. Should be in `]0, 1]`.
-- `nbins=32`: Number of bins into which each feature is quantized. Buckets are defined based on quantiles, hence resulting in equal weight bins. Should be between 2 and 255.
+- `min_weight=8.0`: Minimum weight needed in a node to perform a split. Matches the number of observations by default or the sum of weights as provided by the `weights` vector. Must be > 0.
+- `rowsample=1.0`: Proportion of rows that are sampled at each iteration to build the tree. Should be in `]0, 1]`.
+- `colsample=1.0`: Proportion of columns / features that are sampled at each iteration to build the tree. Should be in `]0, 1]`.
+- `nbins=32`: Number of bins into which each feature is quantized. Buckets are defined based on quantiles, hence resulting in equal weight bins. Should be between 2 and 255.
 - `monotone_constraints=Dict{Int, Int}()`: Specify monotonic constraints using a dict where the key is the feature index and the value the applicable constraint (-1=decreasing, 0=none, 1=increasing). !Experimental feature: note that for MLE regression, constraints may not be enforced systematically.
-- `tree_type="binary"` Tree structure to be used. One of:
+- `tree_type="binary"` Tree structure to be used. One of:
   - `binary`: Each node of a tree is grown independently. Tree are built depthwise until max depth is reach or if min weight or gain (see `gamma`) stops further node splits.
   - `oblivious`: A common splitting condition is imposed to all nodes of a given depth.
-- `rng=123`: Either an integer used as a seed to the random number generator or an actual random number generator (`::Random.AbstractRNG`).
+- `rng=123`: Either an integer used as a seed to the random number generator or an actual random number generator (`::Random.AbstractRNG`).

 # Internal API

diff --git a/src/loss.jl b/src/loss.jl
index 07d006d..63b6420 100644
--- a/src/loss.jl
+++ b/src/loss.jl
@@ -144,13 +144,13 @@ end
 # GradientRegression
 function get_gain(params::EvoTypes{L}, ∑::AbstractVector) where {L<:GradientRegression}
     ϵ = eps(eltype(∑))
-    ∑[1]^2 / max(ϵ, (∑[2] + params.lambda * ∑[3])) / 2
+    ∑[1]^2 / max(ϵ, (∑[2] + params.lambda * ∑[3] + params.L2)) / 2
 end
 # GaussianRegression
 function get_gain(params::EvoTypes{L}, ∑::AbstractVector) where {L<:MLE2P}
     ϵ = eps(eltype(∑))
-    (∑[1]^2 / max(ϵ, (∑[3] + params.lambda * ∑[5])) + ∑[2]^2 / max(ϵ, (∑[4] + params.lambda * ∑[5]))) / 2
+    (∑[1]^2 / max(ϵ, (∑[3] + params.lambda * ∑[5] + params.L2)) + ∑[2]^2 / max(ϵ, (∑[4] + params.lambda * ∑[5] + params.L2))) / 2
 end
 # MultiClassRegression
@@ -159,7 +159,7 @@ function get_gain(params::EvoTypes{L}, ∑::AbstractVector{T}) where {L<:MLogLos
     gain = zero(T)
     K = (length(∑) - 1) ÷ 2
     @inbounds for k = 1:K
-        gain += ∑[k]^2 / max(ϵ, (∑[k+K] + params.lambda * ∑[end])) / 2
+        gain += ∑[k]^2 / max(ϵ, (∑[k+K] + params.lambda * ∑[end] + params.L2)) / 2
     end
     return gain
 end
diff --git a/src/models.jl b/src/models.jl
index 5a74c0f..9e8571a 100644
--- a/src/models.jl
+++ b/src/models.jl
@@ -34,6 +34,7 @@ mk_rng(int::Integer) = Random.MersenneTwister(int)

 mutable struct EvoTreeRegressor{L<:ModelType} <: MMI.Deterministic
     nrounds::Int
+    L2::Float64
     lambda::Float64
     gamma::Float64
     eta::Float64
@@ -54,6 +55,7 @@ function EvoTreeRegressor(; kwargs...)
     args = Dict{Symbol,Any}(
         :loss => :mse,
         :nrounds => 100,
+        :L2 => 0.0,
         :lambda => 0.0,
         :gamma => 0.0, # min gain to split
         :eta => 0.1, # learning rate
@@ -102,6 +104,7 @@ function EvoTreeRegressor(; kwargs...)

     model = EvoTreeRegressor{L}(
         args[:nrounds],
+        args[:L2],
         args[:lambda],
         args[:gamma],
         args[:eta],
@@ -125,6 +128,7 @@ end

 mutable struct EvoTreeCount{L<:ModelType} <: MMI.Probabilistic
     nrounds::Int
+    L2::Float64
     lambda::Float64
     gamma::Float64
     eta::Float64
@@ -144,6 +148,7 @@ function EvoTreeCount(; kwargs...)
     # defaults arguments
     args = Dict{Symbol,Any}(
         :nrounds => 100,
+        :L2 => 0.0,
         :lambda => 0.0,
         :gamma => 0.0, # min gain to split
         :eta => 0.1, # learning rate
@@ -169,6 +174,7 @@ function EvoTreeCount(; kwargs...)

     model = EvoTreeCount{L}(
         args[:nrounds],
+        args[:L2],
         args[:lambda],
         args[:gamma],
         args[:eta],
@@ -192,6 +198,7 @@ end

 mutable struct EvoTreeClassifier{L<:ModelType} <: MMI.Probabilistic
     nrounds::Int
+    L2::Float64
     lambda::Float64
     gamma::Float64
     eta::Float64
@@ -210,6 +217,7 @@ function EvoTreeClassifier(; kwargs...)
     # defaults arguments
     args = Dict{Symbol,Any}(
         :nrounds => 100,
+        :L2 => 0.0,
         :lambda => 0.0,
         :gamma => 0.0, # min gain to split
         :eta => 0.1, # learning rate
@@ -234,6 +242,7 @@ function EvoTreeClassifier(; kwargs...)

     model = EvoTreeClassifier{L}(
         args[:nrounds],
+        args[:L2],
         args[:lambda],
         args[:gamma],
         args[:eta],
@@ -256,6 +265,7 @@ end

 mutable struct EvoTreeMLE{L<:ModelType} <: MMI.Probabilistic
     nrounds::Int
+    L2::Float64
     lambda::Float64
     gamma::Float64
     eta::Float64
@@ -276,6 +286,7 @@ function EvoTreeMLE(; kwargs...)
     args = Dict{Symbol,Any}(
         :loss => :gaussian_mle,
         :nrounds => 100,
+        :L2 => 0.0,
         :lambda => 0.0,
         :gamma => 0.0, # min gain to split
         :eta => 0.1, # learning rate
@@ -312,6 +323,7 @@ function EvoTreeMLE(; kwargs...)

     model = EvoTreeMLE{L}(
         args[:nrounds],
+        args[:L2],
         args[:lambda],
         args[:gamma],
         args[:eta],
@@ -341,6 +353,7 @@ end

 mutable struct EvoTreeGaussian{L<:ModelType} <: MMI.Probabilistic
     nrounds::Int
+    L2::Float64
     lambda::Float64
     gamma::Float64
     eta::Float64
@@ -359,6 +372,7 @@ function EvoTreeGaussian(; kwargs...)
     # defaults arguments
     args = Dict{Symbol,Any}(
         :nrounds => 100,
+        :L2 => 0.0,
         :lambda => 0.0,
         :gamma => 0.0, # min gain to split
         :eta => 0.1, # learning rate
@@ -384,6 +398,7 @@ function EvoTreeGaussian(; kwargs...)

     model = EvoTreeGaussian{L}(
         args[:nrounds],
+        args[:L2],
         args[:lambda],
         args[:gamma],
         args[:eta],
diff --git a/src/predict.jl b/src/predict.jl
index 9c753b2..2e2f07d 100644
--- a/src/predict.jl
+++ b/src/predict.jl
@@ -120,22 +120,22 @@ end
 function pred_leaf_cpu!(p::Matrix, n, ∑::AbstractVector{T}, params::EvoTypes{L}, ∇, is) where {L<:GradientRegression,T}
     ϵ = eps(T)
-    p[1, n] = -params.eta * ∑[1] / max(ϵ, (∑[2] + params.lambda * ∑[3]))
+    p[1, n] = -params.eta * ∑[1] / max(ϵ, (∑[2] + params.lambda * ∑[3] + params.L2))
 end
 function pred_scalar(∑::AbstractVector{T}, params::EvoTypes{L}) where {L<:GradientRegression,T}
     ϵ = eps(T)
-    -params.eta * ∑[1] / max(ϵ, (∑[2] + params.lambda * ∑[3]))
+    -params.eta * ∑[1] / max(ϵ, (∑[2] + params.lambda * ∑[3] + params.L2))
 end

 # prediction in Leaf - MLE2P
 function pred_leaf_cpu!(p::Matrix, n, ∑::AbstractVector{T}, params::EvoTypes{L}, ∇, is) where {L<:MLE2P,T}
     ϵ = eps(T)
-    p[1, n] = -params.eta * ∑[1] / max(ϵ, (∑[3] + params.lambda * ∑[5]))
-    p[2, n] = -params.eta * ∑[2] / max(ϵ, (∑[4] + params.lambda * ∑[5]))
+    p[1, n] = -params.eta * ∑[1] / max(ϵ, (∑[3] + params.lambda * ∑[5] + params.L2))
+    p[2, n] = -params.eta * ∑[2] / max(ϵ, (∑[4] + params.lambda * ∑[5] + params.L2))
 end
 function pred_scalar(∑::AbstractVector{T}, params::EvoTypes{L}) where {L<:MLE2P,T}
     ϵ = eps(T)
-    -params.eta * ∑[1] / max(ϵ, (∑[3] + params.lambda * ∑[5]))
+    -params.eta * ∑[1] / max(ϵ, (∑[3] + params.lambda * ∑[5] + params.L2))
 end

 # prediction in Leaf - MultiClassRegression
@@ -143,21 +143,21 @@ function pred_leaf_cpu!(p::Matrix, n, ∑::AbstractVector{T}, params::EvoTypes{L
     ϵ = eps(T)
     K = size(p, 1)
     @inbounds for k = axes(p, 1)
-        p[k, n] = -params.eta * ∑[k] / max(ϵ, (∑[k+K] + params.lambda * ∑[end]))
+        p[k, n] = -params.eta * ∑[k] / max(ϵ, (∑[k+K] + params.lambda * ∑[end] + params.L2))
     end
 end

 # prediction in Leaf - Quantile
 function pred_leaf_cpu!(p::Matrix, n, ∑::AbstractVector{T}, params::EvoTypes{L}, ∇, is) where {L<:Quantile,T}
-    p[1, n] = params.eta * quantile(∇[2, is], params.alpha) / (1 + params.lambda)
+    p[1, n] = params.eta * quantile(∇[2, is], params.alpha) / (1 + params.lambda + params.L2)
 end

 # prediction in Leaf - L1
 function pred_leaf_cpu!(p::Matrix, n, ∑::AbstractVector{T}, params::EvoTypes{L}, ∇, is) where {L<:L1,T}
     ϵ = eps(T)
-    p[1, n] = params.eta * ∑[1] / max(ϵ, (∑[3] * (1 + params.lambda)))
+    p[1, n] = params.eta * ∑[1] / max(ϵ, (∑[3] * (1 + params.lambda + params.L2)))
 end
 function pred_scalar(∑::AbstractVector, params::EvoTypes{L1})
     ϵ = eps(T)
-    params.eta * ∑[1] / max(ϵ, (∑[3] * (1 + params.lambda)))
+    params.eta * ∑[1] / max(ϵ, (∑[3] * (1 + params.lambda + params.L2)))
 end
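Taken together, the `loss.jl` and `predict.jl` changes give `L2` and `lambda` distinct roles in the denominator of the leaf values and split gains. The following standalone sketch (plain Julia with made-up node statistics, independent of the package internals) mirrors the updated GradientRegression formulas: `lambda` is multiplied by the node weight sum, whereas `L2` is an absolute term applied identically to every node.

```julia
# Made-up node statistics for illustration only
sum_grad = -120.0   # sum of gradients in the node   (∑[1] above)
sum_hess = 80.0     # sum of hessians in the node    (∑[2] above)
sum_weight = 100.0  # sum of observation weights     (∑[3] above)

eta, lambda, L2 = 0.1, 0.0, 1.0
ϵ = eps(Float64)

# leaf prediction, as in pred_leaf_cpu! for GradientRegression
leaf = -eta * sum_grad / max(ϵ, sum_hess + lambda * sum_weight + L2)

# split gain, as in get_gain for GradientRegression
gain = sum_grad^2 / max(ϵ, sum_hess + lambda * sum_weight + L2) / 2

# With lambda > 0 the penalty grows with the node weight sum;
# L2 applies the same absolute shrinkage regardless of node size.
@show leaf gain
```

This is consistent with the Higgs benchmark change at the top of the patch, where `lambda=1.0` is replaced by `L2=1` together with `lambda=0.0`.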