Skip to content

Commit

Permalink
add pre-commit checks
Browse files Browse the repository at this point in the history
  • Loading branch information
sneakers-the-rat committed Jun 21, 2024
1 parent 37e6910 commit c6f86c3
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 18 deletions.
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.10
hooks:
- id: ruff
args: [ --fix ]
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black
17 changes: 10 additions & 7 deletions miniscope_io/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class Frame(BaseModel, arbitrary_types_allowed=True):
Typically returned from :meth:`.SDCard.read`
"""

data: Optional[np.ndarray] = None
headers: Optional[List[SDBufferHeader]] = None
data: np.ndarray
headers: List[SDBufferHeader]

@field_validator("headers")
@classmethod
Expand All @@ -34,9 +34,9 @@ def frame_nums_must_be_equal(cls, v: List[SDBufferHeader]) -> Optional[List[SDBu
return v

@property
def frame_num(self) -> int:
def frame_num(self) -> Optional[int]:
"""
Frame number for this set of headers
Frame number for this set of headers, if headers are present
"""
return self.headers[0].frame_num

Expand All @@ -49,10 +49,12 @@ class Frames(BaseModel):
frames: List[Frame]

@overload
def flatten_headers(self, as_dict: Literal[False] = False) -> List[SDBufferHeader]: ...
def flatten_headers(self, as_dict: Literal[False]) -> List[SDBufferHeader]:
...

@overload
def flatten_headers(self, as_dict: Literal[True] = True) -> List[dict]: ...
def flatten_headers(self, as_dict: Literal[True]) -> List[dict]:
...

def flatten_headers(self, as_dict: bool = False) -> Union[List[dict], List[SDBufferHeader]]:
"""
Expand All @@ -62,8 +64,9 @@ def flatten_headers(self, as_dict: bool = False) -> Union[List[dict], List[SDBuf
as_dict (bool): If `True`, return a list of dictionaries, if `False`
(default), return a list of :class:`.SDBufferHeader` s.
"""
h = []
h: Union[List[dict], List[SDBufferHeader]] = []
for frame in self.frames:
headers: Union[List[dict], List[SDBufferHeader]]
if as_dict:
headers = [header.model_dump() for header in frame.headers]
else:
Expand Down
15 changes: 9 additions & 6 deletions miniscope_io/devices/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import os
from pathlib import Path
from typing import Dict
from typing import Dict, Optional

from miniscope_io.exceptions import EndOfRecordingException

Expand All @@ -21,7 +21,7 @@ class okDevMock:
Mock class for :class:`~miniscope_io.devices.opalkelly.okDev`
"""

DATA_FILE = None
DATA_FILE: Optional[Path] = None
"""
Recorded data file to use for simulating read.
Expand All @@ -34,19 +34,22 @@ class okDevMock:

def __init__(self, serial_id: str = ""):
self.serial_id = serial_id
self.bit_file = None
self.bit_file: Optional[Path] = None

self._wires: Dict[int, int] = {}
self._buffer_position = 0

# preload the data file to a byte array
if self.DATA_FILE is None:
if os.environ.get("PYTEST_OKDEV_DATA_FILE", False):
if os.environ.get("PYTEST_OKDEV_DATA_FILE") is not None:
# need to get file from env variables here because on some platforms
# the default method for creating a new process is "spawn" which creates
# an entirely new python session instead of "fork" which would preserve
# the classvar
okDevMock.DATA_FILE = Path(os.environ.get("PYTEST_OKDEV_DATA_FILE"))
data_file: str = os.environ.get("PYTEST_OKDEV_DATA_FILE") # type: ignore

self.DATA_FILE = Path(data_file)
okDevMock.DATA_FILE = Path(data_file)
else:
raise RuntimeError("DATA_FILE class attr must be set before using the mock")

Expand All @@ -55,7 +58,7 @@ def __init__(self, serial_id: str = ""):

def uploadBit(self, bit_file: str) -> None:
assert Path(bit_file).exists()
self.bit_file = bit_file
self.bit_file = Path(bit_file)

def readData(self, length: int, addr: int = 0xA0, blockSize: int = 16) -> bytearray:
if self._buffer_position >= len(self._buffer):
Expand Down
4 changes: 2 additions & 2 deletions miniscope_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def hash_file(path: Union[Path, str]) -> str:

def hash_video(
path: Union[Path, str],
method: hashlib.algorithms_available = "blake2s",
method: str = "blake2s",
) -> str:
"""
Create a hash of a video by digesting the byte string each of its decoded frames.
Expand All @@ -60,6 +60,6 @@ def hash_video(
ret, frame = vid.read()
if not ret:
break
h.update(frame)
h.update(frame) # type: ignore

return h.hexdigest()
Loading

0 comments on commit c6f86c3

Please sign in to comment.