-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e464492
commit 8783312
Showing
5 changed files
with
204 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,4 @@ | |
ROOT = path.dirname(path.abspath(__file__)) | ||
|
||
FRAMES = path.join(ROOT, "frames") | ||
QUIET = False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,29 @@ | ||
|
||
class ColorUpdate: | ||
def __init__(self, colors: list[tuple[float, float, float]], temps: list[float]): | ||
def __init__(self, colors: list[tuple[float, float, float]], temps: list[float], timestamp: float = -1.0): | ||
self._colors = colors | ||
self._temps = temps | ||
self._timestamp = timestamp | ||
self._invalid = False | ||
|
||
def set_timestamp(self, timestamp: float): | ||
self._timestamp = timestamp | ||
|
||
def __str__(self) -> str: | ||
if self._invalid: | ||
return f"{self._timestamp:.2f}s: invalid" | ||
return f"{self._timestamp:.2f}s: {self._colors} {self._temps}" | ||
|
||
def __repr__(self) -> str: | ||
return self.__str__() | ||
|
||
def __eq__(self, __value: object) -> bool: | ||
if not isinstance(__value, ColorUpdate): | ||
return False | ||
return self._colors == __value._colors and self._temps == __value._temps and self._timestamp == __value._timestamp | ||
|
||
@staticmethod | ||
def invalid(timestamp: float = -1.0): | ||
update = ColorUpdate([], [], timestamp) | ||
update._invalid = True | ||
return update |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,110 @@ | ||
import logging | ||
from abc import ABC | ||
from multiprocessing import Pool | ||
from typing import Optional | ||
|
||
import numpy.typing as npt | ||
import tqdm | ||
|
||
from . import color, constants, data, image, model | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AbstractExtractor(ABC): | ||
def is_valid(self, frame: npt.NDArray) -> bool: | ||
"""Checks if the frame is valid for the scheme.""" | ||
raise NotImplementedError | ||
|
||
def extract(self, frame: npt.NDArray) -> model.ColorUpdate: | ||
"""Extracts the color from the frame.""" | ||
raise NotImplementedError | ||
|
||
def similar(self, update1: model.ColorUpdate, update2: model.ColorUpdate) -> bool: | ||
"""Checks if the two updates are similar.""" | ||
raise NotImplementedError | ||
|
||
|
||
class Extractor(AbstractExtractor): | ||
def __init__(self, c: color.AbstractColor, hue_areas: list[npt.NDArray], temp_areas: list[npt.NDArray], valid_mask: npt.NDArray, valid_content: npt.NDArray, valid_threshold: float = 0.8): | ||
"""Defines how to search for the color and temperature of the image.""" | ||
self._color = c | ||
self._hue_areas = hue_areas | ||
self._temp_areas = temp_areas | ||
self._valid_mask = valid_mask | ||
self._valid_content = valid_content | ||
self._valid_threshold = valid_threshold | ||
|
||
def is_valid(self, frame: npt.NDArray) -> bool: | ||
"""Checks if the frame is valid for the scheme.""" | ||
return image.similar_image(frame, self._valid_content, self._valid_mask) > self._valid_threshold | ||
|
||
def extract(self, frame: npt.NDArray) -> model.ColorUpdate: | ||
"""Extracts the color from the frame.""" | ||
if not self.is_valid(frame): | ||
raise Exception("Frame is not valid for this scheme.") | ||
|
||
hues = [] | ||
temps = [] | ||
for area in self._hue_areas: | ||
hue = self._color.hue(frame, area) | ||
hues.append(hue) | ||
for area in self._temp_areas: | ||
temp = self._color.temp(frame, area) | ||
temps.append(temp) | ||
return model.ColorUpdate(hues, temps) | ||
|
||
def similar(self, update1: model.ColorUpdate, update2: model.ColorUpdate) -> bool: | ||
"""Checks if the two updates are similar.""" | ||
return self._color.similar(update1, update2) | ||
|
||
|
||
class Search: | ||
def __init__(self): | ||
pass | ||
def __init__(self, s: AbstractExtractor, d: data.FrameGenerator, step: int = 120, refinement_accuracy: float = 3.0, workers: int = 1, quiet: bool = constants.QUIET): | ||
"""Extracts the color and temperature of the data in a binary-search fashion.""" | ||
self._scheme = s | ||
self._data = d | ||
self._step = step | ||
self._refinement_accuracy = refinement_accuracy | ||
self._workers = workers | ||
self._quiet = quiet | ||
|
||
def _search_step(self, step: float) -> model.ColorUpdate: | ||
frame = self._data.get_frame(step) | ||
if not self._scheme.is_valid(frame): | ||
logger.debug(f"frame {step:.2f}s not valid") | ||
return model.ColorUpdate.invalid(timestamp=step) | ||
|
||
update = self._scheme.extract(frame) | ||
update.set_timestamp(step) | ||
return update | ||
|
||
def _search_raw(self) -> list[model.ColorUpdate]: | ||
steps = int(self._data.length / self._step) | ||
|
||
with Pool(self._workers) as p: | ||
return list( | ||
tqdm.tqdm(p.imap( | ||
self._search_step, | ||
range(0, int(self._data.length)+1, self._step) | ||
), total=steps, disable=self._quiet)) | ||
|
||
def _search_compact(self, raw: list[model.ColorUpdate]) -> list[model.ColorUpdate]: | ||
updates: list[model.ColorUpdate] = [] | ||
|
||
for update in raw: | ||
if update._invalid: | ||
continue | ||
if len(updates) > 0 and self._scheme.similar(update, updates[-1]): | ||
logger.debug( | ||
f"frame {update._timestamp:.2f}s similar to previous") | ||
continue | ||
updates.append(update) | ||
return updates | ||
|
||
def search(self) -> list[model.ColorUpdate]: | ||
"""Searches for the color and temperature of the data in a binary-search fashion.""" | ||
updates = self._search_raw() | ||
updates = self._search_compact(updates) | ||
|
||
return updates |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ matplotlib | |
yt-dlp | ||
sewar | ||
colour-science | ||
tqdm | ||
|
||
autopep8 | ||
isort | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
|
||
from extractor import data, model, search | ||
|
||
|
||
class TestExtractor(search.AbstractExtractor): | ||
def __init__(self, valid_frames: list[int] = [], similar_frames: dict[int, int] = {}, updates: dict[int, model.ColorUpdate] = {}) -> None: | ||
self._valid_frames = valid_frames | ||
self._updates = updates | ||
self._similar_frames = similar_frames | ||
|
||
def is_valid(self, frame: npt.NDArray) -> bool: | ||
"""Checks if the frame is valid for the scheme.""" | ||
return int(frame.item()) in self._valid_frames | ||
|
||
def extract(self, frame: npt.NDArray) -> model.ColorUpdate: | ||
"""Extracts the color from the frame.""" | ||
return self._updates[int(frame.item())] | ||
|
||
def similar(self, update1: model.ColorUpdate, update2: model.ColorUpdate) -> bool: | ||
"""Checks if the two updates are similar.""" | ||
try: | ||
return self._similar_frames[int(update1._timestamp)] == int(update2._timestamp) or self._similar_frames[int(update2._timestamp)] == int(update1._timestamp) | ||
except KeyError: | ||
return False | ||
|
||
|
||
class TestFrameGenerator(data.FrameGenerator): | ||
def __init__(self, length: int) -> None: | ||
self._length = length | ||
|
||
def get_frame(self, second: float) -> npt.NDArray: | ||
"""Obtain a frame at a certain time point.""" | ||
return np.array([int(second)]) | ||
|
||
@property | ||
def length(self) -> float: | ||
return float(self._length) | ||
|
||
|
||
class TestSearch(unittest.TestCase): | ||
def test_search_raw(self): | ||
generator = TestFrameGenerator(5) | ||
extractor = TestExtractor( | ||
updates={ | ||
0: model.ColorUpdate([], [0]), | ||
1: model.ColorUpdate([], [1]), | ||
2: model.ColorUpdate([], [2]), | ||
3: model.ColorUpdate([], [3]), | ||
4: model.ColorUpdate([], [4]), | ||
5: model.ColorUpdate([], [5]), | ||
}, | ||
valid_frames=[0, 2, 3, 5], | ||
) | ||
actual = search.Search(extractor, generator, | ||
step=1, workers=1, quiet=True)._search_raw() | ||
expected = [ | ||
model.ColorUpdate([], [0], 0.0), | ||
model.ColorUpdate.invalid(1.0), | ||
model.ColorUpdate([], [2], 2.0), | ||
model.ColorUpdate([], [3], 3.0), | ||
model.ColorUpdate.invalid(4.0), | ||
model.ColorUpdate([], [5], 5.0), | ||
] | ||
|
||
self.assertEqual(actual, expected) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |