From d954f00528e06a4063a4fd1c7c58b8453317b4fd Mon Sep 17 00:00:00 2001 From: Jared Galloway Date: Sun, 13 Aug 2023 11:31:25 -0700 Subject: [PATCH] reset index, and add_pheno_to_df index check (#117) Closes #116 --- CHANGELOG.rst | 6 ++++++ multidms/data.py | 6 +++--- multidms/model.py | 15 +++++++++------ 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 302f378..9e44da5 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,6 +7,12 @@ All notable changes to this project will be documented in this file. The format is based on `Keep a Changelog `_. +main/HEAD +--------- +- Fixed a `bug `_ + caused by non-unique indices in input variant functional score dataframes. + + 0.2.1 ----- - Made lineplot_and_heatmap() more private to remove from docs. diff --git a/multidms/data.py b/multidms/data.py index 8a21935..b284696 100644 --- a/multidms/data.py +++ b/multidms/data.py @@ -205,7 +205,7 @@ def __init__( self._conditions = tuple(variants_df["condition"].astype(str).unique()) if str(reference) not in self._conditions: - if type(reference) != str: + if isinstance(reference, str): raise ValueError( "reference must be a string, note that if your " "condition names are numeric, they are being" @@ -223,7 +223,7 @@ def __init__( raise ValueError("not enough `condition_colors`") else: self.condition_colors = dict(zip(self._conditions, condition_colors)) - if not onp.all([type(c) == str for c in self.condition_colors.values()]): + if not onp.all([isinstance(c, str) for c in self.condition_colors.values()]): raise ValueError("condition_color values must be hexidecimal") # Check and initialize alphabet & mut parser attributes @@ -259,7 +259,7 @@ def __init__( ) else: - df = variants_df.copy() + df = variants_df[cols].reset_index(drop=True) self._split_subs = partial(split_subs, parser=self._mutparser.parse_mut) df["wts"], df["sites"], df["muts"] = zip( diff --git a/multidms/model.py b/multidms/model.py index ff032c1..d94b430 100644 ---
a/multidms/model.py +++ b/multidms/model.py @@ -280,12 +280,12 @@ def __init__( raise ValueError( "softplus activation requires a lower bound be specified" ) - elif type(lower_bound) != float: + if not isinstance(lower_bound, float): raise ValueError("lower_bound must be a float") - else: - output_activation = partial( - multidms.biophysical.softplus_activation, lower_bound=lower_bound - ) + + output_activation = partial( + multidms.biophysical.softplus_activation, lower_bound=lower_bound + ) for condition in data.conditions: self._params[f"gamma_{condition}"] = jnp.zeros(shape=(1,)) @@ -493,7 +493,8 @@ def add_phenotypes_to_df( ---------- df : pandas.DataFrame Data frame containing variants. Requirements are the same as - those used to initialize the `multidms.Data` object + those used to initialize the `multidms.Data` object - except + the indices must be unique. substitutions_col : str Column in `df` giving variants as substitution strings with respect to a given variants condition. @@ -537,6 +538,8 @@ def add_phenotypes_to_df( raise ValueError("`df` lacks `substitutions_col` " f"{substitutions_col}") if condition_col not in df.columns: raise ValueError("`df` lacks `condition_col` " f"{condition_col}") + if not df.index.is_unique: + raise ValueError("`df` must have unique indices") # return copy ret = df.copy()