From 9fd5086de54acd6bea910f0b0800ffe1f323424c Mon Sep 17 00:00:00 2001 From: Romain Beaumont Date: Thu, 8 Feb 2024 16:56:55 +0100 Subject: [PATCH] Revert "Download worker refactor (#288)" (#308) This reverts commit 83afef059ba1a29eb92bb6cb922f1f8e0ffd5965. --- video2dataset/subsamplers/__init__.py | 2 - video2dataset/types.py | 4 - video2dataset/workers/download_worker.py | 304 ++++++++++++++++------- video2dataset/workers/subset_worker.py | 197 ++++++++++++--- video2dataset/workers/worker.py | 197 --------------- 5 files changed, 374 insertions(+), 330 deletions(-) delete mode 100644 video2dataset/workers/worker.py diff --git a/video2dataset/subsamplers/__init__.py b/video2dataset/subsamplers/__init__.py index 53b4141f..90e4cd58 100644 --- a/video2dataset/subsamplers/__init__.py +++ b/video2dataset/subsamplers/__init__.py @@ -12,5 +12,3 @@ from .optical_flow_subsampler import OpticalFlowSubsampler from .whisper_subsampler import WhisperSubsampler from .caption_subsampler import CaptionSubsampler - -from .subsampler import Subsampler diff --git a/video2dataset/types.py b/video2dataset/types.py index 03ae7c1e..77240d32 100644 --- a/video2dataset/types.py +++ b/video2dataset/types.py @@ -10,7 +10,3 @@ class EncodeFormats(TypedDict, total=False): class Streams(TypedDict, total=False): video: List[bytes] audio: List[bytes] - - -# TODO: make more structured -Metadata = dict diff --git a/video2dataset/workers/download_worker.py b/video2dataset/workers/download_worker.py index 3c1bdac4..d008f921 100644 --- a/video2dataset/workers/download_worker.py +++ b/video2dataset/workers/download_worker.py @@ -1,15 +1,29 @@ """the downloader module handles the downloading""" -import fsspec + import math -from multiprocessing.pool import ThreadPool -import pyarrow as pa import time +import pyarrow as pa import traceback -from typing import cast + +import fsspec + +from multiprocessing.pool import ThreadPool +from threading import Semaphore +from typing import List, Any +import numpy as np from video2dataset.data_reader import VideoDataReader +from video2dataset.logger import CappedCounter from video2dataset.logger import write_stats -from video2dataset.workers.worker import ShardStatus, Streams, get_subsamplers, process_sample +from video2dataset.subsamplers import ( + ClippingSubsampler, + CutDetectionSubsampler, + FrameSubsampler, + FFProbeSubsampler, + NoOpSubsampler, + ResolutionSubsampler, + AudioRateSubsampler, +) def compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count): @@ -38,154 +52,252 @@ def __init__( self.save_caption = save_caption self.output_folder = output_folder self.column_list = column_list - self.input_encode_formats = encode_formats + self.encode_formats = encode_formats self.config = config + self.data_reader = VideoDataReader(encode_formats, tmp_dir, config["reading"]) - self.url_indice = self.column_list.index("url") - self.caption_indice = self.column_list.index("caption") if "caption" in self.column_list else None - self.oom_sample_per_shard = math.ceil(math.log10(self.config["storage"]["number_sample_per_shard"])) - self.subsamplers, self.output_encode_formats = get_subsamplers( - config, + + self.clipping_subsampler = ClippingSubsampler( + 5, # oom_clip_count encode_formats, - do_clipping=("clips" in self.column_list), + **self.config["subsampling"].get("ClippingSubsampler", {"args": {}})["args"], ) + need_keyframes = self.clipping_subsampler.precision == "keyframe_adjusted" + + self.ffprobe_subsampler = None + if "FFProbeSubsampler" in self.config["subsampling"] or need_keyframes: + self.ffprobe_subsampler = FFProbeSubsampler( + **self.config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"] + ) + self.ffprobe_subsampler.extract_keyframes |= need_keyframes + + self.cut_detector = None + self.cuts_are_clips = False + if "CutDetectionSubsampler" in self.config["subsampling"]: + if "args" in self.config["subsampling"]["CutDetectionSubsampler"]: + self.cut_detector = CutDetectionSubsampler( + **self.config["subsampling"]["CutDetectionSubsampler"]["args"] + ) + self.cuts_are_clips = self.config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False) + + self.noop_subsampler = NoOpSubsampler() + + video_subsamplers: List[Any] = [] + if "ResolutionSubsampler" in self.config["subsampling"]: + video_subsamplers.append(ResolutionSubsampler(**self.config["subsampling"]["ResolutionSubsampler"]["args"])) + if "FrameSubsampler" in self.config["subsampling"]: + video_subsamplers.append(FrameSubsampler(**self.config["subsampling"]["FrameSubsampler"]["args"])) + + audio_subsamplers: List[Any] = [] + if "AudioRateSubsampler" in self.config["subsampling"]: + audio_subsamplers.append(AudioRateSubsampler(**self.config["subsampling"]["AudioRateSubsampler"]["args"])) + + self.subsamplers = {"video": video_subsamplers, "audio": audio_subsamplers} def __call__( self, row, ): try: - shard_file, shard_id = row - self.process_shard(shard_file, shard_id) + self.download_shard(row) return (True, row) except Exception as err: # pylint: disable=broad-except traceback.print_exc() print(f"shard {row[0]} failed with error {err}") return (False, row) - def get_shard_processors( + def download_shard( self, - shard_file: str, - shard_id: int, + row, ): - """Get objects for loading and writing data""" + """Function to start an video downloading in one process""" + + # shard_id, shard_file = row + shard_file, shard_id = row + start_time = time.time() fs, shard_path = fsspec.core.url_to_fs(shard_file) - print(shard_path) with fs.open(shard_path, "rb") as f: df = pa.ipc.open_file(f).read_all() - schema = df.schema schema = df.schema schema = ( schema.append(pa.field("key", pa.string())) .append(pa.field("status", pa.string())) .append(pa.field("error_message", pa.string())) ) - shard_sample_writer = self.sample_writer_class( - shard_id, - self.output_folder, - self.save_caption, - self.config["storage"]["oom_shard_count"], - schema, - self.output_encode_formats, - ) + pydict = df.select(self.column_list).to_pydict() shard_to_dl = list(enumerate(zip(*(pydict[col] for col in self.column_list)))) + del pydict + del df - def rm_shard_path(): - fs.rm(shard_path) + status_dict = CappedCounter() - return shard_sample_writer, shard_to_dl, rm_shard_path + count = len(shard_to_dl) + successes = 0 + failed = { + "failed_to_download": 0, + "failed_to_subsample": 0, + } + bytes_downloaded = 0 + url_indice = self.column_list.index("url") + caption_indice = self.column_list.index("caption") if "caption" in self.column_list else None + key_url_list = [(key, x[url_indice]) for key, x in shard_to_dl] - def process_shard( - self, - shard_file: str, - shard_id: int, - ): - """Function to start an video downloading in one process""" - - start_time = time.time() - shard_sample_writer, shard_to_dl, rm_shard_path = self.get_shard_processors(shard_file, shard_id) - shard_status = ShardStatus(count=len(shard_to_dl)) + semaphore = Semaphore(self.config["distribution"]["thread_count"]) def data_generator(): - for key_and_url in [(key, x[self.url_indice]) for key, x in shard_to_dl]: - yield key_and_url + for e in key_url_list: + semaphore.acquire() # pylint: disable=(consider-using-with) + yield e + + loader = data_generator() - data_reader_call_param_generator = data_generator() + # The subsamplers might change the output format, so we need to update the writer + writer_encode_formats = self.encode_formats.copy() + if self.subsamplers["audio"]: + writer_encode_formats["audio"] = self.subsamplers["audio"][0].encode_formats["audio"] + if self.subsamplers["video"]: + writer_encode_formats["video"] = self.subsamplers["video"][0].encode_formats["video"] + + # give schema to writer + sample_writer = self.sample_writer_class( + shard_id, + self.output_folder, + self.save_caption, + self.config["storage"]["oom_shard_count"], + schema, + writer_encode_formats, + ) + oom_sample_per_shard = math.ceil(math.log10(self.config["storage"]["number_sample_per_shard"])) with ThreadPool(self.config["distribution"]["thread_count"]) as thread_pool: - for key, streams, yt_meta_dict, shard_status.error_message in thread_pool.imap_unordered( + for key, streams, yt_meta_dict, error_message in thread_pool.imap_unordered( self.data_reader, # pylint: disable=(unnecessary-lambda) - data_reader_call_param_generator, + loader, ): try: _, sample_data = shard_to_dl[key] str_key = compute_key( - key, shard_id, self.oom_sample_per_shard, self.config["storage"]["oom_shard_count"] + key, shard_id, oom_sample_per_shard, self.config["storage"]["oom_shard_count"] ) - caption = sample_data[self.caption_indice] if self.caption_indice is not None else None - metadata = { + meta = { **{self.column_list[i]: sample_data[i] for i in range(len(self.column_list))}, "key": str_key, "status": None, - "error_message": shard_status.error_message, + "error_message": error_message, "yt_meta_dict": yt_meta_dict, } - except Exception as err: # pylint: disable=broad-except - traceback.print_exc() - print(f"Sample {key} failed to download: {err}") - return - try: - if shard_status.error_message is not None: - print(shard_status.error_message) - if "[youtube]" in shard_status.error_message: # video-specific error, remove videoID - shard_status.error_message = "ERROR: [youtube]:" + shard_status.error_message.split(":")[-1] - raise ValueError - except Exception: # pylint: disable=broad-except - shard_status.failed["failed_to_download"] += 1 - shard_status.status_dict.increment(shard_status.error_message) - metadata["status"] = "failed_to_download" - metadata["error_message"] = shard_status.error_message - shard_sample_writer.write( - {}, - str_key, - sample_data[self.caption_indice] if self.caption_indice is not None else None, - metadata, + if error_message is not None: + print(error_message) + if "[youtube]" in error_message: # video-specific error, remove videoID + error_message = "ERROR: [youtube]:" + error_message.split(":")[-1] + raise ValueError("failed_to_download") + + for stream in streams.values(): + bytes_downloaded += len(stream) + for mod in streams: + streams[mod] = [streams[mod]] + + if self.ffprobe_subsampler is not None: + streams, meta, error_message = self.ffprobe_subsampler(streams, meta) + if error_message is not None: + raise ValueError("failed_to_subsample") + + if self.config["storage"]["captions_are_subtitles"]: # create clips + # all langs have same start and end times + subtitles = meta["yt_meta_dict"]["subtitles"][list(meta["yt_meta_dict"]["subtitles"].keys())[0]] + meta["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles] + elif self.cut_detector is not None: # apply cut detection to get clips + streams, cuts, error_message = self.cut_detector(streams) + + if error_message is not None: + raise ValueError("failed_to_subsample") + + meta["cuts"] = cuts + + if self.cuts_are_clips: + cuts = meta["cuts"]["cuts_original_fps"] + native_fps = meta["cuts"]["original_fps"] + meta["clips"] = (np.array(cuts) / native_fps).tolist() + + # 1 video -> many videos (either clipping or noop which does identity broadcasting) + broadcast_subsampler = ( + self.clipping_subsampler + if ( + "clips" in self.column_list + or self.config["storage"]["captions_are_subtitles"] + or self.cuts_are_clips + ) + else self.noop_subsampler ) - return - - for stream in streams.values(): - shard_status.bytes_downloaded += len(stream) - for modality in streams: - streams[modality] = [streams[modality]] - - process_sample( - subsamplers=self.subsamplers, - shard_status=shard_status, - streams=cast(Streams, streams), - key=str_key, - caption=cast(str, caption), - metadata=metadata, - captions_are_subtitles=self.config["storage"]["captions_are_subtitles"], - shard_sample_writer=shard_sample_writer, - ) + subsampled_streams, metas, error_message = broadcast_subsampler(streams, meta) - shard_sample_writer.close() - rm_shard_path() - end_time = time.time() + for modality in subsampled_streams: + for modality_subsampler in self.subsamplers[modality]: + subsampled_streams, metas, error_message = modality_subsampler(subsampled_streams, metas) + + if error_message is not None: + meta["clips"] = [] + raise ValueError("failed_to_subsample") + successes += 1 + status = "success" + status_dict.increment(status) + subsampled_streams_list = [ + dict(zip(subsampled_streams, s)) for s in zip(*subsampled_streams.values()) + ] + for subsampled_streams, meta in zip(subsampled_streams_list, metas): + meta["status"] = status + + text_caption = sample_data[caption_indice] if caption_indice is not None else None + if self.config["storage"]["captions_are_subtitles"]: + text_caption = meta.get("clip_subtitles")[0]["lines"] + + sample_writer.write( + subsampled_streams, + meta["key"], + text_caption, + meta, + ) + except Exception as err: # pylint: disable=broad-except + status = str(err) + if status.startswith("failed_to_"): + failed[status] += 1 + status_dict.increment(error_message) + meta["status"] = status + meta["error_message"] = error_message + sample_writer.write( + {}, + str_key, + sample_data[caption_indice] if caption_indice is not None else None, + meta, + ) + semaphore.release() + else: + traceback.print_exc() + print(f"Sample {key} failed to download: {err}") + + semaphore.release() + + sample_writer.close() + thread_pool.terminate() + thread_pool.join() + del thread_pool + + end_time = time.time() write_stats( self.output_folder, shard_id, - shard_status.count, - shard_status.successes, - shard_status.failed["failed_to_download"], - shard_status.failed["failed_to_subsample"], - shard_status.bytes_downloaded, + count, + successes, + failed["failed_to_download"], + failed["failed_to_subsample"], + bytes_downloaded, start_time, end_time, - shard_status.status_dict, + status_dict, self.config["storage"]["oom_shard_count"], ) + fs.rm(shard_path) diff --git a/video2dataset/workers/subset_worker.py b/video2dataset/workers/subset_worker.py index 519aad3e..ad4c9ebc 100644 --- a/video2dataset/workers/subset_worker.py +++ b/video2dataset/workers/subset_worker.py @@ -1,16 +1,77 @@ """creates a subset of an existing dataset inside the sample dimension""" -import fsspec +from dataclasses import dataclass, field +import time import json import pyarrow as pa -import time import traceback -from typing import Literal, cast + +import fsspec +import numpy as np import webdataset as wds +from typing import List, Any, Optional, Literal, cast from video2dataset.dataloader import get_video_dataset -from video2dataset.logger import write_stats +from video2dataset.logger import CappedCounter, write_stats +from video2dataset.subsamplers import ( + ClippingSubsampler, + CutDetectionSubsampler, + FrameSubsampler, + FFProbeSubsampler, + NoOpSubsampler, + ResolutionSubsampler, + AudioRateSubsampler, +) from video2dataset.types import EncodeFormats, Streams -from video2dataset.workers.worker import ShardStatus, get_subsamplers, process_sample + + +def get_subsamplers(config: dict, encode_formats: EncodeFormats): + """Initialize all subsamplers using config""" + + clipping_subsampler = ClippingSubsampler( + 5, # oom_clip_count + encode_formats, + **config["subsampling"].get("ClippingSubsampler", {"args": {}})["args"], + ) + need_keyframes = clipping_subsampler.precision == "keyframe_adjusted" + + cut_detection_subsampler = None + cuts_are_clips = False + if "CutDetectionSubsampler" in config["subsampling"]: + if "args" in config["subsampling"]["CutDetectionSubsampler"]: + cut_detection_subsampler = CutDetectionSubsampler(**config["subsampling"]["CutDetectionSubsampler"]["args"]) + cuts_are_clips = config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False) + + broadcast_subsampler = ( + clipping_subsampler if (config["storage"]["captions_are_subtitles"] or cuts_are_clips) else NoOpSubsampler() + ) + + ffprobe_subsampler = None + if "FFProbeSubsampler" in config["subsampling"] or need_keyframes: + ffprobe_subsampler = FFProbeSubsampler(**config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"]) + ffprobe_subsampler.extract_keyframes |= need_keyframes + + video_subsamplers: List[Any] = [] + if "ResolutionSubsampler" in config["subsampling"]: + video_subsamplers.append(ResolutionSubsampler(**config["subsampling"]["ResolutionSubsampler"]["args"])) + if "FrameSubsampler" in config["subsampling"]: + video_subsamplers.append(FrameSubsampler(**config["subsampling"]["FrameSubsampler"]["args"])) + + audio_subsamplers: List[Any] = [] + if "AudioRateSubsampler" in config["subsampling"]: + audio_subsamplers.append(AudioRateSubsampler(**config["subsampling"]["AudioRateSubsampler"]["args"])) + + modal_subsamplers = {"video": video_subsamplers, "audio": audio_subsamplers} + + return ffprobe_subsampler, modal_subsamplers, cut_detection_subsampler, cuts_are_clips, broadcast_subsampler + + +@dataclass +class ShardStatus: + successes: int = 0 + failed_to_subsample: int = 0 + status_dict: CappedCounter = field(default_factory=CappedCounter) + error_message: Optional[str] = None + count: int = 0 class SubsetWorker: @@ -26,31 +87,50 @@ def __init__( self.sample_writer_class = sample_writer_class self.output_folder = output_folder self.config = config + ( + self.ffprobe_subsampler, + self.modal_subsamplers, + self.cut_detection_subsampler, + self.cuts_are_clips, + self.broadcast_subsampler, + ) = get_subsamplers(config, encode_formats) + + # set encoding formats self.input_encode_formats = encode_formats - self.subsamplers, self.output_encode_formats = get_subsamplers(config, self.input_encode_formats) + self.output_encode_formats = self.input_encode_formats.copy() + if self.modal_subsamplers["audio"]: + assert ( + len({s.encode_format for s in self.modal_subsamplers["audio"]}) == 1 + ) # assert that all audio subsamplers have the same output format + self.output_encode_formats["audio"] = self.modal_subsamplers["audio"][0].encode_format + if self.modal_subsamplers["video"]: + assert ( + len({s.encode_format for s in self.modal_subsamplers["video"]}) == 1 + ) # assert that all video subsamplers have the same output format + self.output_encode_formats["video"] = self.modal_subsamplers["video"][0].encode_format def __call__( self, row, ): try: - shard_file, shard_id = row - self.process_shard(shard_file, shard_id) + shard, shard_id = row + self.process_shard(shard, shard_id) return (True, row) except Exception as err: # pylint: disable=broad-except traceback.print_exc() - print(f"shard_file {row[0]} failed with error {err}") + print(f"shard {row[0]} failed with error {err}") return (False, row) def get_shard_processors( self, - shard_file: str, + shard: str, shard_id: int, ): """Get objects for loading and writing data""" try: - fs, shard_path = fsspec.core.url_to_fs(shard_file[: -len(".tar")] + ".parquet") + fs, shard_path = fsspec.core.url_to_fs(shard[: -len(".tar")] + ".parquet") with fs.open(shard_path, "rb") as f: df = pa.parquet.read_table(f) schema = df.schema @@ -70,7 +150,7 @@ def get_shard_processors( self.output_encode_formats, ) shard_dataloader = get_video_dataset( - urls=shard_file, + urls=shard, batch_size=1, decoder_kwargs={}, enforce_additional_keys=[], @@ -80,13 +160,13 @@ def get_shard_processors( def process_shard( self, - shard_file: str, + shard: str, shard_id: int, ): """Function to start an video processing in one process""" start_time = time.time() - shard_sample_writer, shard_dataloader = self.get_shard_processors(shard_file, shard_id) + shard_sample_writer, shard_dataloader = self.get_shard_processors(shard, shard_id) shard_status = ShardStatus() for sample in shard_dataloader: @@ -94,27 +174,82 @@ def process_shard( key = sample["__key__"] try: caption = sample.get("txt", b"").decode("utf-8") - metadata = json.loads(sample.get("json", b"{}").decode("utf-8")) + meta = json.loads(sample.get("json", b"{}").decode("utf-8")) except Exception as err: # pylint: disable=broad-except traceback.print_exc() print(f"Sample {key} failed to download: {err}") return - streams: Streams = {} - for modality, encode_format in self.input_encode_formats.items(): - modality = cast(Literal["audio", "video"], modality) - streams[modality] = [sample[encode_format]] - - process_sample( - subsamplers=self.subsamplers, - shard_status=shard_status, - streams=streams, - key=key, - caption=caption, - metadata=metadata, - captions_are_subtitles=self.config["storage"]["captions_are_subtitles"], - shard_sample_writer=shard_sample_writer, - ) + try: + streams: Streams = {} + for modality, encode_format in self.input_encode_formats.items(): + modality = cast(Literal["audio", "video"], modality) + streams[modality] = [sample[encode_format]] + + if self.ffprobe_subsampler is not None: + streams, meta, shard_status.error_message = self.ffprobe_subsampler(streams, meta) + assert shard_status.error_message is None + + if self.config["storage"]["captions_are_subtitles"]: # create clips + subtitles = meta["yt_meta_dict"]["subtitles"] + meta["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles] + elif self.cut_detection_subsampler is not None: # apply cut detection to get clips + streams, cuts, shard_status.error_message = self.cut_detection_subsampler(streams) + assert shard_status.error_message is None + meta["cuts"] = cuts + assert cuts is not None + if self.cuts_are_clips: + meta["clips"] = (np.array(cuts["cuts_original_fps"]) / cuts["original_fps"]).tolist() + + # 1 video -> many videos (either clipping or noop which does identity broadcasting) + subsampled_streams, metas, shard_status.error_message = self.broadcast_subsampler(streams, meta) + if shard_status.error_message is not None: + meta["clips"] = [] + assert False + + for modality in list(subsampled_streams.keys()): + for modality_subsampler in self.modal_subsamplers[modality]: + subsampled_streams, metas, shard_status.error_message = modality_subsampler( + subsampled_streams, metas + ) + assert shard_status.error_message is None + + shard_status.successes += 1 + status = "success" + shard_status.status_dict.increment(status) + + subsampled_streams_list = [dict(zip(subsampled_streams, s)) for s in zip(*subsampled_streams.values())] + if len(subsampled_streams_list) == 0: # no audio or video, just write meta + meta["status"] = status + shard_sample_writer.write( + {}, + key, + caption, + meta, + ) + continue + for subsampled_streams, meta in zip(subsampled_streams_list, metas): + meta["status"] = status + text_caption = caption + if self.config["storage"]["captions_are_subtitles"]: + text_caption = meta.get("clip_subtitles")[0]["lines"][0] + shard_sample_writer.write( + subsampled_streams, + meta["key"], + text_caption, + meta, + ) + except Exception: # pylint: disable=broad-except + shard_status.failed_to_subsample += 1 + shard_status.status_dict.increment(shard_status.error_message) + meta["status"] = "failed_to_subsample" + meta["error_message"] = shard_status.error_message + shard_sample_writer.write( + {}, + key, + caption, + meta, + ) shard_sample_writer.close() end_time = time.time() @@ -125,7 +260,7 @@ def process_shard( shard_status.count, shard_status.successes, 0, # failed to download - shard_status.failed["failed_to_subsample"], + shard_status.failed_to_subsample, 0, # bytes downloaded start_time, end_time, diff --git a/video2dataset/workers/worker.py b/video2dataset/workers/worker.py deleted file mode 100644 index 45650829..00000000 --- a/video2dataset/workers/worker.py +++ /dev/null @@ -1,197 +0,0 @@ -"""Standard worker for video2dataset.""" -from dataclasses import dataclass, field -import numpy as np -from typing import Any, List, Tuple, Optional - -from video2dataset.logger import CappedCounter -from video2dataset.subsamplers import ( - ClippingSubsampler, - CutDetectionSubsampler, - FrameSubsampler, - FFProbeSubsampler, - NoOpSubsampler, - ResolutionSubsampler, - AudioRateSubsampler, - Subsampler, -) -from video2dataset.types import EncodeFormats, Streams, Metadata - - -@dataclass -class ShardStatus: - """Shard processing status""" - - successes: int = 0 - failed: dict = field( - default_factory=lambda: { - "failed_to_download": 0, - "failed_to_subsample": 0, - } - ) - status_dict: CappedCounter = field(default_factory=CappedCounter) - error_message: Optional[str] = None - count: int = 0 - bytes_downloaded: int = 0 - - -@dataclass -class Subsamplers: - """Subsamplers used in processing""" - - ffprobe_subsampler: Optional[FFProbeSubsampler] = None - modal_subsamplers: dict = field(default_factory=dict) - cut_detection_subsampler: Optional[CutDetectionSubsampler] = None - cuts_are_clips: bool = False - broadcast_subsampler: Subsampler = field(default_factory=NoOpSubsampler) - - -def get_subsamplers( - config: dict, - input_encode_formats: EncodeFormats, - do_clipping: bool = False, -) -> Tuple[Subsamplers, EncodeFormats]: - """Initialize all subsamplers using config""" - - clipping_subsampler = ClippingSubsampler( - oom_clip_count=5, - encode_formats=input_encode_formats, - **config["subsampling"].get("ClippingSubsampler", {"args": {}})["args"], - ) - need_keyframes = clipping_subsampler.precision == "keyframe_adjusted" - - cut_detection_subsampler = None - cuts_are_clips = False - if "CutDetectionSubsampler" in config["subsampling"]: - if "args" in config["subsampling"]["CutDetectionSubsampler"]: - cut_detection_subsampler = CutDetectionSubsampler(**config["subsampling"]["CutDetectionSubsampler"]["args"]) - cuts_are_clips = config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False) - - broadcast_subsampler = ( - clipping_subsampler - if (do_clipping or config["storage"]["captions_are_subtitles"] or cuts_are_clips) - else NoOpSubsampler() - ) - - ffprobe_subsampler = None - if "FFProbeSubsampler" in config["subsampling"] or need_keyframes: - ffprobe_subsampler = FFProbeSubsampler(**config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"]) - ffprobe_subsampler.extract_keyframes |= need_keyframes - - video_subsamplers: List[Any] = [] - if "ResolutionSubsampler" in config["subsampling"]: - video_subsamplers.append(ResolutionSubsampler(**config["subsampling"]["ResolutionSubsampler"]["args"])) - if "FrameSubsampler" in config["subsampling"]: - video_subsamplers.append(FrameSubsampler(**config["subsampling"]["FrameSubsampler"]["args"])) - - audio_subsamplers: List[Any] = [] - if "AudioRateSubsampler" in config["subsampling"]: - audio_subsamplers.append(AudioRateSubsampler(**config["subsampling"]["AudioRateSubsampler"]["args"])) - - modal_subsamplers = {"video": video_subsamplers, "audio": audio_subsamplers} - - # output encoding formats - output_encode_formats = input_encode_formats.copy() - if modal_subsamplers["audio"]: - assert ( - len({s.encode_format for s in modal_subsamplers["audio"]}) == 1 - ) # assert that all audio subsamplers have the same output format - output_encode_formats["audio"] = modal_subsamplers["audio"][0].encode_format - if modal_subsamplers["video"]: - assert ( - len({s.encode_format for s in modal_subsamplers["video"]}) == 1 - ) # assert that all video subsamplers have the same output format - output_encode_formats["video"] = modal_subsamplers["video"][0].encode_format - - return ( - Subsamplers( - ffprobe_subsampler=ffprobe_subsampler, - modal_subsamplers=modal_subsamplers, - cut_detection_subsampler=cut_detection_subsampler, - cuts_are_clips=cuts_are_clips, - broadcast_subsampler=broadcast_subsampler, - ), - output_encode_formats, - ) - - -def process_sample( - subsamplers: Subsamplers, - shard_status: ShardStatus, - streams: Streams, - key: str, - caption: str, - metadata: Metadata, - captions_are_subtitles: bool, - shard_sample_writer: Any, # TODO: type correctly -): - """Process a single video""" - - try: - if subsamplers.ffprobe_subsampler is not None: - streams, metadata, shard_status.error_message = subsamplers.ffprobe_subsampler(streams, metadata) - assert shard_status.error_message is None - - if captions_are_subtitles: # create clips - subtitles = metadata["yt_meta_dict"]["subtitles"] - metadata["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles] - elif subsamplers.cut_detection_subsampler is not None: # apply cut detection to get clips - streams, cuts, shard_status.error_message = subsamplers.cut_detection_subsampler(streams) - assert shard_status.error_message is None - metadata["cuts"] = cuts - assert cuts is not None - if subsamplers.cuts_are_clips: - metadata["clips"] = (np.array(cuts["cuts_original_fps"]) / cuts["original_fps"]).tolist() - - # 1 video -> many videos (either clipping or noop which does identity broadcasting) - subsampled_streams, metadatas, shard_status.error_message = subsamplers.broadcast_subsampler(streams, metadata) - if shard_status.error_message is not None: - metadata["clips"] = [] - assert False - - for modality in list(subsampled_streams.keys()): - for modality_subsampler in subsamplers.modal_subsamplers[modality]: - subsampled_streams, metadatas, shard_status.error_message = modality_subsampler( - subsampled_streams, metadatas - ) - assert shard_status.error_message is None - - shard_status.successes += 1 - status = "success" - shard_status.status_dict.increment(status) - - subsampled_streams_list = [dict(zip(subsampled_streams, s)) for s in zip(*subsampled_streams.values())] - if len(subsampled_streams_list) == 0: # no audio or video, just write metadata - metadata["status"] = status - shard_sample_writer.write( - {}, - key, - caption, - metadata, - ) - return - for subsampled_streams, subsampled_metadata in zip(subsampled_streams_list, metadatas): - subsampled_metadata["status"] = status - text_caption = caption - if captions_are_subtitles: - clip_subtitles = subsampled_metadata.get("clip_subtitles") - first_clip_subtitles = clip_subtitles[0] if clip_subtitles else None - subtitle_lines = first_clip_subtitles["lines"] if first_clip_subtitles else None - text_caption = subtitle_lines[0] if subtitle_lines else text_caption - shard_sample_writer.write( - subsampled_streams, - subsampled_metadata["key"], - text_caption, - subsampled_metadata, - ) - except Exception as err: # pylint: disable=broad-except - print(err) - shard_status.failed["failed_to_subsample"] += 1 - shard_status.status_dict.increment(shard_status.error_message) - metadata["status"] = "failed_to_subsample" - metadata["error_message"] = shard_status.error_message - shard_sample_writer.write( - {}, - key, - caption, - metadata, - )