Commit

Merge pull request #289 from RichieHakim/dev

Dev

RichieHakim authored Aug 10, 2024
2 parents 115a3d3 + 799c9e6 · commit 183f471

Showing 11 changed files with 344 additions and 119 deletions.
174 changes: 104 additions & 70 deletions notebooks/jupyter/tracking/tracking_interactive_notebook.ipynb

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions requirements.txt
@@ -1,8 +1,8 @@
- hdbscan==0.8.37
+ hdbscan==0.8.38.post1
holoviews[recommended]==1.19.1
jupyter==1.0.0
kymatio==0.3.0
- matplotlib==3.9.1
+ matplotlib==3.9.1.post1
natsort==8.4.0
numpy==1.26.4
opencv_contrib_python_headless<=4.10.0.84
@@ -20,8 +20,8 @@ bokeh==3.4.1
psutil==6.0.0
py_cpuinfo==9.0.0
GPUtil==1.4.0
- PyYAML==6.0.1
- mat73==0.63
+ PyYAML==6.0.2
+ mat73==0.65
torch==2.4.0
torchvision==0.19.0
torchaudio==2.4.0
@@ -30,4 +30,4 @@ skl2onnx==1.17.0
onnx==1.16.2
onnxruntime==1.18.1
jupyter_bokeh==4.0.5
- onnx2torch==1.5.14
+ onnx2torch==1.5.15
2 changes: 1 addition & 1 deletion roicat/__init__.py
@@ -13,4 +13,4 @@
for pkg in __all__:
exec('from . import ' + pkg)

- __version__ = '1.1.39'
+ __version__ = '1.2.0'
13 changes: 12 additions & 1 deletion roicat/helpers.py
@@ -519,6 +519,10 @@ class Convergence_checker_optuna:
max_duration (float):
Maximum number of seconds to run before stopping.
(Default is *600*)
value_stop (Optional[float]):
Value at which to stop the optimization. If the best value is equal
to or less than this value, the optimization will stop.
(Default is *None*)
verbose (bool):
If ``True``, print messages.
(Default is ``True``)
@@ -552,6 +556,7 @@ def __init__(
tol_frac: float = 0.05,
max_trials: int = 350,
max_duration: float = 60*10,
value_stop: Optional[float] = None,
verbose: bool = True,
):
"""
@@ -563,6 +568,7 @@ def __init__(
self.tol_frac = tol_frac
self.max_trials = max_trials
self.max_duration = max_duration
self.value_stop = value_stop
self.num_trial = 0
self.verbose = verbose

@@ -601,6 +607,11 @@ def check(
if duration > self.max_duration:
print(f'Stopping. Duration limit reached. study.duration={duration}, max_duration={self.max_duration}.') if self.verbose else None
study.stop()

if self.value_stop is not None:
if self.best <= self.value_stop:
print(f'Stopping. Best value ({self.best}) is less than or equal to value_stop ({self.value_stop}).') if self.verbose else None
study.stop()

if self.verbose:
print(f'Trial num: {self.num_trial}. Duration: {duration:.3f}s. Best value: {self.best:3e}. Current value:{trial.value:3e}') if self.verbose else None
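
A minimal usage sketch of the new value_stop argument (hedged: the callback wiring follows the standard Optuna convention implied by study.stop() and trial.value above; only arguments visible in this diff are shown):

    import optuna
    from roicat import helpers

    ## Stop early once the best value reaches 0.0, in addition to the
    ## existing max_trials / max_duration limits.
    checker = helpers.Convergence_checker_optuna(
        tol_frac=0.05,
        max_trials=350,
        max_duration=60*10,
        value_stop=0.0,
        verbose=True,
    )

    def objective(trial):
        x = trial.suggest_float('x', -10, 10)
        return x ** 2  ## best value converges toward 0.0

    study = optuna.create_study(direction='minimize')
    study.optimize(objective, callbacks=[checker.check])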
@@ -1642,7 +1653,7 @@ def yaml_save(
"""
path = prepare_filepath_for_saving(filepath, mkdir=mkdir, allow_overwrite=allow_overwrite)
with open(path, mode) as f:
- yaml.dump(obj, f, indent=indent)
+ yaml.dump(obj, f, indent=indent, sort_keys=False)
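
A quick illustration of what sort_keys=False changes (standard PyYAML behavior; minimal sketch):

    import yaml

    d = {'b': 1, 'a': 2}
    print(yaml.dump(d))                   ## default sorts keys: 'a' comes first
    print(yaml.dump(d, sort_keys=False))  ## preserves insertion order: 'b' first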

def yaml_load(
filepath: str,
119 changes: 104 additions & 15 deletions roicat/pipelines.py
@@ -3,12 +3,13 @@
import copy
import tempfile
from IPython.display import display
import time

# import matplotlib.pyplot as plt
import numpy as np

## Import roicat submodules
- from . import data_importing, ROInet, helpers, util, tracking, classification
+ from . import data_importing, ROInet, helpers, util, visualization, tracking, classification

def pipeline_tracking(params: dict):
"""
@@ -32,6 +33,8 @@ def pipeline_tracking(params: dict):
Parameters used in the pipeline. See
``roicat.helpers.prepare_params()`` for details.
"""
## Start timer
tic_start = time.time()

## Prepare params
defaults = util.get_default_parameters(pipeline='tracking')
@@ -56,15 +59,13 @@ def pipeline_tracking(params: dict):
find_folders=False,
natsorted=True,
)[:]
- paths_allOps = np.array([Path(path).resolve().parent / 'ops.npy' for path in paths_allStat])[:]
+ paths_allOps = [str(Path(path).resolve().parent / 'ops.npy') for path in paths_allStat][:]

print(f"Found the following stat.npy files:")
[print(f" {path}") for path in paths_allStat]
print(f"Found the following corresponding ops.npy files:")
[print(f" {path}") for path in paths_allOps]

- params['data_loading']['paths_allStat'] = paths_allStat
- params['data_loading']['paths_allOps'] = paths_allOps

## Import data
data = data_importing.Data_suite2p(
@@ -190,13 +191,19 @@ def pipeline_tracking(params: dict):
s_sesh=sim.s_sesh,
verbose=VERBOSE,
)
- kwargs_makeConjunctiveDistanceMatrix_best = clusterer.find_optimal_parameters_for_pruning(
-     seed=SEED,
-     **params['clustering']['automatic_mixing'],
- )
- kwargs_mcdm_tmp = kwargs_makeConjunctiveDistanceMatrix_best ## Use the optimized parameters
+ if params['clustering']['mixing_method'] == 'automatic':
+     kwargs_makeConjunctiveDistanceMatrix_best = clusterer.find_optimal_parameters_for_pruning(
+         seed=SEED,
+         **params['clustering']['parameters_automatic_mixing'],
+     )
+ elif params['clustering']['mixing_method'] == 'manual':
+     kwargs_makeConjunctiveDistanceMatrix_best = params['clustering']['parameters_manual_mixing']
+ else:
+     ## Not implemented
+     raise NotImplementedError(f"Mixing method '{params['clustering']['mixing_method']}' is not implemented. Select from: ['automatic', 'manual']")

clusterer.make_pruned_similarity_graphs(
-     kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp,
+     kwargs_makeConjunctiveDistanceMatrix=kwargs_makeConjunctiveDistanceMatrix_best,
**params['clustering']['pruning'],
)
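
A sketch of the params['clustering'] keys this branch expects (the key names mixing_method, parameters_automatic_mixing, parameters_manual_mixing, and pruning appear in the diff; the inner values are illustrative assumptions only):

    params['clustering'] = {
        'mixing_method': 'automatic',  ## or 'manual'
        'parameters_automatic_mixing': {
            ## forwarded to clusterer.find_optimal_parameters_for_pruning
            'max_trials': 350,  ## illustrative
        },
        'parameters_manual_mixing': {
            ## used directly as the conjunctive distance matrix kwargs
            'power_SF': 1.0,  ## illustrative; see bounds_findParameters in clustering.py below
        },
        'pruning': {},  ## forwarded to clusterer.make_pruned_similarity_graphs
    }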

@@ -267,7 +274,7 @@ def choose_clustering_method(method='automatic', n_sessions_switch=8, n_sessions
})


- ## Visualize results
+ ## Print some results
print(f'Number of clusters: {len(np.unique(results["clusters"]["labels"]))}')
print(f'Number of discarded ROIs: {(results["clusters"]["labels"]==-1).sum()}')

@@ -276,14 +283,13 @@
if params['results_saving']['dir_save'] is not None:

dir_save = Path(params['results_saving']['dir_save']).resolve()
- name_save = params['results_saving']['prefix_name_save']
+ name_save = str(params['results_saving']['prefix_name_save'])

- path_save = dir_save / (name_save + '.ROICaT.tracking.results' + '.pkl')
- print(f'path_save: {path_save}')
+ print(f'dir_save: {dir_save}')

helpers.pickle_save(
obj=results,
-     filepath=path_save,
+     filepath=str(dir_save / (name_save + '.ROICaT.tracking.results' + '.pkl')),
mkdir=True,
)

@@ -292,6 +298,89 @@ def choose_clustering_method(method='automatic', n_sessions_switch=8, n_sessions
filepath=str(dir_save / (name_save + '.ROICaT.tracking.rundata' + '.pkl')),
mkdir=True,
)

helpers.yaml_save(
obj=params,
filepath=str(dir_save / (name_save + '.ROICaT.tracking.params' + '.yaml')),
mkdir=True,
)


## Visualize results
### Save some figures

#### Save FOV_images as .png files
def save_image(array, path, normalize=True):
## Use PIL to save the image
from PIL import Image
Path(path).parent.mkdir(parents=True, exist_ok=True)
Image.fromarray((np.array(array / array.max() if normalize else array) * 255).astype(np.uint8)).save(path)
[save_image(array, str(Path(dir_save).resolve() / 'visualization' / 'FOV_images' / f'FOV_images_{ii}.png') ) for ii, array in enumerate(data.FOV_images)]
[save_image(array, str(Path(dir_save).resolve() / 'visualization' / 'FOV_images_aligned_geometric' / f'FOV_images_aligned_geometric_{ii}.png') ) for ii, array in enumerate(aligner.ims_registered_geo)]
[save_image(array, str(Path(dir_save).resolve() / 'visualization' / 'FOV_images_aligned_nonrigid' / f'FOV_images_aligned_nonrigid_{ii}.png') ) for ii, array in enumerate(aligner.ims_registered_nonrigid)]
[save_image(array, str(Path(dir_save).resolve() / 'visualization' / 'ROIs' / f'ROIs_{ii}.png') ) for ii, array in enumerate(data.get_maxIntensityProjection_spatialFootprints())]
[save_image(array, str(Path(dir_save).resolve() / 'visualization' / 'ROIs_aligned' / f'ROIs_aligned_{ii}.png') ) for ii, array in enumerate(aligner.get_ROIsAligned_maxIntensityProjection(normalize=True))]
[save_image(array, str(Path(dir_save).resolve() / 'visualization' / 'ROIs_aligned_blurred' / f'ROIs_aligned_blurred_{ii}.png') ) for ii, array in enumerate(blurrer.get_ROIsBlurred_maxIntensityProjection())]

#### Save some sample ROI images
[save_image(array, str(Path(dir_save).resolve() / 'visualization' / 'ROIs_sample' / f'ROIs_sample_{ii}.png') ) for ii, array in enumerate(roinet.ROI_images_rs[:100])]
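
A standalone check of the save_image scaling above (assumes a non-negative float array; the output path is illustrative):

    import numpy as np
    from PIL import Image

    arr = np.random.rand(64, 64).astype(np.float32)  ## float image, values in [0, 1)
    ## Same scaling as save_image: peak-normalize, then map to 0..255 uint8
    Image.fromarray(((arr / arr.max()) * 255).astype(np.uint8)).save('FOV_example.png')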

#### Save the similarity graph blocks
fig = sim.visualize_blocks()
(Path(dir_save).resolve() / 'visualization' / 'similarity_graph').mkdir(parents=True, exist_ok=True)
fig.savefig(str(Path(dir_save).resolve() / 'visualization' / 'similarity_graph' / 'blocks.png'))

#### Save the similarity / distance plots for the given conjunctive distance matrix kwargs
fig = clusterer.plot_distSame(kwargs_makeConjunctiveDistanceMatrix=kwargs_makeConjunctiveDistanceMatrix_best)
(Path(dir_save).resolve() / 'visualization' / 'clustering').mkdir(parents=True, exist_ok=True)
fig.savefig(str(Path(dir_save).resolve() / 'visualization' / 'clustering' / 'dist.png'))
fig, axs = clusterer.plot_similarity_relationships(
plots_to_show=[1,2,3],
max_samples=100000, ## Make smaller if it is running too slow
kwargs_scatter={'s':1, 'alpha':0.2},
kwargs_makeConjunctiveDistanceMatrix=kwargs_makeConjunctiveDistanceMatrix_best
)
(Path(dir_save).resolve() / 'visualization' / 'clustering').mkdir(parents=True, exist_ok=True)
fig.savefig(str(Path(dir_save).resolve() / 'visualization' / 'clustering' / 'similarity_relationships.png'))

#### Save the clustering results
fig, axs = tracking.clustering.plot_quality_metrics(
quality_metrics=quality_metrics,
labels=labels_squeezed,
n_sessions=data.n_sessions,
)
(Path(dir_save).resolve() / 'visualization' / 'clustering').mkdir(parents=True, exist_ok=True)
fig.savefig(str(Path(dir_save).resolve() / 'visualization' / 'clustering' / 'quality_metrics.png'))

### Save a gif of the ROIs
FOV_clusters = visualization.compute_colored_FOV(
spatialFootprints=[r.power(1.0) for r in results['ROIs']['ROIs_aligned']], ## Spatial footprint sparse arrays
FOV_height=results['ROIs']['frame_height'],
FOV_width=results['ROIs']['frame_width'],
labels=results["clusters"]["labels_bySession"], ## cluster labels
# labels=(np.array(results["clusters"]["labels"])!=-1).astype(np.int64), ## cluster labels
# alphas_labels=confidence*1.5, ## Set brightness of each cluster based on some 1-D array
# alphas_labels=(clusterer.quality_metrics['cluster_silhouette'] > 0) * (clusterer.quality_metrics['cluster_intra_means'] > 0.4),
# alphas_sf=clusterer.quality_metrics['sample_silhouette'], ## Set brightness of each ROI based on some 1-D array
)
helpers.save_gif(
array=helpers.add_text_to_images(
images=[(f * 255).astype(np.uint8) for f in FOV_clusters],
text=[[f"{ii}",] for ii in range(len(FOV_clusters))],
font_size=3,
line_width=10,
position=(30, 90),
),
path=str(Path(dir_save).resolve() / 'visualization' / 'FOV_clusters.gif'),
frameRate=10.0,
loop=0,
)



## End timer
tic_end = time.time()
print(f"Elapsed time: {tic_end - tic_start:.2f} seconds")

return results, run_data, params
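
A minimal sketch of calling the pipeline end-to-end (the 'dir_outer' key under data_loading is an assumption for illustration; the other keys appear in the diff above):

    from roicat import pipelines, util

    params = util.get_default_parameters(pipeline='tracking')
    params['data_loading']['dir_outer'] = '/path/to/suite2p/outputs'  ## assumed key name
    params['results_saving']['dir_save'] = '/path/to/results'
    params['results_saving']['prefix_name_save'] = 'mouse01'

    results, run_data, params_used = pipelines.pipeline_tracking(params)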

6 changes: 5 additions & 1 deletion roicat/tracking/alignment.py
@@ -601,6 +601,7 @@ def get_ROIsAligned_maxIntensityProjection(
self,
H: Optional[int] = None,
W: Optional[int] = None,
normalize: bool = True,
) -> List[np.ndarray]:
"""
Returns the max intensity projection of the ROIs aligned to the template
@@ -613,6 +614,9 @@
W (Optional[int]):
The width of the output projection. If not provided and if not
already set, an error will be thrown. (Default is ``None``)
normalize (bool):
If ``True``, the ROIs are normalized by the maximum value.
(Default is ``True``)
Returns:
(List[np.ndarray]):
@@ -622,7 +626,7 @@
if H is None:
assert self._HW is not None, 'H and W must be provided if not already set.'
H, W = self._HW
- return [rois.max(0).toarray().reshape(H, W) for rois in self.ROIs_aligned]
+ return [(rois.multiply(rois.max(1).power(-1)) if normalize else rois).max(0).toarray().reshape(H, W) for rois in self.ROIs_aligned]
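
The one-line return above packs several sparse operations; an unpacked sketch of the same logic (assuming each element of ROIs_aligned is a scipy.sparse matrix with one flattened ROI per row):

    import numpy as np
    import scipy.sparse

    def max_projection(rois, H: int, W: int, normalize: bool = True) -> np.ndarray:
        if normalize:
            ## Scale each ROI (row) so it peaks at 1. power(-1) inverts only
            ## the stored nonzero entries, so empty rows are unaffected.
            rois = rois.multiply(rois.max(axis=1).power(-1))
        ## Pixel-wise max across ROIs, reshaped back to the FOV
        return rois.max(axis=0).toarray().reshape(H, W)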

def get_flowFields(
self,
57 changes: 44 additions & 13 deletions roicat/tracking/clustering.py
@@ -134,6 +134,7 @@ def find_optimal_parameters_for_pruning(
'tol_frac': 0.05,
'max_trials': 350,
'max_duration': 60*10,
'value_stop': 0.0,
},
bounds_findParameters: Dict[str, Tuple[float, float]] = {
'power_SF': (0.3, 2),
@@ -781,7 +782,7 @@ def make_conjunctive_distance_matrix(
def _activation_function(
cls,
s: Optional[torch.Tensor] = None,
-     sig_kwargs: Dict[str, float] = {'mu':0.0, 'b':1.0},
+     sig_kwargs: Optional[Dict[str, float]] = {'mu':0.0, 'b':1.0},
power: Optional[float] = 1
) -> Optional[torch.Tensor]:
"""
@@ -806,15 +807,13 @@
"""
if s is None:
return None
- if (sig_kwargs is not None) and (power is not None):
-     return helpers.generalised_logistic_function(torch.as_tensor(s, dtype=torch.float32), **sig_kwargs)**power
- elif (sig_kwargs is None) and (power is not None):
-     return torch.maximum(torch.as_tensor(s, dtype=torch.float32), torch.as_tensor([0], dtype=torch.float32))**power
-     # return torch.as_tensor(s, dtype=torch.float32)**power
- elif (sig_kwargs is not None) and (power is None):
-     return helpers.generalised_logistic_function(torch.as_tensor(s, dtype=torch.float32), **sig_kwargs)
- else:
-     return torch.as_tensor(s, dtype=torch.float32)
+ s = torch.as_tensor(s, dtype=torch.float32)
+ ## make functions such that if the param is None, then no operation is applied
+ fn_sigmoid = lambda x, params: helpers.generalised_logistic_function(x, **params) if params is not None else x
+ fn_power = lambda x, p: x ** p if p is not None else x

+ return fn_power(torch.clamp(fn_sigmoid(s, sig_kwargs), min=0), power)

@classmethod
def _pNorm(
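
To make the refactored None-handling concrete, a standalone sketch (torch.sigmoid stands in for helpers.generalised_logistic_function; the pattern mirrors the diff above):

    import torch

    def activation(s, sig_kwargs={'mu': 0.0, 'b': 1.0}, power=1):
        ## None for a parameter means "skip that operation"
        fn_sigmoid = lambda x, params: torch.sigmoid((x - params['mu']) * params['b']) if params is not None else x
        fn_power = lambda x, p: x ** p if p is not None else x
        s = torch.as_tensor(s, dtype=torch.float32)
        return fn_power(torch.clamp(fn_sigmoid(s, sig_kwargs), min=0), power)

    s = [-1.0, 0.0, 2.0]
    print(activation(s))                               ## sigmoid -> clamp -> power
    print(activation(s, sig_kwargs=None))              ## clamp -> power only
    print(activation(s, sig_kwargs=None, power=None))  ## clamp only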
@@ -948,7 +947,7 @@ def plot_distSame(self, kwargs_makeConjunctiveDistanceMatrix: Optional[dict] = N
print('No crossover found, not plotting')
return None

- plt.figure()
+ fig = plt.figure()
plt.stairs(dens_same, edges, linewidth=5)
plt.stairs(dens_same_crop, edges, linewidth=3)
plt.stairs(dens_diff, edges)
Expand All @@ -962,6 +961,7 @@ def plot_distSame(self, kwargs_makeConjunctiveDistanceMatrix: Optional[dict] = N
plt.xlabel('distance or prob(different)')
plt.ylabel('counts or density')
plt.legend(['same', 'same (cropped)', 'diff', 'all', 'diff - same', 'all - diff', '(diff * same) * 1000', 'crossover'])
return fig

def _separate_diffSame_distributions(
self,
@@ -1166,7 +1166,12 @@ def compute_quality_metrics(
d_dense.fill_value = (dist_mat.data.max() - dist_mat.data.min()).astype(np.float16) * 10
d_dense = d_dense.todense()
np.fill_diagonal(d_dense, 0)
- rs_sil = sklearn.metrics.silhouette_samples(X=d_dense, labels=labels, metric='precomputed')
+ ## Number of labels must be at least 2
+ if len(np.unique(labels)) < 2:
+     warnings.warn(f"Silhouette samples calculation requires at least 2 labels. Returning None. Found {len(np.unique(labels))} labels.")
+     rs_sil = None
+ else:
+     rs_sil = sklearn.metrics.silhouette_samples(X=d_dense, labels=labels, metric='precomputed')
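
For reference, the constraint the new guard protects against (a sketch using sklearn's documented API; silhouette_samples raises a ValueError when fewer than 2 unique labels are present):

    import numpy as np
    import sklearn.metrics

    D = np.array([[0.0, 0.1, 0.9],
                  [0.1, 0.0, 0.8],
                  [0.9, 0.8, 0.0]])  ## precomputed distance matrix
    labels = np.array([0, 0, 1])     ## >= 2 unique labels required
    scores = sklearn.metrics.silhouette_samples(X=D, labels=labels, metric='precomputed')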

self.quality_metrics = {
'cluster_labels_unique': labels_unique,
@@ -1449,4 +1454,30 @@ def make_label_variants(
assert np.allclose(labels_bool.nonzero()[1] - 1, labels_squeezed)
assert np.all([np.allclose(np.where(labels_squeezed==u)[0], ldu) for u, ldu in labels_dict.items()])

return labels_squeezed, labels_bySession, labels_bool, labels_bool_bySession, labels_dict


def plot_quality_metrics(quality_metrics: dict, labels: Union[np.ndarray, list], n_sessions: int) -> tuple:
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15,7))

axs[0,0].hist(quality_metrics['cluster_silhouette'], 50);
axs[0,0].set_xlabel('cluster_silhouette');
axs[0,0].set_ylabel('cluster counts');

axs[0,1].hist(quality_metrics['cluster_intra_means'], 50);
axs[0,1].set_xlabel('cluster_intra_means');
axs[0,1].set_ylabel('cluster counts');

axs[1,0].hist(quality_metrics['sample_silhouette'], 50);
axs[1,0].set_xlabel('sample_silhouette score');
axs[1,0].set_ylabel('roi sample counts');

_, counts = np.unique(labels[labels!=-1], return_counts=True)

axs[1,1].hist(counts, n_sessions*2 + 1, range=(0, n_sessions+1));
axs[1,1].set_xlabel('n_sessions')
axs[1,1].set_ylabel('cluster counts');

# Make the title include the number of excluded (label==-1) ROIs
fig.suptitle(f'Quality metrics n_excluded: {np.sum(labels==-1)}, n_included: {np.sum(labels!=-1)}, n_total: {len(labels)}, n_clusters: {len(np.unique(labels[labels!=-1]))}, n_sessions: {n_sessions}')
return fig, axs