Skip to content

Commit

Permalink
fix #267
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiedb committed Mar 31, 2024
1 parent b4dfe6f commit 6e92369
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EvoTrees"
uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
authors = ["jeremiedb <jeremie.db@evovest.com>"]
version = "0.16.6"
version = "0.16.7"

[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Expand Down
3 changes: 3 additions & 0 deletions ext/EvoTreesCUDAExt/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ function EvoTrees.init_core(params::EvoTrees.EvoTypes{L}, ::Type{<:EvoTrees.GPU}
T = Float32

target_levels = nothing
target_isordered = false
if L == EvoTrees.Logistic
@assert eltype(y_train) <: Real && minimum(y_train) >= 0 && maximum(y_train) <= 1
K = 1
Expand All @@ -22,6 +23,7 @@ function EvoTrees.init_core(params::EvoTrees.EvoTypes{L}, ::Type{<:EvoTrees.GPU}
elseif L == EvoTrees.MLogLoss
if eltype(y_train) <: EvoTrees.CategoricalValue
target_levels = EvoTrees.CategoricalArrays.levels(y_train)
target_isordered = isordered(y_train)
y = UInt32.(EvoTrees.CategoricalArrays.levelcode.(y_train))
elseif eltype(y_train) <: Integer || eltype(y_train) <: Bool || eltype(y_train) <: String || eltype(y_train) <: Char
target_levels = sort(unique(y_train))
Expand Down Expand Up @@ -89,6 +91,7 @@ function EvoTrees.init_core(params::EvoTrees.EvoTypes{L}, ::Type{<:EvoTrees.GPU}
info = Dict(
:fnames => fnames,
:target_levels => target_levels,
:target_isordered => target_isordered,
:edges => edges,
:featbins => featbins,
:feattypes => feattypes,
Expand Down
2 changes: 1 addition & 1 deletion src/MLJ.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ end

function predict(::EvoTreeClassifier, fitresult, A)
pred = predict(fitresult, A)
return MMI.UnivariateFinite(fitresult.info[:target_levels], pred, pool=missing)
return MMI.UnivariateFinite(fitresult.info[:target_levels], pred, pool=missing, ordered=fitresult.info[:target_isordered])
end

function predict(::EvoTreeCount, fitresult, A)
Expand Down
3 changes: 3 additions & 0 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ function init_core(params::EvoTypes{L}, ::Type{CPU}, data, fnames, y_train, w, o
T = Float32

target_levels = nothing
target_isordered = false
if L == Logistic
@assert eltype(y_train) <: Real && minimum(y_train) >= 0 && maximum(y_train) <= 1
K = 1
Expand All @@ -22,6 +23,7 @@ function init_core(params::EvoTypes{L}, ::Type{CPU}, data, fnames, y_train, w, o
elseif L == MLogLoss
if eltype(y_train) <: CategoricalValue
target_levels = CategoricalArrays.levels(y_train)
target_isordered = isordered(y_train)
y = UInt32.(CategoricalArrays.levelcode.(y_train))
elseif eltype(y_train) <: Integer || eltype(y_train) <: Bool || eltype(y_train) <: String || eltype(y_train) <: Char
target_levels = sort(unique(y_train))
Expand Down Expand Up @@ -87,6 +89,7 @@ function init_core(params::EvoTypes{L}, ::Type{CPU}, data, fnames, y_train, w, o
info = Dict(
:fnames => fnames,
:target_levels => target_levels,
:target_isordered => target_isordered,
:edges => edges,
:featbins => featbins,
:feattypes => feattypes,
Expand Down
18 changes: 18 additions & 0 deletions test/MLJ.jl
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,21 @@ end
fit!(mach)
predict(mach, X)
end

##################################################
### issue #267: ordered target
##################################################
using CategoricalArrays
y = categorical(collect("cbbba"), levels=['b', 'a', 'c'], ordered=true)
lvls = levels(y)
eltype(y) <: CategoricalValue
isordered(y)

using MLJBase, EvoTrees
# using StatisticalMeasures
X = (; x=rand(10))
y = coerce(rand("ab", 10), OrderedFactor)
model = EvoTreeClassifier()
mach = machine(model, X, y) |> fit!
yhat = predict(mach, X)
@assert isordered(yhat)

0 comments on commit 6e92369

Please sign in to comment.