Skip to content

Commit

Permalink
deal with namedtuple in output of TMLECLI
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Dec 14, 2023
1 parent 3adc78a commit 7ad919e
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 62 deletions.
66 changes: 30 additions & 36 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ version = "7.6.1"

[[deps.Arrow]]
deps = ["ArrowTypes", "BitIntegers", "CodecLz4", "CodecZstd", "ConcurrentUtilities", "DataAPI", "Dates", "EnumX", "LoggingExtras", "Mmap", "PooledArrays", "SentinelArrays", "Tables", "TimeZones", "TranscodingStreams", "UUIDs"]
git-tree-sha1 = "954666e252835c4cf8819ce4ffaf31073c1b7233"
git-tree-sha1 = "cd893c29839c524ca2c5944b8e05f26e299df105"
uuid = "69666777-d1a9-59fb-9406-91d4454c9d45"
version = "2.6.2"
version = "2.7.0"

[[deps.ArrowTypes]]
deps = ["Sockets", "UUIDs"]
git-tree-sha1 = "8c37bfdf1b689c6677bbfc8986968fe641f6a299"
git-tree-sha1 = "404265cd8128a2515a81d5eae16de90fdef05101"
uuid = "31f734f8-188a-4ce0-8406-c8a06bd891cd"
version = "2.2.2"
version = "2.3.0"

[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
Expand Down Expand Up @@ -174,9 +174,9 @@ uuid = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1"
version = "0.3.1"

[[deps.CEnum]]
git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90"
git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.4.2"
version = "0.5.0"

[[deps.CSV]]
deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"]
Expand Down Expand Up @@ -229,9 +229,9 @@ version = "0.1.13"

[[deps.ChainRules]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"]
git-tree-sha1 = "006cc7170be3e0fa02ccac6d4164a1eee1fc8c27"
git-tree-sha1 = "0aa0a3dd7b9bacbbadf1932ccbdfa938985c5561"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.58.0"
version = "1.58.1"

[[deps.ChainRulesCore]]
deps = ["Compat", "LinearAlgebra"]
Expand All @@ -256,10 +256,10 @@ uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.7.3"

[[deps.CodecZstd]]
deps = ["CEnum", "TranscodingStreams", "Zstd_jll"]
git-tree-sha1 = "849470b337d0fa8449c21061de922386f32949d9"
deps = ["TranscodingStreams", "Zstd_jll"]
git-tree-sha1 = "f69e46bf7b307d15a896b57d5b3321c01cd64923"
uuid = "6b39b394-51ab-5f42-8807-6242bab2b4c2"
version = "0.7.2"
version = "0.8.1"

[[deps.ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
Expand Down Expand Up @@ -538,9 +538,9 @@ uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"

[[deps.FillArrays]]
deps = ["LinearAlgebra", "Random"]
git-tree-sha1 = "25a10f2b86118664293062705fd9c7e2eda881a2"
git-tree-sha1 = "5b93957f6dcd33fc343044af3d48c215be2562f1"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "1.9.2"
version = "1.9.3"
weakdeps = ["PDMats", "SparseArrays", "Statistics"]

[deps.FillArrays.extensions]
Expand Down Expand Up @@ -592,9 +592,9 @@ version = "1.9.0"

[[deps.GLMNet]]
deps = ["DataFrames", "Distributed", "Distributions", "Printf", "Random", "SparseArrays", "StatsBase", "glmnet_jll"]
git-tree-sha1 = "7ea4e2bbb84183fe52a488d05e16c152b2387b95"
git-tree-sha1 = "49ef90cd140f8a99a81338f1e08e8ebc18837a63"
uuid = "8d5ece8b-de18-5317-b113-243142960cc6"
version = "0.7.2"
version = "0.7.0"

[[deps.GPUArrays]]
deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"]
Expand Down Expand Up @@ -924,16 +924,10 @@ uuid = "64a0f543-368b-4a9a-827a-e71edb2a0b83"
version = "0.4.4"

[[deps.MLJ]]
deps = ["CategoricalArrays", "ComputationalResources", "Distributed", "Distributions", "LinearAlgebra", "MLJBalancing", "MLJBase", "MLJEnsembles", "MLJFlow", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "Reexport", "ScientificTypes", "StatisticalMeasures", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "981196c41a23cbc1befbad190558b1f0ebb97910"
deps = ["CategoricalArrays", "ComputationalResources", "Distributed", "Distributions", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlow", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "Reexport", "ScientificTypes", "StatisticalMeasures", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "58d17a367ee211ade6e53f83a9cc5adf9d26f833"
uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
version = "0.20.2"

[[deps.MLJBalancing]]
deps = ["MLJBase", "MLJModelInterface", "MLUtils", "OrderedCollections", "Random", "StatsBase"]
git-tree-sha1 = "e4be85602f010291f49b6a6464ccde1708ce5d62"
uuid = "45f359ea-796d-4f51-95a5-deb1a414c586"
version = "0.1.3"
version = "0.20.0"

[[deps.MLJBase]]
deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
Expand All @@ -953,9 +947,9 @@ version = "0.4.0"

[[deps.MLJFlow]]
deps = ["MLFlowClient", "MLJBase", "MLJModelInterface"]
git-tree-sha1 = "89d0e7a7e08359476482f20b2d8ff12080d171ee"
git-tree-sha1 = "dc0de70a794c6d4c1aa4bde8196770c6b6e6b550"
uuid = "7b7b8358-b45c-48ea-a8ef-7ca328ad328f"
version = "0.3.0"
version = "0.2.0"

[[deps.MLJGLMInterface]]
deps = ["Distributions", "GLM", "MLJModelInterface", "StatsModels", "Tables"]
Expand Down Expand Up @@ -1473,9 +1467,9 @@ version = "0.1.15"

[[deps.StableRNGs]]
deps = ["Random", "Test"]
git-tree-sha1 = "3be7d49667040add7ee151fefaf1f8c04c8c8276"
git-tree-sha1 = "ddc1a7b85e760b5285b50b882fa91e40c603be47"
uuid = "860ef19b-820b-49d6-a774-d7a799459cd3"
version = "1.0.0"
version = "1.0.1"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
Expand Down Expand Up @@ -1531,9 +1525,9 @@ version = "1.7.0"

[[deps.StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "1d77abd07f617c4868c33d4f5b9e1dbb2643c9cf"
git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.34.2"
version = "0.33.21"

[[deps.StatsFuns]]
deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
Expand Down Expand Up @@ -1590,9 +1584,9 @@ version = "5.10.1+6"

[[deps.TMLE]]
deps = ["AbstractDifferentiation", "CategoricalArrays", "Combinatorics", "Distributions", "GLM", "Graphs", "HypothesisTests", "LogExpFunctions", "MLJBase", "MLJGLMInterface", "MLJModels", "MetaGraphsNext", "Missings", "PrecompileTools", "PrettyTables", "Random", "Statistics", "TableOperations", "Tables", "Zygote"]
git-tree-sha1 = "e66fe5f9cfdfb69f8fb8f870a89b5ce151789ae9"
git-tree-sha1 = "a9cb213aede82499800d04c33837497357c6dfd6"
uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
version = "0.12.1"
version = "0.12.2"
weakdeps = ["JSON", "YAML"]

[deps.TMLE.extensions]
Expand Down Expand Up @@ -1635,7 +1629,7 @@ version = "1.10.0"

[[deps.TargetedEstimation]]
deps = ["Arrow", "CSV", "CategoricalArrays", "Combinatorics", "Comonicon", "Configurations", "DataFrames", "EvoTrees", "GLMNet", "HDF5", "JLD2", "JSON", "MKL", "MLJ", "MLJBase", "MLJLinearModels", "MLJModelInterface", "MLJModels", "MLJXGBoostInterface", "Mmap", "MultipleTesting", "Optim", "Random", "Serialization", "TMLE", "Tables", "YAML"]
git-tree-sha1 = "94334dd6c98f3a4ae10143d0d85e3f99383b69be"
git-tree-sha1 = "b81e841116848f61526b808384f35e7f06a13a52"
repo-rev = "cv_tmle"
repo-url = "https://github.com/TARGENE/TargetedEstimation.jl.git"
uuid = "2573d147-4098-46ba-9db2-8608d210ccac"
Expand Down Expand Up @@ -1785,10 +1779,10 @@ uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.4"

[[deps.glmnet_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "31adae3b983b579a1fbd7cfd43a4bc0d224c2f5a"
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "a88d1783391cea1503e092e8a346751ec5e3b5d1"
uuid = "78c6b45d-5eaf-5d68-bcfb-a5a2cb06c27f"
version = "2.0.13+0"
version = "5.0.0+0"

[[deps.libaec_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
Expand Down
4 changes: 4 additions & 0 deletions bin/generate_permutation_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ function parse_commandline()
help = "The p-value threshold for significant results calling"
default = 0.05
arg_type = Float64
"--estimator-key"
help = "Estimator to use to check significance."
default = "TMLE"
arg_type = String
"--limit"
help = "The max number of permutation parameters to be generated"
default = nothing
Expand Down
4 changes: 4 additions & 0 deletions bin/generate_random_variant_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ function parse_commandline()
help = "Number of random variants per trans-actor"
arg_type = Int
default = 10
"--estimator-key"
help = "Estimator to use to check significance."
default = "TMLE"
arg_type = String
"--pval-threshold"
help = "The p-value threshold for significant results calling"
default = 0.05
Expand Down
3 changes: 2 additions & 1 deletion src/permutation_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,11 @@ function generate_permutation_parameters_and_dataset(parsed_args)
limit = parsed_args["limit"]
rng = StableRNG(parsed_args["rng"])
chunksize = parsed_args["chunksize"]
estimator_key = Symbol(parsed_args["estimator-key"])

# Generating Permutation Parameters
verbosity > 0 && @info string("Retrieving significant parameters.")
results = read_significant_results(resultsfile, threshold=pval_threshold)
results = read_significant_results(resultsfile, threshold=pval_threshold, estimator_key=estimator_key)
verbosity > 0 && @info string(size(results, 1), " parameters satisfying the threshold.")
verbosity > 0 && @info "Generating permutation parameters."
parameters, permuted_variables = make_permutation_parameters(results; optimize=true, orders=orders)
Expand Down
3 changes: 2 additions & 1 deletion src/random_variants_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,10 @@ function generate_random_variants_parameters_and_dataset(parsed_args)
pval_threshold = parsed_args["pval-threshold"]
out = parsed_args["out"]
verbosity = parsed_args["verbosity"]
estimator_key = Symbol(parsed_args["estimator-key"])

verbosity > 0 && @info string("Retrieving significant estimands.")
significant_estimands = read_significant_results(resultsfile; threshold=pval_threshold)
significant_estimands = read_significant_results(resultsfile; threshold=pval_threshold, estimator_key=estimator_key)
all_rsids = unique_treatments(significant_estimands)

verbosity > 0 && @info string("Looking for random MAF matching variants for each Trans-Actor.")
Expand Down
31 changes: 17 additions & 14 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,45 +26,48 @@ Split string and remove principal components from the list.
"""
getconfounders(v) = Symbol.(filter(x -> !occursin(r"^PC[0-9]*$", x), split_string(v)))

default_statistical_test(Ψ̂; threshold=0.05) = pvalue(OneSampleTTest(Ψ̂)) < threshold
default_statistical_test(Ψ̂::TMLE.Estimate; threshold=0.05) = pvalue(OneSampleTTest(Ψ̂)) < threshold

default_statistical_test(Ψ̂::TMLE.ComposedEstimate; threshold=0.05) =
length(Ψ̂.estimate) > 1 ? pvalue(TMLE.OneSampleHotellingT2Test(Ψ̂)) < threshold : pvalue(TMLE.OneSampleTTest(Ψ̂)) < threshold

is_significant(Ψ̂; threshold=0.05) =
is_significant(Ψ̂::TMLE.Estimate; threshold=0.05) =
default_statistical_test(Ψ̂; threshold=threshold)

function read_significant_from_hdf5(filename; threshold=0.05)
is_significant(nt; threshold=0.05, estimator_key=:TMLE) =
is_significant(nt[estimator_key]; threshold=threshold)

function read_significant_from_hdf5(filename; threshold=0.05, estimator_key=:TMLE)
jldopen(filename) do io
return mapreduce(vcat, keys(io)) do key
[Ψ̂.estimand for Ψ̂ ∈ io[key] if is_significant(Ψ̂; threshold=threshold)]
[nt[estimator_key].estimand for nt ∈ io[key] if is_significant(nt; estimator_key=estimator_key, threshold=threshold)]
end
end
end

function read_significant_from_jls(filename; threshold=0.05)
function read_significant_from_jls(filename; threshold=0.05, estimator_key=:TMLE)
results = []
open(filename) do io
while !eof(io)
Ψ̂ = deserialize(io)
if is_significant(Ψ̂, threshold=threshold)
push!(results, Ψ̂.estimand)
nt = deserialize(io)
if is_significant(nt, threshold=threshold, estimator_key=estimator_key)
push!(results, nt[estimator_key].estimand)
end
end
end
return results
end

read_significant_from_json(filename; threshold=0.05) =
[Ψ̂.estimand for Ψ̂ ∈ TMLE.read_json(filename) if is_significant(Ψ̂; threshold=threshold)]
read_significant_from_json(filename; threshold=0.05, estimator_key=:TMLE) =
[nt[estimator_key].estimand for nt ∈ TMLE.read_json(filename) if is_significant(nt; threshold=threshold, estimator_key=estimator_key)]

function read_significant_results(filename; threshold=0.05)
function read_significant_results(filename; threshold=0.05, estimator_key=:TMLE)
results = if endswith(filename, "hdf5")
read_significant_from_hdf5(filename; threshold=threshold)
read_significant_from_hdf5(filename; threshold=threshold, estimator_key=estimator_key)
elseif endswith(filename, "jls")
read_significant_from_jls(filename; threshold=threshold)
read_significant_from_jls(filename; threshold=threshold, estimator_key=estimator_key)
elseif endswith(filename, "json")
read_significant_from_json(filename; threshold=threshold)
read_significant_from_json(filename; threshold=threshold, estimator_key=estimator_key)
else
throw(ArgumentError("Unupported estimate file format: $filepath"))
end
Expand Down
8 changes: 4 additions & 4 deletions test/permutation_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ include(joinpath(TESTDIR, "testutils.jl"))

@testset "Test permuted_estimand!" begin
estimates = make_estimates()
Ψ = estimates[1].estimand
Ψ = estimates[1].TMLE.estimand
@test Ψ isa TMLE.StatisticalIATE
# Treatment and Outcome
permutation_variables = Set([Ψ.outcome, :RSID_103])
Expand Down Expand Up @@ -43,7 +43,7 @@ include(joinpath(TESTDIR, "testutils.jl"))
@test Ψpermuted.treatment_confounders == Ψ.treatment_confounders
@test Ψpermuted.outcome_extra_covariates == Ψ.outcome_extra_covariates
# Composed Estimand
Ψ = estimates[3].estimand
Ψ = estimates[3].TMLE.estimand
@test Ψ isa ComposedEstimand
outcome = Symbol("High light scatter reticulocyte percentage")
permutation_variables = Set([:RSID_103, outcome])
Expand All @@ -58,9 +58,8 @@ include(joinpath(TESTDIR, "testutils.jl"))
@test arg₂.outcome_extra_covariates == arg₁.outcome_extra_covariates == Ψ.args[1].outcome_extra_covariates
end


@testset "Test make_permutation_parameters" begin
estimands = [Ψ.estimand for Ψ ∈ make_estimates()]
estimands = [Ψ.TMLE.estimand for Ψ ∈ make_estimates()]
expected_permuted_variables = Set([
:rs117913124,
:rs10043934,
Expand Down Expand Up @@ -102,6 +101,7 @@ end
"dataset" => joinpath(TESTDIR, "data", "final.data.csv"),
"results" => "tmle_output.hdf5",
"outdir" => ".",
"estimator-key" => "TMLE",
"pval-threshold" => 1e-10,
"verbosity" => 0,
"limit" => nothing,
Expand Down
5 changes: 3 additions & 2 deletions test/random_variants_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ end
bgen_prefix = joinpath(TESTDIR, "data", "bgen", "ukb")
trans_actors = Set(["RSID_103", "RSID_104"])

estimands = [Ψ̂.estimand for Ψ̂ ∈ make_estimates()]
estimands = [nt.TMLE.estimand for nt ∈ make_estimates()]
variant_map = NegativeControl.find_maf_matching_random_variants(
trans_actors,
bgen_prefix,
Expand Down Expand Up @@ -170,14 +170,15 @@ end
@testset "Test generate_random_variants_parameters_and_dataset" begin
estimates = make_estimates()
save(estimates)
estimands = [Ψ̂.estimand for Ψ̂ in estimates]
estimands = [Ψ̂.TMLE.estimand for Ψ̂ in estimates]
parsed_args = Dict(
"p" => 5,
"results" => "tmle_output.hdf5",
"trans-actors-prefix" => joinpath(TESTDIR, "data", "trans_act"),
"bgen-prefix" => joinpath(TESTDIR, "data", "bgen", "ukb"),
"out" => "random_variants_parameters.yaml",
"pval-threshold" => 0.05,
"estimator-key" => "TMLE",
"verbosity" => 0,
"reltol" => 0.05,
"rng" => 123,
Expand Down
2 changes: 1 addition & 1 deletion test/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ function make_estimates()
IC = []
)

return [IATE₁, IATE₂, jointIATE, ATE₁]
return [(TMLE=IATE₁,), (TMLE=IATE₂,), (TMLE=jointIATE,), (TMLE=ATE₁,)]
end

function save(estimates; prefix="tmle_output")
Expand Down
6 changes: 3 additions & 3 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ include(joinpath(TESTDIR, "testutils.jl"))
@testset "Test treatment_variables and outcome_variables" begin
estimates = make_estimates()
# Classic estimand
Ψ = estimates[1].estimand
Ψ = estimates[1].TMLE.estimand
@test Ψ isa TMLE.StatisticalIATE
@test NegativeControl.treatment_variables(Ψ) == [:RSID_103, :rs10043934]
@test NegativeControl.outcome_variables(Ψ) == [Ψ.outcome]
# Composed estimand
Ψ = estimates[3].estimand
Ψ = estimates[3].TMLE.estimand
@test Ψ isa ComposedEstimand
@test NegativeControl.treatment_variables(Ψ) == [:RSID_103, :rs10043934]
@test NegativeControl.outcome_variables(Ψ) == [Symbol("High light scatter reticulocyte percentage")]
Expand All @@ -31,7 +31,7 @@ end
results = NegativeControl.read_significant_results(string(prefix, "." ,ext); threshold=threshold)
@test length(results) == length(estimates)
for (index, Ψ̋) ∈ enumerate(estimates)
results[index] == Ψ̋.estimand
results[index] == Ψ̋.TMLE.estimand
end
end

Expand Down

0 comments on commit 7ad919e

Please sign in to comment.