Skip to content

Commit

Permalink
add support for failedestimate
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Dec 14, 2023
1 parent 7ad919e commit 9c40a22
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 17 deletions.
1 change: 1 addition & 0 deletions src/NegativeControl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using JSON
using HTTP
using Statistics
using JLD2
using TargetedEstimation

include("utils.jl")
include("permutation_test.jl")
Expand Down
30 changes: 21 additions & 9 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,33 @@ Split string and remove principal components from the list.
"""
getconfounders(v) = Symbol.(filter(x -> !occursin(r"^PC[0-9]*$", x), split_string(v)))

default_statistical_test(Ψ̂::TMLE.Estimate; threshold=0.05) = pvalue(OneSampleTTest(Ψ̂)) < threshold
is_significant(Ψ̂::TMLE.Estimate; threshold=0.05) =
pvalue(OneSampleTTest(Ψ̂)) < threshold

default_statistical_test(Ψ̂::TMLE.ComposedEstimate; threshold=0.05) =
length(Ψ̂.estimate) > 1 ? pvalue(TMLE.OneSampleHotellingT2Test(Ψ̂)) < threshold : pvalue(TMLE.OneSampleTTest(Ψ̂)) < threshold
function is_significant(Ψ̂::TMLE.ComposedEstimate; threshold=0.05)
sig = if length(Ψ̂.estimate) > 1
pvalue(TMLE.OneSampleHotellingT2Test(Ψ̂)) < threshold
else
pvalue(TMLE.OneSampleTTest(Ψ̂)) < threshold
end
return sig
end

is_significant(Ψ̂::TMLE.Estimate; threshold=0.05) =
default_statistical_test(Ψ̂; threshold=threshold)
"""
For FailedEstimates
"""
is_significant(Ψ̂; threshold=0.05) = false

is_significant(nt; threshold=0.05, estimator_key=:TMLE) =
"""
For NamedTuples/Dicts stored in a results file
"""
is_significant(nt, estimator_key; threshold=0.05) =
is_significant(nt[estimator_key]; threshold=threshold)

function read_significant_from_hdf5(filename; threshold=0.05, estimator_key=:TMLE)
jldopen(filename) do io
return mapreduce(vcat, keys(io)) do key
[nt[estimator_key].estimand for nt io[key] if is_significant(nt; estimator_key=estimator_key, threshold=threshold)]
[nt[estimator_key].estimand for nt io[key] if is_significant(nt, estimator_key, threshold=threshold)]
end
end
end
Expand All @@ -50,7 +62,7 @@ function read_significant_from_jls(filename; threshold=0.05, estimator_key=:TMLE
open(filename) do io
while !eof(io)
nt = deserialize(io)
if is_significant(nt, threshold=threshold, estimator_key=estimator_key)
if is_significant(nt, estimator_key; threshold=threshold)
push!(results, nt[estimator_key].estimand)
end
end
Expand All @@ -59,7 +71,7 @@ function read_significant_from_jls(filename; threshold=0.05, estimator_key=:TMLE
end

read_significant_from_json(filename; threshold=0.05, estimator_key=:TMLE) =
[nt[estimator_key].estimand for nt TMLE.read_json(filename) if is_significant(nt; threshold=threshold, estimator_key=estimator_key)]
[nt[estimator_key].estimand for nt TMLE.read_json(filename) if is_significant(nt, estimator_key; threshold=threshold)]

function read_significant_results(filename; threshold=0.05, estimator_key=:TMLE)
results = if endswith(filename, "hdf5")
Expand Down
2 changes: 1 addition & 1 deletion test/permutation_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ include(joinpath(TESTDIR, "testutils.jl"))
end

@testset "Test make_permutation_parameters" begin
estimands =.TMLE.estimand for Ψ make_estimates()]
estimands =.TMLE.estimand for Ψ make_estimates()[1:end-1]]
expected_permuted_variables = Set([
:rs117913124,
:rs10043934,
Expand Down
3 changes: 1 addition & 2 deletions test/random_variants_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ end
bgen_prefix = joinpath(TESTDIR, "data", "bgen", "ukb")
trans_actors = Set(["RSID_103", "RSID_104"])

estimands = [nt.TMLE.estimand for nt make_estimates()]
estimands = [nt.TMLE.estimand for nt make_estimates()[1:end-1]]
variant_map = NegativeControl.find_maf_matching_random_variants(
trans_actors,
bgen_prefix,
Expand Down Expand Up @@ -170,7 +170,6 @@ end
@testset "Test generate_random_variants_parameters_and_dataset" begin
estimates = make_estimates()
save(estimates)
estimands = [Ψ̂.TMLE.estimand for Ψ̂ in estimates]
parsed_args = Dict(
"p" => 5,
"results" => "tmle_output.hdf5",
Expand Down
18 changes: 17 additions & 1 deletion test/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,23 @@ function make_estimates()
IC = []
)

return [(TMLE=IATE₁,), (TMLE=IATE₂,), (TMLE=jointIATE,), (TMLE=ATE₁,)]
failed_estimate = TargetedEstimation.FailedEstimate(
ATE(
outcome = "L50-L54 Urticaria and erythema",
treatment_values = (
rs117913124 = (case="GA", control="GG"),
RSID_104 = (case="GA", control="GG")
),
treatment_confounders = (
rs117913124 = (:PC1, :PC2, :PC3, :PC4, :PC5, :PC6),
RSID_104 = (:PC1, :PC2, :PC3, :PC4, :PC5, :PC6)
),
outcome_extra_covariates = ("Age-Assessment", "Genetic-Sex")
),
"Could not fluctuate"
)

return [(TMLE=IATE₁,), (TMLE=IATE₂,), (TMLE=jointIATE,), (TMLE=ATE₁,), (TMLE=failed_estimate,)]
end

function save(estimates; prefix="tmle_output")
Expand Down
8 changes: 4 additions & 4 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ end

threshold = 1.
for ext in ("jls", "json", "hdf5")
results = NegativeControl.read_significant_results(string(prefix, "." ,ext); threshold=threshold)
@test length(results) == length(estimates)
for (index, Ψ̋) enumerate(estimates)
results[index] == Ψ̋.TMLE.estimand
results = NegativeControl.read_significant_results(string(prefix, "." , ext); threshold=threshold)
length(results) == 4
for index 1:4
results[index] == estimates[index].TMLE.estimand
end
end

Expand Down

0 comments on commit 9c40a22

Please sign in to comment.