Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
svittoz committed Jun 14, 2024
1 parent b8d0716 commit bb06576
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 37 deletions.
42 changes: 5 additions & 37 deletions eds_scikit/period/stays.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from eds_scikit.utils.checks import MissingConceptError, algo_checker, concept_checker
from eds_scikit.utils.datetime_helpers import substract_datetime
from eds_scikit.utils.framework import get_framework
from eds_scikit.utils.sort_values_first import sort_values_first
from eds_scikit.utils.typing import DataFrame


Expand Down Expand Up @@ -292,44 +293,11 @@ def get_first(
how="inner",
)

# Getting the corresponding first visit
# Replacement for :
# first_visit = merged.sort_values(by=[flag_name, "visit_start_datetime_1"],
# ascending=[False, False])
# .groupby(visit_occurrence_id_2).first()["visit_occurrence_id_1"]
# which is not deterministic in Koalas

flagged = (
merged[merged[flag_name]]
.groupby("visit_occurrence_id_2", as_index=False)[
["visit_start_datetime_1"]
]
.max()
first_visit = sort_values_first(
merged,
by_cols=["visit_occurrence_id_2"],
cols=[flag_name, "visit_start_datetime_1", "visit_occurrence_id_1"],
)
flagged = merged[merged[flag_name]].merge(
flagged, on=["visit_occurrence_id_2", "visit_start_datetime_1"], how="right"
)
flagged["flagged"] = True
unflagged = (
merged[~merged[flag_name]]
.groupby("visit_occurrence_id_2", as_index=False)[
["visit_start_datetime_1"]
]
.max()
)
unflagged = merged[~merged[flag_name]].merge(
unflagged,
on=["visit_occurrence_id_2", "visit_start_datetime_1"],
how="right",
)
unflagged = unflagged.merge(
flagged[["visit_occurrence_id_2", "flagged"]],
on="visit_occurrence_id_2",
how="left",
)
unflagged = unflagged[unflagged.flagged.isna()]
first_visit = fw.concat((flagged, unflagged), axis=0)

first_visit = first_visit.rename(
columns={
"visit_occurrence_id_1": f"{concept_prefix}STAY_ID",
Expand Down
8 changes: 8 additions & 0 deletions eds_scikit/utils/sort_values_first.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import List, Union

from eds_scikit.utils.typing import DataFrame

def sort_values_first(
    df: DataFrame,
    by_cols: List[str],
    cols: List[str],
    ascending: Union[bool, List[bool]] = False,
) -> DataFrame:
    """
    Keep, for each group, the first row once the group is sorted.

    The frame is grouped on ``by_cols``; every group is sorted on ``cols``
    and only its first row is kept. The sort happens inside each group
    (through ``groupby(...).apply``) rather than on the whole frame before
    grouping.

    Parameters
    ----------
    df : DataFrame
        Input frame. Must contain every column named in ``by_cols`` and
        ``cols``.
    by_cols : List[str]
        Columns defining the groups.
    cols : List[str]
        Columns used to order the rows inside each group, by decreasing
        priority.
    ascending : Union[bool, List[bool]], default False
        Sort direction. A single value is applied to every column of
        ``cols``; a list (or tuple) gives one direction per column and must
        have the same length as ``cols``. With the default (``False``) the
        row holding the largest values is kept.

    Returns
    -------
    DataFrame
        One row per group, with the index reset to a fresh default index.

    Raises
    ------
    ValueError
        If ``ascending`` is a list whose length differs from ``cols``.
    """
    if isinstance(ascending, (list, tuple)):
        orders = list(ascending)
        # Fail early with a clear message instead of from inside the
        # per-group callback.
        if len(orders) != len(cols):
            raise ValueError(
                f"Length of ascending ({len(orders)}) != length of cols ({len(cols)})"
            )
    else:
        # A single flag applies to every sort column.
        orders = [ascending] * len(cols)

    # NOTE(review): this relies on the grouping columns being part of each
    # group handed to the callback, since the group keys held in the index are
    # dropped by reset_index(drop=True). Recent pandas versions deprecate that
    # behaviour of groupby.apply - confirm against the supported pandas range.
    return (
        df.groupby(by_cols)
        .apply(lambda group: group.sort_values(by=cols, ascending=orders).head(1))
        .reset_index(drop=True)
    )
40 changes: 40 additions & 0 deletions tests/test_sort_values_first.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pandas as pd
import pytest
from eds_scikit.utils import framework
from eds_scikit.utils.sort_values_first import sort_values_first
from eds_scikit.utils.test_utils import assert_equal_no_order

from databricks import koalas as ks
import numpy as np

# Reproducible fixture: one categorical key column ("A") and four integer
# columns ("B".."E") drawn from the seeded global NumPy generator. The draw
# order (A, then B, C, D, E) is kept so the generated values stay the same.
np.random.seed(0)
size = 10000
data = {"A": np.random.choice(["X", "Y", "Z"], size)}
data.update({name: np.random.randint(1, 5, size) for name in ("B", "C", "D", "E")})

inputs = pd.DataFrame(data)
# Overwrite two cells of the first row: B gets 0, a value randint(1, 5) never
# produces, and C gets 4.
inputs.at[0, "B"] = 0
inputs.at[0, "C"] = 4

@pytest.mark.parametrize("module", ["pandas", "koalas"])
def test_sort_values_first(module):
    """Compare `sort_values_first` with a plain pandas sort + groupby-first.

    The comparison is run for both sort directions on the module-level
    `inputs` fixture, converted to the backend named by `module`.
    """
    inputs_fr = framework.to(module, inputs)

    for direction in (True, False):
        # Result of the function under test, brought back to pandas.
        computed = framework.pandas(
            sort_values_first(inputs_fr, ["A"], ["B", "C"], ascending=direction)
        )
        # Reference result computed on the pandas fixture.
        expected = (
            inputs.sort_values(["B", "C"], ascending=direction)
            .groupby("A", as_index=False)
            .first()
        )
        assert_equal_no_order(computed, expected)

0 comments on commit bb06576

Please sign in to comment.