Skip to content

Commit

Permalink
Merge pull request #23 from TARGENE/runner_constructor
Browse files Browse the repository at this point in the history
Runner constructor
  • Loading branch information
olivierlabayle authored Feb 21, 2024
2 parents 19462d3 + 6fba0df commit 6696024
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 74 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TargetedEstimation"
uuid = "2573d147-4098-46ba-9db2-8608d210ccac"
authors = ["Olivier Labayle"]
version = "0.8.0"
version = "0.9.0"

[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
Expand Down
26 changes: 19 additions & 7 deletions src/runner.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
instantiate_estimators(file::AbstractString) = load_tmle_spec(;file=file)
instantiate_estimators(estimators) = estimators

function load_tmle_spec(;file="glmnet")
file = endswith(file, ".jl") ? file : joinpath(
pkgdir(TargetedEstimation),
"estimators-configs",
string(file, ".jl"))
include(abspath(file))
return ESTIMATORS
end

mutable struct Runner
estimators::NamedTuple
estimands::Vector{TMLE.Estimand}
Expand All @@ -8,8 +20,8 @@ mutable struct Runner
verbosity::Int
failed_nuisance::Set
function Runner(dataset;
estimands="factorialATE",
estimators="glmnet",
estimands_config="factorialATE",
estimators_spec="glmnet",
verbosity=0,
outputs=Outputs(),
chunksize=100,
Expand All @@ -18,11 +30,11 @@ mutable struct Runner
sort_estimands=false
)
# Retrieve TMLE specifications
estimators = TargetedEstimation.load_tmle_spec(file=estimators)
estimators = instantiate_estimators(estimators_spec)
# Load dataset
dataset = TargetedEstimation.instantiate_dataset(dataset)
dataset = instantiate_dataset(dataset)
# Read parameter files
estimands = TargetedEstimation.build_estimands_list(estimands, dataset)
estimands = instantiate_estimands(estimands_config, dataset)
if sort_estimands
estimands = groups_ordering(estimands;
brute_force=true,
Expand Down Expand Up @@ -161,8 +173,8 @@ function tmle(dataset::String;
sort_estimands::Bool=false
)
runner = Runner(dataset;
estimands=estimands,
estimators=estimators,
estimands_config=estimands,
estimators_spec=estimators,
verbosity=verbosity,
outputs=outputs,
chunksize=chunksize,
Expand Down
34 changes: 16 additions & 18 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@ function fix_treatment_values!(treatment_types::AbstractDict, Ψ::ComposedEstima
return ComposedEstimand.f, new_args)
end

wrapped_type(x) = x
wrapped_type(x::Type{<:CategoricalValue{T,}}) where T = T

"""
Uses the values found in the dataset to create a new estimand with adjusted values.
"""
function fix_treatment_values!(treatment_types::AbstractDict, Ψ, dataset)
treatment_names = keys.treatment_values)
for tn in treatment_names
haskey(treatment_types, tn) ? nothing : treatment_types[tn] = eltype(dataset[!, tn])
haskey(treatment_types, tn) ? nothing : treatment_types[tn] = wrapped_type(eltype(dataset[!, tn]))
end
new_treatment = NamedTuple{treatment_names}(
convert_treatment_values.treatment_values, treatment_types)
Expand All @@ -55,13 +58,11 @@ function fix_treatment_values!(treatment_types::AbstractDict, Ψ, dataset)
end

"""
proofread_estimands(param_file, dataset)
proofread_estimands(config, dataset)
Reads estimands from file and ensures that the treatment values in the config file
respects the treatment types in the dataset.
Ensures that the treatment values in the config respect the treatment types in the dataset.
"""
function proofread_estimands(filename, dataset)
config = read_estimands_config(filename)
function proofread_estimands(config, dataset)
adjustment_method = get_identification_method(config.adjustment)
estimands = Vector{TMLE.Estimand}(undef, length(config.estimands))
treatment_types = Dict()
Expand Down Expand Up @@ -91,11 +92,15 @@ function TMLE.factorialATE(dataset)
return [factorialATE(dataset, (:T, ), :Y; confounders=confounding_variables)]
end

function build_estimands_list(estimands_pattern, dataset)
instantiate_config(file::AbstractString) = read_estimands_config(file)
instantiate_config(config) = config

function instantiate_estimands(estimands_pattern, dataset)
estimands = if estimands_pattern == "factorialATE"
factorialATE(dataset)
else
proofread_estimands(estimands_pattern, dataset)
config = instantiate_config(estimands_pattern)
proofread_estimands(config, dataset)
end
return estimands
end
Expand All @@ -112,9 +117,11 @@ TMLE.emptyIC(nt::NamedTuple{names}, pval_threshold) where names =
Returns a DataFrame wrapper around a dataset, either in CSV format.
"""
instantiate_dataset(path::String) =
instantiate_dataset(path::AbstractString) =
endswith(path, ".csv") ? CSV.read(path, DataFrame, ntasks=1) : DataFrame(Arrow.Table(path))

instantiate_dataset(dataset) = dataset

isbinary(col, dataset) = Set(unique(skipmissing(dataset[!, col]))) == Set([0, 1])

make_categorical(x::CategoricalVector, ordered) = x
Expand Down Expand Up @@ -176,14 +183,5 @@ variables(Ψ::TMLE.Estimand) = Set([
Iterators.flatten(values.treatment_confounders))...
])

function load_tmle_spec(;file="glmnet")
file = endswith(file, ".jl") ? file : joinpath(
pkgdir(TargetedEstimation),
"estimators-configs",
string(file, ".jl"))
include(abspath(file))
return ESTIMATORS
end

TMLE.to_dict(nt::NamedTuple{names, <:Tuple{Vararg{Union{TMLE.EICEstimate, FailedEstimate, TMLE.ComposedEstimate}}}}) where names =
Dict(key => TMLE.to_dict(val) for (key, val) zip(keys(nt), nt))
49 changes: 34 additions & 15 deletions test/runner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,47 @@ using CSV
using Serialization
using YAML
using JSON
using MLJBase

TESTDIR = joinpath(pkgdir(TargetedEstimation), "test")

PKGDIR = pkgdir(TargetedEstimation)
TESTDIR = joinpath(PKGDIR, "test")
CONFIGDIR = joinpath(TESTDIR, "config")

include(joinpath(TESTDIR, "testutils.jl"))

@testset "Test instantiate_estimators" begin
# From template name
for file in readdir(joinpath(PKGDIR, "estimators-configs"))
configname = replace(file, ".jl" => "")
estimators = TargetedEstimation.instantiate_estimators(configname)
@test estimators.TMLE isa TMLEE
end
# From explicit file
estimators = TargetedEstimation.instantiate_estimators(joinpath(TESTDIR, "config", "tmle_ose_config.jl"))
@test estimators.TMLE isa TMLE.TMLEE
@test estimators.OSE isa TMLE.OSE
@test estimators.TMLE.weighted === true
@test estimators.TMLE.models.G_default === estimators.OSE.models.G_default
@test estimators.TMLE.models.G_default isa MLJBase.ProbabilisticStack
# From already constructed estimators
estimators_new = TargetedEstimation.instantiate_estimators(estimators)
@test estimators_new === estimators
end

@testset "Integration Test" begin
build_dataset(;n=1000, format="csv")
dataset = build_dataset(;n=1000)
tmpdir = mktempdir(cleanup=true)
estimands_filename = joinpath(tmpdir, "configuration.yaml")
TMLE.write_json(estimands_filename, statistical_estimands_only_config())
config = statistical_estimands_only_config()
outputs = TargetedEstimation.Outputs(
json=TargetedEstimation.JSONOutput(filename="output.json"),
hdf5=TargetedEstimation.HDF5Output(filename="output.hdf5", pval_threshold=1., sample_ids=true),
jls=TargetedEstimation.JLSOutput(filename="output.jls", pval_threshold=1e-5),
)
estimators = TargetedEstimation.instantiate_estimators(joinpath(CONFIGDIR, "tmle_ose_config.jl"))
runner = Runner(
"data.csv";
estimands=estimands_filename,
estimators=joinpath(CONFIGDIR, "tmle_ose_config.jl"),
dataset;
estimands_config=config,
estimators_spec=estimators,
outputs=outputs,
cache_strategy="release-unusable",
)
Expand Down Expand Up @@ -100,7 +120,6 @@ include(joinpath(TESTDIR, "testutils.jl"))
close(hdf5file)

# Clean
rm("data.csv")
rm(outputs.jls.filename)
rm(outputs.json.filename)
rm(outputs.hdf5.filename)
Expand All @@ -119,7 +138,7 @@ end
# Run tests over CSV and Arrow data formats
for format in ("csv", "arrow")
datafile = string("data.", format)
build_dataset(;n=1000, format=format)
write_dataset(;n=1000, format=format)
for chunksize in (4, 10)
tmle(datafile;
estimands=estimands_filename,
Expand Down Expand Up @@ -154,7 +173,7 @@ end
end

@testset "Test tmle: lower p-value threshold only JSON output" begin
build_dataset(;n=1000, format="csv")
write_dataset(;n=1000, format="csv")
tmpdir = mktempdir(cleanup=true)
estimandsfile = joinpath(tmpdir, "configuration.json")
configuration = statistical_estimands_only_config()
Expand Down Expand Up @@ -186,7 +205,7 @@ end
end

@testset "Test tmle: Failing estimands" begin
build_dataset(;n=1000, format="csv")
write_dataset(;n=1000, format="csv")
outputs = TargetedEstimation.Outputs(
json=TargetedEstimation.JSONOutput(filename="output.json"),
hdf5=TargetedEstimation.HDF5Output(filename="output.hdf5")
Expand All @@ -199,8 +218,8 @@ end
datafile = "data.csv"

runner = Runner(datafile;
estimands=estimandsfile,
estimators=estimatorfile,
estimands_config=estimandsfile,
estimators_spec=estimatorfile,
outputs=outputs
);
runner()
Expand Down Expand Up @@ -241,7 +260,7 @@ end
end

@testset "Test tmle: Causal and Composed Estimands" begin
build_dataset(;n=1000, format="csv")
write_dataset(;n=1000, format="csv")
tmpdir = mktempdir(cleanup=true)
estimandsfile = joinpath(tmpdir, "configuration.jls")

Expand Down
4 changes: 2 additions & 2 deletions test/sieve_variance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ TESTDIR = joinpath(pkgdir(TargetedEstimation), "test")

include(joinpath(TESTDIR, "testutils.jl"))

function build_dataset(sample_ids)
function write_sieve_dataset(sample_ids)
rng = StableRNG(123)
n = size(sample_ids, 1)
# Confounders
Expand Down Expand Up @@ -52,7 +52,7 @@ function build_tmle_output_file(sample_ids, estimandfile, outprefix;
pval=1.,
estimatorfile=joinpath(TESTDIR, "config", "tmle_ose_config.jl")
)
build_dataset(sample_ids)
write_sieve_dataset(sample_ids)
outputs = TargetedEstimation.Outputs(
hdf5=TargetedEstimation.HDF5Output(filename=string(outprefix, ".hdf5"), pval_threshold=pval, sample_ids=true),
)
Expand Down
2 changes: 1 addition & 1 deletion test/summary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ CONFIGDIR = joinpath(TESTDIR, "config")
include(joinpath(TESTDIR, "testutils.jl"))

@testset "Test make_summary" begin
build_dataset()
write_dataset()
datafile = "data.csv"
estimatorfile = joinpath(CONFIGDIR, "ose_config.jl")
tmpdir = mktempdir(cleanup=true)
Expand Down
5 changes: 5 additions & 0 deletions test/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,10 @@ function build_dataset(;n=1000, format="csv")
dataset[!, "BINARY/OUTCOME"] = categorical(y₂)
dataset[!, "EXTREME_BINARY"] = categorical(vcat(0, ones(n-1)))

return dataset
end

function write_dataset(;n=1000, format="csv")
dataset = build_dataset(;n=1000)
format == "csv" ? CSV.write("data.csv", dataset) : Arrow.write("data.arrow", dataset)
end
37 changes: 7 additions & 30 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ using TargetedEstimation
using TMLE
using DataFrames
using CSV
using MLJBase
using MLJLinearModels
using CategoricalArrays

Expand All @@ -14,33 +13,10 @@ check_type(treatment_value, ::Type{T}) where T = @test treatment_value isa T
check_type(treatment_values::NamedTuple, ::Type{T}) where T =
@test treatment_values.case isa T && treatment_values.control isa T

PKGDIR = pkgdir(TargetedEstimation)
TESTDIR = joinpath(PKGDIR, "test")
TESTDIR = joinpath(pkgdir(TargetedEstimation), "test")

include(joinpath(TESTDIR, "testutils.jl"))

@testset "Test load_tmle_spec" begin
# Default
noarg_estimators = TargetedEstimation.load_tmle_spec()
default_models = noarg_estimators.TMLE.models
@test noarg_estimators.TMLE isa TMLEE
@test default_models.Q_binary_default.glm_net_classifier isa GLMNetClassifier
@test default_models.Q_continuous_default.glm_net_regressor isa GLMNetRegressor
@test default_models.G_default isa GLMNetClassifier
# From template name
for file in readdir(joinpath(PKGDIR, "estimators-configs"))
configname = replace(file, ".jl" => "")
estimators = TargetedEstimation.load_tmle_spec(;file=configname)
@test estimators.TMLE isa TMLEE
end
# From explicit file
estimators = TargetedEstimation.load_tmle_spec(file=joinpath(TESTDIR, "config", "tmle_ose_config.jl"))
@test estimators.TMLE isa TMLE.TMLEE
@test estimators.OSE isa TMLE.OSE
@test estimators.TMLE.weighted === true
@test estimators.TMLE.models.G_default === estimators.OSE.models.G_default
@test estimators.TMLE.models.G_default isa MLJBase.ProbabilisticStack
end
@testset "Test convert_treatment_values" begin
treatment_types = Dict(:T₁=> Union{Missing, Bool}, :T₂=> Int)
newT = TargetedEstimation.convert_treatment_values((T₁=1,), treatment_types)
Expand All @@ -56,13 +32,14 @@ end
@test newT == [(case = true, control = false), (case = 1, control = 0)]
end

@testset "Test proofread_estimands" for extension in ("yaml", "json")
@testset "Test instantiate_config" for extension in ("yaml", "json")
# Write estimands file
filename = "statistical_estimands.$extension"
eval(Meta.parse("TMLE.write_$extension"))(filename, statistical_estimands_only_config())

dataset = DataFrame(T1 = [1., 0.], T2=[true, false])
estimands = TargetedEstimation.proofread_estimands(filename, dataset)
config = TargetedEstimation.instantiate_config(filename)
estimands = TargetedEstimation.proofread_estimands(config, dataset)
for estimand in estimands
if haskey(estimand.treatment_values, :T1)
check_type(estimand.treatment_values.T1, Float64)
Expand All @@ -77,13 +54,13 @@ end

@testset "Test factorialATE" begin
dataset = DataFrame(C=[1, 2, 3, 4],)
@test_throws ArgumentError TargetedEstimation.build_estimands_list("factorialATE", dataset)
@test_throws ArgumentError TargetedEstimation.instantiate_estimands("factorialATE", dataset)
dataset.T = [0, 1, missing, 2]
@test_throws ArgumentError TargetedEstimation.build_estimands_list("factorialATE", dataset)
@test_throws ArgumentError TargetedEstimation.instantiate_estimands("factorialATE", dataset)
dataset.Y = [0, 1, 2, 2]
dataset.W1 = [1, 1, 1, 1]
dataset.W_2 = [1, 1, 1, 1]
composedATE = TargetedEstimation.build_estimands_list("factorialATE", dataset)[1]
composedATE = TargetedEstimation.instantiate_estimands("factorialATE", dataset)[1]
@test composedATE.args == (
TMLE.StatisticalATE(:Y, (T = (case = 1, control = 0),), (T = (:W1, :W_2),), ()),
TMLE.StatisticalATE(:Y, (T = (case = 2, control = 1),), (T = (:W1, :W_2),), ())
Expand Down

0 comments on commit 6696024

Please sign in to comment.