Skip to content

Commit

Permalink
up TMLE and get rid of windows
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Jan 27, 2024
1 parent f0da774 commit 289bfa2
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 17 deletions.
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ jobs:
os:
- ubuntu-latest
- macOS-latest
- windows-latest
arch:
- 'x64'
steps:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ MLJModels = "0.16"
MLJXGBoostInterface = "0.3.4"
MultipleTesting = "0.6.0"
Optim = "1.7"
TMLE = "0.13.1"
TMLE = "0.14.0"
Tables = "1.10.1"
YAML = "0.4.9"
julia = "1.7, 1"
4 changes: 2 additions & 2 deletions src/cli.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ function cli_settings()

"--estimands"
arg_type = String
help = "A string (`generateATEs`) or a serialized TMLE.Configuration (accepted formats: .json | .yaml | .jls)"
default = "generateATEs"
help = "A string (`factorialATE`) or a serialized TMLE.Configuration (accepted formats: .json | .yaml | .jls)"
default = "factorialATE"

"--estimators"
arg_type = String
Expand Down
8 changes: 4 additions & 4 deletions src/runner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mutable struct Runner
verbosity::Int
failed_nuisance::Set
function Runner(dataset;
estimands="generateATEs",
estimands="factorialATE",
estimators="glmnet",
verbosity=0,
outputs=Outputs(),
Expand Down Expand Up @@ -120,7 +120,7 @@ end

"""
tmle(dataset;
estimands="generateATEs",
estimands="factorialATE",
estimators="glmnet";
verbosity=0,
outputs=Outputs(),
Expand All @@ -138,7 +138,7 @@ TMLE CLI.
# Options
- `--estimands`: A string ("generateATEs") or a serialized TMLE.Configuration (accepted formats: .json | .yaml | .jls)
- `--estimands`: A string ("factorialATE") or a serialized TMLE.Configuration (accepted formats: .json | .yaml | .jls)
- `--estimators`: A julia file containing the estimators to use.
- `-v, --verbosity`: Verbosity level.
- `-o, --outputs`: Ouputs to be generated.
Expand All @@ -151,7 +151,7 @@ TMLE CLI.
- `-s, --sort_estimands`: Sort estimands to minimize cache usage (A brute force approach will be used, resulting in exponentially long sorting time).
"""
function tmle(dataset::String;
estimands::String="generateATEs",
estimands::String="factorialATE",
estimators::String="glmnet",
verbosity::Int=0,
outputs::Outputs=Outputs(),
Expand Down
8 changes: 4 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,19 @@ This explicitely requires that the following columns belong to the dataset:
All ATE parameters are generated.
"""
function TMLE.generateATEs(dataset)
function TMLE.factorialATE(dataset)
colnames = names(dataset)
"T" colnames || throw(ArgumentError("No column 'T' found in the dataset for the treatment variable."))
"Y" colnames || throw(ArgumentError("No column 'Y' found in the dataset for the outcome variable."))
confounding_variables = Tuple(name for name in colnames if occursin(r"^W", name))
length(confounding_variables) > 0 || throw(ArgumentError("Could not find any confounding variable (starting with 'W') in the dataset."))

return [generateATEs(dataset, (:T, ), :Y; confounders=confounding_variables)]
return [factorialATE(dataset, (:T, ), :Y; confounders=confounding_variables)]
end

function build_estimands_list(estimands_pattern, dataset)
estimands = if estimands_pattern == "generateATEs"
generateATEs(dataset)
estimands = if estimands_pattern == "factorialATE"
factorialATE(dataset)
else
proofread_estimands(estimands_pattern, dataset)
end
Expand Down
9 changes: 4 additions & 5 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,17 @@ end
rm(filename)
end

@testset "Test generateATEs" begin
@testset "Test factorialATE" begin
dataset = DataFrame(C=[1, 2, 3, 4],)
@test_throws ArgumentError TargetedEstimation.build_estimands_list("generateATEs", dataset)
@test_throws ArgumentError TargetedEstimation.build_estimands_list("factorialATE", dataset)
dataset.T = [0, 1, missing, 2]
@test_throws ArgumentError TargetedEstimation.build_estimands_list("generateATEs", dataset)
@test_throws ArgumentError TargetedEstimation.build_estimands_list("factorialATE", dataset)
dataset.Y = [0, 1, 2, 2]
dataset.W1 = [1, 1, 1, 1]
dataset.W_2 = [1, 1, 1, 1]
composedATE = TargetedEstimation.build_estimands_list("generateATEs", dataset)[1]
composedATE = TargetedEstimation.build_estimands_list("factorialATE", dataset)[1]
@test composedATE.args == (
TMLE.StatisticalATE(:Y, (T = (case = 1, control = 0),), (T = (:W1, :W_2),), ()),
TMLE.StatisticalATE(:Y, (T = (case = 2, control = 0),), (T = (:W1, :W_2),), ()),
TMLE.StatisticalATE(:Y, (T = (case = 2, control = 1),), (T = (:W1, :W_2),), ())
)
end
Expand Down

0 comments on commit 289bfa2

Please sign in to comment.