Skip to content

Commit

Permalink
add ref_cov parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengp0 committed Jun 26, 2024
1 parent 8a718b8 commit 295ea52
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 13 deletions.
37 changes: 30 additions & 7 deletions src/mrtool/core/cov_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Covariates model for `mrtool`.
"""

import warnings

import numpy as np
import pandas as pd
import xspline
Expand Down Expand Up @@ -972,31 +974,52 @@ def __init__(
raise ValueError("alt_cov should be a single column.")
if len(self.ref_cov) > 1:
raise ValueError("ref_cov should be nothing or a single column.")
if len(self.ref_cov) == 1 and self.ref_cat is None:
warnings.warn(
"ref_cat is not provided for a comparison covmodel, it will be "
"inferenced as the most common categories when attaching data."
)
if len(self.ref_cov) == 0 and self.ref_cat is not None:
raise ValueError(
"Cannot set ref_cat when this is not a comparison model."
)

self.cats: pd.Series

def attach_data(self, data: MRData) -> None:
"""Attach data and parse the categories. Number of variables will be
determined here and priors will be processed here as well.
determined here and priors will be processed and if ref_cov is not set
before, and this is a comparison model, ref_cov will be inferred as the
most common category.
"""
alt_cov = data.get_covs(self.alt_cov)
ref_cov = data.get_covs(self.ref_cov)
self.cats = pd.Series(
np.unique(np.hstack([alt_cov, ref_cov])),
name="cats",
unique_cats, counts = np.unique(
np.hstack([alt_cov, ref_cov]), return_counts=True
)
self.cats = pd.Series(unique_cats, name="cats")
self._process_priors()

if len(self.ref_cov) == 1:
if self.ref_cat is None:
self.ref_cat = unique_cats[counts.argmax()]
if self.ref_cat not in unique_cats:
raise ValueError(
f"ref_cat {self.ref_cat} is not in the categories."
)

def has_data(self) -> bool:
"""Return if the data has been attached and categories has been parsed."""
return hasattr(self, "cats")

def encode(self, x: NDArray) -> NDArray:
"""Encode the provided categories into dummy variables."""
col = pd.merge(pd.Series(x, name="cats"), self.cats.reset_index())[
"index"
]
col = pd.merge(
pd.Series(x, name="cats"), self.cats.reset_index(), how="left"
)["index"]
if np.isnan(col).any():
raise ValueError("Categories not found")
mat = np.zeros((len(x), self.num_x_vars))
mat[range(len(x)), col] = 1.0
return mat
Expand Down
45 changes: 39 additions & 6 deletions tests/test_cat_covmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def data():


def test_init():
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A")
assert covmodel.alt_cov == ["alt_cat"]
assert covmodel.ref_cov == ["ref_cat"]

Expand All @@ -41,26 +41,51 @@ def test_init():
CatCovModel(alt_cov=["a", "b"])

with pytest.raises(ValueError):
CatCovModel(alt_cov="a", ref_cov=["a", "b"])
CatCovModel(alt_cov="a", ref_cov=["a", "b"], ref_cat="A")


def test_attach_data(data):
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A")
assert not hasattr(covmodel, "cats")
covmodel.attach_data(data)
assert covmodel.cats.to_list() == ["A", "B", "C", "D"]


def test_ref_cov(data):
with pytest.raises(ValueError):
covmodel = CatCovModel(
alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="E"
)
covmodel.attach_data(data)

with pytest.raises(ValueError):
covmodel = CatCovModel(alt_cov="alt_cat", ref_cat="A")

covmodel = CatCovModel(alt_cov="alt_cat")
covmodel.attach_data(data)
assert covmodel.ref_cat is None

with pytest.warns():
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
assert covmodel.ref_cat is None
covmodel.attach_data(data)
assert covmodel.ref_cat == "A"

covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="B")
covmodel.attach_data(data)
assert covmodel.ref_cat == "B"


def test_has_data(data):
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A")
assert not covmodel.has_data()

covmodel.attach_data(data)
assert covmodel.has_data()


def test_encode(data):
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A")
covmodel.attach_data(data)

mat = covmodel.encode(["A", "B", "C", "C"])
Expand All @@ -77,8 +102,16 @@ def test_encode(data):
assert np.allclose(mat, true_mat)


def test_encode_fail(data):
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A")
covmodel.attach_data(data)

with pytest.raises(ValueError):
covmodel.encode(["A", "B", "C", "E"])


def test_create_design_mat(data):
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A")
covmodel.attach_data(data)

alt_mat, ref_mat = covmodel.create_design_mat(data)
Expand Down

0 comments on commit 295ea52

Please sign in to comment.