Skip to content

Commit

Permalink
fix counterfactual treatment with order added
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Apr 8, 2023
1 parent ca20ac8 commit e7ee755
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ function counterfactualTreatment(vals, T)
Tnames = Tables.columnnames(T)
n = nrows(T)
NamedTuple{Tnames}(
[categorical(repeat([vals[i]], n), levels=levels(Tables.getcolumn(T, name)))
[categorical(repeat([vals[i]], n), levels=levels(Tables.getcolumn(T, name)), ordered=isordered(Tables.getcolumn(T, name)))
for (i, name) in enumerate(Tnames)])
end

Expand Down
5 changes: 4 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,17 @@ end
@testset "Test counterfactualTreatment" begin
vals = (true, "a")
T = (
t₁ = categorical([true, false, false]),
t₁ = categorical([true, false, false], ordered=true),
t₂ = categorical(["a", "a", "c"])
)
cfT = TMLE.counterfactualTreatment(vals, T)
@test cfT == (
t₁ = categorical([true, true, true]),
t₂ = categorical(["a", "a", "a"])
)
@test isordered(cfT.t₁)
@test !isordered(cfT.t₂)

end

@testset "Test compute_covariate" begin
Expand Down

0 comments on commit e7ee755

Please sign in to comment.