From 5fcc780ee680a1dec26aa7db54af1202c756f485 Mon Sep 17 00:00:00 2001 From: Brice Parent <34689945+Flowtter@users.noreply.github.com> Date: Sun, 28 Jan 2024 19:23:59 +0100 Subject: [PATCH] feat: Add the finals (#44) * feat: The finals implemented * feat: Fix pytest dep until they fix https://github.com/pytest-dev/pytest/issues/11662 * feat: Add test Dockerfile --- README.md | 41 +++++++- crispy-api/Dockerfile.test | 14 +++ crispy-api/api/__init__.py | 25 ++--- crispy-api/api/config.py | 23 +++-- crispy-api/api/models/highlight.py | 98 +++++++++++++++++-- crispy-api/api/routes/highlight.py | 4 +- crispy-api/api/tools/enums.py | 1 + crispy-api/api/tools/image.py | 7 +- crispy-api/api/tools/setup.py | 63 ++++++++++++- crispy-api/api/tools/utils.py | 24 +++++ crispy-api/api/tools/video.py | 135 +++++++++++++++++++++++++-- crispy-api/requirements-dev.txt | 6 +- crispy-api/requirements.txt | 1 + crispy-api/tests/assets | 2 +- crispy-api/tests/constants.py | 1 + crispy-api/tests/models/highlight.py | 10 +- crispy-api/tests/tools/setup.py | 22 +++++ crispy-api/tests/tools/utils.py | 19 ++++ crispy-api/tests/tools/video.py | 30 +++++- 19 files changed, 477 insertions(+), 49 deletions(-) create mode 100644 crispy-api/Dockerfile.test create mode 100644 crispy-api/tests/tools/utils.py diff --git a/README.md b/README.md index 10cde67..4ec8662 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ It uses a neural network to detect highlights in the video-game frames.\ # Supported games -Currently it supports **[Valorant](https://playvalorant.com/)**, **[Overwatch](https://playoverwatch.com/)** and **[CSGO2](https://www.counter-strike.net/cs2)**. +Currently it supports **[Valorant](https://playvalorant.com/)**, **[Overwatch](https://playoverwatch.com/)**, **[CSGO2](https://www.counter-strike.net/cs2)** and **[The Finals](https://www.reachthefinals.com/)**. # Usage @@ -122,6 +122,21 @@ Here are some settings that I found to work well for me: } ``` +#### The Finals + +```json +{ + "clip": { + "framerate": 8, + "second-before": 6, + "second-after": 0, + "second-between-kills": 3 + }, + "stretch": false, + "game": "thefinals" +} +``` + ## Run You can now run the application with the run.[sh|bat] file. @@ -176,6 +191,30 @@ The following effects are available to use: In the result view, you can see the result of your montage. +# Q&A + +### **Q:** Why are some games not using the neural network? + +**A:** To detect highlights in a video-game, the neural-network searches for things that always happen in a highlight.\ +For example, in Overwatch, a kill is symbolized by a red skull. So the neural-network will search for red skulls in the frames.\ +Unfortunately, not all games have such things.\ +The finals, for example, is a game where you don't have any symbol to represent a kill.\ +So for those games, the neural-network is not used. Instead, we're using an OCR to detect the killfeed.\ +The OCR is definitely not as efficient as the neural-network, slow, and depends on the quality of the video.\ +But it's the best we can do for now. + +### **Q:** Why are some games not supported? + +**A:** The neural-network has simply not been trained for those games.\ +If you want to add support for a game, you can train the neural-network yourself and then make a pull request.\ +A tutorial is available [here](https://github.com/Flowtter/crispy/tree/master/crispy-api/dataset). + +### **Q:** In CSGO2, I moved the UI, and the kills are not detected anymore. What can I do? + +**A:** Unfortunately, there is nothing you can do.\ +The neural-network is trained to detect kills in the default UI.\ +I'm planning to add support for custom UI in the future, but this is definitely not a priority. + # Contributing Every contribution is welcome. diff --git a/crispy-api/Dockerfile.test b/crispy-api/Dockerfile.test new file mode 100644 index 0000000..d476f47 --- /dev/null +++ b/crispy-api/Dockerfile.test @@ -0,0 +1,14 @@ +FROM python:3.10-slim +WORKDIR /app +RUN apt-get -y update +RUN apt-get install ffmpeg -y + +COPY requirements.txt . +COPY requirements-dev.txt . +RUN pip install -r requirements-dev.txt +COPY api api +COPY tests tests +COPY assets assets +COPY settings.json settings.json +COPY setup.cfg setup.cfg +CMD ["pytest"] diff --git a/crispy-api/api/__init__.py b/crispy-api/api/__init__.py index af3535c..968a8da 100644 --- a/crispy-api/api/__init__.py +++ b/crispy-api/api/__init__.py @@ -12,24 +12,27 @@ from montydb import MontyClient, set_storage from pydantic.json import ENCODERS_BY_TYPE -from api.config import ASSETS, DATABASE_PATH, DEBUG, FRAMERATE, GAME, MUSICS, VIDEOS +from api.config import ( + ASSETS, + DATABASE_PATH, + DEBUG, + FRAMERATE, + GAME, + MUSICS, + USE_NETWORK, + VIDEOS, +) from api.tools.AI.network import NeuralNetwork -from api.tools.enums import SupportedGames from api.tools.filters import apply_filters # noqa from api.tools.setup import handle_highlights, handle_musics ENCODERS_BY_TYPE[ObjectId] = str -neural_network = NeuralNetwork(GAME) +NEURAL_NETWORK = None -if GAME == SupportedGames.OVERWATCH: - neural_network.load(os.path.join(ASSETS, "overwatch.npy")) -elif GAME == SupportedGames.VALORANT: - neural_network.load(os.path.join(ASSETS, "valorant.npy")) -elif GAME == SupportedGames.CSGO2: - neural_network.load(os.path.join(ASSETS, "csgo2.npy")) -else: - raise ValueError(f"game {GAME} not supported") +if USE_NETWORK: + NEURAL_NETWORK = NeuralNetwork(GAME) + NEURAL_NETWORK.load(os.path.join(ASSETS, GAME + ".npy")) logging.getLogger("PIL").setLevel(logging.ERROR) diff --git a/crispy-api/api/config.py b/crispy-api/api/config.py index ea8eaf4..b06e4e1 100644 --- a/crispy-api/api/config.py +++ b/crispy-api/api/config.py @@ -1,6 +1,7 @@ import json import os +import easyocr from starlette.config import Config from api.tools.enums import SupportedGames @@ -49,15 +50,23 @@ FRAMES_BEFORE = __clip.get("second-before", 0) * FRAMERATE FRAMES_AFTER = __clip.get("second-after", 0) * FRAMERATE - __neural_network = __settings.get("neural-network") - if __neural_network is None: - raise KeyError("neural-network not found in settings.json") - - CONFIDENCE = __neural_network.get("confidence", 0.6) - - STRETCH = __settings.get("stretch", False) GAME = __settings.get("game") if GAME is None: raise KeyError("game not found in settings.json") if GAME.upper() not in [game.name for game in SupportedGames]: raise ValueError(f"game {GAME} not supported") + + USE_NETWORK = GAME not in [SupportedGames.THEFINALS] + + __neural_network = __settings.get("neural-network") + if __neural_network is None and USE_NETWORK: + raise KeyError("neural-network not found in settings.json") + + if __neural_network is not None: + CONFIDENCE = __neural_network.get("confidence", 0.6) + else: + CONFIDENCE = 0 + + STRETCH = __settings.get("stretch", False) + +READER = easyocr.Reader(["en", "fr"], gpu=True, verbose=False) diff --git a/crispy-api/api/models/highlight.py b/crispy-api/api/models/highlight.py index 969c60c..bc94b51 100644 --- a/crispy-api/api/models/highlight.py +++ b/crispy-api/api/models/highlight.py @@ -22,28 +22,34 @@ class Box: def __init__( self, - offset_x: int, + x: int, y: int, width: int, height: int, shift_x: int, stretch: bool, + from_center: bool = True, ) -> None: """ - :param offset_x: Offset in pixels from the center of the video to the left + :param x: Offset in pixels from the left of the video or from the center if use_offset is enabled :param y: Offset in pixels from the top of the video :param width: Width of the box in pixels :param height: Height of the box in pixels :param shift_x: Shift the box by a certain amount of pixels to the right + :param stretch: Stretch the box to fit the video + :param use_offset: If enabled, x will be from the center of the video, else it will be from the left (usef) example: If you want to create a box at 50 px from the center on x, but shifted by 20px to the right you would do: Box(50, 0, 100, 100, 20) """ - half = 720 if stretch else 960 + if from_center: + half = 720 if stretch else 960 + self.x = half - x + shift_x + else: + self.x = x + shift_x - self.x = half - offset_x + shift_x self.y = y self.width = width self.height = height @@ -93,17 +99,21 @@ async def extract_images( post_process: Callable, coordinates: Box, framerate: int = 4, + save_path: str = "images", + force_extract: bool = False, ) -> bool: """ Extracts images from a video at a given framerate :param post_process: Function to apply to each image + :param coordinates: Coordinates of the box to extract :param framerate: Framerate to extract the images + :param save_path: Path to save the images """ - if self.images_path: + if self.images_path and not force_extract: return False - images_path = os.path.join(self.directory, "images") + images_path = os.path.join(self.directory, save_path) if not os.path.exists(images_path): os.mkdir(images_path) @@ -124,8 +134,9 @@ async def extract_images( post_process(im).save(im_path) - self.update({"images_path": images_path}) - self.save() + if save_path == "images": + self.update({"images_path": images_path}) + self.save() return True @@ -220,6 +231,73 @@ def post_process(image: Image) -> Image: post_process, Box(50, 925, 100, 100, 20, stretch), framerate=framerate ) + async def extract_the_finals_images( + self, framerate: int = 4, stretch: bool = False + ) -> bool: + def is_color_close( + pixel: Tuple[int, int, int], + expected: Tuple[int, int, int], + threshold: int = 100, + ) -> bool: + distance: int = ( + sum((pixel[i] - expected[i]) ** 2 for i in range(len(pixel))) ** 0.5 + ) + return distance < threshold + + def post_process_killfeed(image: Image) -> Image: + r, g, b = image.split() + for x in range(image.width): + for y in range(image.height): + if not is_color_close( + (r.getpixel((x, y)), g.getpixel((x, y)), b.getpixel((x, y))), + (12, 145, 201), + 120, + ): + r.putpixel((x, y), 0) + b.putpixel((x, y), 0) + g.putpixel((x, y), 0) + + im = ImageOps.grayscale(Image.merge("RGB", (r, g, b))) + + final = Image.new("RGB", (250, 115)) + final.paste(im, (0, 0)) + return final + + killfeed_state = await self.extract_images( + post_process_killfeed, + Box(1500, 75, 250, 115, 0, stretch, from_center=False), + framerate=framerate, + ) + + def post_process(image: Image) -> Image: + r, g, b = image.split() + for x in range(image.width): + for y in range(image.height): + if not is_color_close( + (r.getpixel((x, y)), g.getpixel((x, y)), b.getpixel((x, y))), + (255, 255, 255), + ): + r.putpixel((x, y), 0) + b.putpixel((x, y), 0) + g.putpixel((x, y), 0) + + im = ImageOps.grayscale(Image.merge("RGB", (r, g, b))) + + final = Image.new("RGB", (200, 120)) + final.paste(im, (0, 0)) + return final + + return ( + await self.extract_images( + post_process, + Box(20, 800, 200, 120, 0, stretch, from_center=False), + framerate=framerate, + save_path="usernames", + force_extract=True, + ) + and killfeed_state + ) + async def extract_images_from_game( self, game: SupportedGames, framerate: int = 4, stretch: bool = False ) -> bool: @@ -229,8 +307,10 @@ async def extract_images_from_game( return await self.extract_valorant_images(framerate, stretch) elif game == SupportedGames.CSGO2: return await self.extract_csgo2_images(framerate, stretch) + elif game == SupportedGames.THEFINALS: + return await self.extract_the_finals_images(framerate, stretch) else: - raise NotImplementedError + raise NotImplementedError(f"game {game} not supported") def recompile(self) -> bool: from api.tools.utils import sanitize_dict diff --git a/crispy-api/api/routes/highlight.py b/crispy-api/api/routes/highlight.py index 78cc20f..f74f576 100644 --- a/crispy-api/api/routes/highlight.py +++ b/crispy-api/api/routes/highlight.py @@ -5,7 +5,7 @@ from fastapi.responses import FileResponse from pydantic import BaseModel -from api import app, neural_network +from api import NEURAL_NETWORK, app from api.config import CONFIDENCE, FRAMERATE, FRAMES_AFTER, FRAMES_BEFORE, OFFSET from api.models.highlight import Highlight from api.models.segment import Segment @@ -84,7 +84,7 @@ async def post_highlights_segments_generate() -> None: extract_segments, kwargs={ "highlight": highlight, - "neural_network": neural_network, + "neural_network": NEURAL_NETWORK, "confidence": CONFIDENCE, "framerate": FRAMERATE, "offset": OFFSET, diff --git a/crispy-api/api/tools/enums.py b/crispy-api/api/tools/enums.py index d3e8c89..b72edde 100644 --- a/crispy-api/api/tools/enums.py +++ b/crispy-api/api/tools/enums.py @@ -5,3 +5,4 @@ class SupportedGames(str, Enum): VALORANT = "valorant" OVERWATCH = "overwatch" CSGO2 = "csgo2" + THEFINALS = "thefinals" diff --git a/crispy-api/api/tools/image.py b/crispy-api/api/tools/image.py index 41d6424..3ef66ab 100644 --- a/crispy-api/api/tools/image.py +++ b/crispy-api/api/tools/image.py @@ -19,4 +19,9 @@ def compare_image(path1: str, path2: str) -> bool: data1 = np.asarray(blur1) data2 = np.asarray(blur2) - return bool((1 + np.corrcoef(data1.flat, data2.flat)[0, 1]) / 2 > 0.8) + # https://stackoverflow.com/questions/51248810/python-why-would-numpy-corrcoef-return-nan-values + corrcoef = np.corrcoef(data1.flat, data2.flat) + if np.isnan(corrcoef).all(): # pragma: no cover + return True + + return bool((1 + corrcoef[0, 1]) / 2 > 0.8) diff --git a/crispy-api/api/tools/setup.py b/crispy-api/api/tools/setup.py index 8798055..24e70ff 100644 --- a/crispy-api/api/tools/setup.py +++ b/crispy-api/api/tools/setup.py @@ -1,18 +1,20 @@ import logging import os import shutil +from collections import Counter from typing import List import ffmpeg from PIL import Image -from api.config import SESSION, SILENCE_PATH, STRETCH +from api.config import READER, SESSION, SILENCE_PATH, STRETCH from api.models.filter import Filter from api.models.highlight import Highlight from api.models.music import Music from api.tools.audio import video_has_audio from api.tools.enums import SupportedGames from api.tools.job_scheduler import JobScheduler +from api.tools.utils import levenstein_distance logger = logging.getLogger("uvicorn") @@ -23,6 +25,63 @@ def __sanitize_path(path: str) -> str: return path +def handle_the_finals( + new_highlights: List[Highlight], + framerate: int = 4, +) -> None: + for highlight in new_highlights: + path = os.path.join(highlight.directory, "usernames") + images = os.listdir(path) + usernames: List[str] = [] + usernames_histogram: Counter = Counter() + + step = int(framerate / 2) + step = 1 if step == 0 else step + for i in range(0, len(images), step): + image = images[i] + image_path = os.path.join(path, image) + result = READER.readtext(image_path) + for text in result: + if text[1].isnumeric(): + continue + usernames_histogram[text[1].lower()] += 1 + most_common_usernames = usernames_histogram.most_common(2) + if ( + len(most_common_usernames) == 2 + and most_common_usernames[0][1] >= 10 + and most_common_usernames[1][1] >= 10 + and levenstein_distance( + most_common_usernames[0][0], most_common_usernames[1][0] + ) + >= 3 + ): # pragma: no cover + break + + for username, count in usernames_histogram.items(): + if count > 2: + if username not in usernames: + usernames.append(username) + + for ch in ("_", " ", ".", "-"): + split_username = username.split(ch) + if len(split_username) > 1: + for split in split_username: + if split not in usernames and len(split) > 2: + usernames.append(split) + + highlight.update({"usernames": usernames}) + highlight.save() + + +def handle_specific_game( + new_highlights: List[Highlight], + game: SupportedGames, + framerate: int = 4, +) -> None: + if game == SupportedGames.THEFINALS: + handle_the_finals(new_highlights, framerate) + + async def handle_highlights( path: str, game: SupportedGames, @@ -104,6 +163,8 @@ async def handle_highlights( Highlight.update_many({}, {"$set": {"job_id": None}}) + handle_specific_game(new_highlights, game, framerate) + return new_highlights diff --git a/crispy-api/api/tools/utils.py b/crispy-api/api/tools/utils.py index 1206a4d..8fe2f9b 100644 --- a/crispy-api/api/tools/utils.py +++ b/crispy-api/api/tools/utils.py @@ -25,3 +25,27 @@ def get_all_jobs_from_highlights( def sanitize_dict(d: Any) -> Dict: """Remove all keys with None, False or Empty string values from a dict/object""" "" return {k: v for k, v in d.items() if v} + + +def levenstein_distance(s1: str, s2: str) -> int: + """Calculates the levenstein distance between two strings""" + if len(s1) < len(s2): + return levenstein_distance(s2, s1) + + if len(s2) == 0: + return len(s1) + + previous_row = list(range(len(s2) + 1)) + + for i, c1 in enumerate(s1): + current_row = [i + 1] + + for j, c2 in enumerate(s2): + insertions = previous_row[j + 1] + 1 + deletions = current_row[j] + 1 + substitutions = previous_row[j] + (c1 != c2) + current_row.append(min(insertions, deletions, substitutions)) + + previous_row = current_row + + return previous_row[-1] diff --git a/crispy-api/api/tools/video.py b/crispy-api/api/tools/video.py index d1dfc83..b96b5ca 100644 --- a/crispy-api/api/tools/video.py +++ b/crispy-api/api/tools/video.py @@ -1,14 +1,18 @@ import asyncio import logging import os +from collections import Counter from typing import List, Tuple import numpy as np from PIL import Image +from api.config import GAME, READER from api.models.highlight import Highlight from api.models.segment import Segment from api.tools.AI.network import NeuralNetwork +from api.tools.enums import SupportedGames +from api.tools.utils import levenstein_distance logger = logging.getLogger("uvicorn") @@ -54,6 +58,116 @@ def _create_query_array( return queries +def _create_the_finals_query_array(highlight: Highlight) -> List[int]: + teammate_usernames = highlight.usernames + images = os.listdir(highlight.images_path) + images.sort() + + usernames_histogram: Counter = Counter() + + for image in images: + image_path = os.path.join(highlight.images_path, image) + + result = READER.readtext(image_path) + for text in result: + if text[1].isnumeric(): + continue + usernames_histogram[text[1].lower()] += 1 + + # filter all usernames that have a levenstein distance of 3 or more to the usernames array (filtering teammates) + for username in list(usernames_histogram): + if ( + min( + levenstein_distance(username, teammate_username) + for teammate_username in teammate_usernames + ) + <= 3 + ): + usernames_histogram.pop(username) + + # filter all usernames that appear only once + for username in list(usernames_histogram): + if usernames_histogram[username] == 1: + usernames_histogram.pop(username) + + # merge all usernames that have a levenstein distance of 1 or 2 to the usernames_histogram + while True: + final_usernames_histogram: Counter = Counter() + seen = set() + + for i, username in enumerate(list(usernames_histogram)): + if username in seen: + continue + shift = i + 1 + for other_username in list(usernames_histogram)[shift:]: + if levenstein_distance(username, other_username) <= 2: + most_common_username = max( + username, + other_username, + key=lambda username: usernames_histogram[username], + ) + least_common_username = min( + username, + other_username, + key=lambda username: usernames_histogram[username], + ) + final_usernames_histogram[most_common_username] = ( + usernames_histogram[least_common_username] + + usernames_histogram[most_common_username] + ) + seen.add(least_common_username) + seen.add(most_common_username) + break + else: + final_usernames_histogram[username] = usernames_histogram[username] + + if len(final_usernames_histogram) == len(usernames_histogram): + break + + usernames_histogram = final_usernames_histogram + + if len(final_usernames_histogram) == 0: # pragma: no cover + logger.warning(f"No usernames found for highlight {highlight.id}") + return [] + + queries = [] + predicted_username = max( + final_usernames_histogram, key=final_usernames_histogram.__getitem__ + ) + + for i, image in enumerate(images): + image_path = os.path.join(highlight.images_path, image) + + result = READER.readtext(image_path) + for text in result: + if text[1].isnumeric(): + continue + if levenstein_distance(text[1].lower(), predicted_username) <= 1: + queries.append(i) + break + + logger.debug( + f"For highlight {highlight.id} found {predicted_username} with" + + f"{final_usernames_histogram[predicted_username]} occurences" + ) + return queries + + +def _get_query_array( + neural_network: NeuralNetwork, + highlight: Highlight, + confidence: float, + game: SupportedGames, +) -> List[int]: + if neural_network: + return _create_query_array(neural_network, highlight, confidence) + if game == SupportedGames.THEFINALS: + return _create_the_finals_query_array(highlight) + raise ValueError( + f"No neural network for game {game} and no custom query array" + ) # pragma: no cover + + def _normalize_queries( queries: List[int], frames_before: int, frames_after: int ) -> List[Tuple[int, int]]: @@ -123,19 +237,20 @@ async def extract_segments( offset: int, frames_before: int, frames_after: int, + game: SupportedGames = GAME, ) -> Tuple[List[Tuple[float, float]], List[Segment]]: """ - Extract segments from a highlight - - :param highlight: highlight to extract segments from - :param neural_network: neural network to query - :param confidence: confidence to query - :param offset: offset to post process - :param framerate: framerate of the video - - :return: list of segments + Extract segments from a highlight + game + :param highlight: highlight to extract segments from + :param neural_network: neural network to query + :param confidence: confidence to query + :param offset: offset to post process + :param framerate: framerate of the video + + :return: list of segments """ - queries = _create_query_array(neural_network, highlight, confidence) + queries = _get_query_array(neural_network, highlight, confidence, game) normalized = _normalize_queries(queries, frames_before, frames_after) processed = _post_process_query_array(normalized, offset, framerate) segments = await highlight.extract_segments(processed) diff --git a/crispy-api/requirements-dev.txt b/crispy-api/requirements-dev.txt index a0c50ab..51da8ad 100644 --- a/crispy-api/requirements-dev.txt +++ b/crispy-api/requirements-dev.txt @@ -1,9 +1,9 @@ -r requirements.txt mypy black -pytest -pytest-cov -pytest-asyncio +pytest==7.2.1 +pytest-asyncio==0.20.3 +pytest-cov==4.0.0 flake8 httpx mutagen diff --git a/crispy-api/requirements.txt b/crispy-api/requirements.txt index d761250..01093ea 100644 --- a/crispy-api/requirements.txt +++ b/crispy-api/requirements.txt @@ -20,3 +20,4 @@ progressbar2==4.0.0 scipy==1.8.0 pydub==0.25.1 montydb==2.4.0 +easyocr==1.7.1 diff --git a/crispy-api/tests/assets b/crispy-api/tests/assets index 0c51f43..87e16dd 160000 --- a/crispy-api/tests/assets +++ b/crispy-api/tests/assets @@ -1 +1 @@ -Subproject commit 0c51f4332690fde477f94d1ecf6e6e08dcb82903 +Subproject commit 87e16dd0c1c321c16562d4b5afef0fe47e71621f diff --git a/crispy-api/tests/constants.py b/crispy-api/tests/constants.py index bdc5183..4615bde 100644 --- a/crispy-api/tests/constants.py +++ b/crispy-api/tests/constants.py @@ -14,6 +14,7 @@ MAIN_VIDEO_1440 = os.path.join(VIDEOS_PATH, "main-video-1440.mp4") MAIN_VIDEO_OVERWATCH = os.path.join(VIDEOS_PATH, "main-video-overwatch.mp4") MAIN_VIDEO_CSGO2 = os.path.join(VIDEOS_PATH, "main-video-csgo2.mp4") +MAIN_VIDEO_THEFINALS = os.path.join(VIDEOS_PATH, "main-video-thefinals.mp4") MAIN_SEGMENT = os.path.join(VIDEOS_PATH, "main-video-segment.mp4") DATASET_VALUES_PATH = os.path.join(ROOT_ASSETS, "dataset-values.json") diff --git a/crispy-api/tests/models/highlight.py b/crispy-api/tests/models/highlight.py index 3d66a69..6b993fc 100644 --- a/crispy-api/tests/models/highlight.py +++ b/crispy-api/tests/models/highlight.py @@ -7,7 +7,12 @@ from api.models.highlight import Highlight from api.models.segment import Segment from api.tools.enums import SupportedGames -from tests.constants import MAIN_VIDEO_CSGO2, MAIN_VIDEO_NO_AUDIO, MAIN_VIDEO_OVERWATCH +from tests.constants import ( + MAIN_VIDEO_CSGO2, + MAIN_VIDEO_NO_AUDIO, + MAIN_VIDEO_OVERWATCH, + MAIN_VIDEO_THEFINALS, +) async def test_highlight(highlight): @@ -152,8 +157,9 @@ async def test_segment_video_segments_are_removed(highlight, tmp_path): (None, SupportedGames.VALORANT, 8), (MAIN_VIDEO_OVERWATCH, SupportedGames.OVERWATCH, 1.5), (MAIN_VIDEO_CSGO2, SupportedGames.CSGO2, 1.5), + (MAIN_VIDEO_THEFINALS, SupportedGames.THEFINALS, 0.75), ], - ids=["valorant", "overwatch", "csgo2"], + ids=["valorant", "overwatch", "csgo2", "thefinals"], ) async def test_extract_game_images(highlight, highlight_path, game, rate): if highlight_path is not None: diff --git a/crispy-api/tests/tools/setup.py b/crispy-api/tests/tools/setup.py index 66d992d..808d5cc 100644 --- a/crispy-api/tests/tools/setup.py +++ b/crispy-api/tests/tools/setup.py @@ -10,6 +10,7 @@ MAIN_VIDEO, MAIN_VIDEO_NO_AUDIO, MAIN_VIDEO_STRETCH, + MAIN_VIDEO_THEFINALS, ) @@ -85,6 +86,27 @@ async def test_handle_highlights_stretch(tmp_path): shutil.rmtree(tmp_resources) +async def test_handle_highlights_the_finals(tmp_path): + tmp_session = os.path.join(tmp_path, "session") + tmp_resources = os.path.join(tmp_path, "resources") + os.mkdir(tmp_resources) + + shutil.copy(MAIN_VIDEO_THEFINALS, tmp_resources) + + assert await handle_highlights( + tmp_resources, SupportedGames.THEFINALS, session=tmp_session + ) + + assert Highlight.count_documents() == 1 + usernames = sorted(Highlight.find_one().usernames) + assert "heximius" in usernames + assert "raynox" in usernames + assert "sxr" in usernames + + shutil.rmtree(tmp_session) + shutil.rmtree(tmp_resources) + + async def test_handle_musics(tmp_path): tmp_resources = os.path.join(tmp_path, "resources") os.mkdir(tmp_resources) diff --git a/crispy-api/tests/tools/utils.py b/crispy-api/tests/tools/utils.py new file mode 100644 index 0000000..c4cb86a --- /dev/null +++ b/crispy-api/tests/tools/utils.py @@ -0,0 +1,19 @@ +from api.tools.utils import levenstein_distance + + +async def test_levenshtein_distance(): + assert levenstein_distance("", "") == 0 + assert levenstein_distance("test", "test") == 0 + + assert levenstein_distance("test", "tst") == 1 + assert levenstein_distance("test", "tast") == 1 + assert levenstein_distance("test", "teest") == 1 + + assert levenstein_distance("test", "yolo") == 4 + + assert levenstein_distance("test", "y") == 4 + assert levenstein_distance("y", "test") == 4 + assert levenstein_distance("test", "t") == 3 + assert levenstein_distance("t", "test") == 3 + assert levenstein_distance("test", "") == 4 + assert levenstein_distance("", "test") == 4 diff --git a/crispy-api/tests/tools/video.py b/crispy-api/tests/tools/video.py index 6e12d41..c1a8e5a 100644 --- a/crispy-api/tests/tools/video.py +++ b/crispy-api/tests/tools/video.py @@ -1,3 +1,4 @@ +import os import shutil import pytest @@ -6,7 +7,7 @@ from api.models.segment import Segment from api.tools.enums import SupportedGames from api.tools.video import extract_segments -from tests.constants import MAIN_VIDEO +from tests.constants import MAIN_VIDEO, MAIN_VIDEO_THEFINALS @pytest.mark.parametrize( @@ -228,3 +229,30 @@ async def test_extract_segment_recompile_global( ) assert timestamps == expected shutil.rmtree(highlight.images_path) + + +async def test_extract_segments_the_finals(highlight): + highlight.path = MAIN_VIDEO_THEFINALS + highlight.usernames = ["heximius", "sxr_raynox", "srx", "raynox"] + highlight = highlight.save() + + await highlight.extract_images_from_game(SupportedGames.THEFINALS, 8) + timestamps, _ = await extract_segments( + highlight, + None, + confidence=0, + framerate=8, + offset=0, + frames_before=0, + frames_after=8, + game=SupportedGames.THEFINALS, + ) + assert timestamps == [ + (5.5, 7.875), + (12.125, 13.5), + (19.75, 21.0), + (21.125, 22.375), + (23.0, 25.875), + ] + shutil.rmtree(highlight.images_path) + shutil.rmtree(os.path.join(os.path.dirname(highlight.images_path), "usernames"))