Skip to content

Commit

Permalink
Merge pull request #14 from biolab/kaplan-meier/plot_legend
Browse files Browse the repository at this point in the history
[ENH] owkaplanmeier: add more information to the plot legend
  • Loading branch information
JakaKokosar authored Mar 12, 2021
2 parents dbe833b + 31adfe2 commit 850a5c2
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 13 deletions.
79 changes: 70 additions & 9 deletions orangecontrib/survival_analysis/widgets/owkaplanmeier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

from typing import Dict, List, Optional, NamedTuple
from itertools import zip_longest
from xml.sax.saxutils import escape

from AnyQt.QtGui import QBrush, QColor, QPainterPath
from AnyQt.QtCore import Qt, QSize
from AnyQt.QtCore import pyqtSignal as Signal
from pyqtgraph.functions import mkPen
from pyqtgraph.graphicsItems.ViewBox import ViewBox
from pyqtgraph.graphicsItems.LegendItem import ItemSample, LabelItem
from lifelines import KaplanMeierFitter
from lifelines.utils import median_survival_times

Expand Down Expand Up @@ -44,7 +44,7 @@ def generate_curve_coordinates(timeline, probabilities):
def __init__(self, time, events, label=None, color=None):
self._kmf = KaplanMeierFitter().fit(time.astype(np.float64), events.astype(np.float64))

self.label: str = label
self._label: str = label
self.color: List[int] = color

# refactor this
Expand All @@ -68,7 +68,9 @@ def __init__(self, time, events, label=None, color=None):
self.selection = pg.PlotDataItem(pen=mkPen(color=QColor(Qt.yellow), width=4))
self.selection.hide()

self.median_survival = median = median_survival_times(self._kmf.survival_function_.astype(np.float32))
self.median_survival = median = np.round(
median_survival_times(self._kmf.survival_function_.astype(np.float32)), 1
)
self.median_vertical = pg.PlotDataItem(x=(median, median), y=(0, 0.5), pen=MEDIAN_LINE_PEN)

censored_data = self.get_censored_data()
Expand All @@ -83,13 +85,20 @@ def __init__(self, time, events, label=None, color=None):
)
self.censored_data.setZValue(10)

self.num_of_samples = len(events)
self.num_of_censored_samples = len(censored_data)

@property
def label(self):
return self._label if self._label else 'All'

def get_censored_data(self):
time_events = np.column_stack((self._kmf.durations, self._kmf.event_observed))
censored_time = time_events[np.argwhere(time_events[:, 1] == 0), 0]
survival = self._kmf.survival_function_.values
return np.column_stack((censored_time, survival[np.where(censored_time == self._kmf.timeline)[1]]))

def get_color(self, alpha) -> QColor:
def get_color(self, alpha=255) -> QColor:
color = QColor(*self.color) if self.color else QColor(Qt.darkGray)
color.setAlpha(alpha)
return color
Expand Down Expand Up @@ -156,6 +165,58 @@ def mouseDragEvent(self, ev, axis=None):
ev.ignore()


class HLineItemSample(ItemSample):
def __init__(self):
super().__init__(None)
self.pen = pg.mkPen(color=QColor(Qt.darkGray), width=2)

def paint(self, p, *args):
p.setPen(self.pen)
p.drawLine(0, 0, int(self.width()), 0)


class CustomLegendItem(LegendItem):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.layout.setVerticalSpacing(3)
self._sample_column_label = 'n/N'
self._median_column_label = 'Median'

def set_header(self):
samples = LabelItem(self._sample_column_label)
median = LabelItem(self._median_column_label)
separator = HLineItemSample()

self.layout.addItem(samples, 1, 2)
self.layout.addItem(median, 1, 3)
self.items.append(samples)
self.items.append(median)

row = self.layout.rowCount()
self.layout.addItem(separator, row, 1, row, 3)
self.items.append(separator)

def set_curve(self, curve: EstimatedFunctionCurve):
curve_label = LabelItem(curve.label, color=curve.get_color())
samples = LabelItem(f'{curve.num_of_samples - curve.num_of_censored_samples}/{curve.num_of_samples}')
median = LabelItem(f'{curve.median_survival}')

row = self.layout.rowCount()
self.items.append(curve_label)
self.items.append(samples)
self.items.append(median)

self.layout.addItem(curve_label, row, 1)
self.layout.addItem(samples, row, 2)
self.layout.addItem(median, row, 3)

def clear(self):
for item in self.items:
self.layout.removeItem(item)
self.items = []
self.updateSize()


class KaplanMeierPlot(gui.OWComponent, pg.PlotWidget):
HIGHLIGHT_RADIUS = 20 # in pixels
selection_changed = Signal()
Expand All @@ -179,7 +240,7 @@ def __init__(self, parent=None):
)
self.view_box.selection_changed.connect(self.on_selection_changed)

self.legend = LegendItem()
self.legend = CustomLegendItem()
self.legend.setParentItem(self.getViewBox())
self.legend.restoreAnchor(((1, 0), (1, 0)))

Expand Down Expand Up @@ -321,10 +382,10 @@ def update_plot(self, confidence_interval=False, median=False, censored=False):
def update_legend(self):
self.legend.hide()

for curve in [curve for curve in self.curves.values() if curve.color and curve.label]:
c = QColor(*curve.color)
dot = pg.ScatterPlotItem(pen=c, brush=c, size=10, symbol='s')
self.legend.addItem(dot, escape(curve.label))
self.legend.set_header()
for curve in [curve for curve in self.curves.values()]:
self.legend.set_curve(curve)
self.legend.updateSize()

if bool(len(self.legend.items)):
self.legend.show()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import pyqtgraph as pg
from pyqtgraph.tests import mouseMove, mousePress, mouseRelease, mouseClick
from pyqtgraph.graphicsItems.LegendItem import LabelItem
from pyqtgraph.Qt import QtTest

from Orange.data.table import Table, Domain, StringVariable, ContinuousVariable, DiscreteVariable
from orangewidget.tests.base import WidgetTest
from orangecontrib.survival_analysis.widgets.owkaplanmeier import OWKaplanMeier
from pyqtgraph.Qt import QtTest


from AnyQt.QtCore import Qt

Expand Down Expand Up @@ -81,11 +83,15 @@ def test_group_variable(self):
self.assertTrue(len(items) == 4)

def test_legend(self):
self.assertFalse(self.widget.graph.legend.items)
legend = tuple(label.text for label in self.widget.graph.legend.items if isinstance(label, LabelItem))
self.assertIn('All', legend)

self.widget.group_var = self.widget.data.domain['Group']
self.widget.on_controls_changed()
legend_text = tuple(label.text for _, label in self.widget.graph.legend.items)
self.assertEqual(self.widget.group_var.values, legend_text)

legend = tuple(label.text for label in self.widget.graph.legend.items if isinstance(label, LabelItem))
for group in self.widget.group_var.values:
self.assertIn(group, legend)

def test_curve_highlight(self):
self.widget.group_var = self.widget.data.domain['Group']
Expand All @@ -106,6 +112,7 @@ def test_selection(self):
selected_data = self.get_output(self.widget.Outputs.selected_data)
self.assertIsNone(selected_data)

self.widget.graph.legend.hide()
self.simulate_mouse_drag((0.1, 1), (6, 1))

# check if correct curve is selected
Expand All @@ -131,6 +138,7 @@ def test_selection(self):
self.widget.group_var = self.widget.data.domain['Group']
self.widget.on_controls_changed()

self.widget.graph.legend.hide()
self.simulate_mouse_drag((0.1, 1), (6, 1))

# check if correct curve is selected
Expand Down Expand Up @@ -163,6 +171,7 @@ def test_selection(self):
self.assertEqual(0, len(self.widget.graph.selection))

# test selection of a second group
self.widget.graph.legend.hide()
self.simulate_mouse_drag((0.4, 0.8), (6, 0.8))

# check if correct curve is selected
Expand Down

0 comments on commit 850a5c2

Please sign in to comment.