Skip to content

Commit

Permalink
Merge pull request #108 from RichieHakim/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
RichieHakim authored Oct 17, 2023
2 parents f8a66d5 + ddce267 commit 2e6912f
Show file tree
Hide file tree
Showing 8 changed files with 2,065 additions and 649 deletions.
1 change: 1 addition & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ updates:
directory: "/" # This denotes the directory where Dependabot will look for dependency files.
schedule:
interval: "daily" # This will tell Dependabot to check for updates once a day.
open-pull-requests-limit: 99999 # This will tell Dependabot to keep up to n PRs open at any time.
target-branch: "dev" # This will ensure Dependabot PRs target the "dev" branch.
commit-message:
prefix: "chore" # This is the prefix for commit messages.
Expand Down
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ With ROICaT, you can:

**What are the minimum computing needs?**
- We recommend the following as a starting point:
- 4 GB of RAM (more for large data sets e.g., ~32 Gb for 100K neurons)
- GPU not required but will increase run speed.
- 4 GB of RAM (more for large data sets e.g., ~32 GB for 100K neurons)
- GPU not required but will increase run speeds ~5-50x


<br>
Expand Down Expand Up @@ -68,8 +68,8 @@ Listed below, we have a suite of easy to run notebooks for running the ROICaT pi
# General workflow:
- **Pass ROIs through ROInet:** Images of the ROIs are passed through a neural network which outputs a feature vector for each image describing what the ROI looks like.
- **Classification:** The feature vectors can then be used to classify ROIs:
- A simple classifier can be trained using user supplied labeled data (e.g. an array of images of ROIs and a corresponding array of labels for each ROI).
- Alternatively, classification can be done by projecting the feature vectors into a lower-dimensional space using UMAP and then simply circling the region of space to classify the ROIs.
- A simple regression-like classifier can be trained using user-supplied labeled data (e.g. an array of images of ROIs and a corresponding array of labels for each ROI).
- Alternatively, classification can be done by projecting the feature vectors into a lower-dimensional space using UMAP and then simply circling the region of space to classify the ROIs.
- **Tracking**: The feature vectors can be combined with information about the position of the ROIs to track the ROIs across imaging sessions/planes.


Expand Down

Large diffs are not rendered by default.

2,367 changes: 1,734 additions & 633 deletions notebooks/jupyter/tracking/tracking_visualize_results.ipynb

Large diffs are not rendered by default.

19 changes: 10 additions & 9 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
hdbscan==0.8.33
holoviews[recommended]==1.16.2
holoviews[recommended]==1.17.1
jupyter==1.0.0
kymatio==0.3.0
matplotlib==3.7.1
matplotlib==3.8.0
natsort==8.4.0
numpy==1.24.3
opencv_contrib_python==4.8.0.76
opencv_contrib_python==4.8.1.78
optuna==3.3.0
Pillow==10.0.0
Pillow==10.1.0
pytest==7.4.2
scikit_learn==1.3.0
scikit_learn==1.3.1
scipy==1.11.1
seaborn==0.12.2
seaborn==0.13.0
sparse==0.14.0
tqdm==4.66.1
umap-learn==0.5.3
umap-learn==0.5.4
xxhash==3.2.0
bokeh==3.2.2
psutil==5.9.5
Expand All @@ -25,6 +25,7 @@ mat73==0.62
torch==2.0.1
torchvision==0.15.2
torchaudio==2.0.2
selenium==4.14.0
skl2onnx==1.15.0
onnx==1.14.0
onnxruntime==1.15.1
onnx==1.14.1
onnxruntime==1.16.0
2 changes: 1 addition & 1 deletion roicat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
for pkg in __all__:
exec('from . import ' + pkg)

__version__ = '1.1.23'
__version__ = '1.1.24'
2 changes: 1 addition & 1 deletion roicat/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def select_region_scatterPlot(
# Declare points as source of selection stream
selection = hv.streams.Selection1D(source=points)

path_tempFile = tempfile.gettempdir() + '/indices.csv' if path is None else path
path_tempFile = os.path.join(tempfile.gettempdir(), 'indices.csv') if path is None else path

# Write function that uses the selection indices to slice points and compute stats
def callback(index):
Expand Down
313 changes: 313 additions & 0 deletions tests/test_interactive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
import os
import warnings
import pytest
import tempfile
import shutil
import requests
import socket
import psutil
import time
from functools import partial

import numpy as np
import torch
import multiprocessing as mp
from multiprocessing import Queue

from roicat import visualization
import holoviews as hv
from bokeh.server.server import Server
from bokeh.application import Application
from bokeh.application.handlers.function import FunctionHandler

from selenium import webdriver
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.common.by import By
from selenium.webdriver.common.action_chains import ActionChains
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.chrome.options import Options

hv.extension("bokeh")

def get_tempdir_free_space():
temp_dir = tempfile.gettempdir()
usage = shutil.disk_usage(temp_dir)
return usage.free


def create_mock_input():
## Create mock input. Looks dumb. But it's just a test.
mock_data = np.array([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=np.float32)
mock_idx_images_overlay = torch.tensor([0, 1, 2, 3])
mock_images_overlay = np.random.rand(4, 2, 2)
return mock_data, mock_idx_images_overlay, mock_images_overlay


def get_indices(path_tempFile):
## Steal the fn_get_indices function from roicat.visualization for testing purposes
with open(path_tempFile, "r") as f:
indices = f.read().split(",")
indices = [int(i) for i in indices if i != ""] if len(indices) > 0 else None
return indices


def deploy_bokeh(instance):
## Draw test plot and add to Bokeh document
hv.extension("bokeh")

## Create a mock input
mock_data, mock_idx_images_overlay, mock_images_overlay = create_mock_input()

## Create a scatter plot
_, layout, _ = visualization.select_region_scatterPlot(
data=mock_data,
idx_images_overlay=mock_idx_images_overlay,
images_overlay=mock_images_overlay,
size_images_overlay=0.01,
frac_overlap_allowed=0.5,
figsize=(1200, 1200),
alpha_points=1.0,
size_points=10,
color_points="b",
)

## Render plot
hv_layout = hv.render(layout)
hv_layout.name = "drawing_test"

## Add to Bokeh document
instance.add_root(hv_layout)


def is_port_available(port, host='localhost'):
"""Check if a given port is available."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
try:
s.bind((host, port))
except socket.error:
return False
return True


def start_server(apps, query):
## Start Bokeh server given a test scatter plot
server = Server(
apps,
port=5006,
address="0.0.0.0",
allow_websocket_origin=["0.0.0.0:5006", "localhost:5006"],
)
server.start()

## Put server details into the queue
query.put(
{
"address": server.address,
"port": server.port,
}
)

## Setup IO loop
server.io_loop.start()


def kill_process_on_port(port):
for proc in psutil.process_iter(attrs=["pid", "connections"]):
if (proc.info["connections"] != []) and (proc.info["connections"] is not None):
for connection in proc.info["connections"]:
if connection.status == psutil.CONN_LISTEN and connection.laddr.port == port:
try:
psutil.Process(proc.info["pid"]).terminate()
return True
except psutil.AccessDenied:
return False
return False


def check_server():
try:
response = requests.get("http://localhost:5006/test_drawing")
if response.status_code == 200:
warnings.warn("Server is up and running!")
return 1
else:
warnings.warn(f"Server responded with status code: {response.status_code}")
return 0
except requests.ConnectionError:
warnings.warn("Cannot connect to the server!")
return 0


def test_interactive_drawing():
warnings.warn("Interactive GUI Drawing Test is running. Please wait...")
## Parameter setting
port = 5006
server_iter = 0
max_server_iter = 3
max_webdriver_iter = 10
path_tempdir = tempfile.gettempdir()
path_tempfile = os.path.join(path_tempdir, "indices.csv")

## Make sure the tempdir has enough space
warnings.warn("Check if tempdir has enough space...")
free_space = get_tempdir_free_space()
warnings.warn(f"Free space in tempdir: {free_space / (1024**3):.2f} GB")

## Start iteration. Iteration will stop if indices.csv is created.
while server_iter <= max_server_iter:
## First, check if the port for Bokeh server is available
if is_port_available(port):
warnings.warn(f"Port {port} is available!")
else:
warnings.warn(f"Port {port} is not available!")
warnings.warn(f"Kill process using port {port}...")
kill_process_on_port(port)

## Bokeh server deployment at http://localhost:5006/test_drawing
apps = {"/test_drawing": Application(FunctionHandler(deploy_bokeh))}

warnings.warn("Deploy Bokeh server to localhost:5006/test_drawing...")
## Let it run in the background so that the test can continue
query = Queue()
server_process = mp.Process(target=start_server, args=(apps, query))
server_process.start()

## Wait for the server to start
time.sleep(5)

## Start webdriver iteration using selenium
webdriver_iter = 0
while webdriver_iter <= max_webdriver_iter:
warnings.warn(f"Iteration {webdriver_iter} starts...")
## Get server info
server_query = query.get()

warnings.warn(f"Server address: {server_query['address']}")
warnings.warn(f"Server port: {server_query['port']}")

## Check if the server is up and running
warnings.warn("Check if Bokeh server is up and running...")
server_status = check_server()
if not server_status:
server_process.terminate()
server_process.join()
raise Exception("Server is not up and running!")

warnings.warn("Setup chrome webdriver...")
service = Service()
chrome_options = Options()
chrome_options.add_argument("--window-size=1280,1280")

## For local testing, comment out these options to visualize actions in bokeh / selenium server.
chrome_options.add_argument("--headless")
chrome_options.add_argument("--disable-gpu")

## if you are on latest version say selenium v4.6.0 or higher, you don't have to use third party library such as WebDriverManager
warnings.warn("Driver parameter sanity check...")
driver = webdriver.Chrome(service=service, options=chrome_options)
capabilities = driver.capabilities
warnings.warn("Browser Name: {}".format(capabilities.get("browserName")))
warnings.warn("Browser Version: {}".format(capabilities.get("browserVersion")))
warnings.warn("Platform Name: {}".format(capabilities.get("platformName")))
warnings.warn(
"Chrome Driver Version: {}".format(
capabilities.get("chrome").get("chromedriverVersion")
)
)

warnings.warn("Get to the Bokeh server...")
driver.get("http://localhost:5006/test_drawing")
wait = WebDriverWait(driver, 10)
warnings.warn("Found the Bokeh server, locate drawing Bokeh element...")
try:
element = wait.until(EC.presence_of_element_located((By.XPATH, "//*")))
warnings.warn("Found Bokeh drawing element!")
except Exception as e:
warnings.warn(f"Failed to locate element: {str(e)}")

## Create movement set
driver.execute_script("document.body.style.zoom='60%'") ## Zoom out
size = element.size
width, height = size["width"], size["height"]

warnings.warn("element size: {}".format(size))
warnings.warn("element tagname: {}".format(element.tag_name))
warnings.warn("element text: {}".format(element.text))
warnings.warn("element location: {}".format(element.location))
warnings.warn("element displayed: {}".format(element.is_displayed()))
warnings.warn("element enabled: {}".format(element.is_enabled()))
warnings.warn("element selected: {}".format(element.is_selected()))

## Move to the center of the element
warnings.warn("Start mouse movement...")
actions = ActionChains(driver)

## Surprisingly, this pause seems to be crucial.
actions.pause(5)

## Move to the center of the element
actions.move_to_element(element)
actions.click_and_hold()

## Draw!
actions.move_by_offset(
int(width / 2), int(0)
) ## Move from center to midpoint of right edge
actions.move_by_offset(
int(0), int(-height / 2)
) ## Move from midpoint of right edge to top right corner
actions.move_by_offset(
int(-width / 2), int(0)
) ## Move from top right corner to midpoint of top edge
actions.release()
actions.perform()

warnings.warn("Mouse movement done! Detach Selenium from Bokeh server...")
driver.quit()

## Wait for the server to save indices.csv
time.sleep(5)

## Test if indices.csv is created
if os.path.exists(path_tempfile):
break
else:
webdriver_iter += 1

## Kill the process to prevent potential race condition
warnings.warn("Kill the Bokeh server...")
server_process.terminate()
server_process.join()

## Wait for the process to terminate
time.sleep(5)

## If indices.csv is created, break the loop
if os.path.exists(path_tempfile):
warnings.warn("indices.csv is created!")
break
else:
warnings.warn("indices.csv is not created!")
server_iter += 1

warnings.warn("Test if indices.csv has correct permission...")
if not os.access(path_tempfile, os.R_OK):
warnings.warn("indices.csv is not readable!")
server_process.terminate()
server_process.join()
raise Exception("indices.csv is not readable!")

warnings.warn("Test if indices are correctly saved...")
indices = get_indices(path_tempfile)

if indices is None:
warnings.warn("indices.csv is created, but no indices are saved!")
server_process.terminate()
server_process.join()
raise Exception("indices.csv is created, but no indices are saved!")

## Check if the indices are correct
assert indices == [3]
warnings.warn("Test is done.")

0 comments on commit 2e6912f

Please sign in to comment.