Skip to content

Commit

Permalink
ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
grg2rsr committed Dec 4, 2024
1 parent 264a5ac commit d32a5e5
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 29 deletions.
15 changes: 11 additions & 4 deletions src/iblphotometry/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def from_dataframe(
if data_columns is None:
# this hacky parser currently deals with the inconsistency between carolinas and alejandros extraction
# https://github.com/int-brain-lab/ibl-photometry/issues/35
data_columns = [col for col in raw_df.columns if col.startswith('Region') or col.startswith('G')]
data_columns = [
col
for col in raw_df.columns
if col.startswith('Region') or col.startswith('G')
]

# infer name of time column if not provided
if time_column is None:
Expand Down Expand Up @@ -75,10 +79,11 @@ def from_dataframe(

return raw_dfs


def from_dataframes(raw_df: pd.DataFrame, locations_df: pd.DataFrame):
data_columns = (list(locations_df.index),)
rename = locations_df['brain_region'].to_dict()

read_config = dict(
data_columns=data_columns,
time_column='times',
Expand Down Expand Up @@ -125,7 +130,9 @@ def from_pqt(
return from_dataframe(raw_df, **read_config)


def from_raw_neurophotometrics_df(raw_df: pd.DataFrame, rois=None, drop_first=True) -> pd.DataFrame:
def from_raw_neurophotometrics_df(
raw_df: pd.DataFrame, rois=None, drop_first=True
) -> pd.DataFrame:
"""reads in parses the output of the neurophotometrics FP3002
Args:
Expand Down Expand Up @@ -251,6 +258,7 @@ def _validate_dataframe(

return schema_raw_data.validate(df)


def _validate_neurophotometrics_digital_inputs(df: pd.DataFrame) -> pd.DataFrame:
schema_digital_inputs = pandera.DataFrameSchema(
columns=dict(
Expand All @@ -262,4 +270,3 @@ def _validate_neurophotometrics_digital_inputs(df: pd.DataFrame) -> pd.DataFrame
)
)
return schema_digital_inputs.validate(df)

1 change: 0 additions & 1 deletion src/iblphotometry/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def _load_data_from_eid(self, eid: str, rename=True):

return raw_dfs


def _eid2pnames(self, eid: str):
session_path = self.one.eid2path(eid)
pnames = [reg.name for reg in session_path.joinpath('alf').glob('Region*')]
Expand Down
11 changes: 9 additions & 2 deletions src/iblphotometry/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
import pandas as pd
from scipy import stats

from iblphotometry.processing import z, Regression, ExponDecay, detect_spikes, detect_outliers
from iblphotometry.processing import (
z,
Regression,
ExponDecay,
detect_spikes,
detect_outliers,
)
from iblphotometry.behavior import psth


def percentile_dist(A: pd.Series | np.ndarray, pc: tuple = (50, 95), axis=-1) -> float:
"""the distance between two percentiles in units of z. Captures the magnitude of transients.
Expand Down Expand Up @@ -65,7 +72,7 @@ def n_spikes(A: pd.Series | np.ndarray, sd: int = 5):
"""count the number of spike artifacts in the recording."""
a = A.values if isinstance(A, pd.Series) else A
return detect_spikes(a, sd=sd).shape[0]


def n_outliers(
A: pd.Series | np.ndarray, w_size: int = 1000, alpha: float = 0.0005
Expand Down
63 changes: 58 additions & 5 deletions src/iblphotometry/neurophotometrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
The light source map refers to the available LEDs on the system.
The flags refers to the byte encoding of led states in the system.
"""

LIGHT_SOURCE_MAP = {
'color': ['None', 'Violet', 'Blue', 'Green'],
'wavelength': [0, 415, 470, 560],
Expand All @@ -24,8 +25,60 @@
10: 'Input 0 signal HIGH + Stimulation',
11: 'Output 0 HIGH + Input 0 HIGH + Stimulation',
},
'No LED ON': {0: 0, 1: 8, 2: 16, 3: 32, 4: 64, 5: 128, 6: 256, 7: 512, 8: 48, 9: 528, 10: 544, 11: 560},
'L415': {0: 1, 1: 9, 2: 17, 3: 33, 4: 65, 5: 129, 6: 257, 7: 513, 8: 49, 9: 529, 10: 545, 11: 561},
'L470': {0: 2, 1: 10, 2: 18, 3: 34, 4: 66, 5: 130, 6: 258, 7: 514, 8: 50, 9: 530, 10: 546, 11: 562},
'L560': {0: 4, 1: 12, 2: 20, 3: 36, 4: 68, 5: 132, 6: 260, 7: 516, 8: 52, 9: 532, 10: 548, 11: 564}
}
'No LED ON': {
0: 0,
1: 8,
2: 16,
3: 32,
4: 64,
5: 128,
6: 256,
7: 512,
8: 48,
9: 528,
10: 544,
11: 560,
},
'L415': {
0: 1,
1: 9,
2: 17,
3: 33,
4: 65,
5: 129,
6: 257,
7: 513,
8: 49,
9: 529,
10: 545,
11: 561,
},
'L470': {
0: 2,
1: 10,
2: 18,
3: 34,
4: 66,
5: 130,
6: 258,
7: 514,
8: 50,
9: 530,
10: 546,
11: 562,
},
'L560': {
0: 4,
1: 12,
2: 20,
3: 36,
4: 68,
5: 132,
6: 260,
7: 516,
8: 52,
9: 532,
10: 548,
11: 564,
},
}
9 changes: 8 additions & 1 deletion src/iblphotometry/pipelines.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import numpy as np
import pandas as pd
from iblphotometry.processing import remove_spikes, lowpass_bleachcorrect, isosbestic_correct, sliding_mad, zscore
from iblphotometry.processing import (
remove_spikes,
lowpass_bleachcorrect,
isosbestic_correct,
sliding_mad,
zscore,
)

import logging

logger = logging.getLogger()


def run_pipeline(
pipeline,
F_signal: pd.DataFrame,
Expand Down
44 changes: 31 additions & 13 deletions src/iblphotometry/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from copy import copy



# machine resolution
eps = np.finfo(np.float64).eps

Expand All @@ -34,6 +33,7 @@
## ####### ## ## ###### ## #### ####### ## ## ######
"""


def z(A: np.ndarray, mode='classic'):
"""classic z-score. Deviation from sample mean in units of sd
Expand All @@ -49,38 +49,41 @@ def z(A: np.ndarray, mode='classic'):
if mode == 'median':
return (A - np.median(A)) / np.std(A)


def mad(A: np.ndarray):
""" the MAD is defined as the median of the absolute deviations from the data's median
"""the MAD is defined as the median of the absolute deviations from the data's median
see https://en.wikipedia.org/wiki/Median_absolute_deviation
:param A: _description_
:type A: np.ndarray
:return: _description_
:rtype: _type_
"""
return np.median(np.absolute(A - np.median(A)), axis=-1)


def madscore(F: pd.Series):
# TODO overloading of mad?
y, t = F.values, F.index.values
return pd.Series(mad(y), index=t)


def zscore(F: pd.Series, mode='classic'):
y, t = F.values, F.index.values
# mu, sig = np.average(y), np.std(y)
return pd.Series(z(y, mode=mode), index=t)


def filt(F: pd.Series, N: int, Wn: float, fs: float | None = None, btype='low'):
""" a wrapper for scipy.signal.butter and sosfiltfilt
"""
"""a wrapper for scipy.signal.butter and sosfiltfilt"""
y, t = F.values, F.index.values
if fs is None:
fs = 1 / np.median(np.diff(t))
sos = signal.butter(N, Wn, btype, fs=fs, output='sos')
y_filt = signal.sosfiltfilt(sos, y)
return pd.Series(y_filt, index=t)


def sliding_rcoeff(signal_a, signal_b, nswin, overlap=0):
"""
Computes the local correlation coefficient between two signals in sliding windows
Expand Down Expand Up @@ -110,6 +113,7 @@ def sliding_rcoeff(signal_a, signal_b, nswin, overlap=0):
######## ####### ###### ###### ## ####### ## ## ###### ## #### ####### ## ## ######
"""


def mse_loss(p, x, y, fun):
# mean squared error
y_hat = fun(x, *p)
Expand Down Expand Up @@ -138,6 +142,7 @@ def irls_loss(p, x, y, fun, d=1e-7):
w = 1 / f
return np.sum(w * np.abs(a) ** 2) / y.shape[0]


"""
######## ######## ###### ######## ######## ###### ###### #### ####### ## ##
## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ### ##
Expand All @@ -151,6 +156,7 @@ def irls_loss(p, x, y, fun, d=1e-7):
# wrapper class for regressions with different loss functions
# following a sklearn style of .fit() and .predict()


class Regression:
def __init__(self, model=None, method: str = 'mse', method_params=None):
self.model = model
Expand Down Expand Up @@ -220,6 +226,7 @@ def predict(self, x: np.ndarray, return_type='numpy'):
if return_type == 'pandas':
return pd.Series(y_hat, index=x)


"""
######## ## ######## ### ###### ## ## ###### ####### ######## ######## ######## ###### ######## #### ####### ## ##
## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ### ##
Expand All @@ -230,10 +237,11 @@ def predict(self, x: np.ndarray, return_type='numpy'):
######## ######## ######## ## ## ###### ## ## ###### ####### ## ## ## ## ######## ###### ## #### ####### ## ##
"""


class BleachCorrection:
def __init__(
self,
model = None, # TODO bring back type checking
model=None, # TODO bring back type checking
regression_method: str = 'mse',
regression_params: dict = None,
correction_method: str = 'subtract',
Expand Down Expand Up @@ -262,7 +270,7 @@ def __init__(
def correct(self, F: pd.Series):
F_filt = filt(F, **self.filter_params)
return correct(F, F_filt, mode=self.correction_method)


class IsosbesticCorrection:
def __init__(
Expand Down Expand Up @@ -293,7 +301,10 @@ def correct(

return correct(F_ca, F_iso_fit, mode=self.correction_method)

def correct(signal: pd.Series, reference: pd.Series, mode: str = 'subtract') -> pd.Series:

def correct(
signal: pd.Series, reference: pd.Series, mode: str = 'subtract'
) -> pd.Series:
"""the main function that applies the correction of a signal with a reference. Correcions can be applied in 3 principle ways:
- The reference can be subtracted from the signal
- the signal can be divided by the reference
Expand Down Expand Up @@ -329,7 +340,6 @@ def correct(signal: pd.Series, reference: pd.Series, mode: str = 'subtract') ->


class AbstractModel(ABC):

@abstractmethod
def eq():
# the model equation
Expand Down Expand Up @@ -374,8 +384,10 @@ def calc_model_stats(self, y, y_hat, n_samples: int = -1, use_kde: bool = False)
aic = self._calc_aic(ll, k)
return dict(r_sq=r_sq, ll=ll, aic=aic)


# the actual models


class LinearModel(AbstractModel):
def eq(self, x, m, b):
return x * m + b
Expand Down Expand Up @@ -449,6 +461,7 @@ def est_p0(self, t: np.ndarray, y: np.ndarray):
b_est,
)


"""
###### ####### ######## ######## ######## ###### ######## #### ####### ## ## ######## ## ## ## ## ###### ######## #### ####### ## ## ######
## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ### ## ## ## ## ### ## ## ## ## ## ## ## ### ## ## ##
Expand All @@ -461,15 +474,18 @@ def est_p0(self, t: np.ndarray, y: np.ndarray):

# these are the convenience functions that are called in pipelines


def lowpass_bleachcorrect(F: pd.Series, **kwargs):
bc = LowpassBleachCorrection(**kwargs)
return bc.correct(F)


def exponential_bleachcorrect(F: pd.Series, **kwargs):
model = DoubleExponDecay()
ec = BleachCorrection(model, **kwargs)
return ec.correct(F)


def isosbestic_correct(F_sig: pd.DataFrame, F_ref: pd.DataFrame, **kwargs):
ic = IsosbesticCorrection(**kwargs)
return ic.correct(F_sig, F_ref)
Expand All @@ -485,7 +501,8 @@ def isosbestic_correct(F_sig: pd.DataFrame, F_ref: pd.DataFrame, **kwargs):
####### ####### ## ######## #### ######## ## ## ######## ######## ## ######## ###### ## #### ####### ## ##
"""

def _grubbs_single(y: np.ndarray, alpha: float =0.005, mode: str='median') -> bool:

def _grubbs_single(y: np.ndarray, alpha: float = 0.005, mode: str = 'median') -> bool:
# to apply a single pass of grubbs outlier detection
# see https://en.wikipedia.org/wiki/Grubbs%27s_test

Expand All @@ -505,7 +522,7 @@ def _grubbs_single(y: np.ndarray, alpha: float =0.005, mode: str='median') -> bo
return False


def grubbs_test(y: np.ndarray, alpha: float =0.005, mode:str='median'):
def grubbs_test(y: np.ndarray, alpha: float = 0.005, mode: str = 'median'):
# apply grubbs test iteratively until no more outliers are found
outliers = []
while _grubbs_single(y, alpha=alpha):
Expand Down Expand Up @@ -565,7 +582,9 @@ def fillnan_kde(y: np.ndarray, w: int = 25):
return y


def remove_outliers(F: pd.Series, w_size: int = 1000, alpha: float = 0.005, w: int = 25):
def remove_outliers(
F: pd.Series, w_size: int = 1000, alpha: float = 0.005, w: int = 25
):
y, t = F.values, F.index.values
y = copy(y)
outliers = detect_outliers(y, w_size=w_size, alpha=alpha)
Expand Down Expand Up @@ -620,7 +639,6 @@ def make_sliding_window(
method='stride_tricks',
warning=None,
):

"""use np.stride_tricks to make a sliding window view of a 1-d np.ndarray A
full overlap, step size 1
assumes 8 byte numbers (to be exposed? but needs to be tested first)
Expand Down
1 change: 1 addition & 0 deletions src/iblphotometry/qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

logger = logging.getLogger()


# %% # those could be in metrics
def sliding_metric(
F: pd.Series,
Expand Down
Loading

0 comments on commit d32a5e5

Please sign in to comment.