Skip to content

Commit

Permalink
Merge branch 'main' of github.com:TARGENE/TMLE.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Dec 19, 2023
2 parents b236548 + 624f76b commit 44e0827
Show file tree
Hide file tree
Showing 6 changed files with 382 additions and 38 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TMLE"
uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
authors = ["Olivier Labayle"]
version = "0.12.2"
version = "0.13.0"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
Expand All @@ -21,6 +21,7 @@ Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Expand Down Expand Up @@ -55,6 +56,7 @@ TableOperations = "1.2"
Tables = "1.6"
YAML = "0.4.9"
Zygote = "0.6"
SplitApplyCombine = "1.2.2"
julia = "1.6, 1.7, 1"

[extras]
Expand Down
3 changes: 2 additions & 1 deletion src/TMLE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import AbstractDifferentiation as AD
using Graphs
using MetaGraphsNext
using Combinatorics
using SplitApplyCombine

# #############################################################################
# EXPORTS
Expand All @@ -28,7 +29,7 @@ using Combinatorics
export SCM, StaticSCM, add_equations!, add_equation!, parents, vertices
export CM, ATE, IATE
export AVAILABLE_ESTIMANDS
export generateATEs
export generateATEs, generateIATEs
export TMLEE, OSE, NAIVE
export ComposedEstimand
export var, estimate, OneSampleTTest, OneSampleZTest, OneSampleHotellingT2Test,pvalue, confint, emptyIC
Expand Down
220 changes: 195 additions & 25 deletions src/counterfactual_mean_based/estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,37 +231,207 @@ unique_non_missing(dataset, colname) = unique(skipmissing(Tables.getcolumn(datas

# Build a NamedTuple mapping each treatment column name to the unique,
# non-missing values found for that column in `dataset`.
function unique_treatment_values(dataset, colnames)
    levels_by_treatment = (
        colname => unique_non_missing(dataset, colname) for colname in colnames
    )
    return (; levels_by_treatment...)
end

# For every treatment variable, list all pairs of levels (combinations of size 2).
# The result has one entry per treatment, in the iteration order of the input.
function get_treatments_contrasts(treatments_unique_values)
    return [
        collect(Combinatorics.combinations(levels, 2))
        for levels in values(treatments_unique_values)
    ]
end

"""
    generateComposedEstimandFromContrasts(
        constructor, treatments_levels::NamedTuple{names}, outcome;
        confounders=nothing,
        outcome_extra_covariates=(),
        freq_table=nothing,
        positivity_constraint=nothing
    ) where names

Build a `ComposedEstimand` whose components are created by `constructor`, one per joint
treatment contrast.

For each treatment in `treatments_levels`, every pair of levels is turned into a
`(control, case)` contrast; the joint contrasts are the cartesian product of these
per-treatment contrasts. A component is kept only if
`satisfies_positivity(Ψ, freq_table; positivity_constraint=positivity_constraint)` holds.
"""
function generateComposedEstimandFromContrasts(
    constructor,
    treatments_levels::NamedTuple{names},
    outcome;
    confounders=nothing,
    outcome_extra_covariates=(),
    freq_table=nothing,
    positivity_constraint=nothing
) where names
    treatments_contrasts = get_treatments_contrasts(treatments_levels)
    components = []
    # One iteration per joint contrast: `combo` holds one level pair per treatment.
    for combo in Iterators.product(treatments_contrasts...)
        treatments_contrast = [
            NamedTuple{(:control, :case)}(treatment_control_case)
            for treatment_control_case in combo
        ]
        Ψ = constructor(
            outcome=outcome,
            treatment_values=NamedTuple{names}(treatments_contrast),
            treatment_confounders=confounders,
            outcome_extra_covariates=outcome_extra_covariates
        )
        if satisfies_positivity(Ψ, freq_table; positivity_constraint=positivity_constraint)
            push!(components, Ψ)
        end
    end
    return ComposedEstimand(joint_estimand, Tuple(components))
end

# Shared documentation fragment interpolated into the docstrings of the estimand
# generators. Declared `const` because it is a fixed module-level string.
const GENERATE_DOCSTRING = """
The components of this estimand are generated from the treatment variables contrasts.
For example, consider two treatment variables T₁ and T₂ each taking three possible values (0, 1, 2).
For each treatment variable, the marginal contrasts are defined by (0 → 1, 1 → 2, 0 → 2), there are thus
3 x 3 = 9 joint contrasts to be generated:
- (T₁: 0 → 1, T₂: 0 → 1)
- (T₁: 0 → 1, T₂: 1 → 2)
- (T₁: 0 → 1, T₂: 0 → 2)
- (T₁: 1 → 2, T₂: 0 → 1)
- (T₁: 1 → 2, T₂: 1 → 2)
- (T₁: 1 → 2, T₂: 0 → 2)
- (T₁: 0 → 2, T₂: 0 → 1)
- (T₁: 0 → 2, T₂: 1 → 2)
- (T₁: 0 → 2, T₂: 0 → 2)

# Return
A `ComposedEstimand` with causal or statistical components.

# Args
- `treatments_levels`: A NamedTuple providing the unique levels each treatment variable can take.
- `outcome`: The outcome variable.
- `confounders=nothing`: The generated components will inherit these confounding variables.
If `nothing`, causal estimands are generated.
- `outcome_extra_covariates=()`: The generated components will inherit these `outcome_extra_covariates`.
- `freq_table=nothing`: A mapping from joint treatment levels to their frequency, used to check the
`positivity_constraint`. If `nothing`, no positivity check is performed.
- `positivity_constraint=nothing`: Only components that pass the positivity constraint are added to the `ComposedEstimand`
"""

"""
generateATEs(
treatments_levels::NamedTuple{names}, outcome;
confounders=nothing,
outcome_extra_covariates=(),
freq_table=nothing,
positivity_constraint=nothing
) where names
Generate a `ComposedEstimand` of ATEs from the `treatments_levels`. $GENERATE_DOCSTRING
# Example:
To generate a causal composed estimand with 3 components:
```@example
generateATEs((T₁ = (0, 1), T₂=(0, 1, 2)), :Y₁)
```
To generate a statistical composed estimand with 9 components:
```@example
generateATEs((T₁ = (0, 1, 2), T₂=(0, 1, 2)), :Y₁, confounders=[:W₁, :W₂])
```
"""
function generateATEs(
    treatments_levels::NamedTuple{names}, outcome;
    confounders=nothing,
    outcome_extra_covariates=(),
    freq_table=nothing,
    positivity_constraint=nothing
) where names
    # Thin wrapper: the contrast generation is shared with the other estimand
    # generators, only the component constructor (`ATE`) is specific to this method.
    options = (
        confounders=confounders,
        outcome_extra_covariates=outcome_extra_covariates,
        freq_table=freq_table,
        positivity_constraint=positivity_constraint,
    )
    return generateComposedEstimandFromContrasts(ATE, treatments_levels, outcome; options...)
end

"""
generateATEs(dataset, treatments, outcome; confounders=nothing, outcome_extra_covariates=())
generateATEs(dataset, treatments, outcome;
confounders=nothing,
outcome_extra_covariates=(),
positivity_constraint=nothing
)
Find all unique values for each treatment variable in the dataset and generate all possible ATEs from these values.
"""
function generateATEs(dataset, treatments, outcome; confounders=nothing, outcome_extra_covariates=())
treatments_unique_values = unique_treatment_values(dataset, treatments)
return generateATEs(treatments_unique_values, outcome; confounders=confounders, outcome_extra_covariates=outcome_extra_covariates)
function generateATEs(dataset, treatments, outcome;
    confounders=nothing,
    outcome_extra_covariates=(),
    positivity_constraint=nothing
    )
    # Treatment levels are the unique non-missing values observed in the dataset.
    treatments_levels = unique_treatment_values(dataset, treatments)
    # The frequency table is only computed when a positivity constraint is requested.
    freq_table = if positivity_constraint === nothing
        nothing
    else
        frequency_table(dataset, keys(treatments_levels))
    end
    options = (
        confounders=confounders,
        outcome_extra_covariates=outcome_extra_covariates,
        freq_table=freq_table,
        positivity_constraint=positivity_constraint,
    )
    return generateATEs(treatments_levels, outcome; options...)
end

"""
generateATEs(treatments_unique_values, outcome; confounders=nothing, outcome_extra_covariates=())
generateIATEs(
treatments_levels::NamedTuple{names}, outcome;
confounders=nothing,
outcome_extra_covariates=(),
freq_table=nothing,
positivity_constraint=nothing
) where names
Generates a `ComposedEstimand` of Average Interaction Effects from `treatments_levels`. $GENERATE_DOCSTRING
# Example:
To generate a causal composed estimand with 3 components:
Generate all possible ATEs from the `treatments_unique_values`.
```@example
generateIATEs((T₁ = (0, 1), T₂=(0, 1, 2)), :Y₁)
```
To generate a statistical composed estimand with 9 components:
```@example
generateIATEs((T₁ = (0, 1, 2), T₂=(0, 1, 2)), :Y₁, confounders=[:W₁, :W₂])
```
"""
function generateATEs(treatments_unique_values, outcome; confounders=nothing, outcome_extra_covariates=())
    # Generate one ATE per joint contrast of the treatment levels and return them
    # as a vector.
    treatments = Tuple(Symbol.(keys(treatments_unique_values)))
    # For each treatment, all pairs of levels, each later read as (control, case).
    contrasts_per_treatment = [
        collect(Combinatorics.combinations(treatments_unique_values[T], 2)) for T in treatments
    ]

    ATEs = []
    for combo in Iterators.product(contrasts_per_treatment...)
        # A distinct name is used here so the per-treatment contrasts above are
        # not overwritten while the product is being iterated.
        control_case_values = [
            NamedTuple{(:control, :case)}(treatment_control_case)
            for treatment_control_case in combo
        ]
        push!(
            ATEs,
            ATE(
                outcome=outcome,
                treatment_values=NamedTuple{treatments}(control_case_values),
                treatment_confounders=confounders,
                outcome_extra_covariates=outcome_extra_covariates
            )
        )
    end
    return ATEs
end
function generateIATEs(
    treatments_levels::NamedTuple{names}, outcome;
    confounders=nothing,
    outcome_extra_covariates=(),
    freq_table=nothing,
    positivity_constraint=nothing
) where names
    # Thin wrapper: the contrast generation is shared with the other estimand
    # generators, only the component constructor (`IATE`) is specific to this method.
    options = (
        confounders=confounders,
        outcome_extra_covariates=outcome_extra_covariates,
        freq_table=freq_table,
        positivity_constraint=positivity_constraint,
    )
    return generateComposedEstimandFromContrasts(IATE, treatments_levels, outcome; options...)
end

"""
generateIATEs(dataset, treatments, outcome;
confounders=nothing,
outcome_extra_covariates=(),
positivity_constraint=nothing
)
Finds the treatment levels from the dataset and generates a `ComposedEstimand` of Average Interaction Effects from them
(see [`generateIATEs(treatments_levels, outcome; confounders=nothing, outcome_extra_covariates=())`](@ref)).
"""
function generateIATEs(dataset, treatments, outcome;
    confounders=nothing,
    outcome_extra_covariates=(),
    positivity_constraint=nothing
    )
    # Treatment levels are the unique non-missing values observed in the dataset.
    treatments_levels = unique_treatment_values(dataset, treatments)
    # The frequency table is only computed when a positivity constraint is requested.
    freq_table = if positivity_constraint === nothing
        nothing
    else
        frequency_table(dataset, keys(treatments_levels))
    end
    options = (
        confounders=confounders,
        outcome_extra_covariates=outcome_extra_covariates,
        freq_table=freq_table,
        positivity_constraint=positivity_constraint,
    )
    return generateIATEs(treatments_levels, outcome; options...)
end

"""
    joint_levels(Ψ::StatisticalIATE)

Return an iterator over the cartesian product of the values stored for each
treatment in `Ψ.treatment_values`.
"""
joint_levels(Ψ::StatisticalIATE) = Iterators.product(values(Ψ.treatment_values)...)

"""
    joint_levels(Ψ::StatisticalATE)

Return a generator yielding two tuples: the `:case` value of every treatment,
then the `:control` value of every treatment, in the order of
`keys(Ψ.treatment_values)`.
"""
joint_levels(Ψ::StatisticalATE) =
    (Tuple(Ψ.treatment_values[T][c] for T in keys(Ψ.treatment_values)) for c in (:case, :control))

"""
    joint_levels(Ψ::StatisticalCM)

Return a one-element tuple containing the values of `Ψ.treatment_values`.
"""
joint_levels(Ψ::StatisticalCM) = (values(Ψ.treatment_values),)
22 changes: 21 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,24 @@ default_models(;Q_binary=LinearBinaryClassifier(), Q_continuous=LinearRegressor(
G_default = G
)

# Return `true` when the set of non-missing values in column `columnname` of `dataset`
# equals the set {0, 1}. Defined once: a second identical definition would
# overwrite this method in the same module.
function is_binary(dataset, columnname)
    observed_values = Set(skipmissing(Tables.getcolumn(dataset, columnname)))
    return observed_values == Set([0, 1])
end

# Check that every joint treatment level of `Ψ` is present in `freq_table` and
# that its frequency is not below `positivity_constraint`. Stops at the first
# level that fails.
function satisfies_positivity(Ψ, freq_table; positivity_constraint=0.01)
    for level in joint_levels(Ψ)
        # A level absent from the table fails the check.
        haskey(freq_table, level) || return false
        freq_table[level] < positivity_constraint && return false
    end
    return true
end

# Without a frequency table there is nothing to check: every estimand passes.
function satisfies_positivity(Ψ, freq_table::Nothing; positivity_constraint=nothing)
    return true
end

"""
    frequency_table(dataset, colnames)

Count how often each joint combination of values occurs across the columns
`colnames` of `dataset`, then divide every count by `nrows(dataset)`.

The columns are zipped in sorted column-name order, so each key is a tuple of
values ordered by sorted column name. The returned container is the one produced
by `groupcount`, updated in place.
"""
function frequency_table(dataset, colnames)
# Row-wise iterator of value tuples; columns are taken in sorted name order.
iterator = zip((Tables.getcolumn(dataset, colname) for colname in sort(collect(colnames)))...)
# Group identical tuples and count them (presumably SplitApplyCombine's `groupcount` — confirm).
counts = groupcount(x -> x, iterator)
# Turn counts into frequencies in place.
# NOTE(review): `nrows` is not defined in the visible code; assumed to return the number of rows.
for key in keys(counts)
counts[key] /= nrows(dataset)
end
return counts
end
Loading

0 comments on commit 44e0827

Please sign in to comment.