Skip to content

Commit

Permalink
update TMLE
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Apr 1, 2024
1 parent fa165e0 commit bf65808
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 41 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[compat]
PackageCompiler = "2.1.16"
ArgParse = "1.1.4"
Arrow = "2.5.2"
CSV = "0.10"
Expand All @@ -55,7 +54,8 @@ MLJModels = "0.16"
MLJXGBoostInterface = "0.3.4"
MultipleTesting = "0.6.0"
Optim = "1.7"
TMLE = "0.15.0"
PackageCompiler = "2.1.16"
TMLE = "0.16"
Tables = "1.10.1"
YAML = "0.4.9"
julia = "1.7, 1"
25 changes: 6 additions & 19 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,27 +152,14 @@ function make_float!(dataset, colnames)
end
end

function coerce_types!(dataset, Ψ::ComposedEstimand)
for arg in Ψ.args
coerce_types!(dataset, arg)
end
function coerce_types!(dataset, colnames)
infered_types = autotype(dataset[!, colnames])
coerce!(dataset, infered_types)
end

function coerce_types!(dataset, Ψ)
# Make Treatments categorical but preserve order
categorical_variables = Set(keys.treatment_values))
make_categorical!(dataset, categorical_variables, infer_ordered=true)
# Make Confounders and extra covariates continuous
continuous_variables = Set(Iterators.flatten(values.treatment_confounders)))
union!(continuous_variables, Ψ.outcome_extra_covariates)
make_float!(dataset, continuous_variables)
# Make outcome categorical if binary but do not infer order
if TMLE.is_binary(dataset, Ψ.outcome)
make_categorical!(dataset, Ψ.outcome, infer_ordered=false)
else
make_float!(dataset, Ψ.outcome)
end
end
coerce_types!(dataset, Ψ::TMLE.Estimand) =
coerce_types!(dataset, collect(variables(Ψ)))


variables::TMLE.ComposedEstimand) = union((variables(arg) for arg in Ψ.args)...)

Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Expand Down
4 changes: 3 additions & 1 deletion test/runner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Serialization
using YAML
using JSON
using MLJBase
using MLJModels

PKGDIR = pkgdir(TargetedEstimation)
TESTDIR = joinpath(PKGDIR, "test")
Expand All @@ -29,7 +30,8 @@ include(joinpath(TESTDIR, "testutils.jl"))
@test estimators.OSE isa TMLE.OSE
@test estimators.TMLE.weighted === true
@test estimators.TMLE.models.G_default === estimators.OSE.models.G_default
@test estimators.TMLE.models.G_default isa MLJBase.ProbabilisticStack
@test estimators.TMLE.models.G_default.continuous_encoder isa MLJModels.ContinuousEncoder
@test estimators.TMLE.models.G_default.probabilistic_stack isa MLJBase.ProbabilisticStack
# From already constructed estimators
estimators_new = TargetedEstimation.instantiate_estimators(estimators)
@test estimators_new === estimators
Expand Down
10 changes: 5 additions & 5 deletions test/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,17 @@ function build_dataset(;n=1000, format="csv")

dataset = DataFrame(
SAMPLE_ID = 1:n,
T1 = categorical(T₁),
T2 = categorical(T₂),
T1 = T₁,
T2 = T₂,
W1 = W₁,
W2 = W₂,
C1 = C₁,
)
# Comma in name
dataset[!, "CONTINUOUS, OUTCOME"] = categorical(y₁)
dataset[!, "CONTINUOUS, OUTCOME"] = y₁
# Slash in name
dataset[!, "BINARY/OUTCOME"] = categorical(y₂)
dataset[!, "EXTREME_BINARY"] = categorical(vcat(0, ones(n-1)))
dataset[!, "BINARY/OUTCOME"] = y₂
dataset[!, "EXTREME_BINARY"] = vcat(0, ones(n-1))

return dataset
end
Expand Down
29 changes: 15 additions & 14 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using DataFrames
using CSV
using MLJLinearModels
using CategoricalArrays
using MLJBase

check_type(treatment_value, ::Type{T}) where T = @test treatment_value isa T

Expand Down Expand Up @@ -74,21 +75,21 @@ end
)

dataset = DataFrame(
Ycont = [1.1, 2.2, missing],
Ycat = [1., 0., missing],
T₁ = [1, 0, missing],
T₂ = [missing, "AC", "CC"],
W₁ = [1., 0., 0.],
W₂ = [missing, 0., 0.],
C = [1, 2, 3]
Ycont = [1.1, 2.2, missing, 3.5, 6.6, 0., 4.],
Ycat = [1., 0., missing, 1., 0, 0, 0],
T₁ = [1, 0, missing, 0, 0, 0, missing],
T₂ = [missing, "AC", "CC", "CC", missing, "AA", "AA"],
W₁ = [1., 0., 0., 1., 0., 1, 1],
W₂ = [missing, 0., 0., 0., 0., 0., 0.],
C = [1, 2, 3, 4, 5, 6, 6]
)
TargetedEstimation.coerce_types!(dataset, Ψ)

@test dataset.T₁ isa CategoricalArray
@test dataset.T₂ isa CategoricalArray
for var in [:W₁, :W₂, :Ycont]
@test eltype(dataset[!, var]) <: Union{Missing, Float64}
end
@test scitype(dataset.T₁) == AbstractVector{Union{Missing, OrderedFactor{2}}}
@test scitype(dataset.T₂) == AbstractVector{Union{Missing, Multiclass{3}}}
@test scitype(dataset.Ycont) == AbstractVector{Union{Missing, MLJBase.Continuous}}
@test scitype(dataset.W₁) == AbstractVector{OrderedFactor{2}}
@test scitype(dataset.W₂) == AbstractVector{Union{Missing, OrderedFactor{1}}}

Ψ = IATE(
outcome=:Ycat,
Expand All @@ -98,8 +99,8 @@ end
)
TargetedEstimation.coerce_types!(dataset, Ψ)

@test dataset.Ycat isa CategoricalArray
@test eltype(dataset.C) <: Union{Missing, Float64}
@test scitype(dataset.Ycat) == AbstractVector{Union{Missing, OrderedFactor{2}}}
@test scitype(dataset.C) == AbstractVector{Count}
end

@testset "Test misc" begin
Expand Down

0 comments on commit bf65808

Please sign in to comment.