Skip to content

Commit

Permalink
fix bug in composed estimand generation
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Jan 15, 2024
1 parent 8d6143b commit 135f819
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
20 changes: 7 additions & 13 deletions src/random_variants_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ NotEnoughMatchingVariantsError(rsid, p, reltol) =
))


read_snps_from_csv(path::String) = unique(CSV.read(path, DataFrame; select=[:ID, :CHR]), :ID)

get_variants_to_randomize(filepath::AbstractString) = Set(open(readlines, filepath))


Expand Down Expand Up @@ -169,17 +167,13 @@ function find_maf_matching_random_variants(
return variant_map
end

"""
Each sub estimand can potentially lead to p new sub estimands. A D-dimensional composite estimand
would lead to p^D new composed estimands. Typically p = 10, D = 9 for interactions leading to too many
new estimands. Instead we randomly select p for each subestimands, leading to p new composed estimands.
"""
function make_random_variants_estimands::ComposedEstimand, variant_map; p=10, rng=MersenneTwister())
newargs = Tuple(rand(rng, make_random_variants_estimands(arg, variant_map), p) for arg Ψ.args)

function make_random_variants_estimands::ComposedEstimand, variant_map)
newargs = Tuple(make_random_variants_estimands(arg, variant_map) for arg Ψ.args)
return [ComposedEstimand.f, Tuple(args)) for args zip(newargs...)]
end

function make_random_variants_estimands::T, variant_map; kwargs...) where T <: TMLE.Estimand
function make_random_variants_estimands::T, variant_map) where T <: TMLE.Estimand
transactors = keys(variant_map)
origin_treatment_variables = keys.treatment_values)
# At least one trans-actor in the parameter treatments to be processed
Expand Down Expand Up @@ -217,8 +211,8 @@ function make_random_variants_estimands(Ψ::T, variant_map; kwargs...) where T <
return new_estimands
end

make_random_variants_estimands(estimands, variant_map; p=10, rng=MersenneTwister(123)) =
vcat((make_random_variants_estimands(Ψ, variant_map; p=p, rng=rng) for Ψ in estimands)...)
make_random_variants_estimands(estimands, variant_map) =
vcat((make_random_variants_estimands(Ψ, variant_map) for Ψ in estimands)...)

function generate_random_variants_parameters_and_dataset(parsed_args)
resultsfile = parsed_args["results"]
Expand All @@ -243,7 +237,7 @@ function generate_random_variants_parameters_and_dataset(parsed_args)
p=p, rng=rng, reltol=reltol, verbosity=verbosity
)
verbosity > 0 && @info string("Building new estimands from matched random variants.")
new_estimands = make_random_variants_estimands(significant_estimands, variant_map; rng=rng, p=p)
new_estimands = make_random_variants_estimands(significant_estimands, variant_map)
new_estimands = groups_ordering(new_estimands, brute_force=false, do_shuffle=false)
serialize(out, Configuration(estimands = new_estimands))
verbosity > 0 && @info string("Done.")
Expand Down
8 changes: 7 additions & 1 deletion test/random_variants_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ end
trans_actors;
p=p, rng=rng, reltol=reltol, verbosity=0
)
new_estimands = NegativeControl.make_random_variants_estimands(estimands, variant_map, p=p)
new_estimands = NegativeControl.make_random_variants_estimands(estimands, variant_map)
# 5 new estimands for each of the 4 input estimands
@test length(new_estimands) == 20
# Check 5 new estimands generated from estimand 1
Expand All @@ -156,6 +156,12 @@ end
@test expected_new_rsids == new_rsids
# Check 5 new estimands generated from estimand 3 (composed estimand)
for Ψ new_estimands[11:15]
@test all(arg.outcome == first.args).outcome for arg in Ψ.args)
treatment_settings = [arg.treatment_values for arg in Ψ.args]
treatment_variables = [keys(ts) for ts in treatment_settings]
@test length(unique(treatment_variables)) == 1
treatment_values = [values(ts) for ts in treatment_settings]
@test length(unique(treatment_values)) == length(treatment_variables)
@test Ψ isa ComposedEstimand
@test Ψ != estimands[3]
end
Expand Down

0 comments on commit 135f819

Please sign in to comment.