From dfbee003776764f03c59daa353c6f148f3245db7 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 21 Feb 2024 11:59:29 +0000 Subject: [PATCH 1/3] add runner build estimator from estimator --- src/runner.jl | 14 +++++++++++++- src/utils.jl | 9 --------- test/runner.jl | 24 ++++++++++++++++++++++-- test/utils.jl | 26 +------------------------- 4 files changed, 36 insertions(+), 37 deletions(-) diff --git a/src/runner.jl b/src/runner.jl index e91c09d..af148c3 100644 --- a/src/runner.jl +++ b/src/runner.jl @@ -1,3 +1,15 @@ +instantiate_estimators(file::AbstractString) = load_tmle_spec(;file=file) +instantiate_estimators(estimators) = estimators + +function load_tmle_spec(;file="glmnet") + file = endswith(file, ".jl") ? file : joinpath( + pkgdir(TargetedEstimation), + "estimators-configs", + string(file, ".jl")) + include(abspath(file)) + return ESTIMATORS +end + mutable struct Runner estimators::NamedTuple estimands::Vector{TMLE.Estimand} @@ -18,7 +30,7 @@ mutable struct Runner sort_estimands=false ) # Retrieve TMLE specifications - estimators = TargetedEstimation.load_tmle_spec(file=estimators) + estimators = instantiate_estimators(estimators) # Load dataset dataset = TargetedEstimation.instantiate_dataset(dataset) # Read parameter files diff --git a/src/utils.jl b/src/utils.jl index 8fa48ca..7b11833 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -176,14 +176,5 @@ variables(Ψ::TMLE.Estimand) = Set([ Iterators.flatten(values(Ψ.treatment_confounders))... ]) -function load_tmle_spec(;file="glmnet") - file = endswith(file, ".jl") ? file : joinpath( - pkgdir(TargetedEstimation), - "estimators-configs", - string(file, ".jl")) - include(abspath(file)) - return ESTIMATORS -end - TMLE.to_dict(nt::NamedTuple{names, <:Tuple{Vararg{Union{TMLE.EICEstimate, FailedEstimate, TMLE.ComposedEstimate}}}}) where names = Dict(key => TMLE.to_dict(val) for (key, val) ∈ zip(keys(nt), nt)) \ No newline at end of file diff --git a/test/runner.jl b/test/runner.jl index 01bd6da..a4a28d3 100644 --- a/test/runner.jl +++ b/test/runner.jl @@ -8,13 +8,33 @@ using CSV using Serialization using YAML using JSON +using MLJBase -TESTDIR = joinpath(pkgdir(TargetedEstimation), "test") - +PKGDIR = pkgdir(TargetedEstimation) +TESTDIR = joinpath(PKGDIR, "test") CONFIGDIR = joinpath(TESTDIR, "config") include(joinpath(TESTDIR, "testutils.jl")) +@testset "Test instantiate_estimators" begin + # From template name + for file in readdir(joinpath(PKGDIR, "estimators-configs")) + configname = replace(file, ".jl" => "") + estimators = TargetedEstimation.instantiate_estimators(configname) + @test estimators.TMLE isa TMLEE + end + # From explicit file + estimators = TargetedEstimation.instantiate_estimators(joinpath(TESTDIR, "config", "tmle_ose_config.jl")) + @test estimators.TMLE isa TMLE.TMLEE + @test estimators.OSE isa TMLE.OSE + @test estimators.TMLE.weighted === true + @test estimators.TMLE.models.G_default === estimators.OSE.models.G_default + @test estimators.TMLE.models.G_default isa MLJBase.ProbabilisticStack + # From already constructed estimators + estimators_new = TargetedEstimation.instantiate_estimators(estimators) + @test estimators_new === estimators +end + @testset "Integration Test" begin build_dataset(;n=1000, format="csv") tmpdir = mktempdir(cleanup=true) diff --git a/test/utils.jl b/test/utils.jl index 9279170..0113eeb 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -5,7 +5,6 @@ using TargetedEstimation using TMLE using DataFrames using CSV -using MLJBase using MLJLinearModels using CategoricalArrays @@ -14,33 +13,10 @@ check_type(treatment_value, ::Type{T}) where T = @test treatment_value isa T check_type(treatment_values::NamedTuple, ::Type{T}) where T = @test treatment_values.case isa T && treatment_values.control isa T -PKGDIR = pkgdir(TargetedEstimation) -TESTDIR = joinpath(PKGDIR, "test") +TESTDIR = joinpath(pkgdir(TargetedEstimation), "test") include(joinpath(TESTDIR, "testutils.jl")) -@testset "Test load_tmle_spec" begin - # Default - noarg_estimators = TargetedEstimation.load_tmle_spec() - default_models = noarg_estimators.TMLE.models - @test noarg_estimators.TMLE isa TMLEE - @test default_models.Q_binary_default.glm_net_classifier isa GLMNetClassifier - @test default_models.Q_continuous_default.glm_net_regressor isa GLMNetRegressor - @test default_models.G_default isa GLMNetClassifier - # From template name - for file in readdir(joinpath(PKGDIR, "estimators-configs")) - configname = replace(file, ".jl" => "") - estimators = TargetedEstimation.load_tmle_spec(;file=configname) - @test estimators.TMLE isa TMLEE - end - # From explicit file - estimators = TargetedEstimation.load_tmle_spec(file=joinpath(TESTDIR, "config", "tmle_ose_config.jl")) - @test estimators.TMLE isa TMLE.TMLEE - @test estimators.OSE isa TMLE.OSE - @test estimators.TMLE.weighted === true - @test estimators.TMLE.models.G_default === estimators.OSE.models.G_default - @test estimators.TMLE.models.G_default isa MLJBase.ProbabilisticStack -end @testset "Test convert_treatment_values" begin treatment_types = Dict(:T₁=> Union{Missing, Bool}, :T₂=> Int) newT = TargetedEstimation.convert_treatment_values((T₁=1,), treatment_types) From 79970945f32359599b3ec9a0a2e8d74040006e58 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 21 Feb 2024 13:40:29 +0000 Subject: [PATCH 2/3] add runner constructor with already built Julia structs --- src/runner.jl | 4 ++-- src/utils.jl | 25 ++++++++++++++++--------- test/runner.jl | 21 ++++++++++----------- test/sieve_variance.jl | 4 ++-- test/summary.jl | 2 +- test/testutils.jl | 5 +++++ test/utils.jl | 11 ++++++----- 7 files changed, 42 insertions(+), 30 deletions(-) diff --git a/src/runner.jl b/src/runner.jl index af148c3..57f5da4 100644 --- a/src/runner.jl +++ b/src/runner.jl @@ -32,9 +32,9 @@ mutable struct Runner # Retrieve TMLE specifications estimators = instantiate_estimators(estimators) # Load dataset - dataset = TargetedEstimation.instantiate_dataset(dataset) + dataset = instantiate_dataset(dataset) # Read parameter files - estimands = TargetedEstimation.build_estimands_list(estimands, dataset) + estimands = instantiate_estimands(estimands, dataset) if sort_estimands estimands = groups_ordering(estimands; brute_force=true, diff --git a/src/utils.jl b/src/utils.jl index 7b11833..0f46498 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -35,13 +35,16 @@ function fix_treatment_values!(treatment_types::AbstractDict, Ψ::ComposedEstima return ComposedEstimand(Ψ.f, new_args) end +wrapped_type(x) = x +wrapped_type(x::Type{<:CategoricalValue{T,}}) where T = T + """ Uses the values found in the dataset to create a new estimand with adjusted values. """ function fix_treatment_values!(treatment_types::AbstractDict, Ψ, dataset) treatment_names = keys(Ψ.treatment_values) for tn in treatment_names - haskey(treatment_types, tn) ? nothing : treatment_types[tn] = eltype(dataset[!, tn]) + haskey(treatment_types, tn) ? nothing : treatment_types[tn] = wrapped_type(eltype(dataset[!, tn])) end new_treatment = NamedTuple{treatment_names}( convert_treatment_values(Ψ.treatment_values, treatment_types) @@ -55,13 +58,11 @@ function fix_treatment_values!(treatment_types::AbstractDict, Ψ, dataset) end """ - proofread_estimands(param_file, dataset) + proofread_estimands(config, dataset) -Reads estimands from file and ensures that the treatment values in the config file -respects the treatment types in the dataset. +Ensures that the treatment values in the config respect the treatment types in the dataset. """ -function proofread_estimands(filename, dataset) - config = read_estimands_config(filename) +function proofread_estimands(config, dataset) adjustment_method = get_identification_method(config.adjustment) estimands = Vector{TMLE.Estimand}(undef, length(config.estimands)) treatment_types = Dict() @@ -91,11 +92,15 @@ function TMLE.factorialATE(dataset) return [factorialATE(dataset, (:T, ), :Y; confounders=confounding_variables)] end -function build_estimands_list(estimands_pattern, dataset) +instantiate_config(file::AbstractString) = read_estimands_config(file) +instantiate_config(config) = config + +function instantiate_estimands(estimands_pattern, dataset) estimands = if estimands_pattern == "factorialATE" factorialATE(dataset) else - proofread_estimands(estimands_pattern, dataset) + config = instantiate_config(estimands_pattern) + proofread_estimands(config, dataset) end return estimands end @@ -112,9 +117,11 @@ TMLE.emptyIC(nt::NamedTuple{names}, pval_threshold) where names = Returns a DataFrame wrapper around a dataset, either in CSV format. """ -instantiate_dataset(path::String) = +instantiate_dataset(path::AbstractString) = endswith(path, ".csv") ? CSV.read(path, DataFrame, ntasks=1) : DataFrame(Arrow.Table(path)) +instantiate_dataset(dataset) = dataset + isbinary(col, dataset) = Set(unique(skipmissing(dataset[!, col]))) == Set([0, 1]) make_categorical(x::CategoricalVector, ordered) = x diff --git a/test/runner.jl b/test/runner.jl index a4a28d3..7e9ac78 100644 --- a/test/runner.jl +++ b/test/runner.jl @@ -36,19 +36,19 @@ include(joinpath(TESTDIR, "testutils.jl")) end @testset "Integration Test" begin - build_dataset(;n=1000, format="csv") + dataset = build_dataset(;n=1000) tmpdir = mktempdir(cleanup=true) - estimands_filename = joinpath(tmpdir, "configuration.yaml") - TMLE.write_json(estimands_filename, statistical_estimands_only_config()) + config = statistical_estimands_only_config() outputs = TargetedEstimation.Outputs( json=TargetedEstimation.JSONOutput(filename="output.json"), hdf5=TargetedEstimation.HDF5Output(filename="output.hdf5", pval_threshold=1., sample_ids=true), jls=TargetedEstimation.JLSOutput(filename="output.jls", pval_threshold=1e-5), ) + estimators = TargetedEstimation.instantiate_estimators(joinpath(CONFIGDIR, "tmle_ose_config.jl")) runner = Runner( - "data.csv"; - estimands=estimands_filename, - estimators=joinpath(CONFIGDIR, "tmle_ose_config.jl"), + dataset; + estimands=config, + estimators=estimators, outputs=outputs, cache_strategy="release-unusable", ) @@ -120,7 +120,6 @@ end close(hdf5file) # Clean - rm("data.csv") rm(outputs.jls.filename) rm(outputs.json.filename) rm(outputs.hdf5.filename) @@ -139,7 +138,7 @@ end # Run tests over CSV and Arrow data formats for format in ("csv", "arrow") datafile = string("data.", format) - build_dataset(;n=1000, format=format) + write_dataset(;n=1000, format=format) for chunksize in (4, 10) tmle(datafile; estimands=estimands_filename, @@ -174,7 +173,7 @@ end end @testset "Test tmle: lower p-value threshold only JSON output" begin - build_dataset(;n=1000, format="csv") + write_dataset(;n=1000, format="csv") tmpdir = mktempdir(cleanup=true) estimandsfile = joinpath(tmpdir, "configuration.json") configuration = statistical_estimands_only_config() @@ -206,7 +205,7 @@ end end @testset "Test tmle: Failing estimands" begin - build_dataset(;n=1000, format="csv") + write_dataset(;n=1000, format="csv") outputs = TargetedEstimation.Outputs( json=TargetedEstimation.JSONOutput(filename="output.json"), hdf5=TargetedEstimation.HDF5Output(filename="output.hdf5") @@ -261,7 +260,7 @@ end end @testset "Test tmle: Causal and Composed Estimands" begin - build_dataset(;n=1000, format="csv") + write_dataset(;n=1000, format="csv") tmpdir = mktempdir(cleanup=true) estimandsfile = joinpath(tmpdir, "configuration.jls") diff --git a/test/sieve_variance.jl b/test/sieve_variance.jl index 0095067..5bf29a2 100644 --- a/test/sieve_variance.jl +++ b/test/sieve_variance.jl @@ -15,7 +15,7 @@ TESTDIR = joinpath(pkgdir(TargetedEstimation), "test") include(joinpath(TESTDIR, "testutils.jl")) -function build_dataset(sample_ids) +function write_sieve_dataset(sample_ids) rng = StableRNG(123) n = size(sample_ids, 1) # Confounders @@ -52,7 +52,7 @@ function build_tmle_output_file(sample_ids, estimandfile, outprefix; pval=1., estimatorfile=joinpath(TESTDIR, "config", "tmle_ose_config.jl") ) - build_dataset(sample_ids) + write_sieve_dataset(sample_ids) outputs = TargetedEstimation.Outputs( hdf5=TargetedEstimation.HDF5Output(filename=string(outprefix, ".hdf5"), pval_threshold=pval, sample_ids=true), ) diff --git a/test/summary.jl b/test/summary.jl index 40f3a73..093eec1 100644 --- a/test/summary.jl +++ b/test/summary.jl @@ -14,7 +14,7 @@ CONFIGDIR = joinpath(TESTDIR, "config") include(joinpath(TESTDIR, "testutils.jl")) @testset "Test make_summary" begin - build_dataset() + write_dataset() datafile = "data.csv" estimatorfile = joinpath(CONFIGDIR, "ose_config.jl") tmpdir = mktempdir(cleanup=true) diff --git a/test/testutils.jl b/test/testutils.jl index ef5b992..99adfd6 100644 --- a/test/testutils.jl +++ b/test/testutils.jl @@ -126,5 +126,10 @@ function build_dataset(;n=1000, format="csv") dataset[!, "BINARY/OUTCOME"] = categorical(y₂) dataset[!, "EXTREME_BINARY"] = categorical(vcat(0, ones(n-1))) + return dataset +end + +function write_dataset(;n=1000, format="csv") + dataset = build_dataset(;n=1000) format == "csv" ? CSV.write("data.csv", dataset) : Arrow.write("data.arrow", dataset) end \ No newline at end of file diff --git a/test/utils.jl b/test/utils.jl index 0113eeb..2843b01 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -32,13 +32,14 @@ include(joinpath(TESTDIR, "testutils.jl")) @test newT == [(case = true, control = false), (case = 1, control = 0)] end -@testset "Test proofread_estimands" for extension in ("yaml", "json") +@testset "Test instantiate_config" for extension in ("yaml", "json") # Write estimands file filename = "statistical_estimands.$extension" eval(Meta.parse("TMLE.write_$extension"))(filename, statistical_estimands_only_config()) dataset = DataFrame(T1 = [1., 0.], T2=[true, false]) - estimands = TargetedEstimation.proofread_estimands(filename, dataset) + config = TargetedEstimation.instantiate_config(filename) + estimands = TargetedEstimation.proofread_estimands(config, dataset) for estimand in estimands if haskey(estimand.treatment_values, :T1) check_type(estimand.treatment_values.T1, Float64) @@ -53,13 +54,13 @@ end @testset "Test factorialATE" begin dataset = DataFrame(C=[1, 2, 3, 4],) - @test_throws ArgumentError TargetedEstimation.build_estimands_list("factorialATE", dataset) + @test_throws ArgumentError TargetedEstimation.instantiate_estimands("factorialATE", dataset) dataset.T = [0, 1, missing, 2] - @test_throws ArgumentError TargetedEstimation.build_estimands_list("factorialATE", dataset) + @test_throws ArgumentError TargetedEstimation.instantiate_estimands("factorialATE", dataset) dataset.Y = [0, 1, 2, 2] dataset.W1 = [1, 1, 1, 1] dataset.W_2 = [1, 1, 1, 1] - composedATE = TargetedEstimation.build_estimands_list("factorialATE", dataset)[1] + composedATE = TargetedEstimation.instantiate_estimands("factorialATE", dataset)[1] @test composedATE.args == ( TMLE.StatisticalATE(:Y, (T = (case = 1, control = 0),), (T = (:W1, :W_2),), ()), TMLE.StatisticalATE(:Y, (T = (case = 2, control = 1),), (T = (:W1, :W_2),), ()) From 6fba0df8ce3a99d2e726c0edcd57fb76c01e07e8 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 21 Feb 2024 13:46:52 +0000 Subject: [PATCH 3/3] rename some of Runner args [BREAKING] --- Project.toml | 2 +- src/runner.jl | 12 ++++++------ test/runner.jl | 8 ++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index e968a6d..9aa819a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TargetedEstimation" uuid = "2573d147-4098-46ba-9db2-8608d210ccac" authors = ["Olivier Labayle"] -version = "0.8.0" +version = "0.9.0" [deps] ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" diff --git a/src/runner.jl b/src/runner.jl index 57f5da4..841d633 100644 --- a/src/runner.jl +++ b/src/runner.jl @@ -20,8 +20,8 @@ mutable struct Runner verbosity::Int failed_nuisance::Set function Runner(dataset; - estimands="factorialATE", - estimators="glmnet", + estimands_config="factorialATE", + estimators_spec="glmnet", verbosity=0, outputs=Outputs(), chunksize=100, @@ -30,11 +30,11 @@ mutable struct Runner sort_estimands=false ) # Retrieve TMLE specifications - estimators = instantiate_estimators(estimators) + estimators = instantiate_estimators(estimators_spec) # Load dataset dataset = instantiate_dataset(dataset) # Read parameter files - estimands = instantiate_estimands(estimands, dataset) + estimands = instantiate_estimands(estimands_config, dataset) if sort_estimands estimands = groups_ordering(estimands; brute_force=true, @@ -173,8 +173,8 @@ function tmle(dataset::String; sort_estimands::Bool=false ) runner = Runner(dataset; - estimands=estimands, - estimators=estimators, + estimands_config=estimands, + estimators_spec=estimators, verbosity=verbosity, outputs=outputs, chunksize=chunksize, diff --git a/test/runner.jl b/test/runner.jl index 7e9ac78..daece69 100644 --- a/test/runner.jl +++ b/test/runner.jl @@ -47,8 +47,8 @@ end estimators = TargetedEstimation.instantiate_estimators(joinpath(CONFIGDIR, "tmle_ose_config.jl")) runner = Runner( dataset; - estimands=config, - estimators=estimators, + estimands_config=config, + estimators_spec=estimators, outputs=outputs, cache_strategy="release-unusable", ) @@ -218,8 +218,8 @@ end datafile = "data.csv" runner = Runner(datafile; - estimands=estimandsfile, - estimators=estimatorfile, + estimands_config=estimandsfile, + estimators_spec=estimatorfile, outputs=outputs ); runner()