Skip to content

Commit

Permalink
First search implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
bauersimon committed Dec 21, 2023
1 parent e464492 commit 8783312
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 3 deletions.
1 change: 1 addition & 0 deletions extractor/extractor/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
ROOT = path.dirname(path.abspath(__file__))

FRAMES = path.join(ROOT, "frames")
QUIET = False
21 changes: 20 additions & 1 deletion extractor/extractor/model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,29 @@

class ColorUpdate:
def __init__(self, colors: list[tuple[float, float, float]], temps: list[float]):
def __init__(self, colors: list[tuple[float, float, float]], temps: list[float], timestamp: float = -1.0):
self._colors = colors
self._temps = temps
self._timestamp = timestamp
self._invalid = False

def set_timestamp(self, timestamp: float):
self._timestamp = timestamp

def __str__(self) -> str:
if self._invalid:
return f"{self._timestamp:.2f}s: invalid"
return f"{self._timestamp:.2f}s: {self._colors} {self._temps}"

def __repr__(self) -> str:
return self.__str__()

def __eq__(self, __value: object) -> bool:
if not isinstance(__value, ColorUpdate):
return False
return self._colors == __value._colors and self._temps == __value._temps and self._timestamp == __value._timestamp

@staticmethod
def invalid(timestamp: float = -1.0):
update = ColorUpdate([], [], timestamp)
update._invalid = True
return update
111 changes: 109 additions & 2 deletions extractor/extractor/search.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,110 @@
import logging
from abc import ABC
from multiprocessing import Pool
from typing import Optional

import numpy.typing as npt
import tqdm

from . import color, constants, data, image, model

logger = logging.getLogger(__name__)


class AbstractExtractor(ABC):
def is_valid(self, frame: npt.NDArray) -> bool:
"""Checks if the frame is valid for the scheme."""
raise NotImplementedError

def extract(self, frame: npt.NDArray) -> model.ColorUpdate:
"""Extracts the color from the frame."""
raise NotImplementedError

def similar(self, update1: model.ColorUpdate, update2: model.ColorUpdate) -> bool:
"""Checks if the two updates are similar."""
raise NotImplementedError


class Extractor(AbstractExtractor):
def __init__(self, c: color.AbstractColor, hue_areas: list[npt.NDArray], temp_areas: list[npt.NDArray], valid_mask: npt.NDArray, valid_content: npt.NDArray, valid_threshold: float = 0.8):
"""Defines how to search for the color and temperature of the image."""
self._color = c
self._hue_areas = hue_areas
self._temp_areas = temp_areas
self._valid_mask = valid_mask
self._valid_content = valid_content
self._valid_threshold = valid_threshold

def is_valid(self, frame: npt.NDArray) -> bool:
"""Checks if the frame is valid for the scheme."""
return image.similar_image(frame, self._valid_content, self._valid_mask) > self._valid_threshold

def extract(self, frame: npt.NDArray) -> model.ColorUpdate:
"""Extracts the color from the frame."""
if not self.is_valid(frame):
raise Exception("Frame is not valid for this scheme.")

hues = []
temps = []
for area in self._hue_areas:
hue = self._color.hue(frame, area)
hues.append(hue)
for area in self._temp_areas:
temp = self._color.temp(frame, area)
temps.append(temp)
return model.ColorUpdate(hues, temps)

def similar(self, update1: model.ColorUpdate, update2: model.ColorUpdate) -> bool:
"""Checks if the two updates are similar."""
return self._color.similar(update1, update2)


class Search:
def __init__(self):
pass
def __init__(self, s: AbstractExtractor, d: data.FrameGenerator, step: int = 120, refinement_accuracy: float = 3.0, workers: int = 1, quiet: bool = constants.QUIET):
"""Extracts the color and temperature of the data in a binary-search fashion."""
self._scheme = s
self._data = d
self._step = step
self._refinement_accuracy = refinement_accuracy
self._workers = workers
self._quiet = quiet

def _search_step(self, step: float) -> model.ColorUpdate:
frame = self._data.get_frame(step)
if not self._scheme.is_valid(frame):
logger.debug(f"frame {step:.2f}s not valid")
return model.ColorUpdate.invalid(timestamp=step)

update = self._scheme.extract(frame)
update.set_timestamp(step)
return update

def _search_raw(self) -> list[model.ColorUpdate]:
steps = int(self._data.length / self._step)

with Pool(self._workers) as p:
return list(
tqdm.tqdm(p.imap(
self._search_step,
range(0, int(self._data.length)+1, self._step)
), total=steps, disable=self._quiet))

def _search_compact(self, raw: list[model.ColorUpdate]) -> list[model.ColorUpdate]:
updates: list[model.ColorUpdate] = []

for update in raw:
if update._invalid:
continue
if len(updates) > 0 and self._scheme.similar(update, updates[-1]):
logger.debug(
f"frame {update._timestamp:.2f}s similar to previous")
continue
updates.append(update)
return updates

def search(self) -> list[model.ColorUpdate]:
"""Searches for the color and temperature of the data in a binary-search fashion."""
updates = self._search_raw()
updates = self._search_compact(updates)

return updates
1 change: 1 addition & 0 deletions extractor/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ matplotlib
yt-dlp
sewar
colour-science
tqdm

autopep8
isort
Expand Down
73 changes: 73 additions & 0 deletions extractor/tests/test_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import unittest

import numpy as np
import numpy.typing as npt

from extractor import data, model, search


class TestExtractor(search.AbstractExtractor):
def __init__(self, valid_frames: list[int] = [], similar_frames: dict[int, int] = {}, updates: dict[int, model.ColorUpdate] = {}) -> None:
self._valid_frames = valid_frames
self._updates = updates
self._similar_frames = similar_frames

def is_valid(self, frame: npt.NDArray) -> bool:
"""Checks if the frame is valid for the scheme."""
return int(frame.item()) in self._valid_frames

def extract(self, frame: npt.NDArray) -> model.ColorUpdate:
"""Extracts the color from the frame."""
return self._updates[int(frame.item())]

def similar(self, update1: model.ColorUpdate, update2: model.ColorUpdate) -> bool:
"""Checks if the two updates are similar."""
try:
return self._similar_frames[int(update1._timestamp)] == int(update2._timestamp) or self._similar_frames[int(update2._timestamp)] == int(update1._timestamp)
except KeyError:
return False


class TestFrameGenerator(data.FrameGenerator):
def __init__(self, length: int) -> None:
self._length = length

def get_frame(self, second: float) -> npt.NDArray:
"""Obtain a frame at a certain time point."""
return np.array([int(second)])

@property
def length(self) -> float:
return float(self._length)


class TestSearch(unittest.TestCase):
def test_search_raw(self):
generator = TestFrameGenerator(5)
extractor = TestExtractor(
updates={
0: model.ColorUpdate([], [0]),
1: model.ColorUpdate([], [1]),
2: model.ColorUpdate([], [2]),
3: model.ColorUpdate([], [3]),
4: model.ColorUpdate([], [4]),
5: model.ColorUpdate([], [5]),
},
valid_frames=[0, 2, 3, 5],
)
actual = search.Search(extractor, generator,
step=1, workers=1, quiet=True)._search_raw()
expected = [
model.ColorUpdate([], [0], 0.0),
model.ColorUpdate.invalid(1.0),
model.ColorUpdate([], [2], 2.0),
model.ColorUpdate([], [3], 3.0),
model.ColorUpdate.invalid(4.0),
model.ColorUpdate([], [5], 5.0),
]

self.assertEqual(actual, expected)


if __name__ == '__main__':
unittest.main()

0 comments on commit 8783312

Please sign in to comment.