Download worker refactor (#288)
* ClippingSubsampler rewrite and bug fixes

* More refactoring of ClippingSubsampler, plus a fix to _get_clip_intervals

* Finished refactoring ClippingSubsampler

* Final code changes

* Added docstrings

* Passed tests and linting

* Made type annotations consistent with Python 3.8

* More annotation fixes

* The Python 3.8 annotations need a lot of hand-holding, it seems (see the 3.8 typing sketch after the commit message)

* Pylint has to cut it out, I swear to God

* No real change, just relaunching unit tests that failed due to connection timeouts

* Linting issue

* Another linting issue

* Separated per-shard code from code that should only be executed once

* Pulled ShardStatus parameters into their own data type

* Cleaned up shard processing error handling

* Cleaned up code

* Bug fixes

* Formatting

* Fixed linting issues

* Fixing more damn linting

* Added a missing docstring

* Unified SubsetWorker and DownloadWorker code

* Bug fixes

* Linting

* Linting again

* Forgot a docstring

* Removed unnecessary manual thread handling

* Removed unused import

---------

Co-authored-by: iejMac <kilianmaciej6@gmail.com>
Co-authored-by: Romain Beaumont <romain.rom1@gmail.com>
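
A note on the Python 3.8 annotation commits above: built-in generics such as list[bytes] (added in Python 3.9) and the X | None union syntax (Python 3.10) are not available on 3.8, so annotations have to fall back on the typing module. A minimal sketch of the 3.8-compatible style, with illustrative names that are not taken from this commit:

# Illustrative only: the function and variable names here are hypothetical,
# not code from this commit.
from typing import Dict, List, Optional, Tuple

def split_streams(streams: Dict[str, List[bytes]]) -> Tuple[List[bytes], Optional[str]]:
    # On 3.8, use typing.Dict/List/Optional instead of dict[...]/list[...]/"| None".
    video = streams.get("video", [])
    error_message: Optional[str] = None if video else "no video stream"
    return video, error_message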
3 people authored Jan 27, 2024
1 parent e1b5d89 commit 83afef0
Showing 5 changed files with 330 additions and 374 deletions.
2 changes: 2 additions & 0 deletions video2dataset/subsamplers/__init__.py
@@ -12,3 +12,5 @@
from .optical_flow_subsampler import OpticalFlowSubsampler
from .whisper_subsampler import WhisperSubsampler
from .caption_subsampler import CaptionSubsampler

from .subsampler import Subsampler
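
The new Subsampler import gives the subsamplers a common interface. The base class itself is not shown in this diff; a plausible sketch, assuming the call signature used by the workers below (streams and metadata in, streams, metadata, and an optional error message out):

# Plausible sketch of the shared interface implied by how subsamplers are
# invoked in download_worker.py; the real base class may differ.
from abc import ABC, abstractmethod
from typing import Optional, Tuple

class Subsampler(ABC):
    @abstractmethod
    def __call__(self, streams: dict, metadata: dict) -> Tuple[dict, dict, Optional[str]]:
        """Return (possibly modified) streams, metadata, and an error message or None."""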
4 changes: 4 additions & 0 deletions video2dataset/types.py
@@ -10,3 +10,7 @@ class EncodeFormats(TypedDict, total=False):
class Streams(TypedDict, total=False):
video: List[bytes]
audio: List[bytes]


# TODO: make more structured
Metadata = dict
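
A short usage sketch of these types as they flow through the workers below. The EncodeFormats field names are an assumption here, since its body is collapsed in this diff:

# Usage sketch; the EncodeFormats field names are assumed, not shown in this diff.
from video2dataset.types import EncodeFormats, Metadata, Streams

encode_formats: EncodeFormats = {"video": "mp4", "audio": "m4a"}
streams: Streams = {"video": [b"<video bytes>"], "audio": [b"<audio bytes>"]}
metadata: Metadata = {"key": "000000000", "status": None, "error_message": None}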
304 changes: 96 additions & 208 deletions video2dataset/workers/download_worker.py
@@ -1,29 +1,15 @@
"""the downloader module handles the downloading"""

import fsspec
import math
import time
from multiprocessing.pool import ThreadPool
import pyarrow as pa
import time
import traceback

import fsspec

from multiprocessing.pool import ThreadPool
from threading import Semaphore
from typing import List, Any
import numpy as np
from typing import cast

from video2dataset.data_reader import VideoDataReader
from video2dataset.logger import CappedCounter
from video2dataset.logger import write_stats
from video2dataset.subsamplers import (
ClippingSubsampler,
CutDetectionSubsampler,
FrameSubsampler,
FFProbeSubsampler,
NoOpSubsampler,
ResolutionSubsampler,
AudioRateSubsampler,
)
from video2dataset.workers.worker import ShardStatus, Streams, get_subsamplers, process_sample


def compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count):
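
compute_key itself is unchanged and collapsed in this diff. A plausible reconstruction, inferred from how str_key is built further down (the zero-padding widths come from the oom_* order-of-magnitude arguments); the real implementation may differ:

# Hypothetical reconstruction of compute_key, based only on its call sites below.
def compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count):
    # Combine the shard id and in-shard sample id, then zero-pad to a fixed width.
    true_key = (10**oom_sample_per_shard) * shard_id + key
    key_format = oom_sample_per_shard + oom_shard_count
    return f"{true_key:0{key_format}d}"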
@@ -52,252 +38,154 @@ def __init__(
self.save_caption = save_caption
self.output_folder = output_folder
self.column_list = column_list
self.encode_formats = encode_formats
self.input_encode_formats = encode_formats
self.config = config

self.data_reader = VideoDataReader(encode_formats, tmp_dir, config["reading"])

self.clipping_subsampler = ClippingSubsampler(
5, # oom_clip_count
self.url_indice = self.column_list.index("url")
self.caption_indice = self.column_list.index("caption") if "caption" in self.column_list else None
self.oom_sample_per_shard = math.ceil(math.log10(self.config["storage"]["number_sample_per_shard"]))
self.subsamplers, self.output_encode_formats = get_subsamplers(
config,
encode_formats,
**self.config["subsampling"].get("ClippingSubsampler", {"args": {}})["args"],
do_clipping=("clips" in self.column_list),
)
need_keyframes = self.clipping_subsampler.precision == "keyframe_adjusted"

self.ffprobe_subsampler = None
if "FFProbeSubsampler" in self.config["subsampling"] or need_keyframes:
self.ffprobe_subsampler = FFProbeSubsampler(
**self.config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"]
)
self.ffprobe_subsampler.extract_keyframes |= need_keyframes

self.cut_detector = None
self.cuts_are_clips = False
if "CutDetectionSubsampler" in self.config["subsampling"]:
if "args" in self.config["subsampling"]["CutDetectionSubsampler"]:
self.cut_detector = CutDetectionSubsampler(
**self.config["subsampling"]["CutDetectionSubsampler"]["args"]
)
self.cuts_are_clips = self.config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False)

self.noop_subsampler = NoOpSubsampler()

video_subsamplers: List[Any] = []
if "ResolutionSubsampler" in self.config["subsampling"]:
video_subsamplers.append(ResolutionSubsampler(**self.config["subsampling"]["ResolutionSubsampler"]["args"]))
if "FrameSubsampler" in self.config["subsampling"]:
video_subsamplers.append(FrameSubsampler(**self.config["subsampling"]["FrameSubsampler"]["args"]))

audio_subsamplers: List[Any] = []
if "AudioRateSubsampler" in self.config["subsampling"]:
audio_subsamplers.append(AudioRateSubsampler(**self.config["subsampling"]["AudioRateSubsampler"]["args"]))

self.subsamplers = {"video": video_subsamplers, "audio": audio_subsamplers}

def __call__(
self,
row,
):
try:
self.download_shard(row)
shard_file, shard_id = row
self.process_shard(shard_file, shard_id)
return (True, row)
except Exception as err: # pylint: disable=broad-except
traceback.print_exc()
print(f"shard {row[0]} failed with error {err}")
return (False, row)

def download_shard(
def get_shard_processors(
self,
row,
shard_file: str,
shard_id: int,
):
"""Function to start an video downloading in one process"""

# shard_id, shard_file = row
shard_file, shard_id = row
start_time = time.time()
"""Get objects for loading and writing data"""

fs, shard_path = fsspec.core.url_to_fs(shard_file)
print(shard_path)
with fs.open(shard_path, "rb") as f:
df = pa.ipc.open_file(f).read_all()
schema = df.schema
schema = df.schema
schema = (
schema.append(pa.field("key", pa.string()))
.append(pa.field("status", pa.string()))
.append(pa.field("error_message", pa.string()))
)

shard_sample_writer = self.sample_writer_class(
shard_id,
self.output_folder,
self.save_caption,
self.config["storage"]["oom_shard_count"],
schema,
self.output_encode_formats,
)
pydict = df.select(self.column_list).to_pydict()
shard_to_dl = list(enumerate(zip(*(pydict[col] for col in self.column_list))))
del pydict
del df

status_dict = CappedCounter()

count = len(shard_to_dl)
successes = 0
failed = {
"failed_to_download": 0,
"failed_to_subsample": 0,
}
bytes_downloaded = 0
url_indice = self.column_list.index("url")
caption_indice = self.column_list.index("caption") if "caption" in self.column_list else None
key_url_list = [(key, x[url_indice]) for key, x in shard_to_dl]
def rm_shard_path():
fs.rm(shard_path)

semaphore = Semaphore(self.config["distribution"]["thread_count"])
return shard_sample_writer, shard_to_dl, rm_shard_path

def data_generator():
for e in key_url_list:
semaphore.acquire() # pylint: disable=(consider-using-with)
yield e
def process_shard(
self,
shard_file: str,
shard_id: int,
):
"""Function to start an video downloading in one process"""

loader = data_generator()
start_time = time.time()
shard_sample_writer, shard_to_dl, rm_shard_path = self.get_shard_processors(shard_file, shard_id)
shard_status = ShardStatus(count=len(shard_to_dl))

# The subsamplers might change the output format, so we need to update the writer
writer_encode_formats = self.encode_formats.copy()
if self.subsamplers["audio"]:
writer_encode_formats["audio"] = self.subsamplers["audio"][0].encode_formats["audio"]
if self.subsamplers["video"]:
writer_encode_formats["video"] = self.subsamplers["video"][0].encode_formats["video"]
def data_generator():
for key_and_url in [(key, x[self.url_indice]) for key, x in shard_to_dl]:
yield key_and_url

# give schema to writer
sample_writer = self.sample_writer_class(
shard_id,
self.output_folder,
self.save_caption,
self.config["storage"]["oom_shard_count"],
schema,
writer_encode_formats,
)
oom_sample_per_shard = math.ceil(math.log10(self.config["storage"]["number_sample_per_shard"]))
data_reader_call_param_generator = data_generator()

with ThreadPool(self.config["distribution"]["thread_count"]) as thread_pool:
for key, streams, yt_meta_dict, error_message in thread_pool.imap_unordered(
for key, streams, yt_meta_dict, shard_status.error_message in thread_pool.imap_unordered(
self.data_reader, # pylint: disable=(unnecessary-lambda)
loader,
data_reader_call_param_generator,
):
try:
_, sample_data = shard_to_dl[key]
str_key = compute_key(
key, shard_id, oom_sample_per_shard, self.config["storage"]["oom_shard_count"]
key, shard_id, self.oom_sample_per_shard, self.config["storage"]["oom_shard_count"]
)
meta = {
caption = sample_data[self.caption_indice] if self.caption_indice is not None else None
metadata = {
**{self.column_list[i]: sample_data[i] for i in range(len(self.column_list))},
"key": str_key,
"status": None,
"error_message": error_message,
"error_message": shard_status.error_message,
"yt_meta_dict": yt_meta_dict,
}

if error_message is not None:
print(error_message)
if "[youtube]" in error_message: # video-specific error, remove videoID
error_message = "ERROR: [youtube]:" + error_message.split(":")[-1]
raise ValueError("failed_to_download")

for stream in streams.values():
bytes_downloaded += len(stream)
for mod in streams:
streams[mod] = [streams[mod]]

if self.ffprobe_subsampler is not None:
streams, meta, error_message = self.ffprobe_subsampler(streams, meta)
if error_message is not None:
raise ValueError("failed_to_subsample")

if self.config["storage"]["captions_are_subtitles"]: # create clips
# all langs have same start and end times
subtitles = meta["yt_meta_dict"]["subtitles"][list(meta["yt_meta_dict"]["subtitles"].keys())[0]]
meta["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles]
elif self.cut_detector is not None: # apply cut detection to get clips
streams, cuts, error_message = self.cut_detector(streams)

if error_message is not None:
raise ValueError("failed_to_subsample")

meta["cuts"] = cuts

if self.cuts_are_clips:
cuts = meta["cuts"]["cuts_original_fps"]
native_fps = meta["cuts"]["original_fps"]
meta["clips"] = (np.array(cuts) / native_fps).tolist()

# 1 video -> many videos (either clipping or noop which does identity broadcasting)
broadcast_subsampler = (
self.clipping_subsampler
if (
"clips" in self.column_list
or self.config["storage"]["captions_are_subtitles"]
or self.cuts_are_clips
)
else self.noop_subsampler
)
subsampled_streams, metas, error_message = broadcast_subsampler(streams, meta)

for modality in subsampled_streams:
for modality_subsampler in self.subsamplers[modality]:
subsampled_streams, metas, error_message = modality_subsampler(subsampled_streams, metas)

if error_message is not None:
meta["clips"] = []
raise ValueError("failed_to_subsample")

successes += 1
status = "success"
status_dict.increment(status)
subsampled_streams_list = [
dict(zip(subsampled_streams, s)) for s in zip(*subsampled_streams.values())
]
for subsampled_streams, meta in zip(subsampled_streams_list, metas):
meta["status"] = status

text_caption = sample_data[caption_indice] if caption_indice is not None else None
if self.config["storage"]["captions_are_subtitles"]:
text_caption = meta.get("clip_subtitles")[0]["lines"]

sample_writer.write(
subsampled_streams,
meta["key"],
text_caption,
meta,
)
except Exception as err: # pylint: disable=broad-except
status = str(err)
if status.startswith("failed_to_"):
failed[status] += 1
status_dict.increment(error_message)
meta["status"] = status
meta["error_message"] = error_message
sample_writer.write(
{},
str_key,
sample_data[caption_indice] if caption_indice is not None else None,
meta,
)
semaphore.release()
else:
traceback.print_exc()
print(f"Sample {key} failed to download: {err}")
traceback.print_exc()
print(f"Sample {key} failed to download: {err}")
return

semaphore.release()

sample_writer.close()
thread_pool.terminate()
thread_pool.join()
del thread_pool
try:
if shard_status.error_message is not None:
print(shard_status.error_message)
if "[youtube]" in shard_status.error_message: # video-specific error, remove videoID
shard_status.error_message = "ERROR: [youtube]:" + shard_status.error_message.split(":")[-1]
raise ValueError
except Exception: # pylint: disable=broad-except
shard_status.failed["failed_to_download"] += 1
shard_status.status_dict.increment(shard_status.error_message)
metadata["status"] = "failed_to_download"
metadata["error_message"] = shard_status.error_message
shard_sample_writer.write(
{},
str_key,
sample_data[self.caption_indice] if self.caption_indice is not None else None,
metadata,
)
return

for stream in streams.values():
shard_status.bytes_downloaded += len(stream)
for modality in streams:
streams[modality] = [streams[modality]]

process_sample(
subsamplers=self.subsamplers,
shard_status=shard_status,
streams=cast(Streams, streams),
key=str_key,
caption=cast(str, caption),
metadata=metadata,
captions_are_subtitles=self.config["storage"]["captions_are_subtitles"],
shard_sample_writer=shard_sample_writer,
)

shard_sample_writer.close()
rm_shard_path()
end_time = time.time()

write_stats(
self.output_folder,
shard_id,
count,
successes,
failed["failed_to_download"],
failed["failed_to_subsample"],
bytes_downloaded,
shard_status.count,
shard_status.successes,
shard_status.failed["failed_to_download"],
shard_status.failed["failed_to_subsample"],
shard_status.bytes_downloaded,
start_time,
end_time,
status_dict,
shard_status.status_dict,
self.config["storage"]["oom_shard_count"],
)
fs.rm(shard_path)
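
The ShardStatus type imported from video2dataset.workers.worker is not shown in this diff. Based on how its fields are used above (count, successes, failed, status_dict, error_message, bytes_downloaded), a plausible sketch; the defaults are assumptions and the actual definition may differ:

# Plausible reconstruction of ShardStatus from its usage in process_shard above;
# the real definition lives in video2dataset/workers/worker.py.
from dataclasses import dataclass, field
from typing import Dict, Optional

from video2dataset.logger import CappedCounter

@dataclass
class ShardStatus:
    count: int
    successes: int = 0
    failed: Dict[str, int] = field(
        default_factory=lambda: {"failed_to_download": 0, "failed_to_subsample": 0}
    )
    status_dict: CappedCounter = field(default_factory=CappedCounter)
    error_message: Optional[str] = None
    bytes_downloaded: int = 0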