diff --git a/tests/test_subsamplers.py b/tests/test_subsamplers.py index be6045ac..8e5c47d3 100644 --- a/tests/test_subsamplers.py +++ b/tests/test_subsamplers.py @@ -211,15 +211,14 @@ def test_audio_rate_subsampler(sample_rate, n_audio_channels): "cut_detection_mode,framerates", [("longest", []), ("longest", [1]), ("all", []), ("all", [1])] ) def test_cut_detection_subsampler(cut_detection_mode, framerates): - current_folder = os.path.dirname(__file__) - video = os.path.join(current_folder, "test_files/test_video.mp4") - with open(video, "rb") as vid_f: - video_bytes = vid_f.read() - subsampler = CutDetectionSubsampler(cut_detection_mode, framerates, threshold=5) - streams = {"video": [video_bytes]} - streams, cuts, err_msg = subsampler(streams) + current_folder = os.path.dirname(__file__) + video_filepath = os.path.join(current_folder, "test_files/test_video.mp4") + metadata, error_message = subsampler(video_filepath) + assert error_message is None + cuts = metadata["cuts"] + if cut_detection_mode == "longest": assert len(cuts["cuts_original_fps"]) == 1 assert cuts["cuts_original_fps"][0] == [0, 2096] @@ -276,17 +275,11 @@ def test_optical_flow_subsampler(detector, fps, params): @pytest.mark.parametrize("extract_keyframes", [False, True]) def test_ffprobe_subsampler(extract_keyframes): - current_folder = os.path.dirname(__file__) - # video length - 2:02, 1080x1920, 30 fps - video = os.path.join(current_folder, "test_files/test_video.mp4") - with open(video, "rb") as vid_f: - video_bytes = vid_f.read() - subsampler = FFProbeSubsampler(extract_keyframes) - streams = {"video": [video_bytes]} - metadata = {} - subsampled_streams, metadata, error_message = subsampler(streams, metadata) + current_folder = os.path.dirname(__file__) + video_filepath = os.path.join(current_folder, "test_files/test_video.mp4") + metadata, error_message = subsampler(video_filepath) assert error_message is None assert metadata is not None assert "video_metadata" in metadata diff --git a/video2dataset/subsamplers/cut_detection_subsampler.py b/video2dataset/subsamplers/cut_detection_subsampler.py index a5349f7d..3c57122b 100644 --- a/video2dataset/subsamplers/cut_detection_subsampler.py +++ b/video2dataset/subsamplers/cut_detection_subsampler.py @@ -3,10 +3,10 @@ """ import numpy as np from scenedetect import ContentDetector, SceneManager, open_video -import os -import tempfile +from typing import Tuple, List, Optional, Literal -from .subsampler import Subsampler +from video2dataset.subsamplers.subsampler import Subsampler +from video2dataset.types import Metadata, Error # TODO: this can be done more elegantly: # from scenedetect import scene_manager and set that in correct namespace @@ -45,53 +45,52 @@ class CutDetectionSubsampler(Subsampler): - min_scene_len - minimum scene length to not drop a scene (see pyscenedeteect docs for more explanation) """ - def __init__(self, cut_detection_mode="all", framerates=None, threshold=27, min_scene_len=15): - self.framerates = framerates + def __init__( + self, + cut_detection_mode: Literal["all", "longest"] = "all", + framerates: Optional[List[int]] = None, + threshold: int = 27, + min_scene_len: int = 15, + ): + self.framerates = framerates if framerates is not None else [] self.cut_detection_mode = cut_detection_mode self.threshold = threshold self.min_scene_len = min_scene_len - def __call__(self, streams, metadata=None): - video_bytes = streams["video"][0] - + def __call__(self, video_filepath: str, metadata: Optional[Metadata] = None) -> Tuple[Metadata, Error]: + metadata = metadata if metadata is not None else {} try: - with tempfile.TemporaryDirectory() as tmpdir: - video_path = os.path.join(tmpdir, "input.mp4") - with open(video_path, "wb") as f: - f.write(video_bytes) - - video = open_video(video_path) + # find scene changes + video = open_video(video_filepath) + detector = ContentDetector(threshold=self.threshold, min_scene_len=self.min_scene_len) + scene_manager = SceneManager() + scene_manager.add_detector(detector) + scene_manager.auto_downscale = False + scene_manager.downscale = video.frame_size[0] // DEFAULT_MIN_WIDTH + scene_manager.detect_scenes(video=video) + + # extract cuts in both original fps and target fps + cuts = {} + original_fps = video.frame_rate + cuts["original_fps"] = original_fps + cuts["cuts_original_fps"] = get_scenes_from_scene_manager(scene_manager, self.cut_detection_mode) + for target_fps in self.framerates: + video.reset() detector = ContentDetector(threshold=self.threshold, min_scene_len=self.min_scene_len) scene_manager = SceneManager() scene_manager.add_detector(detector) - scene_manager.auto_downscale = False - scene_manager.downscale = video.frame_size[0] // DEFAULT_MIN_WIDTH - - cuts = {} - original_fps = video.frame_rate - cuts["original_fps"] = original_fps - - scene_manager.detect_scenes(video=video) - cuts["cuts_original_fps"] = get_scenes_from_scene_manager(scene_manager, self.cut_detection_mode) - if self.framerates is not None: - for target_fps in self.framerates: - video.reset() - - detector = ContentDetector(threshold=self.threshold, min_scene_len=self.min_scene_len) - scene_manager = SceneManager() - scene_manager.add_detector(detector) - frame_skip = max( - int(original_fps // target_fps) - 1, 0 - ) # if we take 1 frame and skip N frames we're sampling 1/N+1 % of the video - # so if we desire to sample 1/N of the video, we need to subtract one when doing frame skipping - - scene_manager.detect_scenes(video=video, frame_skip=frame_skip) - cuts[f"cuts_{target_fps}"] = get_scenes_from_scene_manager( - scene_manager, self.cut_detection_mode - ) - scene_manager.clear() - except Exception as err: # pylint: disable=broad-except - return {}, None, str(err) + frame_skip = max( + int(original_fps // target_fps) - 1, 0 + ) # if we take 1 frame and skip N frames we're sampling 1/N+1 % of the video + # so if we desire to sample 1/N of the video, we need to subtract one when doing frame skipping - return streams, cuts, None + scene_manager.detect_scenes(video=video, frame_skip=frame_skip) + cuts[f"cuts_{target_fps}"] = get_scenes_from_scene_manager(scene_manager, self.cut_detection_mode) + scene_manager.clear() + + # save and return metadata + metadata["cuts"] = cuts + except Exception as err: # pylint: disable=broad-except + return metadata, str(err) + return metadata, None diff --git a/video2dataset/subsamplers/ffprobe_subsampler.py b/video2dataset/subsamplers/ffprobe_subsampler.py index 52c5b61e..b1c9c01a 100644 --- a/video2dataset/subsamplers/ffprobe_subsampler.py +++ b/video2dataset/subsamplers/ffprobe_subsampler.py @@ -1,10 +1,10 @@ """extracts basic video compression metadata.""" -import os import json import subprocess -import tempfile +from typing import Tuple, Optional -from .subsampler import Subsampler +from video2dataset.subsamplers.subsampler import Subsampler +from video2dataset.types import Metadata, Error # TODO: figuer out why this is so slow (12 samples/s) @@ -18,41 +18,38 @@ class FFProbeSubsampler(Subsampler): def __init__(self, extract_keyframes=False): self.extract_keyframes = extract_keyframes - def __call__(self, streams, metadata): - # TODO: this should also work for audio (maybe others) - video_bytes = streams["video"][0] - with tempfile.TemporaryDirectory() as tmpdir: - with open(os.path.join(tmpdir, "input.mp4"), "wb") as f: - f.write(video_bytes) - try: - command = [ - "ffprobe", - "-v", - "quiet", - "-print_format", - "json", - "-show_format", - "-show_streams", - f"{tmpdir}/input.mp4", + def __call__(self, video_filepath: str, metadata: Optional[Metadata] = None) -> Tuple[Metadata, Error]: + metadata = metadata if metadata is not None else {} + try: + # extract video metadata + command = [ + "ffprobe", + "-v", + "quiet", + "-print_format", + "json", + "-show_format", + "-show_streams", + f"{video_filepath}", + ] + if self.extract_keyframes: + command.extend(["-select_streams", "v:0", "-show_entries", "packet=pts_time,flags"]) + process = subprocess.run(command, capture_output=True, text=True, check=True) + video_metadata = json.loads(process.stdout) + + # extract keyframe timestamps if requested + if self.extract_keyframes: + keyframe_timestamps = [ + float(packet["pts_time"]) for packet in video_metadata["packets"] if "K" in packet.get("flags", "") ] - - if self.extract_keyframes: - command.extend(["-select_streams", "v:0", "-show_entries", "packet=pts_time,flags"]) - - process = subprocess.run(command, capture_output=True, text=True, check=True) - video_metadata = json.loads(process.stdout) - - if self.extract_keyframes: - keyframe_info = [entry for entry in video_metadata["packets"] if "K" in entry.get("flags", "")] - keyframe_timestamps = [float(entry["pts_time"]) for entry in keyframe_info] - if "duration" in video_metadata["format"]: - duration = float(video_metadata["format"]["duration"]) - keyframe_timestamps.append(duration) - video_metadata["keyframe_timestamps"] = keyframe_timestamps - video_metadata.pop("packets") # Don't need it anymore - metadata["video_metadata"] = video_metadata - - except Exception as err: # pylint: disable=broad-except - return streams, metadata, str(err) - - return streams, metadata, None + if "duration" in video_metadata["format"]: + duration = float(video_metadata["format"]["duration"]) + keyframe_timestamps.append(duration) + video_metadata["keyframe_timestamps"] = keyframe_timestamps + video_metadata.pop("packets") # Don't need it anymore + + # save and return metadata + metadata["video_metadata"] = video_metadata + except Exception as err: # pylint: disable=broad-except + return metadata, str(err) + return metadata, None diff --git a/video2dataset/subsamplers/noop_subsampler.py b/video2dataset/subsamplers/noop_subsampler.py index b12975a8..dd714055 100644 --- a/video2dataset/subsamplers/noop_subsampler.py +++ b/video2dataset/subsamplers/noop_subsampler.py @@ -1,11 +1,13 @@ """No operation subsampler""" +from typing import List, Tuple -from .subsampler import Subsampler +from video2dataset.subsamplers.subsampler import Subsampler +from video2dataset.types import Metadata, Error, TempFilepaths class NoOpSubsampler(Subsampler): def __init__(self): pass - def __call__(self, streams, metadata): - return streams, [metadata], None + def __call__(self, filepaths: TempFilepaths, metadata: Metadata) -> Tuple[TempFilepaths, List[Metadata], Error]: + return filepaths, [metadata], None diff --git a/video2dataset/types.py b/video2dataset/types.py index 03ae7c1e..9e170e64 100644 --- a/video2dataset/types.py +++ b/video2dataset/types.py @@ -1,5 +1,5 @@ """Type definitions for video2dataset.""" -from typing import List, TypedDict +from typing import List, TypedDict, Optional class EncodeFormats(TypedDict, total=False): @@ -14,3 +14,12 @@ class Streams(TypedDict, total=False): # TODO: make more structured Metadata = dict + + +Error = Optional[str] + + +# TODO: remove after refactoring is complete +class TempFilepaths(TypedDict, total=False): + video: List[str] + audio: List[str] diff --git a/video2dataset/workers/worker.py b/video2dataset/workers/worker.py index 45650829..b37e42db 100644 --- a/video2dataset/workers/worker.py +++ b/video2dataset/workers/worker.py @@ -1,7 +1,10 @@ """Standard worker for video2dataset.""" from dataclasses import dataclass, field import numpy as np -from typing import Any, List, Tuple, Optional +import os +import tempfile +from typing import Any, List, Tuple, Optional, Literal, cast +import uuid from video2dataset.logger import CappedCounter from video2dataset.subsamplers import ( @@ -14,24 +17,7 @@ AudioRateSubsampler, Subsampler, ) -from video2dataset.types import EncodeFormats, Streams, Metadata - - -@dataclass -class ShardStatus: - """Shard processing status""" - - successes: int = 0 - failed: dict = field( - default_factory=lambda: { - "failed_to_download": 0, - "failed_to_subsample": 0, - } - ) - status_dict: CappedCounter = field(default_factory=CappedCounter) - error_message: Optional[str] = None - count: int = 0 - bytes_downloaded: int = 0 +from video2dataset.types import EncodeFormats, Streams, Metadata, TempFilepaths @dataclass @@ -114,6 +100,50 @@ def get_subsamplers( ) +@dataclass +class ShardStatus: + """Shard processing status""" + + successes: int = 0 + failed: dict = field( + default_factory=lambda: { + "failed_to_download": 0, + "failed_to_subsample": 0, + } + ) + status_dict: CappedCounter = field(default_factory=CappedCounter) + error_message: Optional[str] = None + count: int = 0 + bytes_downloaded: int = 0 + + +def extract_video_metadata( + subsamplers: Subsamplers, + shard_status: ShardStatus, + metadata: Metadata, + video_filepath: str, + captions_are_subtitles: bool, +): + """Add additional metadata keys for video file""" + + if subsamplers.ffprobe_subsampler is not None: + metadata, shard_status.error_message = subsamplers.ffprobe_subsampler(video_filepath, metadata) + assert shard_status.error_message is None + + if captions_are_subtitles: # create clips + subtitles = metadata["yt_meta_dict"]["subtitles"] + metadata["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles] + elif subsamplers.cut_detection_subsampler is not None: # apply cut detection to get clips + metadata, shard_status.error_message = subsamplers.cut_detection_subsampler(video_filepath, metadata) + assert shard_status.error_message is None + cuts = metadata["cuts"] + assert cuts is not None + if subsamplers.cuts_are_clips: + metadata["clips"] = (np.array(cuts["cuts_original_fps"]) / cuts["original_fps"]).tolist() + + return metadata + + def process_sample( subsamplers: Subsamplers, shard_status: ShardStatus, @@ -127,20 +157,32 @@ def process_sample( """Process a single video""" try: - if subsamplers.ffprobe_subsampler is not None: - streams, metadata, shard_status.error_message = subsamplers.ffprobe_subsampler(streams, metadata) - assert shard_status.error_message is None - - if captions_are_subtitles: # create clips - subtitles = metadata["yt_meta_dict"]["subtitles"] - metadata["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles] - elif subsamplers.cut_detection_subsampler is not None: # apply cut detection to get clips - streams, cuts, shard_status.error_message = subsamplers.cut_detection_subsampler(streams) - assert shard_status.error_message is None - metadata["cuts"] = cuts - assert cuts is not None - if subsamplers.cuts_are_clips: - metadata["clips"] = (np.array(cuts["cuts_original_fps"]) / cuts["original_fps"]).tolist() + with tempfile.TemporaryDirectory() as tmpdir: + # save temp stream dumps + temp_filepaths: TempFilepaths = {} + for modality in streams: + modality = cast(Literal["video", "audio"], modality) + temp_filepaths[modality] = [] + for stream in streams[modality]: + stream_uuid = str(uuid.uuid4()) + temp_filepath = os.path.join(tmpdir, stream_uuid) + with open(temp_filepath, "wb") as f: + f.write(stream) + temp_filepaths[modality].append(temp_filepath) + + # this is pre-broadcast, so there should only be one video + assert "video" in temp_filepaths + assert len(temp_filepaths["video"]) == 1 + video_filepath = temp_filepaths["video"][0] + + # add info about keyframes and cuts + metadata = extract_video_metadata( + subsamplers=subsamplers, + shard_status=shard_status, + metadata=metadata, + video_filepath=video_filepath, + captions_are_subtitles=captions_are_subtitles, + ) # 1 video -> many videos (either clipping or noop which does identity broadcasting) subsampled_streams, metadatas, shard_status.error_message = subsamplers.broadcast_subsampler(streams, metadata)