Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
* fixed typos, added warning for ties

* removed lines, added plots

* cleanup

* Address reviews on the Statement of Need Section

---------

Co-authored-by: melodiemonod <monod.melodie@gmail.com>
  • Loading branch information
tcoroller and melodiemonod authored Dec 11, 2024
1 parent ffd6920 commit a505bfc
Show file tree
Hide file tree
Showing 6 changed files with 454 additions and 79 deletions.
1 change: 1 addition & 0 deletions docs/notebooks/helpers_momentum.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def train_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
persistent_workers=True,
shuffle=True,
)

def val_dataloader(self):
Expand Down
353 changes: 308 additions & 45 deletions docs/notebooks/introduction.ipynb

Large diffs are not rendered by default.

163 changes: 136 additions & 27 deletions docs/notebooks/momentum.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions paper/paper.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ bibliography: paper.bib

# Statement of need

Survival analysis plays a crucial role in various domains, such as medicine, economics or engineering. Thus, sophisticated survival models sugin deep learning opens new opportunities to leverage complex dataset and relationships. However, no existing library provides the flexibility to define the survival model's parameters using a custom `PyTorch`-based neural network.
Survival analysis plays a crucial role in various domains, such as medicine, economics or engineering. Sophisticated survival analysis using deep learning, often referred to as "deep survival analysis," unlocks new opportunities to leverage new data types and uncover intricate relationships.
However, performing comprehensive deep survival analysis remain challenging. Key issues include the lack of flexibility in existing tools to define survival model parameters with custom architectures and limitations in handling complex, high-dimensional datasets. Indeed, existing frameworks often lack the computational efficiency necessary to process large datasets efficiently, making them less suitable for real-world applications where time and resource constraints are paramount.

\autoref{tab:bibliography} compares the functionalities of `TorchSurv` with those of
To address these gaps, we propose a flexible, `PyTorch`-based library that allows users to define survival model parameters using custom neural network architectures. By combining computational efficiency with ease of use, this toolbox opens new opportunities to advance survival analysis research and application, making it more accessible and interpretable for practitioners across disciplines. \autoref{tab:bibliography} compares the functionalities of `TorchSurv` with those of
`auton-survival` [@nagpal2022auton],
`pycox` [@Kvamme2019pycox],
`torchlife` [@torchlifeAbeywardana],
Expand Down
8 changes: 6 additions & 2 deletions src/torchsurv/loss/cox.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,12 @@ def neg_partial_log_likelihood(
>>> time = torch.tensor([1., 2., 3., 4., 5.])
>>> neg_partial_log_likelihood(log_hz, event, time) # default, mean of log likelihoods across patients
tensor(1.0071)
>>> neg_partial_log_likelihood(log_hz, event, time, reduction = 'sum') # sun of log likelihoods across patients
>>> neg_partial_log_likelihood(log_hz, event, time, reduction = 'sum') # sum of log likelihoods across patients
tensor(3.0214)
>>> time = torch.tensor([1., 2., 2., 4., 5.]) # Dealing with ties (default: Efron)
>>> neg_partial_log_likelihood(log_hz, event, time, ties_method = "efron")
tensor(1.0873)
>>> neg_partial_log_likelihood(log_hz, event, time, ties_method = "breslow") # Dealing with ties (Bfron)
>>> neg_partial_log_likelihood(log_hz, event, time, ties_method = "breslow") # Dealing with ties (Breslow)
tensor(1.0873)
References:
Expand Down Expand Up @@ -134,6 +134,10 @@ def neg_partial_log_likelihood(
# if not ties, use traditional cox partial likelihood
pll = _partial_likelihood_cox(log_hz_sorted, event_sorted)
else:
# add warning about ties
warnings.warn(
f"Ties in event time detected; using {ties_method}'s method to handle ties."
)
# if ties, use either efron or breslow approximation of partial likelihood
if ties_method == "efron":
pll = _partial_likelihood_efron(
Expand Down
3 changes: 0 additions & 3 deletions src/torchsurv/stats/kaplan_meier.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,6 @@ def print_survival_table(self):
for t, y in zip(self.time, self.km_est):
print(f"{t:.2f}\t{y:.4f}")

x = torch.randn(1, 50, 50, 50)
print(x.shape) # shows the shape of the tensor

def _compute_counts(
self,
) -> Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor]:
Expand Down

0 comments on commit a505bfc

Please sign in to comment.