diff --git a/orangecontrib/survival_analysis/widgets/owkaplanmeier.py b/orangecontrib/survival_analysis/widgets/owkaplanmeier.py index 9651ec6..8d4bec8 100644 --- a/orangecontrib/survival_analysis/widgets/owkaplanmeier.py +++ b/orangecontrib/survival_analysis/widgets/owkaplanmeier.py @@ -3,13 +3,13 @@ from typing import Dict, List, Optional, NamedTuple from itertools import zip_longest -from xml.sax.saxutils import escape from AnyQt.QtGui import QBrush, QColor, QPainterPath from AnyQt.QtCore import Qt, QSize from AnyQt.QtCore import pyqtSignal as Signal from pyqtgraph.functions import mkPen from pyqtgraph.graphicsItems.ViewBox import ViewBox +from pyqtgraph.graphicsItems.LegendItem import ItemSample, LabelItem from lifelines import KaplanMeierFitter from lifelines.utils import median_survival_times @@ -44,7 +44,7 @@ def generate_curve_coordinates(timeline, probabilities): def __init__(self, time, events, label=None, color=None): self._kmf = KaplanMeierFitter().fit(time.astype(np.float64), events.astype(np.float64)) - self.label: str = label + self._label: str = label self.color: List[int] = color # refactor this @@ -68,7 +68,9 @@ def __init__(self, time, events, label=None, color=None): self.selection = pg.PlotDataItem(pen=mkPen(color=QColor(Qt.yellow), width=4)) self.selection.hide() - self.median_survival = median = median_survival_times(self._kmf.survival_function_.astype(np.float32)) + self.median_survival = median = np.round( + median_survival_times(self._kmf.survival_function_.astype(np.float32)), 1 + ) self.median_vertical = pg.PlotDataItem(x=(median, median), y=(0, 0.5), pen=MEDIAN_LINE_PEN) censored_data = self.get_censored_data() @@ -83,13 +85,20 @@ def __init__(self, time, events, label=None, color=None): ) self.censored_data.setZValue(10) + self.num_of_samples = len(events) + self.num_of_censored_samples = len(censored_data) + + @property + def label(self): + return self._label if self._label else 'All' + def get_censored_data(self): time_events = np.column_stack((self._kmf.durations, self._kmf.event_observed)) censored_time = time_events[np.argwhere(time_events[:, 1] == 0), 0] survival = self._kmf.survival_function_.values return np.column_stack((censored_time, survival[np.where(censored_time == self._kmf.timeline)[1]])) - def get_color(self, alpha) -> QColor: + def get_color(self, alpha=255) -> QColor: color = QColor(*self.color) if self.color else QColor(Qt.darkGray) color.setAlpha(alpha) return color @@ -156,6 +165,58 @@ def mouseDragEvent(self, ev, axis=None): ev.ignore() +class HLineItemSample(ItemSample): + def __init__(self): + super().__init__(None) + self.pen = pg.mkPen(color=QColor(Qt.darkGray), width=2) + + def paint(self, p, *args): + p.setPen(self.pen) + p.drawLine(0, 0, int(self.width()), 0) + + +class CustomLegendItem(LegendItem): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.layout.setVerticalSpacing(3) + self._sample_column_label = 'n/N' + self._median_column_label = 'Median' + + def set_header(self): + samples = LabelItem(self._sample_column_label) + median = LabelItem(self._median_column_label) + separator = HLineItemSample() + + self.layout.addItem(samples, 1, 2) + self.layout.addItem(median, 1, 3) + self.items.append(samples) + self.items.append(median) + + row = self.layout.rowCount() + self.layout.addItem(separator, row, 1, row, 3) + self.items.append(separator) + + def set_curve(self, curve: EstimatedFunctionCurve): + curve_label = LabelItem(curve.label, color=curve.get_color()) + samples = LabelItem(f'{curve.num_of_samples - curve.num_of_censored_samples}/{curve.num_of_samples}') + median = LabelItem(f'{curve.median_survival}') + + row = self.layout.rowCount() + self.items.append(curve_label) + self.items.append(samples) + self.items.append(median) + + self.layout.addItem(curve_label, row, 1) + self.layout.addItem(samples, row, 2) + self.layout.addItem(median, row, 3) + + def clear(self): + for item in self.items: + self.layout.removeItem(item) + self.items = [] + self.updateSize() + + class KaplanMeierPlot(gui.OWComponent, pg.PlotWidget): HIGHLIGHT_RADIUS = 20 # in pixels selection_changed = Signal() @@ -179,7 +240,7 @@ def __init__(self, parent=None): ) self.view_box.selection_changed.connect(self.on_selection_changed) - self.legend = LegendItem() + self.legend = CustomLegendItem() self.legend.setParentItem(self.getViewBox()) self.legend.restoreAnchor(((1, 0), (1, 0))) @@ -321,10 +382,10 @@ def update_plot(self, confidence_interval=False, median=False, censored=False): def update_legend(self): self.legend.hide() - for curve in [curve for curve in self.curves.values() if curve.color and curve.label]: - c = QColor(*curve.color) - dot = pg.ScatterPlotItem(pen=c, brush=c, size=10, symbol='s') - self.legend.addItem(dot, escape(curve.label)) + self.legend.set_header() + for curve in [curve for curve in self.curves.values()]: + self.legend.set_curve(curve) + self.legend.updateSize() if bool(len(self.legend.items)): self.legend.show() diff --git a/orangecontrib/survival_analysis/widgets/tests/test_owkaplanmeier.py b/orangecontrib/survival_analysis/widgets/tests/test_owkaplanmeier.py index e174640..8ab1222 100644 --- a/orangecontrib/survival_analysis/widgets/tests/test_owkaplanmeier.py +++ b/orangecontrib/survival_analysis/widgets/tests/test_owkaplanmeier.py @@ -1,10 +1,12 @@ import pyqtgraph as pg from pyqtgraph.tests import mouseMove, mousePress, mouseRelease, mouseClick +from pyqtgraph.graphicsItems.LegendItem import LabelItem +from pyqtgraph.Qt import QtTest from Orange.data.table import Table, Domain, StringVariable, ContinuousVariable, DiscreteVariable from orangewidget.tests.base import WidgetTest from orangecontrib.survival_analysis.widgets.owkaplanmeier import OWKaplanMeier -from pyqtgraph.Qt import QtTest + from AnyQt.QtCore import Qt @@ -81,11 +83,15 @@ def test_group_variable(self): self.assertTrue(len(items) == 4) def test_legend(self): - self.assertFalse(self.widget.graph.legend.items) + legend = tuple(label.text for label in self.widget.graph.legend.items if isinstance(label, LabelItem)) + self.assertIn('All', legend) + self.widget.group_var = self.widget.data.domain['Group'] self.widget.on_controls_changed() - legend_text = tuple(label.text for _, label in self.widget.graph.legend.items) - self.assertEqual(self.widget.group_var.values, legend_text) + + legend = tuple(label.text for label in self.widget.graph.legend.items if isinstance(label, LabelItem)) + for group in self.widget.group_var.values: + self.assertIn(group, legend) def test_curve_highlight(self): self.widget.group_var = self.widget.data.domain['Group'] @@ -106,6 +112,7 @@ def test_selection(self): selected_data = self.get_output(self.widget.Outputs.selected_data) self.assertIsNone(selected_data) + self.widget.graph.legend.hide() self.simulate_mouse_drag((0.1, 1), (6, 1)) # check if correct curve is selected @@ -131,6 +138,7 @@ def test_selection(self): self.widget.group_var = self.widget.data.domain['Group'] self.widget.on_controls_changed() + self.widget.graph.legend.hide() self.simulate_mouse_drag((0.1, 1), (6, 1)) # check if correct curve is selected @@ -163,6 +171,7 @@ def test_selection(self): self.assertEqual(0, len(self.widget.graph.selection)) # test selection of a second group + self.widget.graph.legend.hide() self.simulate_mouse_drag((0.4, 0.8), (6, 0.8)) # check if correct curve is selected