Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Attempt to clean up and package code to improve reproduciblity #2

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# cache & build
__pycache__/
*.egg-info
.mypy_cache/
build/

*.pth
*.swp
*.log
*.png
wandb/
.ipynb_checkpoints
checkpoints/
*.pt
*.pt
39 changes: 39 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# First install `pre-commit` with `pip install pre-commit`.
# Next install these hooks with `pre-commit install`.
# A good overview of many of the pre-commit tools added below is
# available here: https://scikit-hep.org/developer/style

ci:
autoupdate_schedule: quarterly

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.9
hooks:
- id: ruff
args: [--fix]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-case-conflict
- id: check-symlinks
- id: check-yaml
args: [--unsafe]
- id: destroyed-symlinks
- id: end-of-file-fixer
- id: mixed-line-ending
- id: trailing-whitespace

- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
hooks:
- id: codespell
exclude_types: [json]
args: [--ignore-words-list, 'datas,calender', --builtin, 'clear']

# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v0.942
# hooks:
# - id: mypy
# exclude: (tests|examples)
20 changes: 6 additions & 14 deletions dirac_phi.py → dcsurvival/dirac_phi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
Contains the generator for ACNet (in the class DiracPhi).
"""Contains the generator for ACNet (in the class DiracPhi).
This is named as such since the mixing variable is a convex combination of dirac delta functions.
"""

Expand All @@ -8,12 +7,10 @@


class DiracPhi(nn.Module):
'''
TODO: streamline 3 cases in forward pass.
'''
"""TODO: streamline 3 cases in forward pass."""

def __init__(self, depth, widths, lc_w_range, shift_w_range, device, tol):
super(DiracPhi, self).__init__()
def __init__(self, depth, widths, lc_w_range, shift_w_range, device, tol) -> None:
super().__init__()

# Depth is the number of hidden layers.
self.depth = depth
Expand Down Expand Up @@ -52,19 +49,14 @@ def get_sizes_w_(self):
lc_sizes, shift_sizes = [], []

# Shift weights
prev_width = 1
for pos in range(depth):
width = widths[pos]
shift_sizes.append((width,))
prev_width = width

# Linear combination weights
for pos in range(depth):
width = widths[pos]
if pos < depth-1:
next_width = widths[pos+1]
else:
next_width = 1
next_width = widths[pos + 1] if pos < depth - 1 else 1
lc_sizes.append((next_width, width))

return shift_sizes + lc_sizes
Expand Down Expand Up @@ -122,6 +114,6 @@ def pf(x): return torch.exp(x)

output = states[-1]
assert (output >= 0.).all() and \
(output <= 1.+ self.tol).all(), "t %s, output %s, tol %s, max %s, min %s" % (t, output, self.tol, torch.max(output), torch.min(output))
(output <= 1.+ self.tol).all(), f"t {t}, output {output}, tol {self.tol}, max {torch.max(output)}, min {torch.min(output)}"

return output.reshape(t_raw.size())
83 changes: 41 additions & 42 deletions metrics/metric.py → dcsurvival/metrics/metric.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import torch
import copy

import numpy as np
from tqdm import tqdm
import torch
import copy


def Survival(truth_model, estimate, x, time_steps):
"""
model: the learned survival model
truth: the true survival model
"""model: the learned survival model
truth: the true survival model.
"""
device = torch.device("cpu")
estimate = copy.deepcopy(estimate).to(device)
Expand All @@ -23,7 +22,7 @@ def Survival(truth_model, estimate, x, time_steps):


def surv_diff(truth_model, estimate, x, steps):
device = torch.device("cpu")
torch.device("cpu")
surv1, surv2, time_steps, t_m = Survival(truth_model, estimate, x, steps)
# integ = torch.abs(surv1-surv2).sum()
integ = torch.sum( torch.diff(torch.cat([torch.zeros(1), time_steps])) * torch.abs(surv1-surv2) )
Expand Down Expand Up @@ -65,7 +64,7 @@ def BS_censored(t, event_t, x, model,e, km_h, km_p):
t_ind2 = (t < event_t).type(torch.float32)
G_t = KM_evaluater(t, km_h, km_p).clamp(0.01,1)
G_event = KM_evaluater(event_t, km_h, km_p).clamp(0.01,100)

#print(torch.sum(t_ind1/(G_event+1e-9*(G_event==0))), torch.sum(t_ind2/(G_t+1e-9*(G_t==0))))
return torch.mean((tmp *(t_ind1/(G_event+1e-9*(G_event==0)))) + (tmp * (t_ind2/(G_t+1e-9*(G_t==0)))))

Expand All @@ -83,49 +82,49 @@ def IBS_plain(event_t, x, model, t_max, n_bins=100):
ibs = 0
for t_ in torch.linspace(0, t_max, n_bins):
tmp = BS(torch.ones_like(event_t, device=event_t.device)*t_, event_t, x, model)

ibs += tmp * len_bin
return ibs/t_max



def evaluate_c_index(dep_model, indep_model, dgp, test_dict, E_reverse = False):
E = test_dict['E']
E = test_dict["E"]
if E_reverse:
E = 1-test_dict['E']
dgp_obs = C_index(test_dict['T'], test_dict['X'], E, dgp).cpu().detach().numpy().item()
dep_obs = C_index(test_dict['T'], test_dict['X'], E, dep_model).cpu().detach().numpy().item()
indep_obs = C_index(test_dict['T'], test_dict['X'], E, indep_model).cpu().detach().numpy().item()
E = 1-test_dict["E"]
dgp_obs = C_index(test_dict["T"], test_dict["X"], E, dgp).cpu().detach().numpy().item()
dep_obs = C_index(test_dict["T"], test_dict["X"], E, dep_model).cpu().detach().numpy().item()
indep_obs = C_index(test_dict["T"], test_dict["X"], E, indep_model).cpu().detach().numpy().item()
aux_e = torch.ones_like(E, device = E.device)
t = test_dict['t1']
t = test_dict["t1"]
if E_reverse:
t = test_dict['t2']
dgp_tot = C_index(t, test_dict['X'], aux_e, dgp).cpu().numpy().item()
dep_tot = C_index(t, test_dict['X'], aux_e, dep_model).cpu().numpy().item()
indep_tot = C_index(t, test_dict['X'], aux_e, indep_model).cpu().numpy().item()
t = test_dict["t2"]
dgp_tot = C_index(t, test_dict["X"], aux_e, dgp).cpu().numpy().item()
dep_tot = C_index(t, test_dict["X"], aux_e, dep_model).cpu().numpy().item()
indep_tot = C_index(t, test_dict["X"], aux_e, indep_model).cpu().numpy().item()
return [[dgp_obs, dep_obs, indep_obs], [dgp_tot, dep_tot, indep_tot]]



def evaluate_IBS(dep_model, indep_model, dgp, test_dict,km_h, km_p, E_reverse):
t = test_dict['t1']
t = test_dict["t1"]
if E_reverse:
t = test_dict['t2']
dgp_tot = IBS_plain(t, test_dict['X'], dgp, torch.max(t), n_bins=100).cpu().numpy().item()
dep_tot = IBS_plain(t, test_dict['X'], dep_model, torch.max(t), n_bins=100).cpu().numpy().item()
indep_tot = IBS_plain(t, test_dict['X'], indep_model, torch.max(t), n_bins=100).cpu().numpy().item()
E = test_dict['E']
t = test_dict["t2"]
dgp_tot = IBS_plain(t, test_dict["X"], dgp, torch.max(t), n_bins=100).cpu().numpy().item()
dep_tot = IBS_plain(t, test_dict["X"], dep_model, torch.max(t), n_bins=100).cpu().numpy().item()
indep_tot = IBS_plain(t, test_dict["X"], indep_model, torch.max(t), n_bins=100).cpu().numpy().item()
E = test_dict["E"]
if E_reverse:
E = 1-test_dict['E']
dgp_obs = IBS(test_dict['T'], test_dict['X'], dgp, torch.max(test_dict['T']), E, km_h, km_p).cpu().numpy().item()
dep_obs = IBS(test_dict['T'], test_dict['X'], dep_model, torch.max(test_dict['T']), E, km_h, km_p).cpu().numpy().item()
indep_obs = IBS(test_dict['T'], test_dict['X'], indep_model, torch.max(test_dict['T']), E, km_h, km_p).cpu().numpy().item()
E = 1-test_dict["E"]
dgp_obs = IBS(test_dict["T"], test_dict["X"], dgp, torch.max(test_dict["T"]), E, km_h, km_p).cpu().numpy().item()
dep_obs = IBS(test_dict["T"], test_dict["X"], dep_model, torch.max(test_dict["T"]), E, km_h, km_p).cpu().numpy().item()
indep_obs = IBS(test_dict["T"], test_dict["X"], indep_model, torch.max(test_dict["T"]), E, km_h, km_p).cpu().numpy().item()
return [[dgp_obs, dep_obs, indep_obs], [dgp_tot, dep_tot, indep_tot]]

def KM(t, e):
device= t.device
t = t.cpu().numpy().reshape(-1,)
e = e.cpu().numpy().reshape(-1,)
t = t.cpu().numpy().reshape(-1)
e = e.cpu().numpy().reshape(-1)
indices = np.argsort(t)
t_sorted = t[indices]
e_sorted = e[indices]
Expand All @@ -142,7 +141,7 @@ def KM(t, e):
event_times_end = np.zeros(event_times.shape[0]+1)
event_times_end[1:] = event_times
return torch.from_numpy(event_times_end).to(device), torch.from_numpy(prob_end).to(device)

def KM_evaluater(t, h, p):
device = t.device
h = h.cpu().numpy()
Expand All @@ -151,10 +150,10 @@ def KM_evaluater(t, h, p):
if len(t.shape) == 1:
idx = np.digitize(t, h, False)
return torch.from_numpy(p[idx-1]).to(device)
else:
idx = np.digitize(t, h, False)
prob = np.zeros_like(t)
for i in range(idx.shape[0]):
for j in range(idx.shape[1]):
prob[i, j] = p[idx[i,j]-1]
return torch.from_numpy(prob).to(device)

idx = np.digitize(t, h, False)
prob = np.zeros_like(t)
for i in range(idx.shape[0]):
for j in range(idx.shape[1]):
prob[i, j] = p[idx[i,j]-1]
return torch.from_numpy(prob).to(device)
79 changes: 38 additions & 41 deletions metrics/metric_pycox.py → dcsurvival/metrics/metric_pycox.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import torch
import numpy as np
from tqdm import tqdm
import torch


def Survival(truth_model, estimate, x, time_steps):
"""
model: the learned survival model
truth: the true survival model
"""model: the learned survival model
truth: the true survival model.
"""
device = torch.device("cpu")
# surv1 = torch.zeros((x.shape[0], time_steps.shape[0]),device=device)
Expand All @@ -23,7 +20,7 @@ def Survival(truth_model, estimate, x, time_steps):


def surv_diff(truth_model, estimate, x, steps):
device = torch.device("cpu")
torch.device("cpu")
surv1, surv2, time_steps, t_m = Survival(truth_model, estimate, x, steps)
# integ = torch.abs(surv1-surv2).sum()
integ = torch.sum( torch.diff(torch.cat([torch.zeros(1), time_steps])) * torch.abs(surv1-surv2) )
Expand Down Expand Up @@ -64,7 +61,7 @@ def BS_censored(t, event_t, x, model,e, km_h, km_p):
t_ind2 = (t < event_t).type(torch.float32)
G_t = KM_evaluater(t, km_h, km_p).clamp(0.01,1)
G_event = KM_evaluater(event_t, km_h, km_p).clamp(0.01,100)

#print(torch.sum(t_ind1/(G_event+1e-9*(G_event==0))), torch.sum(t_ind2/(G_t+1e-9*(G_t==0))))
return torch.mean((tmp *(t_ind1/(G_event+1e-9*(G_event==0)))) + (tmp * (t_ind2/(G_t+1e-9*(G_t==0)))))

Expand All @@ -82,49 +79,49 @@ def IBS_plain(event_t, x, model, t_max, n_bins=100):
ibs = 0
for t_ in torch.linspace(0, t_max, n_bins):
tmp = BS(torch.ones_like(event_t, device=event_t.device)*t_, event_t, x, model)

ibs += tmp * len_bin
return ibs/t_max



def evaluate_c_index(dep_model, indep_model, dgp, test_dict, E_reverse = False):
E = test_dict['E']
E = test_dict["E"]
if E_reverse:
E = 1-test_dict['E']
dgp_obs = C_index(test_dict['T'], test_dict['X'], E, dgp).cpu().detach().numpy().item()
dep_obs = C_index(test_dict['T'], test_dict['X'], E, dep_model).cpu().detach().numpy().item()
indep_obs = C_index(test_dict['T'], test_dict['X'], E, indep_model).cpu().detach().numpy().item()
E = 1-test_dict["E"]
dgp_obs = C_index(test_dict["T"], test_dict["X"], E, dgp).cpu().detach().numpy().item()
dep_obs = C_index(test_dict["T"], test_dict["X"], E, dep_model).cpu().detach().numpy().item()
indep_obs = C_index(test_dict["T"], test_dict["X"], E, indep_model).cpu().detach().numpy().item()
aux_e = torch.ones_like(E, device = E.device)
t = test_dict['t1']
t = test_dict["t1"]
if E_reverse:
t = test_dict['t2']
dgp_tot = C_index(t, test_dict['X'], aux_e, dgp).cpu().numpy().item()
dep_tot = C_index(t, test_dict['X'], aux_e, dep_model).cpu().numpy().item()
indep_tot = C_index(t, test_dict['X'], aux_e, indep_model).cpu().numpy().item()
t = test_dict["t2"]
dgp_tot = C_index(t, test_dict["X"], aux_e, dgp).cpu().numpy().item()
dep_tot = C_index(t, test_dict["X"], aux_e, dep_model).cpu().numpy().item()
indep_tot = C_index(t, test_dict["X"], aux_e, indep_model).cpu().numpy().item()
return [[dgp_obs, dep_obs, indep_obs], [dgp_tot, dep_tot, indep_tot]]



def evaluate_IBS(dep_model, indep_model, dgp, test_dict,km_h, km_p, E_reverse):
t = test_dict['t1']
t = test_dict["t1"]
if E_reverse:
t = test_dict['t2']
dgp_tot = IBS_plain(t, test_dict['X'], dgp, torch.max(t), n_bins=100).cpu().numpy().item()
dep_tot = IBS_plain(t, test_dict['X'], dep_model, torch.max(t), n_bins=100).cpu().numpy().item()
indep_tot = IBS_plain(t, test_dict['X'], indep_model, torch.max(t), n_bins=100).cpu().numpy().item()
E = test_dict['E']
t = test_dict["t2"]
dgp_tot = IBS_plain(t, test_dict["X"], dgp, torch.max(t), n_bins=100).cpu().numpy().item()
dep_tot = IBS_plain(t, test_dict["X"], dep_model, torch.max(t), n_bins=100).cpu().numpy().item()
indep_tot = IBS_plain(t, test_dict["X"], indep_model, torch.max(t), n_bins=100).cpu().numpy().item()
E = test_dict["E"]
if E_reverse:
E = 1-test_dict['E']
dgp_obs = IBS(test_dict['T'], test_dict['X'], dgp, torch.max(test_dict['T']), E, km_h, km_p).cpu().numpy().item()
dep_obs = IBS(test_dict['T'], test_dict['X'], dep_model, torch.max(test_dict['T']), E, km_h, km_p).cpu().numpy().item()
indep_obs = IBS(test_dict['T'], test_dict['X'], indep_model, torch.max(test_dict['T']), E, km_h, km_p).cpu().numpy().item()
E = 1-test_dict["E"]
dgp_obs = IBS(test_dict["T"], test_dict["X"], dgp, torch.max(test_dict["T"]), E, km_h, km_p).cpu().numpy().item()
dep_obs = IBS(test_dict["T"], test_dict["X"], dep_model, torch.max(test_dict["T"]), E, km_h, km_p).cpu().numpy().item()
indep_obs = IBS(test_dict["T"], test_dict["X"], indep_model, torch.max(test_dict["T"]), E, km_h, km_p).cpu().numpy().item()
return [[dgp_obs, dep_obs, indep_obs], [dgp_tot, dep_tot, indep_tot]]

def KM(t, e):
device= t.device
t = t.cpu().numpy().reshape(-1,)
e = e.cpu().numpy().reshape(-1,)
t = t.cpu().numpy().reshape(-1)
e = e.cpu().numpy().reshape(-1)
indices = np.argsort(t)
t_sorted = t[indices]
e_sorted = e[indices]
Expand All @@ -141,7 +138,7 @@ def KM(t, e):
event_times_end = np.zeros(event_times.shape[0]+1)
event_times_end[1:] = event_times
return torch.from_numpy(event_times_end).to(device), torch.from_numpy(prob_end).to(device)

def KM_evaluater(t, h, p):
device = t.device
h = h.cpu().numpy()
Expand All @@ -150,10 +147,10 @@ def KM_evaluater(t, h, p):
if len(t.shape) == 1:
idx = np.digitize(t, h, False)
return torch.from_numpy(p[idx-1]).to(device)
else:
idx = np.digitize(t, h, False)
prob = np.zeros_like(t)
for i in range(idx.shape[0]):
for j in range(idx.shape[1]):
prob[i, j] = p[idx[i,j]-1]
return torch.from_numpy(prob).to(device)

idx = np.digitize(t, h, False)
prob = np.zeros_like(t)
for i in range(idx.shape[0]):
for j in range(idx.shape[1]):
prob[i, j] = p[idx[i,j]-1]
return torch.from_numpy(prob).to(device)
Loading