working on refactoring stream_daq_pipeline
sneakers-the-rat committed Jun 22, 2024
1 parent b9138dd commit aac78ce
Showing 2 changed files with 75 additions and 82 deletions.
130 changes: 60 additions & 70 deletions miniscope_io/stream_daq.py
@@ -3,6 +3,7 @@
 """
 
 import argparse
+import logging
 import multiprocessing
 import os
 import sys
@@ -13,7 +14,7 @@
 import cv2
 import numpy as np
 import serial
-from bitstring import Array, BitArray, Bits
+from bitstring import BitArray, Bits
 
 from miniscope_io import init_logger
 from miniscope_io.devices.mocks import okDevMock
@@ -147,6 +148,7 @@ def _parse_header(
             The returned header data and (optionally truncated) buffer data.
         """
         pre = Bits(self.preamble)
+        buffer = Bits(buffer)
         if self.config.LSB:
             pre = pre[::-1]
         pre_len = len(pre)
@@ -162,11 +164,30 @@ def _parse_header(
         header_data = BufferHeader.model_construct(**header_data)
 
         if truncate == "preamble":
-            return header_data, buffer[pre_len:]
+            return header_data, buffer[pre_len:].tobytes()
         elif truncate == "header":
-            return header_data, buffer[self.config.header_len :]
+            return header_data, buffer[self.config.header_len :].tobytes()
         else:
-            return header_data, buffer
+            return header_data, buffer.tobytes()
+
+    def _trim(self, data: np.ndarray, expected_size: int, logger: logging.Logger) -> np.ndarray:
+        """
+        Trim or pad an array to match an expected size
+        """
+        if data.shape[0] != expected_size:
+            logger.warning(
+                f"Expected buffer data length: {expected_size}, got data with shape "
+                f"{data.shape}. Padding to expected length",
+            )
+
+            # trim if too long
+            if data.shape[0] > expected_size:
+                data = data[0:expected_size]
+            # pad if too short
+            else:
+                data = np.pad(data, (0, expected_size - data.shape[0]))
+
+        return data
 
     def _uart_recv(
         self, serial_buffer_queue: multiprocessing.Queue, comport: str, baudrate: int
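
Note: the new `_trim` helper above normalizes every buffer to a fixed length before frame assembly, warning when the device sends more or fewer pixels than expected. A minimal standalone sketch of the same trim-or-pad pattern (the function name and sizes here are illustrative, not part of the module):

```python
import logging

import numpy as np

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("streamDaq.demo")


def trim_or_pad(data: np.ndarray, expected_size: int) -> np.ndarray:
    """Force a 1-D buffer to a fixed length, mirroring _trim."""
    if data.shape[0] != expected_size:
        logger.warning(
            f"Expected {expected_size} elements, got {data.shape[0]}; adjusting"
        )
        if data.shape[0] > expected_size:
            data = data[:expected_size]  # trim trailing excess
        else:
            # np.pad's (before, after) tuple appends zeros at the end
            data = np.pad(data, (0, expected_size - data.shape[0]))
    return data


print(trim_or_pad(np.arange(3, dtype=np.uint8), 5))  # [0 1 2 0 0]
print(trim_or_pad(np.arange(8, dtype=np.uint8), 5))  # [0 1 2 3 4]
```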
@@ -325,16 +346,19 @@ def _buffer_to_frame(
         cur_fm_buffer_index = -1  # Index of buffer within frame
         cur_fm_num = -1  # Frame number
 
-        frame_buffer = [None] * self.nbuffer_per_fm
+        frame_buffer = []
 
         try:
             for serial_buffer in exact_iter(serial_buffer_queue.get, None):
-                serial_buffer = Bits(serial_buffer)
 
-                header_data, serial_buffer = self._parse_header(serial_buffer)
+                header_data, serial_buffer = self._parse_header(serial_buffer, truncate="header")
+                serial_buffer = np.frombuffer(serial_buffer, dtype=np.uint8)
+                serial_buffer = self._trim(
+                    serial_buffer, self.buffer_npix[header_data.frame_buffer_count], locallogs
+                )
 
                 # log metadata
-                locallogs.debug(str(header_data.model_dump()))
+                locallogs.debug(header_data)
 
                 # if first buffer of a frame
                 if header_data.frame_num != cur_fm_num:
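
Note: `_parse_header` now takes and returns plain `bytes` (wrapping them in `Bits` internally), which is what lets the loop above feed its output straight into `np.frombuffer`. A toy illustration of that bytes-in/bytes-out contract, assuming a hypothetical 4-byte header rather than the real `BufferHeader` layout:

```python
import numpy as np
from bitstring import Bits

HEADER_LEN_BITS = 32  # hypothetical: a 4-byte header


def parse_header(buffer: bytes) -> tuple[bytes, bytes]:
    """Split header from payload; bytes in, bytes out, as in the refactor."""
    bits = Bits(buffer)
    header = bits[:HEADER_LEN_BITS].tobytes()
    payload = bits[HEADER_LEN_BITS:].tobytes()
    return header, payload


header, payload = parse_header(b"\x01\x02\x03\x04" + bytes(range(8)))
pixels = np.frombuffer(payload, dtype=np.uint8)  # zero-copy uint8 view
print(header.hex(), pixels)  # 01020304 [0 1 2 3 4 5 6 7]
```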
@@ -345,27 +369,29 @@ def _buffer_to_frame(
                         # push frame_buffer into frame_buffer queue
                         frame_buffer_queue.put(frame_buffer)
                         # init frame_buffer
-                        frame_buffer = [None] * self.nbuffer_per_fm
+                        frame_buffer = []
 
                     # update frame_num and index
                     cur_fm_num = header_data.frame_num
                     cur_fm_buffer_index = header_data.frame_buffer_count
 
-                    # update data
-                    frame_buffer[cur_fm_buffer_index] = serial_buffer.tobytes()
-
                     if cur_fm_buffer_index != 0:
                         locallogs.warning(
                             f"Frame {cur_fm_num} started with buffer {cur_fm_buffer_index}"
                         )
+                        for i in range(cur_fm_buffer_index):
+                            frame_buffer.append(np.zeros(self.buffer_npix[i], dtype=np.uint8))
+
+                    # update data
+                    frame_buffer.append(serial_buffer)
 
                 # if same frame_num with previous buffer.
                 elif (
                     header_data.frame_num == cur_fm_num
                     and header_data.frame_buffer_count > cur_fm_buffer_index
                 ):
                     cur_fm_buffer_index = header_data.frame_buffer_count
-                    frame_buffer[cur_fm_buffer_index] = serial_buffer.tobytes()
+                    frame_buffer.append(serial_buffer)
                     locallogs.debug("----buffer #" + str(cur_fm_buffer_index) + " stored")
 
                 # if lost frame from buffer -> reset index
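
Note: the frame is now built as an append-only list rather than a preallocated `[None] * nbuffer_per_fm` table; when a frame starts partway through, the buffers lost before the first one seen are replaced by zeros so the concatenated frame keeps its full size. A simplified sketch of that zero-fill bookkeeping, with a hypothetical `buffer_npix` table:

```python
import numpy as np

# hypothetical pixels-per-buffer layout for one frame
buffer_npix = [300, 300, 300, 124]


def start_frame(first_index: int, first_payload: np.ndarray) -> list[np.ndarray]:
    """Begin a frame, zero-filling buffers lost before the first one seen."""
    frame_buffer: list[np.ndarray] = []
    for i in range(first_index):
        # a buffer that never arrived becomes zeros of its expected size
        frame_buffer.append(np.zeros(buffer_npix[i], dtype=np.uint8))
    frame_buffer.append(first_payload)
    return frame_buffer


# frame starts at buffer 2: buffers 0 and 1 become zero placeholders
frame = start_frame(2, np.full(buffer_npix[2], 7, dtype=np.uint8))
print([len(b) for b in frame])  # [300, 300, 300]
```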
@@ -399,65 +425,30 @@ def _format_frame(
             Output image array queue.
         """
         locallogs = init_logger("streamDaq.frame")
-        header_data = None
         try:
             for frame_data in exact_iter(frame_buffer_queue.get, None):
                 locallogs.debug("Found frame in queue")
 
-                nbit_lost = 0
-
-                for i, npix_expected in enumerate(self.buffer_npix):
-                    if frame_data[i] is not None:
-                        header_data, fm_dat = self._parse_header(
-                            Bits(frame_data[i]), truncate="header"
-                        )
-                    else:
-                        frame_data[i] = Bits(int=0, length=npix_expected * self.config.pix_depth)
-                        nbit_lost += npix_expected
-                        continue
-                    npix_header = header_data.pixel_count
-                    npix_actual = len(fm_dat) / self.config.pix_depth
-
-                    if npix_actual != npix_expected:
-                        if i < len(self.buffer_npix) - 1:
-                            locallogs.warning(
-                                f"Pixel count inconsistent for frame {header_data.frame_num} "
-                                f"buffer {header_data.frame_buffer_count}. "
-                                f"Expected: {npix_expected}, "
-                                f"Header: {npix_header}, "
-                                f"Actual: {npix_actual}"
-                            )
-                        nbit_expected = npix_expected * self.config.pix_depth
-                        if len(fm_dat) > nbit_expected:
-                            fm_dat = fm_dat[:nbit_expected]
-                        else:
-                            nbit_pad = nbit_expected - len(fm_dat)
-                            fm_dat = fm_dat + Bits(int=0, length=nbit_pad)
-                            nbit_lost += nbit_pad
-
-                    frame_data[i] = fm_dat
-
-                pixel_vector = frame_data[0]
-                for d in frame_data[1:]:
-                    pixel_vector = pixel_vector + d
-
-                assert len(pixel_vector) == (
-                    self.config.frame_height * self.config.frame_width * self.config.pix_depth
-                )
-
+                if len(frame_data) == 0:
+                    continue
+                frame_data = np.concat(frame_data)
                 if self.config.LSB:
-                    pixel_vector = Array(
-                        "uint:32",
-                        [
-                            pixel_vector[i : i + 32][::-1].uint
-                            for i in reversed(range(0, len(pixel_vector), 32))
-                        ],
-                    )
-                img = np.frombuffer(pixel_vector.tobytes(), dtype=np.uint8)
-                imagearray.put(img)
-
-                if header_data is not None:
-                    locallogs.info(f"frame: {header_data.frame_num}, bits lost: {nbit_lost}")
+                    frame_data = np.flip(frame_data)
+
+                frame = np.reshape(frame_data, (self.config.frame_width, self.config.frame_height))
+
+                # if self.config.LSB:
+                #     pixel_vector = Array(
+                #         "uint:32",
+                #         [
+                #             pixel_vector[i : i + 32][::-1].uint
+                #             for i in reversed(range(0, len(pixel_vector), 32))
+                #         ],
+                #     )
+                # img = np.frombuffer(pixel_vector.tobytes(), dtype=np.uint8)
+                imagearray.put(frame)
+
+                # if header_data is not None:
+                #     locallogs.info(f"frame: {header_data.frame_num}, bits lost: {nbit_lost}")
         finally:
             imagearray.put(None)
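
Note: the rewritten `_format_frame` reduces frame construction to three array operations: concatenate the per-buffer uint8 payloads, reverse element order for LSB-first streams, and reshape to 2-D, replacing the old bit-level `bitstring` bookkeeping. One caveat: `np.concat` exists only as an alias of `np.concatenate` from NumPy 2.0 onward, so the alias-free spelling is the safer choice. A minimal sketch with hypothetical frame geometry:

```python
import numpy as np

FRAME_WIDTH, FRAME_HEIGHT = 4, 3  # hypothetical geometry
LSB = True


def format_frame(frame_data: list[np.ndarray]) -> np.ndarray | None:
    """Concatenate buffer payloads into a single 2-D frame."""
    if len(frame_data) == 0:
        return None
    flat = np.concatenate(frame_data)  # np.concat needs NumPy >= 2.0
    if LSB:
        flat = np.flip(flat)  # reverse element order for LSB-first data
    return np.reshape(flat, (FRAME_WIDTH, FRAME_HEIGHT))


bufs = [np.arange(6, dtype=np.uint8), np.arange(6, 12, dtype=np.uint8)]
print(format_frame(bufs).shape)  # (4, 3)
```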

Expand Down Expand Up @@ -573,8 +564,7 @@ def capture(
p_format_frame.start()
# p_terminate.start()
try:
for imagearray_plot in exact_iter(imagearray.get, None):
image = imagearray_plot.reshape(self.config.frame_width, self.config.frame_height)
for image in exact_iter(imagearray.get, None):
if self.config.show_video is True:
cv2.imshow("image", image)
if writer:
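
Note: `capture` can now display queue items directly because `_format_frame` already reshapes them. The `exact_iter(imagearray.get, None)` pattern used throughout drains a queue until a `None` sentinel arrives. A plausible sketch of such a helper (the real `exact_iter` lives elsewhere in miniscope_io and may differ) — identity comparison sidesteps the elementwise `==` that numpy arrays trigger in the built-in two-argument `iter(callable, sentinel)`:

```python
from typing import Any, Callable, Generator


def exact_iter(func: Callable[[], Any], sentinel: Any) -> Generator[Any, None, None]:
    """Yield func() results until the sentinel object itself is returned."""
    while True:
        item = func()
        if item is sentinel:  # `is`, not `==`: safe for numpy arrays
            return
        yield item


# consumer-side usage mirroring capture():
# for image in exact_iter(imagearray.get, None):
#     ...show or write the frame...
```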
27 changes: 15 additions & 12 deletions tests/test_stream_daq.py
@@ -1,22 +1,25 @@
+import pdb
+
 import pytest
 
 from miniscope_io.stream_daq import StreamDaqConfig, StreamDaq
 from miniscope_io.utils import hash_video, hash_file
 from .conftest import DATA_DIR, CONFIG_DIR
 
+
 @pytest.mark.parametrize(
-    'config,data,video_hash',
+    "config,data,video_hash",
     [
         (
-            'stream_daq_test_200px.yml',
-            'stream_daq_test_fpga_raw_input_200px.bin',
-            '40047689185bdbeb81829aea7c6e3070bcd4673976e5836a138c3e1b54d75099'
+            "stream_daq_test_200px.yml",
+            "stream_daq_test_fpga_raw_input_200px.bin",
+            "40047689185bdbeb81829aea7c6e3070bcd4673976e5836a138c3e1b54d75099",
         )
-    ]
+    ],
 )
 @pytest.mark.timeout(30)
 def test_video_output(config, data, video_hash, tmp_path, set_okdev_input):
-    output_video = tmp_path / 'output.avi'
+    output_video = tmp_path / "output.avi"
 
     test_config_path = CONFIG_DIR / config
     daqConfig = StreamDaqConfig.from_yaml(test_config_path)
@@ -35,13 +38,13 @@ def test_video_output(config, data, video_hash, tmp_path, set_okdev_input):
 
 
 @pytest.mark.parametrize(
-    'config,data',
+    "config,data",
     [
         (
-            'stream_daq_test_200px.yml',
-            'stream_daq_test_fpga_raw_input_200px.bin',
+            "stream_daq_test_200px.yml",
+            "stream_daq_test_fpga_raw_input_200px.bin",
         )
-    ]
+    ],
 )
 @pytest.mark.timeout(30)
 def test_binary_output(config, data, set_okdev_input, tmp_path):
@@ -51,11 +54,11 @@ def test_binary_output(config, data, set_okdev_input, tmp_path):
     data_file = DATA_DIR / data
     set_okdev_input(data_file)
 
-    output_file = tmp_path / 'output.bin'
+    output_file = tmp_path / "output.bin"
 
     daq_inst = StreamDaq(config=daqConfig)
     daq_inst.capture(source="fpga", binary=output_file)
 
     assert output_file.exists()
 
-    assert hash_file(data_file) == hash_file(output_file)
\ No newline at end of file
+    assert hash_file(data_file) == hash_file(output_file)
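
Note: both tests pin their outputs by content hash (`hash_video`, `hash_file` from `miniscope_io.utils`); the 64-hex-digit expected value suggests SHA-256. A guess at what a chunked `hash_file` could look like — an illustration, not the library's actual implementation:

```python
import hashlib
from pathlib import Path


def hash_file(path: Path, chunk_size: int = 65536) -> str:
    """SHA-256 a file in chunks so large captures never load fully into memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```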
