Skip to content

Commit

Permalink
Merge pull request #892 from int-brain-lab/ready4recording
Browse files Browse the repository at this point in the history
Ready4recording
  • Loading branch information
mayofaulkner authored Dec 17, 2024
2 parents cd94130 + 3b051e4 commit 5cf0561
Show file tree
Hide file tree
Showing 5 changed files with 299 additions and 104 deletions.
226 changes: 173 additions & 53 deletions brainbox/behavior/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def get_subject_training_status(subj, date=None, details=True, one=None):
if not trials:
return
sess_dates = list(trials.keys())
status, info = get_training_status(trials, task_protocol, ephys_sess, n_delay)
status, info, _ = get_training_status(trials, task_protocol, ephys_sess, n_delay)

if details:
if np.any(info.get('psych')):
Expand Down Expand Up @@ -265,13 +265,13 @@ def get_sessions(subj, date=None, one=None):
if not np.any(np.array(task_protocol) == 'training'):
ephys_sess = one.alyx.rest('sessions', 'list', subject=subj,
date_range=[sess_dates[-1], sess_dates[0]],
django='json__PYBPOD_BOARD__icontains,ephys')
django='location__name__icontains,ephys')
if len(ephys_sess) > 0:
ephys_sess_dates = [sess['start_time'][:10] for sess in ephys_sess]

n_delay = len(one.alyx.rest('sessions', 'list', subject=subj,
date_range=[sess_dates[-1], sess_dates[0]],
django='json__SESSION_START_DELAY_SEC__gte,900'))
django='json__SESSION_DELAY_START__gte,900'))
else:
ephys_sess_dates = []
n_delay = 0
Expand Down Expand Up @@ -313,23 +313,32 @@ def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay):

info = Bunch()
trials_all = concatenate_trials(trials)
info.session_dates = list(trials.keys())
info.protocols = [p for p in task_protocol]

# Case when all sessions are trainingChoiceWorld
if np.all(np.array(task_protocol) == 'training'):
signed_contrast = get_signed_contrast(trials_all)
signed_contrast = np.unique(get_signed_contrast(trials_all))
(info.perf_easy, info.n_trials,
info.psych, info.rt) = compute_training_info(trials, trials_all)
if not np.any(signed_contrast == 0):
status = 'in training'

pass_criteria, criteria = criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt,
signed_contrast)
if pass_criteria:
failed_criteria = Bunch()
failed_criteria['NBiased'] = {'val': info.protocols, 'pass': False}
failed_criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': False}
status = 'trained 1b'
else:
if criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt):
status = 'trained 1b'
elif criterion_1a(info.psych, info.n_trials, info.perf_easy):
failed_criteria = criteria
pass_criteria, criteria = criterion_1a(info.psych, info.n_trials, info.perf_easy, signed_contrast)
if pass_criteria:
status = 'trained 1a'
else:
failed_criteria = criteria
status = 'in training'

return status, info
return status, info, failed_criteria

# Case when there are < 3 biasedChoiceWorld sessions after reaching trained_1b criterion
if ~np.all(np.array(task_protocol) == 'training') and \
Expand All @@ -338,45 +347,52 @@ def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay):
(info.perf_easy, info.n_trials,
info.psych, info.rt) = compute_training_info(trials, trials_all)

return status, info
criteria = Bunch()
criteria['NBiased'] = {'val': info.protocols, 'pass': False}
criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': False}

return status, info, criteria

# Case when there is biasedChoiceWorld or ephysChoiceWorld in last three sessions
if not np.any(np.array(task_protocol) == 'training'):

(info.perf_easy, info.n_trials,
info.psych_20, info.psych_80,
info.rt) = compute_bias_info(trials, trials_all)
# We are still on training rig and so all sessions should be biased
if len(ephys_sess_dates) == 0:
assert np.all(np.array(task_protocol) == 'biased')
if criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy,
info.rt):
status = 'ready4ephysrig'
else:
status = 'trained 1b'

elif len(ephys_sess_dates) < 3:
n_ephys = len(ephys_sess_dates)
info.n_ephys = n_ephys
info.n_delay = n_delay

# Criterion recording
pass_criteria, criteria = criteria_recording(n_ephys, n_delay, info.psych_20, info.psych_80, info.n_trials,
info.perf_easy, info.rt)
if pass_criteria:
# Here the criteria doesn't actually fail but we have no other criteria to meet so we return this
failed_criteria = criteria
status = 'ready4recording'
else:
failed_criteria = criteria
assert all(date in trials for date in ephys_sess_dates)
perf_ephys_easy = np.array([compute_performance_easy(trials[k]) for k in
ephys_sess_dates])
n_ephys_trials = np.array([compute_n_trials(trials[k]) for k in ephys_sess_dates])

if criterion_delay(n_ephys_trials, perf_ephys_easy):
status = 'ready4delay'
else:
status = 'ready4ephysrig'

elif len(ephys_sess_dates) >= 3:
if n_delay > 0 and \
criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy,
info.rt):
status = 'ready4recording'
elif criterion_delay(info.n_trials, info.perf_easy):
pass_criteria, criteria = criterion_delay(n_ephys, n_ephys_trials, perf_ephys_easy)

if pass_criteria:
status = 'ready4delay'
else:
status = 'ready4ephysrig'
failed_criteria = criteria
pass_criteria, criteria = criterion_ephys(info.psych_20, info.psych_80, info.n_trials,
info.perf_easy, info.rt)
if pass_criteria:
status = 'ready4ephysrig'
else:
failed_criteria = criteria
status = 'trained 1b'

return status, info
return status, info, failed_criteria


def display_status(subj, sess_dates, status, perf_easy=None, n_trials=None, psych=None,
Expand Down Expand Up @@ -814,7 +830,7 @@ def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='re
return reaction_time, contrasts, n_contrasts,


def criterion_1a(psych, n_trials, perf_easy):
def criterion_1a(psych, n_trials, perf_easy, signed_contrast):
"""
Returns bool indicating whether criteria for status 'trained_1a' are met.
Expand All @@ -825,6 +841,7 @@ def criterion_1a(psych, n_trials, perf_easy):
- Lapse rate on both sides is less than 0.2
- The total number of trials is greater than 200 for each session
- Performance on easy contrasts > 80% for all sessions
- Zero contrast trials must be present
Parameters
----------
Expand All @@ -835,24 +852,39 @@ def criterion_1a(psych, n_trials, perf_easy):
The number for trials for each session.
perf_easy : numpy.array of float
The proportion of correct high contrast trials for each session.
signed_contrast: numpy.array
Unique list of contrasts displayed
Returns
-------
bool
True if the criteria are met for 'trained_1a'.
Bunch
Bunch containing breakdown of the passing/ failing critieria
Notes
-----
The parameter thresholds chosen here were originally determined by averaging the parameter fits
for a number of sessions determined to be of 'good' performance by an experimenter.
"""

criterion = (abs(psych[0]) < 16 and psych[1] < 19 and psych[2] < 0.2 and psych[3] < 0.2 and
np.all(n_trials > 200) and np.all(perf_easy > 0.8))
return criterion
criteria = Bunch()
criteria['Zero_contrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)}
criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.2}
criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.2}
criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 16}
criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 19}
criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 200)}
criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.8)}

passing = np.all([v['pass'] for k, v in criteria.items()])

criteria['Criteria'] = {'val': 'trained_1a', 'pass': passing}

def criterion_1b(psych, n_trials, perf_easy, rt):
return passing, criteria


def criterion_1b(psych, n_trials, perf_easy, rt, signed_contrast):
"""
Returns bool indicating whether criteria for trained_1b are met.
Expand All @@ -864,6 +896,7 @@ def criterion_1b(psych, n_trials, perf_easy, rt):
- The total number of trials is greater than 400 for each session
- Performance on easy contrasts > 90% for all sessions
- The median response time across all zero contrast trials is less than 2 seconds
- Zero contrast trials must be present
Parameters
----------
Expand All @@ -876,11 +909,15 @@ def criterion_1b(psych, n_trials, perf_easy, rt):
The proportion of correct high contrast trials for each session.
rt : float
The median response time for zero contrast trials.
signed_contrast: numpy.array
Unique list of contrasts displayed
Returns
-------
bool
True if the criteria are met for 'trained_1b'.
Bunch
Bunch containing breakdown of the passing/ failing critieria
Notes
-----
Expand All @@ -890,17 +927,27 @@ def criterion_1b(psych, n_trials, perf_easy, rt):
regrettably means that the maximum threshold fit for 1b is greater than for 1a, meaning the
slope of the psychometric curve may be slightly less steep than 1a.
"""
criterion = (abs(psych[0]) < 10 and psych[1] < 20 and psych[2] < 0.1 and psych[3] < 0.1 and
np.all(n_trials > 400) and np.all(perf_easy > 0.9) and rt < 2)
return criterion

criteria = Bunch()
criteria['Zero_contrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)}
criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.1}
criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.1}
criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 10}
criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 20}
criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)}
criteria['Perf_tasy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)}
criteria['Reaction_time'] = {'val': rt, 'pass': rt < 2}

passing = np.all([v['pass'] for k, v in criteria.items()])

criteria['Criteria'] = {'val': 'trained_1b', 'pass': passing}

return passing, criteria


def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt):
"""
Returns bool indicating whether criteria for ready4ephysrig or ready4recording are met.
NB: The difference between these two is whether the sessions were acquired ot a recording rig
with a delay before the first trial. Neither of these two things are tested here.
Returns bool indicating whether criteria for ready4ephysrig are met.
Criteria
--------
Expand Down Expand Up @@ -929,21 +976,34 @@ def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt):
Returns
-------
bool
True if subject passes the ready4ephysrig or ready4recording criteria.
True if subject passes the ready4ephysrig criteria.
Bunch
Bunch containing breakdown of the passing/ failing critieria
"""
criteria = Bunch()
criteria['LapseLow_80'] = {'val': psych_80[2], 'pass': psych_80[2] < 0.1}
criteria['LapseHigh_80'] = {'val': psych_80[3], 'pass': psych_80[3] < 0.1}
criteria['LapseLow_20'] = {'val': psych_20[2], 'pass': psych_20[2] < 0.1}
criteria['LapseHigh_20'] = {'val': psych_20[3], 'pass': psych_20[3] < 0.1}
criteria['Bias_shift'] = {'val': psych_80[0] - psych_20[0], 'pass': psych_80[0] - psych_20[0] > 5}
criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)}
criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)}
criteria['Reaction_time'] = {'val': rt, 'pass': rt < 2}

criterion = (np.all(np.r_[psych_20[2:4], psych_80[2:4]] < 0.1) and # lapse
psych_80[0] - psych_20[0] > 5 and np.all(n_trials > 400) and # bias shift and n trials
np.all(perf_easy > 0.9) and rt < 2) # overall performance and response times
return criterion
passing = np.all([v['pass'] for k, v in criteria.items()])

criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': passing}

def criterion_delay(n_trials, perf_easy):
return passing, criteria


def criterion_delay(n_ephys, n_trials, perf_easy):
"""
Returns bool indicating whether criteria for 'ready4delay' is met.
Criteria
--------
- At least one session on an ephys rig
- Total number of trials for any of the sessions is greater than 400
- Performance on easy contrasts is greater than 90% for any of the sessions
Expand All @@ -959,9 +1019,69 @@ def criterion_delay(n_trials, perf_easy):
-------
bool
True if subject passes the 'ready4delay' criteria.
Bunch
Bunch containing breakdown of the passing/ failing critieria
"""

criteria = Bunch()
criteria['N_ephys'] = {'val': n_ephys, 'pass': n_ephys > 0}
criteria['N_trials'] = {'val': n_trials, 'pass': np.any(n_trials > 400)}
criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.any(perf_easy > 0.9)}

passing = np.all([v['pass'] for k, v in criteria.items()])

criteria['Criteria'] = {'val': 'ready4delay', 'pass': passing}

return passing, criteria


def criteria_recording(n_ephys, delay, psych_20, psych_80, n_trials, perf_easy, rt):
"""
criterion = np.any(n_trials > 400) and np.any(perf_easy > 0.9)
return criterion
Returns bool indicating whether criteria for ready4recording are met.
Criteria
--------
- At least 3 ephys sessions
- Delay on any session > 0
- Lapse on both sides < 0.1 for both bias blocks
- Bias shift between blocks > 5
- Total number of trials > 400 for all sessions
- Performance on easy contrasts > 90% for all sessions
- Median response time for zero contrast stimuli < 2 seconds
Parameters
----------
psych_20 : numpy.array
The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2.
Parameters are bias, threshold, lapse high, lapse low.
psych_80 : numpy.array
The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8.
Parameters are bias, threshold, lapse high, lapse low.
n_trials : numpy.array
The number of trials for each session (typically three consecutive sessions).
perf_easy : numpy.array
The proportion of correct high contrast trials for each session (typically three
consecutive sessions).
rt : float
The median response time for zero contrast trials.
Returns
-------
bool
True if subject passes the ready4recording criteria.
Bunch
Bunch containing breakdown of the passing/ failing critieria
"""

_, criteria = criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt)
criteria['N_ephys'] = {'val': n_ephys, 'pass': n_ephys >= 3}
criteria['N_delay'] = {'val': delay, 'pass': delay > 0}

passing = np.all([v['pass'] for k, v in criteria.items()])

criteria['Criteria'] = {'val': 'ready4recording', 'pass': passing}

return passing, criteria


def plot_psychometric(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.032, **kwargs):
Expand Down
Loading

0 comments on commit 5cf0561

Please sign in to comment.