diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..b38df29f --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "daily" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..0b123e98 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,56 @@ +name: Continuous integration + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Install + run: | + sudo apt-get update + python3 -m venv .env + source .env/bin/activate + python -m pip install -U pip + make install-dev + sudo apt-get install -y ffmpeg + - name: Lint + run: | + source .env/bin/activate + make lint + tests: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install + run: | + sudo apt-get update + python3 -m venv .env + source .env/bin/activate + make install + make install-dev + sudo apt-get install -y ffmpeg + - name: Unit tests + run: | + source .env/bin/activate + make test + diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml new file mode 100644 index 00000000..2973cbee --- /dev/null +++ b/.github/workflows/python-publish.yml @@ -0,0 +1,37 @@ +name: Release + +on: + push: + branches: + - main +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions-ecosystem/action-regex-match@v2 + id: regex-match + with: + text: ${{ github.event.head_commit.message }} + regex: '^Release ([^ ]+)' + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.8' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install setuptools wheel twine + - name: Release + if: ${{ steps.regex-match.outputs.match != '' }} + uses: softprops/action-gh-release@v1 + with: + tag_name: ${{ steps.regex-match.outputs.group1 }} + - name: Build and publish + if: ${{ steps.regex-match.outputs.match != '' }} + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + python setup.py sdist bdist_wheel + twine upload dist/* diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..69c554dc --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +*.egg-info +.vscode +.env +__pycache__ +.envtest +.coverage* +.env* +wandb +*.pex +.pexing +**/dataset/* +dist/ +build/ \ No newline at end of file diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 00000000..711d32e5 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,268 @@ +[MASTER] + +# Specify a configuration file. +#rcfile= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=CVS + +# Pickle collected data for later comparisons. +persistent=yes + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + + +[MESSAGES CONTROL] + +# Enable the message, report, category or checker with the given id(s). 
You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time. See also the "--disable" option for examples. +enable=indexing-exception,old-raise-syntax + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager,no-else-return,wrong-import-order,unnecessary-pass,logging-fstring-interpolation,logging-format-interpolation,C0330 + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[TYPECHECK] + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of classes names for which member attributes should not be checked +# (useful for classes with attributes dynamically set). +ignored-classes=SQLObject + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E0201 when accessed. Python regular +# expressions are accepted. +generated-members=REQUEST,acl_users,aq_parent + +# List of decorators that create context managers from functions, such as +# contextlib.contextmanager. +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the beginning of the name of dummy variables +# (i.e. not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. 
+additional-builtins= + + +[BASIC] + +# Regular expression which should only match correct module names +module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Regular expression which should only match correct module level names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression which should only match correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression which should only match correct function names +function-rgx=^(?:(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression which should only match correct method names +method-rgx=^(?:(?P__[a-z0-9_]+__|next)|(?P_{0,2}[A-Z][a-zA-Z0-9]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match correct instance attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct attribute names in class +# bodies +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression which should only match correct list comprehension / +# generator expression variable names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main) + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=10 + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=120 + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x) + (^\s*(import|from)\s + |\$Id:\s\/\/depot\/.+#\d+\s\$ + |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+') + |^\s*\#\ LINT\.ThenChange + |^[^#]*\#\ type:\ [a-zA-Z_][a-zA-Z0-9_.,[\] ]*$ + |pylint + |""" + |\# + |lambda + |(https?|ftp):) + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=y + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes= + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +extension-pkg-whitelist=_jsonnet + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. 
+defining-attr-methods=__init__,__new__,setUp + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls,class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# Maximum number of arguments for function / method +max-args=5 + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore +ignored-argument-names=_.* + +# Maximum number of locals for function / method body +max-locals=15 + +# Maximum number of return / yield for function / method body +max-returns=6 + +# Maximum number of branch for function / method body +max-branches=12 + +# Maximum number of statements in function / method body +max-statements=50 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + + + +[TOKENS] + +# Number of spaces of indent required when the last token on the preceding line +# is an open (, [, or {. +indent-after-paren=4 \ No newline at end of file diff --git a/API.md b/API.md new file mode 100644 index 00000000..928f3fa6 --- /dev/null +++ b/API.md @@ -0,0 +1,221 @@ +# API + +The module exposes a single function `video2dataset` which takes the same arguments as the command line tool: + +## Core Args + +``` +url_list: list of input urls - can be any of the supported input formats + (csv, parquet, braceexpand tar paths etc.) +output_folder: Desired location of output dataset (default = "dataset") +output_format: Format of output dataset, can be (default = "files") + - files, samples saved in subdirectory for each shard (useful for debugging) + - webdataset, samples saved in tars (useful for efficient loading) + - parquet, sampels saved in parquet (as bytes) + - tfrecord, samples saved in tfrecord (as bytes) + - dummy, does not save (useful for benchmarks) +input_format: Format of the input, can be (default = "csv") + - txt, text file with a url in each line + - csv, csv file with urls, (and captions + metadata) + - tsv, tsv - || - + - tsv.gz, - || - but compressed gzip + - json, loads urls and metadata as json + - parquet, loads urls and metadata as parquet + - webdataset, usually braceexpand format of mutliple paths to tars to re-process +encode_formats: Dict that specifies what extension each modality should use (default = "{'video': 'mp4'}") + f.e. {"video": "mp4", "audio": "m4a"} +stage: String that tells video2dataset what stage of processing is being performed. 
Can be (default = 'download') + WARNING: To be depracated soon (this information should be deduced based on config) + - download, when input is some tabular format and data must be downloaded first + - subset, tar files are already written and we would like to re-process (input_format == "webdataset") + - optical_flow, tar files are written and we woudl like to compute optical_flow and save to md shards +url_col: Column in input (if has columns) that contains the url (default = "url") +caption_col: Column in input (if has columns) that contains captions (to be written as txt) (default = None) +clip_col: Column in input (if has columns) that contains timeframes of clips for how to split video (default = None) +save_additional_columns: List of column names to save to json component of a sample (defualt = None) +enable_wandb: Whether or not to log info to wandb (default = False) +wandb_project: Name of wandb project to log runs to (default = "video2dataset") +incremental_mode: Decides how to handle restarting, Can be (default = "incremental") + - incremental, checks which shards are done and skips those + - overwrite, deletes and reprocesses shards as it goes +max_shard_retry: Maximum amount of attempts to retry a failed shard (default = 1) +tmp_dir: Path to temporary directory on your file system (default = "/tmp") +config: Path to your config of choice or the config itself (more info on configs in API doc) (default = "default") +``` + +## Config + +Any video2dataset configs is divided into 4 components which describe the following: +- subsampling: how the data should be transformed on its way from input to output +- reading: how the data should be read (from the internet or local FS) +- storage: how the data should be stored +- distribution: how processing should be distributed among compute resources + +Lets look at an example and go over where we can find information on how to fill out a config: + +```yaml +subsampling: + ResolutionSubsampler: + args: + video_size: 224 + resize_mode: "scale,crop,pad" + FrameSubsampler: + args: + frame_rate: 5 + CutDetectionSubsampler: + cuts_are_clips: True + args: + cut_detection_mode: "all" + framerates: null + threshold: 27 + min_scene_len: 15 + ClippingSubsampler: + args: + min_length: 4.0 + max_length: 20.0 + max_length_strategy: "all" + precision: "keyframe_adjusted" + FFProbeSubsampler: + args: + extract_keyframes: False + +reading: + yt_args: + download_size: 360 + download_audio_rate: 44100 + yt_metadata_args: null + timeout: 60 + sampler: null + +storage: + number_sample_per_shard: 1000 + oom_shard_count: 5 + captions_are_subtitles: False + +distribution: + processes_count: 16 + thread_count: 32 + subjob_size: 1000 + distributor: "multiprocessing" +``` + +### Subsampling + +This component of the config will contain a dict with subsampler names and their respective parameters. Each subsampler specification is formatted as such: + +```yaml +SubsamplerName: + ninit_arg1: "test" + ... + ninit_argn: "test" + args: + init_arg1: 1 + ... + init_argn: n +``` + +The arguments within each subsampler are split into 2 groups: +- initialization args: these are the arguments inside the "args" dict of each subsampler specification. If this is not present the subsampler will not be initialized. +- non-initialization args: the video2dataset worker will use these to perform operations outside the subsampler. They can be used to specify processes even without the presenece of the initialized subsampler. 
A good example is the `cuts_are_clips` parameter in the `CutDetectionSubsampler`. If we perform a stage where we detect cuts in the input videos and place those cuts in the metadata but do not clip the videos themselves we can simply perform this clipping in a subsequent stage without the need for re-computing those cuts. We would do this by *just* passing in the CutDetectionSubsampler specification without the *args* but adding `cuts_are_clips`. video2dataset would then just extract the cut information from the metadata and use that for clipping. It would look something like this: +```yaml +... + CutDetectionSubsampler: + cuts_are_clips: True + ClippingSubsampler: + args: + min_length: 4.0 + max_length: 20.0 + max_length_strategy: "all" + precision: "keyframe_adjusted" +... +``` + +To check what values can be present in a specification dict of a given subsampler check the documentation in the [docstring of the subsampler](https://github.com/iejMac/video2dataset/tree/main/video2dataset/subsamplers). + +### Reading + +The reading component of the config informs video2dataset the preferred options for reading data from the specified source. Here is a maxed out reading specification (it has all possible options specified): + +```yaml +reading: + yt_args: + download_size: 360 + yt_metadata_args: + writesubtitles: 'all' + subtitleslangs: ['en'] + writeautomaticsub: True + get_info: True + dataloader_args: + resize_size: 16 + decoder_kwargs: + n_frames: null + fps: 2 + num_threads: 8 + return_bytes: True + timeout: 60 + sampler: null +``` + +Options: + +``` +yt_args: arguments used by the YtDlpDownloader, see the docstring of that class in data_reader.py for + an explanation on what they do. +dataloader_args: arguments passed to the dataloader which will be used to load frames for stages that + need them (f.e. optical flow). Follow dataloader documentation for that +timeout: tells video2dataset the maximum time to consider downloading a video. +sampler: a class that samples shards from the input (f.e. used by slurm distributor to tell workers + which shards to work on) +``` + +### Storage + +The storage specification tells video2dataset how to store shards/samples. Here are the possible arguments: +``` +number_sample_per_shard: how many samples should be in each shard +oom_shard_count: the order of magnitude of the number of shards, used only to decide what zero padding + to use to name the shard files +captions_are_subtitles: If subtitles are present in metadata we can clip the video according + to those by setting this parameter to True (TODO: this could be a ClippingSubsapler arg along with `cuts_are_clips`) +``` + +### Distribution + +The distribution specification tells video2dataset how to parallelize processing at multiple levels. There are 4 required arguments: + +``` +processes_count: The number of processes used for downloading the dataset at the shard level. This is important to be high for performance. +thread_count: The number of threads to use for processing samples (sample level). 
+subjob_size: the number of shards to download in each subjob supporting it, a subjob can be a pyspark job for example +distributor: the type of distribution used, can be + - multiprocessing, uses multiprocessing pool to spawn processes + - spark, use a pyspark session to create workers on a spark cluster (see details in `examples/distributed_spark.md`) + - slurm, use slurm to distribute processing to multiple slurm workers +``` + +On top of these args some distributors like slurm need additional arguments. You can check the docstring of the distributor to see what needs to be specified and then fill in the `distributor_args` entry in the config. Here's an example (currently only slurm requires this): + +```yaml +distribution: + processes_count: 1 + thread_count: 8 + subjob_size: 1000 + distributor: "slurm" + distributor_args: + cpus_per_task: 96 + job_name: "v2ds" + partition: "g80" + n_nodes: 2 + gpus_per_node: 8 + account: "stablediffusion" + tasks_per_node: 1 +``` + +### Examples + +We provide 3 example configs which can be used by setting the config parameter of video2dataset to the string: +- `default` - This performs no subsampling and attempts to download videos at either native resolution or close to 360p for youtube videos. +- `downsample_ml` - performs resolution downsampling to 224x224 via a scale, center crop, and pad for smaller videos, downsamples FPS to 5, detects cuts, clips the videos according to those cuts and saves compression metadata from FFProbe. +- `optical_flow` - takes a created tar dataset and computes the optical flow for the videos in it + +If you want to create your own config based on these you can copy them from video2dataset/configs diff --git a/HISTORY.md b/HISTORY.md new file mode 100644 index 00000000..aedb3dc3 --- /dev/null +++ b/HISTORY.md @@ -0,0 +1,16 @@ +## 1.2.0 + +Official Release of video2dataset v1 +See blog post - https://laion.ai/blog/video2dataset/ + +## 1.1.0 + +* Implemented many subsamplers +* Tested on 3 datasets +* Good DataLoader +* Distributing with spark and multiprocessing +* Multistage setup works + +## 1.0.0 + +* it works a bit diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..bcd77a8b --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Romain Beaumont + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..5dbd2d28 --- /dev/null +++ b/Makefile @@ -0,0 +1,31 @@ +install: ## [Local development] Upgrade pip, install requirements, install package. 
+ python -m pip install -U pip + python -m pip install -e . + +install-dev: ## [Local development] Install test requirements + python -m pip install -r requirements-test.txt + +lint: ## [Local development] Run mypy, pylint and black + python -m mypy video2dataset + python -m pylint video2dataset + python -m black --check -l 120 . + +black: ## [Local development] Auto-format python code using black + python -m black -l 120 . + +build-pex: + python3 -m venv .pexing + . .pexing/bin/activate && python -m pip install -U pip && python -m pip install pex + . .pexing/bin/activate && python -m pex setuptools . -o video2dataset.pex -v + rm -rf .pexing + + +TEST_ARGS = tests ## set default to run all tests +## to run specific test, run `make test TEST_ARGS="tests/test_subsamplers.py::test_audio_rate_subsampler"` +test: ## [Local development] Run unit tests + python -m pytest -x -s -v $(TEST_ARGS) + +.PHONY: help + +help: # Run `make help` to get help on the make commands + @grep -E '^[0-9a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' diff --git a/README.md b/README.md new file mode 100644 index 00000000..6d893dc5 --- /dev/null +++ b/README.md @@ -0,0 +1,276 @@ +# video2dataset +[![pypi](https://img.shields.io/pypi/v/video2dataset.svg)](https://pypi.python.org/pypi/video2dataset) + +Easily create large video dataset from video urls. Can download and package 10M videos in 12h on a single 16 core machine. + +![video2dataset design overview](docs/video2dataset_overview.png) + +If you believe in making reusable tools to make data easy to use for ML and you would like to contribute, please join the [DataToML](https://discord.gg/ep8yUUtCnp) chat. + +## Install + +```bash +pip install video2dataset +``` + +Or from source via + +```bash +git clone https://github.com/iejMac/video2dataset +cd video2dataset +pip install -e . +``` + +#### Mac Installation + + +On Mac, replace 'decord' with 'eva_decord' in requirements.txt. For details, see https://github.com/dmlc/decord/issues/213. + + +## Usage + +First get some video urls and metadata (all [supported sites](https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md) by yt-dlp). For example lets save this small animal video dataset to a csv file called `videos.csv` +```csv +url,caption +https://www.youtube.com/watch?v=od_PmtmMDV0,Driving to the banana store +https://www.youtube.com/watch?v=8FhGOV7fs64,Polar bear eating +https://www.youtube.com/watch?v=TReCLbmhlMs,Cat scared of printer +https://www.dailymotion.com/video/x29ryo7,Cat and owl playing +``` + +Then, run the tool: + +```bash +video2dataset --url_list="videos.csv" --url_col="url" --caption_col="caption" --output_folder="dataset" +``` +If you go into the output folder you should see a nice small video dataset stored with all relevant metadata. + +## Examples + +Here are some more concrete examples of video2dataset usage. + +#### WebVid download + +The WebVid dataset is a high quality video-text of 10M stock videos. It can be easily downloaded and stored using [one video2dataset command](https://github.com/iejMac/video2dataset/blob/main/examples/download_webvid.sh), to perform the same on the train split (much larger) you just need to swap out the csv file and update the distribution params to something more beefy. Here's an [example config](https://github.com/iejMac/video2dataset/blob/main/examples/default_slurm.yaml) that adjusts the default config for slurm distribution (so we can use many nodes to download it quickly). 
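For orientation, here is a rough sketch of what such a run looks like through the Python API. The CSV filename and the `contentUrl`/`name` column names are assumptions about the WebVid metadata file; the linked `download_webvid.sh` script is the reference invocation:

```python
# Hypothetical sketch of a WebVid-style download via the Python API.
# The CSV filename and the column names ("contentUrl", "name") are assumptions;
# check the linked download_webvid.sh for the exact invocation.
from video2dataset import video2dataset

video2dataset(
    url_list="results_2M_val.csv",   # WebVid validation metadata (assumed filename)
    input_format="csv",
    output_folder="dataset",
    output_format="webdataset",
    url_col="contentUrl",            # column containing the video url
    caption_col="name",              # column containing the caption
    config="default",                # swap for a slurm-adjusted config on the train split
)
```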
+ +#### Re-processing + +video2dataset is designed such that you can chain together runs to re-process your downloaded data since `webdataset` is a valid `input_format`. Here's an example - with the WebVid data you downloaded in the previous example you can also run [this script](https://github.com/iejMac/video2dataset/blob/main/examples/optical_flow_webvid.sh) which will compute the optical flow for each video and store it in metadata shards (shards which only have the optical flow metadata in them). You can also run [this script](https://github.com/iejMac/video2dataset/blob/main/examples/downsample_webvid.sh) which will take those videos and perform resizing, fps downsampling, cut detection, and clipping and also store that in a new dataset. We make sure that the content in shards with the same IDs is the same across re-processing runs. Furthermore, if clipping is not performed shard IDs *and* sample IDs are exactly the same since clipping is the only transformation that changes the sample IDs i.e. if sample `000` is clipped into 3 clips they will be split into `000_000`, `000_001`, `000_002`. + +#### Dataloading + +Once you download some chunk of WebVid (or any video dataset) you can load it using our dataloader like in [this example](https://github.com/iejMac/video2dataset/blob/main/examples/dataloader_example.py). Try it out. + +#### Large processing job examples + +Whenever we do a large dataset processing job we document them in [dataset examples](https://github.com/iejMac/video2dataset/tree/main/dataset_examples) as many existing datasets are unique and might require special procesing or the authors just don't specify the best ways of getting the data. Thanks to this we can all share the most efficient ways of processing large video/audio datasets! + +## Output format + +The tool will automatically download the urls and store them with the format: + +``` +output-folder + ├── 00000{.ext if output_format != files, can be tar, parquet, tfrecord, etc.} + | ├── 00000.mp4 + | ├── 00000.txt + | ├── 00000.json + | ├── 00001.mp4 + | ├── 00001.txt + | ├── 00001.json + | └── ... + | ├── 10000.mp4 + | ├── 10000.txt + | ├── 10000.json + ├── 00001.ext + | ├── 10001.mp4 + | ├── 10001.txt + | ├── 10001.json + │ ... + ... +``` + +with each number being the position in the input table or the input sample ID. The subfolders avoid having too many files in a single folder. If **captions** are provided, they will be saved as 0.txt, 1.txt, etc. (matching the ID of the sample they belong to). This can then easily be fed into machine learning training or any other use case. + +Also .json files named 0.json, 1.json,... are saved with these keys: +* url +* caption +* key of the form 000010005: the first 5 digits are the shard id, the last 4 are the index in the shard +* additionally gathered metadata (either specified from input table or collected during downloading/processing) +* status: whether the download succeeded +* error_message + +Also a .parquet file will be saved with the same name as the subfolder/tar files containing these same metadata. +It can be used to analyze the results efficiently. + +.json files will also be saved with the same name suffixed by \_stats, they contain stats collected during downloading (download time, number of success, ...) + +### Output format choice + +video2dataset support several formats. There are trade off for which to choose: +* files: this is the simplest one, videos are simply saved as files. It's good for up to 1M samples on a local file system. 
Beyond that performance issues appear very fast. Handling more than a million files in standard filesystem does not work well. +* webdataset: webdataset format saves samples in tar files, thanks to [webdataset](https://webdataset.github.io/webdataset/) library, this makes it possible to load the resulting dataset fast in both pytorch, tensorflow and jax. Choose this for most use cases. It works well for any filesystem +* parquet: parquet is a columnar format that allows fast filtering. It's particularly easy to read it using pyarrow and pyspark. Choose this if the rest of your data ecosystem is based on pyspark. [petastorm](https://github.com/uber/petastorm) can be used to read the data but it's not as easy to use as webdataset +* tfrecord: tfrecord is a protobuf based format. It's particularly easy to use from tensorflow and using [tf data](https://www.tensorflow.org/guide/data). Use this if you plan to use the dataset only in the tensorflow ecosystem. The tensorflow writer does not use fsspec and as a consequence supports only a limited amount of filesystem, including local, hdfs, s3 and gcs. It is also less efficient than the webdataset writer when writing to other filesystems than local, losing some 30% performance. + + +## API + +The module exposes a single function `video2dataset` which takes the same arguments as the command line tool: + +``` +url_list: list of input urls - can be any of the supported input formats + (csv, parquet, braceexpand tar paths etc.) +output_folder: Desired location of output dataset (default = "dataset") +output_format: Format of output dataset, can be (default = "files") + - files, samples saved in subdirectory for each shard (useful for debugging) + - webdataset, samples saved in tars (useful for efficient loading) + - parquet, sampels saved in parquet (as bytes) + - tfrecord, samples saved in tfrecord (as bytes) + - dummy, does not save (useful for benchmarks) +input_format: Format of the input, can be (default = "csv") + - txt, text file with a url in each line + - csv, csv file with urls, (and captions + metadata) + - tsv, tsv - || - + - tsv.gz, - || - but compressed gzip + - json, loads urls and metadata as json + - parquet, loads urls and metadata as parquet + - webdataset, usually braceexpand format of mutliple paths to tars to re-process +encode_formats: Dict that specifies what extension each modality should use (default = "{'video': 'mp4'}") + f.e. {"video": "mp4", "audio": "m4a"} +stage: String that tells video2dataset what stage of processing is being performed. 
Can be (default = 'download') + WARNING: To be depracated soon (this information should be deduced based on config) + - download, when input is some tabular format and data must be downloaded first + - subset, tar files are already written and we would like to re-process (input_format == "webdataset") + - optical_flow, tar files are written and we woudl like to compute optical_flow and save to md shards +url_col: Column in input (if has columns) that contains the url (default = "url") +caption_col: Column in input (if has columns) that contains captions (to be written as txt) (default = None) +clip_col: Column in input (if has columns) that contains timeframes of clips for how to split video (default = None) +save_additional_columns: List of column names to save to json component of a sample (defualt = None) +enable_wandb: Whether or not to log info to wandb (default = False) +wandb_project: Name of wandb project to log runs to (default = "video2dataset") +incremental_mode: Decides how to handle restarting, Can be (default = "incremental") + - incremental, checks which shards are done and skips those + - overwrite, deletes and reprocesses shards as it goes +max_shard_retry: Maximum amount of attempts to retry a failed shard (default = 1) +tmp_dir: Path to temporary directory on your file system (default = "/tmp") +config: Path to your config of choice or the config itself (more info on configs in API doc) (default = "default") +``` + +These arguments give coarse control over input/output "shape" of the dataset. For finer control of subsamplers, distribution, reading, and storage see the more detailed [API.md](https://github.com/iejMac/video2dataset/blob/main/API.md) doc. + +## Downloading YouTube Metadata + +If we want to download a large amount of YouTube videos with video2dataset we can specify some parameters and also extract useful metadata as well. For directions on how to do so please see this [example](https://github.com/iejMac/video2dataset/blob/main/examples/yt_metadata.md). + +## Incremental mode + +If a first download got interrupted for any reason, you can run again with --incremental "incremental" (this is the default) and using the same output folder , the same number_sample_per_shard and the same input urls, and video2dataset will complete the download. + +## File system support + +Thanks to [fsspec](https://filesystem-spec.readthedocs.io/en/latest/), video2dataset supports reading and writing files in [many file systems](https://github.com/fsspec/filesystem_spec/blob/6233f315548b512ec379323f762b70764efeb92c/fsspec/registry.py#L87). +To use it, simply use the prefix of your filesystem before the path. For example `hdfs://`, `s3://`, `http://`, or `gcs://`. +Some of these file systems require installing an additional package (for example s3fs for s3, gcsfs for gcs). +See fsspec doc for all the details. + +If you need specific configuration for your filesystem, you may handle this problem by using the [fsspec configuration system](https://filesystem-spec.readthedocs.io/en/latest/features.html#configuration) that makes it possible to create a file such as `.config/fsspec/s3.json` and have information in it such as: +```json +{ + "s3": { + "client_kwargs": { + "endpoint_url": "https://some_endpoint", + "aws_access_key_id": "your_user", + "aws_secret_access_key": "your_password" + } + } +} +``` +Which may be necessary if using s3 compatible file systems such as [minio](https://min.io/). That kind of configuration also work for all other fsspec-supported file systems. 
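As a quick sanity check before a big run, you can point fsspec directly at the same prefix you plan to give video2dataset (the bucket and path below are placeholders, and `s3fs` must be installed for `s3://`):

```python
# Minimal sketch: confirm fsspec can reach the target filesystem before launching
# a large job. The bucket/path are placeholders; credentials come from your
# environment or from the fsspec configuration file described above.
import fsspec

fs = fsspec.filesystem("s3")
print(fs.ls("s3://my-bucket/video-dataset/"))
```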
+## Distribution modes + +video2dataset supports several distributors. +* multiprocessing which spawns a process pool and uses these local processes for downloading +* pyspark which spawns workers in a spark pool to do the downloading +* slurm which starts separate worker nodes + +multiprocessing is a good option for downloading on one machine, and as such it is the default. +Pyspark lets video2dataset use many nodes, which makes it as fast as the number of machines. +It can be particularly useful if downloading datasets with more than a billion samples. Here's an [example](https://github.com/iejMac/video2dataset/blob/main/examples/distributed_spark.md) +for how we used pyspark distributed mode to download 40M videos with metadata. If you have access to a slurm cluster, it is more comfortable to use than pyspark, but not everyone does. + +### pyspark configuration + +In order to use video2dataset with pyspark, you will need to do this: +1. `pip install pyspark` +2. set `distributor: pyspark` in your config +3. tweak the `subjob_size: 1000` option in your config. This is the number of shards to download in each subjob. Increasing it will mean a longer preparation time to put the feather files in the temporary dir, while a smaller value will mean sending fewer shards at a time to the pyspark job. + +By default a local spark session will be created. +You may want to create a custom spark session depending on your specific spark cluster. + +## Benchmarks + +As stated at the top of the README, video2dataset is capable of downloading 10M videos in 12h. For more details on end2end performance please see specific runs in [dataset examples](https://github.com/iejMac/video2dataset/tree/main/dataset_examples) as there are nuances to video (how long it is, where it comes from, etc.). For example, it takes considerably longer to pull videos from youtube than just mp4 links (compare [WebVid.md](https://github.com/iejMac/video2dataset/tree/main/dataset_examples/WebVid.md) to [VideoCC.md](https://github.com/iejMac/video2dataset/tree/main/dataset_examples/VideoCC.md)). Each example should have a "Performance" statement at the bottom which contains info about download/processing performance (video/s, Mb/s) along with a cost estimate on popular cloud infrastructure. + +For information about video2dataset subsampler speed please check out the [benchmark suite](https://github.com/iejMac/video2dataset/tree/main/benchmark) which contains code that produces performance numbers for subsamplers, over a grid of parameters, on a given architecture. It also contains a json file with some results we produced. This can be used to estimate the cost of big runs and also to optimize the subsamplers. NOTE: cost can vary drastically based on the chosen subsampler configuration.
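For illustration only, here is a small sketch of how those numbers can be turned into a first-order time estimate. It reads the layout used by `benchmark/subsampler_results.json` and assumes throughput scales roughly linearly with dataset size:

```python
# Rough sketch: estimate how long one machine would need to push N samples through
# each subsampler, using the throughput recorded in the benchmark JSON.
# Linear scaling is only a first-order approximation.
import json

with open("benchmark/subsampler_results.json") as f:
    results = json.load(f)

target_samples = 10_000_000  # size of the run you are planning (assumed)
for name, configs in results["subsamplers"].items():
    best = max(configs, key=lambda c: c["metrics"]["samples/s"])
    hours = target_samples / best["metrics"]["samples/s"] / 3600
    print(f"{name}: ~{hours:,.0f} h for {target_samples:,} samples ({best['config']})")
```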
+ +## Integration with Weights & Biases + +If you pass the `--enable_wandb=True` parameter then performance metrics will be logged to [Weights & Biases](https://wandb.com/) + +![W&B metrics](docs/wandb_logs.png) + +In addition, most frequent errors are logged for easier debugging + +![W&B status](docs/wandb_status.png) + +Other features are available: + +* logging of environment configuration (OS, python version, CPU count, Hostname, etc) +* monitoring of hardware resources (GPU/CPU, RAM, Disk, Networking, etc) +* custom graphs and reports +* comparison of runs (convenient when optimizing parameters such as number of threads/cpus) + +When running the script for the first time, you can decide to either associate your metrics to your account or log them anonymously. + +You can also log in (or create an account) before by running `wandb login`. + +## For development + +Either locally, or in [gitpod](https://gitpod.io/#https://github.com/iejMac/video2dataset) (do `export PIP_USER=false` there) + +Setup a virtualenv: + +```bash +python3 -m venv .env +source .env/bin/activate +pip install -e . +``` + +to run tests: +```bash +pip install -r requirements-test.txt +``` +then +```bash +make lint +make test +``` + +You can use `make black` to reformat the code + +`python -m pytest -x -s -v tests -k "dummy"` to run a specific test + +## Citation +```bibtex +@misc{kilian-2023-video2dataset, + author = {Maciej Kilian, Romain Beaumont, Daniel Mendelevitch, Sumith Kulal, Andreas Blattmann}, + title = {video2dataset: Easily turn large sets of video urls to a video dataset}, + year = {2023}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/iejMac/video2dataset}} +} +``` diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 00000000..1e11dbc4 --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,56 @@ +# video2dataset Benchmark Suite +The code in here can be used to benchmark the performance of video2dataset components (and end2end) on different hardware/data conditions + +## Subsamplers + +Test out the performance of video2dataset subsamplers by configuring the ```subsamplers_config.yaml``` file with the subsamplers you want to benchmark along with the parameters (to form a parameter grid to test over). You can also choose a set of videos to test on (in case thats relevant for your use case). The script will output a JSON file with the information of the system performing the benchmark and a metrics dict for each parameter set in the input grid. 
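Each list of parameter values is expanded into a full grid of configurations to benchmark, roughly like this sketch of the `create_parameter_grid` helper in `benchmark_subsamplers.py`:

```python
# Sketch of how a parameter block expands into the benchmark grid: every
# combination of the listed values becomes one config to time.
import itertools

params = {"video_size": [360, 60], "resize_mode": ["scale"]}
keys, values = zip(*params.items())
grid = [dict(zip(keys, combo)) for combo in itertools.product(*values)]
print(grid)
# -> [{'video_size': 360, 'resize_mode': 'scale'}, {'video_size': 60, 'resize_mode': 'scale'}]
```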
+ +For subsampler_config: +```yaml +video_set: + - path: "dataset/mp4/{00000..00009}.tar" +subsamplers: + - name: ResolutionSubsampler + parameters: + - video_size: [360, 60] + resize_mode: ["scale"] +``` + +The benchmark will output: +```json +{ + "system_info": { + "platform": "Linux", + "cpu_count": 96, + "cpu_info": "x86_64", + "gpu_info": "NVIDIA A100-SXM4-80GB", + "gpu_count": 8 + }, + "configs_and_metrics": [ + { + "name": "ResolutionSubsampler", + "config": { + "video_size": 360, + "resize_mode": "scale" + }, + "metrics": { + "time": 111.43327045440674, + "samples": 100, + "bytes": 236045214 + } + }, + { + "name": "ResolutionSubsampler", + "config": { + "video_size": 60, + "resize_mode": "scale" + }, + "metrics": { + "time": 29.380290031433105, + "samples": 100, + "bytes": 236045214 + } + } + ] +} +``` diff --git a/benchmark/benchmark_dataloader.py b/benchmark/benchmark_dataloader.py new file mode 100644 index 00000000..2d9fe3f2 --- /dev/null +++ b/benchmark/benchmark_dataloader.py @@ -0,0 +1,51 @@ +""" +Benchmark dataloader speed +""" +import time + +# from video2dataset.dataloader import get_bytes_dataloader +from video2dataset.dataloader import get_video_dataset + +# Benchmark videos are the WebVid validation split (5000 videos) +SHARDS = "examples/dataset/{00000..00004}.tar" + + +def benchmark_train_dl(num_frames, num_workers, bs=1, num_threads=4, resize_size=None, crop_size=None): + from argparse import Namespace + from webdataset import WebLoader + + decoder_kwargs = {"n_frames": num_frames, "fps": None, "num_threads": num_threads} + + dset = get_video_dataset( + urls=SHARDS, + batch_size=bs, + decoder_kwargs=decoder_kwargs, + resize_size=resize_size, + crop_size=crop_size, + ) + dl = WebLoader(dset, batch_size=None, num_workers=num_workers) + count = 0 + t0 = time.time() + for samp in dl: + count += 1 + tf = time.time() + return count / (tf - t0) + + +if __name__ == "__main__": + # print("Benchmarking bytes dataloader...") + # print(benchmark_bytes_dl()) + + print("# benchmarking without resizing") + for nf in [8, 16, 32]: + for nw in [6, 8, 10, 12]: + print(f"Benchmarking train dataloader with {nf} decoded frames and {nw} dl workers...") + throughput = benchmark_train_dl(num_frames=nf, num_workers=nw) + print(f"Got {throughput} samples/sec") + + print("# benchmarking with resizing and centercropping") + for nf in [8, 16, 32]: + for nw in [6, 8, 10, 12]: + print(f"Benchmarking train dataloader with {nf} decoded frames and {nw} dl workers...") + throughput = benchmark_train_dl(num_frames=nf, num_workers=nw, resize_size=256, crop_size=256) + print(f"Got {throughput} samples/sec") diff --git a/benchmark/benchmark_subsamplers.py b/benchmark/benchmark_subsamplers.py new file mode 100644 index 00000000..2f25b04b --- /dev/null +++ b/benchmark/benchmark_subsamplers.py @@ -0,0 +1,154 @@ +import os +import sys +import cv2 +import time +import tempfile +import itertools +import platform +import psutil +import json +import yaml +from omegaconf import OmegaConf +from webdataset import WebLoader + +from video2dataset import subsamplers +from video2dataset.dataloader import get_video_dataset + + +# Add this function to gather system information +def gather_system_info(): + cpu_count = os.cpu_count() + cpu_info = platform.processor() + gpu_info = [None] + + # TODO: how do you get GPU info on non linux + if platform.system() == "Linux": + try: + gpu_info = os.popen("nvidia-smi --query-gpu=gpu_name --format=csv,noheader").read().splitlines() + gpu_count = len(gpu_info) + except KeyError: + 
pass + + return { + "platform": platform.system(), + "cpu_count": cpu_count, + "cpu_info": cpu_info, + "gpu_info": gpu_info[0], + "gpu_count": len(gpu_info), + } + + +# Add this function to load and parse the config file +def load_config(filepath): + with open(filepath, "r") as f: + config = yaml.safe_load(f) + return OmegaConf.create(config) + + +def create_parameter_grid(params): + keys, values = zip(*params.items()) + combinations = list(itertools.product(*values)) + return [dict(zip(keys, combination)) for combination in combinations] + + +def make_fake_clips(time, n): + clip_duration = time / n + return [(i * clip_duration, (i + 1) * clip_duration) for i in range(n)] + + +def main(config_file="subsamplers_config.yaml"): + # Gather system information + system_info = gather_system_info() + + # Load config + benchmark_config = load_config(config_file) + + # Dataloader + ds = get_video_dataset( + urls=benchmark_config.video_set[0].path, + batch_size=1, + decoder_kwargs={}, # load bytes + ) + ds = WebLoader(ds, batch_size=None, num_workers=12) + + # Initialize all configs + benchmarks = {} + for subsampler_config in benchmark_config.subsamplers: + subsampler_name = subsampler_config.name + params_grid = create_parameter_grid(OmegaConf.to_container(subsampler_config.parameters, resolve=True)[0]) + benchmarks[subsampler_name] = [ + { + "subsampler": getattr(subsamplers, subsampler_name)(**cfg), + "config": cfg, + "metrics": {"time": 0.0}, + } + for cfg in params_grid + ] + + for subsampler, bm_cfgs in benchmarks.items(): + print(f"Benchmarking {subsampler}...") + size_metrics = { + "samples": 0, + "frames": 0, + "bytes": 0.0, + } + for sample in ds: + # TODO: parallelize this in a safe way i.e. each benchmarker gets certain amount of cores (no interference) + # TODO: report per-core metrics + # Update size metrics: + with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_vid: + temp_vid.write(sample["mp4"]) + vid = cv2.VideoCapture(temp_vid.name) + size_metrics["samples"] += 1 + size_metrics["bytes"] += len(sample["mp4"]) + size_metrics["frames"] += int(vid.get(cv2.CAP_PROP_FRAME_COUNT)) + seconds = int(vid.get(cv2.CAP_PROP_FRAME_COUNT)) / int(vid.get(cv2.CAP_PROP_FPS)) + + # TODO: construct streams based on subsampler (maybe we want audio or sth + streams = {"video": [sample["mp4"]]} + metadata = { + "key": sample["__key__"], + "clips": make_fake_clips(seconds, 5), # TODO: parameterize this, dense/sparse + } + + for cfg in bm_cfgs: + t0 = time.time() + strm, md, err_msg = cfg["subsampler"](streams, metadata) + tf = time.time() + + if err_msg is not None: + print(err_msg) + + cfg["metrics"]["time"] += tf - t0 + + metadata["clips"] = make_fake_clips(seconds, 5) # gets popped + + # TODO: Normalize metrics for core count + # Update and normalize cfg metrics + for cfg in bm_cfgs: + for m in size_metrics: + cfg["metrics"][m + "/s"] = size_metrics[m] / cfg["metrics"]["time"] + + # TODO: visualize in nice way - visual repr but also raw numbers in some json or something + # For now just output JSON + data = { + "system_info": system_info, + "dataset_info": size_metrics, + "subsamplers": {}, + } + + for ss, bm_cfgs in benchmarks.items(): + cfg_metrics = [{"config": bm["config"], "metrics": bm["metrics"]} for bm in bm_cfgs] + data["subsamplers"][ss] = sorted(cfg_metrics, key=lambda cfg: cfg["metrics"]["time"]) + + with open("subsampler_results.json", "w") as f: + json.dump(data, f, indent=4) + + +if __name__ == "__main__": + if len(sys.argv) == 2: + main(sys.argv[1]) + elif len(sys.argv) 
== 1: + main() + else: + print("Usage: python yourscript.py [config_file]") diff --git a/benchmark/download_set.sh b/benchmark/download_set.sh new file mode 100755 index 00000000..c9a34762 --- /dev/null +++ b/benchmark/download_set.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +video2dataset --url_list="video_sets/mp4.parquet" \ + --input_format="parquet" \ + --output_folder="dataset/mp4" \ + --output-format="webdataset" \ + --url_col="contentUrl" \ + --caption_col="name" \ + --enable_wandb=False \ + --video_size=360 \ + --number_sample_per_shard=10 \ + --processes_count 10 \ diff --git a/benchmark/subsampler_config.yaml b/benchmark/subsampler_config.yaml new file mode 100644 index 00000000..f85aff9f --- /dev/null +++ b/benchmark/subsampler_config.yaml @@ -0,0 +1,25 @@ +video_set: + - path: "dataset/mp4/{00000..00009}.tar" +subsamplers: + - name: ResolutionSubsampler + parameters: + - video_size: [360, 60, 1080] + resize_mode: ["scale", "scale,crop,pad"] + - name: FrameSubsampler + parameters: + - frame_rate: [1, 10] + - name: CutDetectionSubsampler + parameters: + - cut_detection_mode: ["all"] + threshold: [27, 12, 8] + min_scene_len: [2, 15] + - name: ClippingSubsampler + parameters: + - oom_clip_count: [5] + encode_formats: [{"video": "mp4"}] + precise: [True, False] + min_length: [0, 4, 10] + max_length: [20, 9999] + - name: MetadataSubsampler + parameters: + - extract_keyframes: [False, True] diff --git a/benchmark/subsampler_results.json b/benchmark/subsampler_results.json new file mode 100644 index 00000000..e2be79ec --- /dev/null +++ b/benchmark/subsampler_results.json @@ -0,0 +1,424 @@ +{ + "system_info": { + "platform": "Linux", + "cpu_count": 96, + "cpu_info": "x86_64", + "gpu_info": "NVIDIA A100-SXM4-80GB", + "gpu_count": 8 + }, + "dataset_info": { + "samples": 100, + "frames": 57106, + "bytes": 236045214.0 + }, + "subsamplers": { + "ResolutionSubsampler": [ + { + "config": { + "video_size": 360, + "resize_mode": "scale" + }, + "metrics": { + "time": 111.13579082489014, + "samples/s": 0.8998001387110646, + "frames/s": 513.8398672123406, + "bytes/s": 2123935.162992829 + } + }, + { + "config": { + "video_size": 360, + "resize_mode": "scale,crop,pad" + }, + "metrics": { + "time": 70.28809905052185, + "samples/s": 1.4227159554866002, + "frames/s": 812.4561735401779, + "bytes/s": 3358252.9217404905 + } + }, + { + "config": { + "video_size": 60, + "resize_mode": "scale" + }, + "metrics": { + "time": 22.140214204788208, + "samples/s": 4.516668134962003, + "frames/s": 2579.2885051514013, + "bytes/s": 10661378.964840868 + } + }, + { + "config": { + "video_size": 60, + "resize_mode": "scale,crop,pad" + }, + "metrics": { + "time": 18.740954160690308, + "samples/s": 5.335907614018548, + "frames/s": 3047.123402061432, + "bytes/s": 12595154.546352375 + } + }, + { + "config": { + "video_size": 1080, + "resize_mode": "scale" + }, + "metrics": { + "time": 226.4908411502838, + "samples/s": 0.4415189571998933, + "frames/s": 252.13381569857108, + "bytes/s": 1042184.3673730566 + } + }, + { + "config": { + "video_size": 1080, + "resize_mode": "scale,crop,pad" + }, + "metrics": { + "time": 214.78087329864502, + "samples/s": 0.46559080640739187, + "frames/s": 265.8802859070052, + "bytes/s": 1099004.8153486538 + } + } + ], + "FrameSubsampler": [ + { + "config": { + "frame_rate": 1 + }, + "metrics": { + "time": 30.05926489830017, + "samples/s": 3.326761327608345, + "frames/s": 1899.7803237440214, + "bytes/s": 7852660.895022359 + } + }, + { + "config": { + "frame_rate": 10 + }, + "metrics": { + "time": 
47.6098906993866, + "samples/s": 2.1004038978247097, + "frames/s": 1199.4566498917786, + "bytes/s": 4957902.875484677 + } + } + ], + "CutDetectionSubsampler": [ + { + "config": { + "cut_detection_mode": "all", + "threshold": 12, + "min_scene_len": 15 + }, + "metrics": { + "time": 15.451006174087524, + "samples/s": 6.47207041879948, + "frames/s": 3695.9405333596314, + "bytes/s": 15277012.47028593 + } + }, + { + "config": { + "cut_detection_mode": "all", + "threshold": 8, + "min_scene_len": 2 + }, + "metrics": { + "time": 15.451342105865479, + "samples/s": 6.471929707778526, + "frames/s": 3695.860178924005, + "bytes/s": 15276680.328655396 + } + }, + { + "config": { + "cut_detection_mode": "all", + "threshold": 12, + "min_scene_len": 2 + }, + "metrics": { + "time": 15.480078220367432, + "samples/s": 6.459915678489796, + "frames/s": 3688.9994473583833, + "bytes/s": 15248321.787510792 + } + }, + { + "config": { + "cut_detection_mode": "all", + "threshold": 8, + "min_scene_len": 15 + }, + "metrics": { + "time": 15.529819011688232, + "samples/s": 6.439225075626242, + "frames/s": 3677.1838716871216, + "bytes/s": 15199482.609703623 + } + }, + { + "config": { + "cut_detection_mode": "all", + "threshold": 27, + "min_scene_len": 15 + }, + "metrics": { + "time": 15.553856611251831, + "samples/s": 6.429273620001029, + "frames/s": 3671.5009934377877, + "bytes/s": 15175992.674976977 + } + }, + { + "config": { + "cut_detection_mode": "all", + "threshold": 27, + "min_scene_len": 2 + }, + "metrics": { + "time": 15.701493978500366, + "samples/s": 6.368820708203137, + "frames/s": 3636.9787536264835, + "bytes/s": 15033296.46995441 + } + } + ], + "ClippingSubsampler": [ + { + "config": { + "oom_clip_count": 5, + "encode_formats": { + "video": "mp4" + }, + "precise": true, + "min_length": 0, + "max_length": 20 + }, + "metrics": { + "time": 103.1305103302002, + "samples/s": 0.9696451581575906, + "frames/s": 553.7255640174737, + "bytes/s": 2288800.9886137233 + } + }, + { + "config": { + "oom_clip_count": 5, + "encode_formats": { + "video": "mp4" + }, + "precise": true, + "min_length": 0, + "max_length": 9999 + }, + "metrics": { + "time": 102.62368559837341, + "samples/s": 0.9744339176372847, + "frames/s": 556.4602330059478, + "bytes/s": 2300104.6261755126 + } + }, + { + "config": { + "oom_clip_count": 5, + "encode_formats": { + "video": "mp4" + }, + "precise": true, + "min_length": 4, + "max_length": 20 + }, + "metrics": { + "time": 57.5337016582489, + "samples/s": 1.7381117000606285, + "frames/s": 992.5660674366226, + "bytes/s": 4102729.481967149 + } + }, + { + "config": { + "oom_clip_count": 5, + "encode_formats": { + "video": "mp4" + }, + "precise": true, + "min_length": 4, + "max_length": 9999 + }, + "metrics": { + "time": 57.83302330970764, + "samples/s": 1.7291158974082954, + "frames/s": 987.4289243739812, + "bytes/s": 4081495.3203454316 + } + }, + { + "config": { + "oom_clip_count": 5, + "encode_formats": { + "video": "mp4" + }, + "precise": true, + "min_length": 10, + "max_length": 20 + }, + "metrics": { + "time": 9.375545024871826, + "samples/s": 10.666046585528195, + "frames/s": 6090.952563131731, + "bytes/s": 25176692.48814972 + } + }, + { + "config": { + "oom_clip_count": 5, + "encode_formats": { + "video": "mp4" + }, + "precise": true, + "min_length": 10, + "max_length": 9999 + }, + "metrics": { + "time": 9.352877855300903, + "samples/s": 10.691896285518505, + "frames/s": 6105.714292808197, + "bytes/s": 25237709.467810202 + } + }, + { + "config": { + "oom_clip_count": 5, + "encode_formats": { + 
"video": "mp4" + }, + "precise": false, + "min_length": 0, + "max_length": 20 + }, + "metrics": { + "time": 8.940687894821167, + "samples/s": 11.184821702357413, + "frames/s": 6387.204281348224, + "bytes/s": 26401236.322848 + } + }, + { + "config": { + "oom_clip_count": 5, + "encode_formats": { + "video": "mp4" + }, + "precise": false, + "min_length": 0, + "max_length": 9999 + }, + "metrics": { + "time": 8.388755559921265, + "samples/s": 11.92071926350642, + "frames/s": 6807.445942617976, + "bytes/s": 28138287.29588295 + } + }, + { + "config": { + "oom_clip_count": 5, + "encode_formats": { + "video": "mp4" + }, + "precise": false, + "min_length": 4, + "max_length": 20 + }, + "metrics": { + "time": 3.405555009841919, + "samples/s": 29.363789370896658, + "frames/s": 16768.485558144246, + "bytes/s": 69311819.45904227 + } + }, + { + "config": { + "oom_clip_count": 5, + "encode_formats": { + "video": "mp4" + }, + "precise": false, + "min_length": 4, + "max_length": 9999 + }, + "metrics": { + "time": 3.384561061859131, + "samples/s": 29.545928754811783, + "frames/s": 16872.49807472282, + "bytes/s": 69741750.757583 + } + }, + { + "config": { + "oom_clip_count": 5, + "encode_formats": { + "video": "mp4" + }, + "precise": false, + "min_length": 10, + "max_length": 20 + }, + "metrics": { + "time": 0.40822601318359375, + "samples/s": 244.96234137589474, + "frames/s": 139888.19466611845, + "bytes/s": 578221882.9201413 + } + }, + { + "config": { + "oom_clip_count": 5, + "encode_formats": { + "video": "mp4" + }, + "precise": false, + "min_length": 10, + "max_length": 9999 + }, + "metrics": { + "time": 0.4058263301849365, + "samples/s": 246.41082296072224, + "frames/s": 140715.36455995005, + "bytes/s": 581640954.3767979 + } + } + ], + "MetadataSubsampler": [ + { + "config": { + "extract_keyframes": false + }, + "metrics": { + "time": 7.498495101928711, + "samples/s": 13.336009244612121, + "frames/s": 7615.661439228198, + "bytes/s": 31479011.560504466 + } + }, + { + "config": { + "extract_keyframes": true + }, + "metrics": { + "time": 7.5574257373809814, + "samples/s": 13.232018874545355, + "frames/s": 7556.276698497871, + "bytes/s": 31233547.268940978 + } + } + ] + } +} diff --git a/benchmark/video_sets/mp4.parquet b/benchmark/video_sets/mp4.parquet new file mode 100644 index 00000000..6bcac204 Binary files /dev/null and b/benchmark/video_sets/mp4.parquet differ diff --git a/benchmark/video_sets/youtube.parquet b/benchmark/video_sets/youtube.parquet new file mode 100644 index 00000000..cc169975 Binary files /dev/null and b/benchmark/video_sets/youtube.parquet differ diff --git a/dataset_examples/ACAV.md b/dataset_examples/ACAV.md new file mode 100644 index 00000000..4ce2401e --- /dev/null +++ b/dataset_examples/ACAV.md @@ -0,0 +1,62 @@ +# [ACAV100M](https://acav100m.github.io) +ACAV100M is a dataset of 100M videos with high audio-visual correspondance. + +## Download the metdata +Go to the [downloads](https://acav100m.github.io/#downloads) section of the dataset page and download the subset you wish. This just has the youtube ID, start time, and end time. For the current version of video2dataset you need to preprocess this file a bit: +1. Create a column with the link instead of just the ID: f"youtube.com/watch?v={youtubeID}" +```python +df["video_link"] = df.apply(lambda row: f"https://youtube.com/watch?v={row['videoID']}", axis=1) +``` +2. 
Create a column for the timespan consistent with video2dataset's expected clip format: [[start, end]]
+```python
+df["clips"] = df.apply(lambda row: [[row['start'], row['end']]], axis=1)
+```
+
+## Create the appropriate video2dataset config
+
+Since we only want to download 10-second clips, we can allow more samples per shard than usual. Also, since this is a massive dataset, we want to use a different distributor like pyspark. We also want to download the metadata of each video and include it in the json file of each sample, which we specify using `yt_metadata_args`.
+
+```yaml
+subsampling: {}
+
+reading:
+    yt_args:
+        download_size: 360
+        download_audio_rate: 44100
+        yt_metadata_args:
+            get_info: True
+    timeout: 60
+    sampler: null
+
+storage:
+    number_sample_per_shard: 2000
+    captions_are_subtitles: False
+
+distribution:
+    processes_count: 96
+    thread_count: 48
+    subjob_size: 10000
+    distributor: "pyspark"
+```
+
+## Download the videos with video2dataset
+
+This command runs video2dataset on the input table and saves the video clips along with the metadata in the webdataset format.
+
+```python
+video2dataset(
+    url_list="ACAV100M_clip_unique.parquet",
+    output_folder="output_folder",
+    output_format="webdataset",
+    input_format="parquet",
+    url_col="video_link",
+    caption_col="title",
+    clip_col="clips",
+    save_additional_columns=["videoID", "start", "end"],
+    enable_wandb=True,
+    config="path/to/config.yaml"
+)
+```
+
+## Performance
+The job ran on 64 nodes with 48 CPU cores each and processed the dataset in 7 hours. This wasn't the entire ACAV100M dataset but a 60M subset containing only the unique videos in the dataset. The title and description of the videos were pre-extracted and placed in the parquet file, and the timespan of each clip was processed into the format that video2dataset expects. [Here's](https://wandb.ai/iejmac/video2dataset/runs/3kon2409?workspace=user-iejmac) a link to the wandb logs from the processing of the dataset. We were able to get 0.8 vid/s/core, which comes out to approximately 100MB/s/core.
diff --git a/dataset_examples/HDVILA.md b/dataset_examples/HDVILA.md
new file mode 100644
index 00000000..29c38f8c
--- /dev/null
+++ b/dataset_examples/HDVILA.md
@@ -0,0 +1,113 @@
+# [HDVILA-100M](https://github.com/microsoft/XPretrain/tree/main/hd-vila-100m)
+HDVILA 100M is a dataset of 100M high-resolution videos from YouTube.
+
+## Download the metadata
+First, run `wget -O hdvila100m.zip "https://hdvila.blob.core.windows.net/dataset/hdvila100m.zip?sp=r&st=2022-06-28T03:33:11Z&se=2026-01-01T11:33:11Z&spr=https&sv=2021-06-08&sr=b&sig=VaqQkLFDqKinfkaPNs1jJ1EQIYCB%2FUPYiqFqmjWye6Y%3D"` to download the HD VILA 100M metadata (the URL must be quoted since it contains `&`). Next, run `unzip hdvila100m.zip` to unzip the metadata. You should now have an `hdvila100m/` directory.
+
+Next, we need to do some preprocessing to get this metadata formatted into a nice parquet. The following script will take the downloaded metadata `.jsonl` files and create a parquet with all the relevant information.
+
+```python
+import pandas as pd
+import glob
+import json
+import os
+
+def time_string_to_seconds(timestamp):
+    # convert "HH:MM:SS.mmm" timestamps into seconds
+    hh, mm, s = timestamp.split(':')
+    ss, ms = s.split('.')
+    seconds = 3600*int(hh) + 60*int(mm) + int(ss) + int(ms)/1000
+    return seconds
+
+def convert_clip_list(clip_list):
+    return [[time_string_to_seconds(x) for x in clip] for clip in clip_list]
+
+parquet_dir = "/path/to/my/metadata/dir/"
+
+data = []
+for jsonl in sorted(glob.glob(f"{parquet_dir}*.jsonl")):
+    path = os.path.join(parquet_dir, jsonl)
+    with open(path, "r") as f:
+        for line in f:
+            json_obj = json.loads(line)
+            clips = [
+                json_obj['clip'][i]['span']
+                for i in range(len(json_obj['clip']))
+            ]
+
+            out = {
+                'video_id': json_obj['video_id'],
+                'url': json_obj['url'],
+                'clips': clips
+            }
+            data.append(out)
+
+df = pd.DataFrame(data)
+df['clips'] = df['clips'].map(convert_clip_list)
+df.to_parquet("hd_vila.parquet")
+```
+
+Once you run this, you should have a file `hd_vila.parquet` with all the relevant metadata.
+
+## Create the config for Stage 1
+
+Since we download these videos in HD, we want to decrease the number of samples per shard (a good heuristic is to keep shards at ~1GB). We also wanted to use fewer processes because processing many shards at once on a single machine exceeded the available RAM. Additionally, we used slurm distribution to process this dataset since we had it set up and it was easier to manage than pyspark. Finally, we want to detect cuts for each video and place them in the metadata, so we add a subsampler for that:
+
+```yaml
+subsampling:
+    CutDetectionSubsampler:
+        args:
+            cut_detection_mode: "all"
+            framerates: null
+            threshold: 11.5
+            min_scene_len: 15
+
+reading:
+    yt_args:
+        download_size: 360
+        download_audio_rate: 44100
+        yt_metadata_args:
+            writesubtitles: 'all'
+            subtitleslangs: ['en']
+            writeautomaticsub: True
+            get_info: True
+    timeout: 60
+    sampler: null
+
+storage:
+    number_sample_per_shard: 100
+    captions_are_subtitles: False
+
+distribution:
+    processes_count: 32
+    thread_count: 32
+    subjob_size: 10000
+    distributor: "slurm"
+    distributor_args:
+        partition: "cpu64"
+        n_nodes: 16
+        account: "laion"
+        cache_path: "/fsx/home-iejmac/.slurm_cache"
+```
+
+## Stage 1 (Downloading + Cut Detection)
+
+This command runs video2dataset on the input parquet and saves the videos and audio along with the metadata in the webdataset format.
+
+```python
+video2dataset(
+    url_list='hd_vila.parquet',
+    input_format='parquet',
+    output_format='webdataset',
+    output_folder='output_folder',
+    url_col="url",
+    enable_wandb=True,
+    encode_formats={"video": "mp4", "audio": "m4a"},
+    config="path/to/config.yaml"
+)
+```
+
+## Stage 1 Performance
+
+Over the course of downloading we ran into some bugs (now resolved) that required restarts; here's the WandB report for [one of the restarts](https://api.wandb.ai/links/iejmac/nn9hcaol). The downloading was performed on a cluster of 16 AWS c6i cpu64 nodes (64 CPU cores each), and the total downloading time can be deduced from the videos/s in the report (~28 hours). We were not just downloading the videos but also the audio and the YouTube metadata, and we performed cut detection on the videos, the results of which are placed in the metadata of each sample.
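+
+To sanity-check the Stage 1 output you can peek at the per-sample json metadata in one of the produced shards. The snippet below is only an illustrative sketch (it is not part of the pipeline): it assumes the webdataset output lives in `output_folder/` and that the cut detection results are stored under a `cuts` key in each sample's json, which may vary between video2dataset versions.
+
+```python
+# Illustrative only: read one Stage 1 shard and print the detected cuts.
+# Assumes 00000.tar exists in output_folder/ and that the
+# CutDetectionSubsampler stored its results under the "cuts" key.
+import json
+import webdataset as wds
+
+dataset = wds.WebDataset("output_folder/00000.tar")
+for sample in dataset:
+    meta = json.loads(sample["json"])
+    print(meta["key"], meta.get("cuts"))
+    break  # only inspect the first sample
+```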
diff --git a/dataset_examples/Kinetics.md b/dataset_examples/Kinetics.md
new file mode 100644
index 00000000..5eda9f02
--- /dev/null
+++ b/dataset_examples/Kinetics.md
@@ -0,0 +1,35 @@
+# [Kinetics](https://www.deepmind.com/open-source/kinetics)
+
+Kinetics is a large video classification dataset. It has multiple variants (400, 600, 700, etc.) but I will only show an example for K400.
+
+## Preparation
+
+Kinetics is a YouTube dataset; however, many of the videos have gone private or been removed and are no longer available. To keep the dataset intact, people host it on S3, so we will make use of that: first get the mp4s onto your machine, then process them and store them in webdataset format using video2dataset. For downloading we can use the [kinetics-dataset](https://github.com/cvdfoundation/kinetics-dataset) repo. Follow their [instructions](https://github.com/cvdfoundation/kinetics-dataset#kinetics-400) on how to download Kinetics400.
+
+
+After this we need to create a table to input into video2dataset. Their annotation csv files are a good base; we just need to add the location of the mp4s. We can do that with the following code:
+
+```python3
+import pandas as pd
+df = pd.read_csv("path/to/train.csv")
+mp4_path = "/path/to/kinetics-dataset/k400/train" # this is where the mp4s are stored after extraction
+df['path'] = mp4_path + "/" + df['youtube_id'] + "_" + df['time_start'].astype(str).str.zfill(6) + "_" + df['time_end'].astype(str).str.zfill(6) + ".mp4"
+df.to_parquet("k400_train.parquet")
+```
+
+## Create the dataset using video2dataset
+
+This simple script takes all those mp4s, organizes them into shards, and preprocesses them according to your specified config (I didn't apply any transformations for my usage):
+
+```bash
+#!/bin/bash
+
+video2dataset --url_list="k400_train.parquet" \
+        --input_format="parquet" \
+        --output-format="webdataset" \
+        --output_folder="dataset" \
+        --url_col="path" \
+        --caption_col="label" \
+        --save_additional_columns='[youtube_id,time_start,time_end,split,is_cc]' \
+        --config=default \
+```
diff --git a/dataset_examples/MSRVTT.md b/dataset_examples/MSRVTT.md
new file mode 100644
index 00000000..72f2bf57
--- /dev/null
+++ b/dataset_examples/MSRVTT.md
@@ -0,0 +1,58 @@
+# [MSR-VTT](https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/)
+
+MSR-VTT is a benchmark dataset for testing video-text understanding; it has 20 captions per video and can be used to evaluate the retrieval capabilities of models. It has multiple variants (which are subsets of the original 10k videos) described [here](https://github.com/albanie/collaborative-experts/tree/master/misc/datasets/msrvtt).
+
+## Download the videos and metadata
+
+MSR-VTT is small enough (10k samples) that the videos and metadata are hosted as a .zip file [here](https://www.mediafire.com/folder/h14iarbs62e7p/shared). Download the files and unzip them.
+
+## Preprocess into video2dataset input format
+
+Before you can create a video2dataset dataset, you need to create the correct input format.
Here's the script that does that for the test split: + +```python3 +import json +import pandas as pd + +with open('test_videodatainfo.json', 'r') as f: + data = json.load(f) + +df_videos = pd.DataFrame(data['videos']) +df_sentences = pd.DataFrame(data['sentences']) + +df = df_videos + +vid_ids = df_videos["video_id"].tolist() +caps = {} +for vid_id in vid_ids: + sents = df_sentences[df_sentences["video_id"] == vid_id]["caption"].tolist() + caps[vid_id] = sents + +def add_caption_columns(row, captions_dict): + video_id = row['video_id'] + captions = captions_dict.get(video_id, [''] * 20) + for i, caption in enumerate(captions): + row[f'caption_{i}'] = caption + return row + +df = df.apply(add_caption_columns, axis=1, args=(caps,)) + +df['location'] = "/path/to/TestVideo/" + df["video_id"].astype(str) + ".mp4" + +df.to_parquet("msr-vtt-test.parquet") +``` + +## Create the dataset + +This video2dataset takes the mp4s unzipped in TestVideo and organizes them into a dataset that can be easily read with our dataloader: + +```bash +#!/bin/bash + +video2dataset --url_list="msr-vtt-test.parquet" \ + --input_format="parquet" \ + --output-format="webdataset" \ + --output_folder="/fsx/iejmac/datasets/msr-vtt/dataset" \ + --url_col="location" \ + --save_additional_columns='[video_id,category,id,caption_0,caption_1,caption_2,caption_3,caption_4,caption_5,caption_6,caption_7,caption_8,caption_9,caption_10,caption_11,caption_12,caption_13,caption_14,caption_15,caption_16,caption_17,caption_18,caption_19]' \ +``` diff --git a/dataset_examples/VideoCC.md b/dataset_examples/VideoCC.md new file mode 100644 index 00000000..e28be3a1 --- /dev/null +++ b/dataset_examples/VideoCC.md @@ -0,0 +1,90 @@ +# [VideoCC](https://github.com/google-research-datasets/videoCC-data) + +VideoCC is a video-text dataset with ~10M samples created by starting from an image-text dataset and retrieving videos with frames similar to images in that dataset. + +## Download the metadata + +Go to [this section](https://github.com/google-research-datasets/videoCC-data#data-format-for-videocc) of their README and download the CSV file. It will need some simple processing which can be done with this code snippet: +```python3 +import pandas as pd + +df = pd.read_csv("video_cc_public.csv") +df.columns = ["video_url", "start", "end", "caption"] +df["video_url"] = df["video_url"].apply(lambda x: f"https://www.youtube.com/watch?v={x}") +df['start'] = df['start'] / 1_000_000 +df['end'] = df['end'] / 1_000_000 +df.to_parquet("video_cc.parquet") +``` + +## Create the config + +We want to perform cut detection and clipping based on those cuts so we need to specify some subsamplers. We also want to extract compression metadata so we add the FFProbeSubsampler. 
+ +```yaml +subsampling: + CutDetectionSubsampler: + cuts_are_clips: True + args: + cut_detection_mode: "all" + framerates: null + threshold: 11.5 + min_scene_len: 15 + ClippingSubsampler: + args: + min_length: 3.0 + max_length: 20.0 + max_length_strategy: "all" + precision: "keyframe_adjusted" + FFProbeSubsampler: + args: + extract_keyframes: False + +reading: + yt_args: + download_size: 360 + download_audio_rate: 44100 + yt_metadata_args: + writesubtitles: 'all' + subtitleslangs: ['en'] + writeautomaticsub: True + get_info: True + timeout: 60 + sampler: null + +storage: + number_sample_per_shard: 1000 + oom_shard_count: 5 + captions_are_subtitles: False + +distribution: + processes_count: 48 + thread_count: 48 + subjob_size: 1000 + distributor: "slurm" + distributor_args: + partition: "cpu64" + n_nodes: 50 + account: "laion" + cache_path: "/fsx/home-iejmac/.slurm_cache" +``` + +## Download and process the videos using video2dataset: + +The dataset can be downloaded with cut detection and clipping using the following python code: + +```python3 +video2dataset( + url_list='video_cc.parquet', + input_format='parquet', + output_format='webdataset', + output_folder='dataset', + url_col="video_url", + encode_formats={"video": "mp4"}, + stage="download", + config="path/to/config.yaml" +) +``` + +## Performance + +NOTE: VideoCC is a youtube-based dataset so it will be slower to download than mp4-based ones. On a single cpu16 (16 core) EC2 instance (c6i-4xlarge) the entirety of VideoCC (~5M samples) can be downloaded in ~220h. It achieves 5 video/s (0.31 videos/s/core) or 87 Mb/s and I highly recommend paralellizing this over many nodes. This means the cost to download VideoCC comes out to ~0.68$/hr * 220h = 150$ diff --git a/dataset_examples/WebVid.md b/dataset_examples/WebVid.md new file mode 100644 index 00000000..00be69c7 --- /dev/null +++ b/dataset_examples/WebVid.md @@ -0,0 +1,50 @@ +# [WebVid](https://m-bain.github.io/webvid-dataset/) + +The entirety of WebVid can be downloaded very easily using scripts/configs provided in video2dataset/examples. You don't even need to use complex distribution strategies, multiprocessing is fine to download all 10M samples in a timely manner on a single machine. + +## Create the config: + +```yaml +subsampling: {} + +reading: + yt_args: + download_size: 360 + download_audio_rate: 44100 + yt_metadata_args: null + timeout: 60 + sampler: null + +storage: + number_sample_per_shard: 1000 + oom_shard_count: 5 + captions_are_subtitles: False + +distribution: + processes_count: 16 + thread_count: 16 + subjob_size: 1000 + distributor: "multiprocessing" +``` + +## Download WebVid: + +```bash +#!/bin/bash + +wget -nc http://www.robots.ox.ac.uk/~maxbain/webvid/results_10M_train.csv + +video2dataset --url_list="results_10M_train.csv" \ + --input_format="csv" \ + --output-format="webdataset" \ + --output_folder="dataset" \ + --url_col="contentUrl" \ + --caption_col="name" \ + --save_additional_columns='[videoid,page_idx,page_dir,duration]' \ + --enable_wandb=True \ + --config="path/to/config.yaml" \ +``` + +## Performance + +On a single cpu16 (16 core) EC2 instance (c6i-4xlarge) the entirety of WebVid (10M samples) can be downloaded in ~12h. It achieves 230 video/s (14.4 videos/s/core) or 420 Mb/s and of course can be proportionally sped up by utilizing more nodes via spark or slurm distribution to reduce the processing time even more. 
The cost to download WebVid is ~ 0.68$/h * 12h = 8.16$ diff --git a/docs/video2dataset_overview.png b/docs/video2dataset_overview.png new file mode 100644 index 00000000..d0e243cc Binary files /dev/null and b/docs/video2dataset_overview.png differ diff --git a/docs/wandb_logs.png b/docs/wandb_logs.png new file mode 100644 index 00000000..7a1e47ec Binary files /dev/null and b/docs/wandb_logs.png differ diff --git a/docs/wandb_status.png b/docs/wandb_status.png new file mode 100644 index 00000000..6f079b5d Binary files /dev/null and b/docs/wandb_status.png differ diff --git a/examples/caption_webvid.sh b/examples/caption_webvid.sh new file mode 100755 index 00000000..ef257f2c --- /dev/null +++ b/examples/caption_webvid.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +video2dataset --url_list="dataset/{00000..00004}.tar" \ + --input_format="webdataset" \ + --output-format="webdataset" \ + --output_folder="dataset_captions" \ + --stage "caption" \ + --encode_formats '{}' \ + --config "caption" diff --git a/examples/dataloader_example.py b/examples/dataloader_example.py new file mode 100644 index 00000000..490b3e45 --- /dev/null +++ b/examples/dataloader_example.py @@ -0,0 +1,36 @@ +from webdataset import WebLoader + +from video2dataset.dataloader import get_video_dataset + +# WebVid validation split +SHARDS = "dataset/{00000..00004}.tar" + +if __name__ == "__main__": + decoder_kwargs = { + "n_frames": 8, # get 8 frames from each video + "fps": 10, # downsample to 10 FPS + "num_threads": 12, # use 12 threads to decode the video + } + resize_size = crop_size = 256 + batch_size = 32 + + dset = get_video_dataset( + urls=SHARDS, + batch_size=batch_size, + decoder_kwargs=decoder_kwargs, + resize_size=resize_size, + crop_size=crop_size, + ) + + num_workers = 6 # 6 dataloader workers + + dl = WebLoader(dset, batch_size=None, num_workers=num_workers) + + for sample in dl: + video_batch = sample["mp4"] + print(video_batch.shape) # torch.Size([32, 8, 256, 256, 3]) + + # TODO: need to add option for text/metadata preprocessing (tokenization etc.) + text_batch = sample["txt"] + print(text_batch[0]) + metadata_batch = sample["json"] diff --git a/examples/default_slurm.yaml b/examples/default_slurm.yaml new file mode 100644 index 00000000..d8e5efc5 --- /dev/null +++ b/examples/default_slurm.yaml @@ -0,0 +1,31 @@ +subsampling: {} + +reading: + yt_args: + download_size: 360 + download_audio_rate: 44100 + yt_metadata_args: + writesubtitles: 'all' + subtitleslangs: ['en'] + writeautomaticsub: True + get_info: True + timeout: 60 + sampler: null + +storage: + number_sample_per_shard: 1000 + oom_shard_count: 5 + captions_are_subtitles: False + +distribution: + processes_count: 1 + thread_count: 8 + subjob_size: 1000 + distributor: "slurm" + distributor_args: + cpus_per_task: 16 + job_name: "v2ds" + partition: "" + n_nodes: 2 + account: "" + cache_path: "/.slurm_cache" diff --git a/examples/distributed_spark.md b/examples/distributed_spark.md new file mode 100644 index 00000000..285fb65c --- /dev/null +++ b/examples/distributed_spark.md @@ -0,0 +1,273 @@ +# Distributed video2dataset tutorial + +video2dataset can be used on a single machine to download at around 14 sample/s/core. + +However, what if you have billion of samples and you don't want to wait weeks ? + +To support that use case, video2dataset proposes to use multiple machines by setting up a pyspark cluster. +This document will help you setup such a cluster and run video2dataset on it. + +## Where to get a cluster, what machines to use? 
+ +These providers have been tested to work well with video2dataset: +* aws c6i.4xlarge nodes ($0.68/h for 230 sample/s) + +Ubuntu 20.04 works well with video2dataset. Centos7 also works. +Other providers probably work too but haven't been tested. + +## Setting up a pyspark cluster + +### You already got a cluster + +That option is of course the best. If you have an existing on-premise cluster, or you're using a cloud cluster like amazon emr, then you're all set, go directly to the use video2dataset section. +You may want to put https://github.com/iejMac/video2dataset/releases/latest/download/video2dataset.pex in a place that is available to all your nodes. + +### You don't have a cluster, but you have access to N machines over ssh + +That's a common case, you have access to N machines, and you have a place to store the data. +This is actually fairly easy to use this to setup a pyspark cluster. Let's see how to do it. + +Tools: +* spark and pyspark +* parallel ssh +* pex + +We will be assuming ubuntu 20.04. + + +#### Setup the master node + +On the master node: + +First download spark: +```bash +wget https://archive.apache.org/dist/spark/spark-3.2.0/spark-3.2.0-bin-hadoop3.2.tgz +tar xf spark-3.2.0-bin-hadoop3.2.tgz +``` + +Then download video2dataset: +```bash +wget https://github.com/iejMac/video2dataset/releases/latest/download/video2dataset.pex -O video2dataset.pex +``` + +If the master node cannot open ports that are visible from your local machine, you can do a tunnel between your local machine and the master node to be able to see the spark ui (at http://localhost:8080) +```bash +ssh -L 8080:localhost:8080 -L 4040:localhost:4040 master_node +``` + + +#### Setup the worker nodes + +##### ssh basic setup + +Still in the master node, create a ips.txt with the ips of all the nodes + +```bash +ssh-keyscan `cat ips.txt` >> ~/.ssh/known_hosts +``` + +You may use a script like this to fill your .ssh/config file +``` +def generate(ip): + print( + f"Host {ip}\n" + f" HostName {ip}\n" + " User ubuntu\n" + " IdentityFile ~/yourkey.pem" + ) + +with open("ips.txt") as f: + lines = f.readlines() + for line in lines: + generate(line.strip()) +``` +python3 generate.py >> ~/.ssh/config + +Install pssh with `sudo apt install pssh` + +Pick the right username (MASTER_USER) for the master node, and (USER) for the worker nodes, then run this to check your parallel ssh setup: +```bash +MASTER_USER=iejMac +USER=iejMac +``` + +Optionally, if another node than the current one has access to the worker nodes, you may need to add a ssh key to all the nodes with: +``` +for IP in `cat ips.txt` +do + ssh-copy-id -i the_new_id_rsa $USER@$IP +done +``` + +Check you can connect to all the nodes with: +``` +parallel-ssh -l $USER -i -h ips.txt uname -a +``` + +##### Install some packages + +```bash +sudo apt update +sudo apt install openjdk-11-jre-headless libgl1 htop tmux bwm-ng sshfs -y +``` + +```bash +parallel-ssh -l $USER -i -h ips.txt "sudo apt update" +parallel-ssh -l $USER -i -h ips.txt "sudo apt install openjdk-11-jre-headless libgl1 htop tmux bwm-ng sshfs -y" +``` + + +#### Network setting + +on master: +```bash +sudo sh -c 'echo `hostname -I` `hostname` >> /etc/hosts' +``` + +on workers +```bash +parallel-ssh -l $USER -i -h ips.txt "sudo sh -c 'echo \`hostname -I\` \`hostname\` >> /etc/hosts'" +``` + + +### Install knot resolver + +```bash +parallel-ssh -l $USER -i -h ips.txt "sudo apt update && sudo apt install libgl1 htop tmux bwm-ng python3.8-venv awscli -y" +parallel-ssh -l $USER -i -h ips.txt "wget 
https://secure.nic.cz/files/knot-resolver/knot-resolver-release.deb && sudo dpkg -i knot-resolver-release.deb && sudo apt update && sudo apt install -y knot-resolver" +``` + +```bash +parallel-ssh -l $USER -i -h ips.txt "sudo systemctl stop systemd-resolved" +parallel-ssh -l $USER -i -h ips.txt "sudo systemctl start kresd@{1..4}.service" +parallel-ssh -l $USER -i -h ips.txt 'sudo sh -c "echo nameserver 127.0.0.1 > /etc/resolv.conf"' +parallel-ssh -l $USER -i -h ips.txt 'dig @localhost google.com' +``` + + +##### Download video2dataset on all nodes + +Download video2dataset on all node by retrying this N times until parallel ssh says success for all: +```bash +parallel-ssh -i -h ips.txt "wget -c https://github.com/iejMac/video2dataset/releases/latest/download/video2dataset.pex -O video2dataset_new.pex" +``` +Then: +```bash +parallel-ssh -l $USER -i -h ips.txt "mv video2dataset_new.pex video2dataset.pex" +parallel-ssh -l $USER -i -h ips.txt "chmod +x video2dataset.pex" +``` + +##### Download spark on workers + +```bash +parallel-ssh -l $USER -i -h ips.txt "wget https://archive.apache.org/dist/spark/spark-3.2.0/spark-3.2.0-bin-hadoop3.2.tgz" +parallel-ssh -l $USER -i -h ips.txt "tar xf spark-3.2.0-bin-hadoop3.2.tgz" +``` + +#### Start the master node + +When you're ready, you can start the master node with: + +```bash +./spark-3.2.0-bin-hadoop3.2/sbin/start-master.sh -h master_node -p 7077 +``` + +Replace master_node by the master node ip. + + +#### Start the worker nodes + +When you're ready, you can start the worker nodes with: + +```bash +parallel-ssh -l $USER -i -h ips.txt "./spark-3.2.0-bin-hadoop3.2/sbin/start-worker.sh -c 16 -m 16G spark://master_node:7077" +``` + +Replace master_node by the master node ip. +Replace -c 16 -m 16g but the number of cores and ram you want to use on each worker. + + +#### Stop the worker nodes + +When you're done, you can stop the worker nodes with: + +```bash +parallel-ssh -l $USER -i -h ips.txt "rm -rf ~/spark-3.2.0-bin-hadoop3.2/work/*" +pkill -f "ssh -R" +parallel-ssh -l $USER -i -h ips.txt "pkill java" +``` + + +#### Stop the master node + +When you're done, you can stop the master node with: + +```bash +pkill java +``` + + +### Running video2dataset on it + +Once your spark cluster is setup, you're ready to start video2dataset in distributed mode. +Make sure to open your spark UI, at http://master_node:8080 + +Save this script to download.py. + +Then run ./video2dataset.pex download.py + +Replace master_node by the master node ip. 
+ +```python +from video2dataset import video2dataset +import shutil +import os +from pyspark.sql import SparkSession # pylint: disable=import-outside-toplevel + +from pyspark import SparkConf, SparkContext + +def create_spark_session(): + # this must be a path that is available on all worker nodes + pex_file = "/home/iejMac/video2dataset.pex" + + os.environ['PYSPARK_PYTHON'] = pex_file + spark = ( + SparkSession.builder + .config("spark.submit.deployMode", "client") \ + #.config("spark.files", pex_file) \ # you may choose to uncomment this option if you want spark to automatically download the pex file, but it may be slow + .config("spark.executorEnv.PEX_ROOT", "./.pex") + #.config("spark.executor.cores", "2") # this can be set to the number of cores of the machine + #.config("spark.cores.max", "200") # total number of cores to use over the whole spark cluster + .config("spark.driver.port", "5678") + .config("spark.driver.blockManager.port", "6678") + .config("spark.driver.host", "master_node") + .config("spark.driver.bindAddress", "master_node") + .config("spark.executor.memory", "16GB") # make sure to increase this if you're using more cores per executor + .config("spark.executor.memoryOverhead", "8GB") + .config("spark.task.maxFailures", "100") + .master("spark://master_node:7077") # this should point to your master node, if using the tunnelling version, keep this to localhost + .appName("spark-stats") + .getOrCreate() + ) + return spark + +output_dir = "/tmp/bench" + + +spark = create_spark_session() + +url_list = "some_file.parquet" + +video2dataset( + url_list=url_list, + output_folder=output_dir, + output_format="webdataset", + input_format="parquet", + url_col="videoLoc", + caption_col="title", + clip_col="clip", + save_additional_columns=["description", "videoID", "start", "end"], + enable_wandb=True, + config="path/to/config.yaml" +) +``` diff --git a/examples/distributed_spark_slurm.md b/examples/distributed_spark_slurm.md new file mode 100644 index 00000000..22e3a2d2 --- /dev/null +++ b/examples/distributed_spark_slurm.md @@ -0,0 +1,181 @@ +# video2dataset distributing with pyspark on a slurm cluster + +Using PySpark for distributing work in video2dataset is preferred for larger datasets. Here's how we used it on a slurm cluster to download and process 40M youtube videos. + +## Setup + +Download and extract spark via ```wget https://archive.apache.org/dist/spark/spark-3.3.1/spark-3.3.1-bin-hadoop3.tgz && tar xf spark-3.3.1-bin-hadoop3.tgz```. Make sure your pyspark version is the same. 
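+
+As a quick sanity check (just a sketch, not part of the original setup), you can confirm from Python that the pyspark package in your environment matches the Spark 3.3.1 binaries you extracted:
+
+```python3
+# Sanity check: the installed pyspark package should match the spark-3.3.1 binaries.
+import pyspark
+
+assert pyspark.__version__.startswith("3.3.1"), (
+    f"pyspark {pyspark.__version__} does not match spark-3.3.1-bin-hadoop3"
+)
+print("pyspark version OK:", pyspark.__version__)
+```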
+ +## Creating the spark server + +To create the spark server you need to run the following sbatch script via ```sbatch``` + +```bash +#!/bin/bash +#SBATCH --partition=gpu +#SBATCH --job-name=spark_on_slurm +#SBATCH --nodes 2 +#SBATCH --ntasks-per-node 1 +#SBATCH --cpus-per-task=48 +#SBATCH --mem=0 # 0 means use all available memory (in MB) +#SBATCH --output=%x_%j.out +#SBATCH --comment laion +#SBATCH --exclusive + +srun --comment laion bash worker_spark_on_slurm.sh +``` + +Which runs this bash script: + +```bash +#!/bin/bash +# +# get environment variables +GLOBAL_RANK=$SLURM_PROCID +CPUS=`grep -c ^processor /proc/cpuinfo` +MEM=$((`grep MemTotal /proc/meminfo | awk '{print $2}'`/1000)) # seems to be in MB +MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +LOCAL_IP=$(hostname -I | awk '{print $1}') + +# setup the master node +if [ $GLOBAL_RANK == 0 ] +then + # print out some info + echo -e "MASTER ADDR: $MASTER_ADDR\tGLOBAL RANK: $GLOBAL_RANK\tCPUS PER TASK: $CPUS\tMEM PER NODE: $MEM" + + # then start the spark master node in the background + ./spark-3.3.1-bin-hadoop3/sbin/start-master.sh -p 7077 -h $LOCAL_IP + +fi + +sleep 10 + +# then start the spark worker node in the background +MEM_IN_GB=$(($MEM / 1000)) +# concat a "G" to the end of the memory string +MEM_IN_GB="$MEM_IN_GB"G +echo "MEM IN GB: $MEM_IN_GB" + +./spark-3.3.1-bin-hadoop3/sbin/start-worker.sh -c $CPUS -m $MEM_IN_GB "spark://$MASTER_ADDR:7077" +echo "Hello from worker $GLOBAL_RANK" + +sleep 10 + +if [ $GLOBAL_RANK == 0 ] +then + # then start some script + echo "hi" +fi + +sleep 1000000 +``` + +If you want to run the spark server in the background of high GPU utilization and low CPU utilization jobs you can do so with the following python script + +```python3 +from pssh.clients import ParallelSSHClient +import subprocess +import socket +import fire +import os + +def get_ips_of_slurm_job(job_id): + c = "sinfo -N -n `squeue -j "+str(job_id)+" | tail -1 | awk '{print $8}'` | tail -n +2 | awk '{print $1}'" + hosts = subprocess.check_output(c, shell=True).decode("utf8")[:-1].split("\n") + ips = [socket.gethostbyname(host) for host in hosts] + return ips + +def run(ips_to_run, command): + print(ips_to_run) + client = ParallelSSHClient(ips_to_run, timeout=10, pool_size=len(ips_to_run)) + output = list(client.run_command(command, stop_on_errors=False)) + print([(o.client.host if o.client is not None else "", ("\n".join(o.stdout) if o.stdout else "")) for o in output]) + + +def start_spark_cluster(ips, cpus, mem_in_gb, spark_path, spark_local_dir): + master_addr = ips[0] + run([master_addr], f'{spark_path}/spark-3.3.1-bin-hadoop3/sbin/start-master.sh -p 7077 -h {master_addr}') + c = f'{spark_path}/spark-3.3.1-bin-hadoop3/sbin/start-worker.sh -c {cpus} -m {mem_in_gb}g "spark://{master_addr}:7077"' + run(ips, "bash -c 'SPARK_WORKER_DIR="+spark_local_dir+"/work SPARK_LOCAL_DIRS="+spark_local_dir+"/local "+c+"'") + + +def stop_spark_cluster(ips): + run(ips, f'bash -c "pkill java"') + + +def main(cpus=48, mem_in_gb=256,spark_path=None,job_id=None,command="start", spark_local_dir="/tmp"): + if spark_path is None: + spark_path = os.getcwd() + + ips = get_ips_of_slurm_job(job_id) + if command == "stop": + stop_spark_cluster(ips) + elif command == "start": + start_spark_cluster(ips, cpus, mem_in_gb, spark_path, spark_local_dir) + +# To start the server: +# python spark_on_ssh.py --cpus=48 --mem_in_gb=256 --job_id=SLURM_JOB_ID --spark_local_dir="/scratch/spark" --command="start" +# To stop the server: +# python 
spark_on_ssh.py --cpus=48 --mem_in_gb=256 --job_id=SLURM_JOB_ID --spark_local_dir="/scratch/spark" --command="stop" + +if __name__ == '__main__': + fire.Fire(main) +``` + +## Running the video2dataset job + +Once you have the slurm cluster running you can ssh into the master node and simply run the following python script adjusted for your particular use case. You might need to adjust num_cores, mem_gb, master_node IP, etc. + +```python3 +from video2dataset import video2dataset +import sys +import shutil +import os +from pyspark.sql import SparkSession # pylint: disable=import-outside-toplevel + +output_dir = os.path.abspath("bench") + +def aws_ec2_s3_spark_session(master, num_cores=128, mem_gb=256): + os.environ["PYSPARK_PYTHON"] = sys.executable + os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable + main_memory = str(int(mem_gb * 0.9)) + "g" + memory_overhead = str(mem_gb - int(mem_gb * 0.9)) + "g" + spark = ( + SparkSession.builder.config("spark.submit.deployMode", "client") + .config("spark.executor.memory", main_memory) + .config("spark.executor.cores", str(num_cores)) # this can be set to the number of cores of the machine + .config("spark.task.cpus", "1") + .config("spark.executor.memoryOverhead", memory_overhead) + .config("spark.task.maxFailures", "2") + .master(master) # this should be set to the spark master url + .appName("video2dataset") + .getOrCreate() + ) + return spark + + +if __name__ == "__main__": + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + master_node = "IP OF THE MASTER NODE" + spark = aws_ec2_s3_spark_session(f"spark://{master_node}:7077", num_cores=48, mem_gb=256) + + video2dataset( + url_list="urls", + output_folder="out", + output_format="webdataset", + input_format="parquet", + url_col="videoLoc", + caption_col="title", + clip_col="clip", + save_additional_columns=["description", "videoID", "start", "end"], + enable_wandb=True, + ) +``` + +Once you run this the video2dataset job should be distributed among all spark workers. + +## Checking on the job + +You can check the output of the workers in the spark folder you untarred earier in the work or logs directories. You can also check the spark UI by doing ```ssh -L 4040:localhost:4040 -L 8080:localhost:8080 login_node``` followed by ```ssh -L localhost:4040:master_node:4040 -L localhost:8080:master_node:8080 master_node``` and checking http://localhost:4040 and http://localhost:8080 in your browser. 
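+
+If you prefer checking from a terminal, the standalone Spark master also exposes a JSON status endpoint on the same port as its web UI. The snippet below is a rough sketch (field names can differ between Spark versions) that lists the registered workers once the tunnels above are in place:
+
+```python3
+# Rough sketch: query the Spark standalone master's JSON status endpoint
+# (reachable on localhost:8080 through the ssh tunnels above) and list workers.
+import requests
+
+status = requests.get("http://localhost:8080/json/", timeout=10).json()
+for worker in status.get("workers", []):
+    print(worker.get("host"), worker.get("state"), worker.get("cores"), "cores")
+```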
diff --git a/examples/download_webvid.sh b/examples/download_webvid.sh new file mode 100755 index 00000000..98be4ff4 --- /dev/null +++ b/examples/download_webvid.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +wget -nc http://www.robots.ox.ac.uk/~maxbain/webvid/results_2M_val.csv + +video2dataset --url_list="results_2M_val.csv" \ + --input_format="csv" \ + --output-format="webdataset" \ + --output_folder="dataset" \ + --url_col="contentUrl" \ + --caption_col="name" \ + --save_additional_columns='[videoid,page_idx,page_dir,duration]' \ + --enable_wandb=True \ + --config=default \ diff --git a/examples/downsample_webvid.sh b/examples/downsample_webvid.sh new file mode 100755 index 00000000..d40ee732 --- /dev/null +++ b/examples/downsample_webvid.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +video2dataset --url_list="dataset/{00000..00004}.tar" \ + --input_format="webdataset" \ + --output-format="webdataset" \ + --output_folder="dataset_downsampled" \ + --stage "subset" \ + --config "downsample_ml" \ diff --git a/examples/optical_flow_webvid.sh b/examples/optical_flow_webvid.sh new file mode 100755 index 00000000..80660309 --- /dev/null +++ b/examples/optical_flow_webvid.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +video2dataset --url_list="dataset/{00000..00004}.tar" \ + --input_format="webdataset" \ + --output-format="webdataset" \ + --output_folder="dataset_optical_flow" \ + --stage "optical_flow" \ + --encode_formats '{"optical_flow": "npy"}' \ + --config "optical_flow" \ diff --git a/examples/yt_metadata.md b/examples/yt_metadata.md new file mode 100644 index 00000000..0d06faf4 --- /dev/null +++ b/examples/yt_metadata.md @@ -0,0 +1,240 @@ +### Download YouTube metadata & subtitles: +#### Usage + +In order to extract metadata from youtube videos you must configure your config such that the the `yt_metadata_args` entry is present in the `yt_args` key of the reading specifications as such: + + +```yaml +yt_args: + download_size: 360 + download_audio_rate: 44100 + yt_metadata_args: + writesubtitles: 'all' + subtitleslangs: ['en'] + writeautomaticsub: True + get_info: True +``` + +Additionally if you specify `captions_are_subtitles` to be true in the storage parameters then each video and audio sample will be clipped according to the subtitles and divided into many unique samples. + +#### Output + +For every sample the metadata will be present in the json file as such: + +```json +{ + "description": "For the past five years, King Fish has been creating a media channel for IBM to generate leads of senior IT decision makers and retain current customers. We produce dozens of webcasts every year for numerous divisions within IBM. King Fish provides managed services, original content and audience development. \n\nKFM worked with IBM to develop video content on how SPSS Statistics can help their clients meet business goals with advanced data insight methods. The result? Much more effective than an info-graphic.", + "videoID": "QW3-5OuWn4M", + "start": 56.1025, + "end": 66.10249999999999, + "caption": "IBM SPSS", + "url": "http://youtube.com/watch?v=QW3-5OuWn4M", + "key": "000000_00001", + "status": "success", + "error_message": null, + "yt_meta_dict": { + "info": { + "id": "QW3-5OuWn4M", + "title": "IBM SPSS", + "thumbnail": "https://i.ytimg.com/vi/QW3-5OuWn4M/maxresdefault.jpg", + "description": "For the past five years, King Fish has been creating a media channel for IBM to generate leads of senior IT decision makers and retain current customers. We produce dozens of webcasts every year for numerous divisions within IBM. 
King Fish provides managed services, original content and audience development. \n\nKFM worked with IBM to develop video content on how SPSS Statistics can help their clients meet business goals with advanced data insight methods. The result? Much more effective than an info-graphic.", + "uploader": "King Fish Media", + "uploader_id": "KingFishMediaBoston", + "uploader_url": "http://www.youtube.com/user/KingFishMediaBoston", + "channel_id": "UCDy7Xb5vYxbmSosQmztCCcQ", + "channel_url": "https://www.youtube.com/channel/UCDy7Xb5vYxbmSosQmztCCcQ", + "duration": 122, + "view_count": 116, + "average_rating": null, + "age_limit": 0, + "webpage_url": "https://www.youtube.com/watch?v=QW3-5OuWn4M", + "categories": [ + "Science & Technology" + ], + "tags": [ + "IBM", + "technology", + "statistics", + "data", + "analysis", + "computers", + "content marketing", + "Software" + ], + "playable_in_embed": true, + "live_status": "not_live", + "release_timestamp": null, + "comment_count": null, + "chapters": null, + "like_count": 1, + "channel": "King Fish Media", + "channel_follower_count": 10, + "upload_date": "20131107", + "availability": "public", + "original_url": "http://youtube.com/watch?v=QW3-5OuWn4M", + "webpage_url_basename": "watch", + "webpage_url_domain": "youtube.com", + "extractor": "youtube", + "extractor_key": "Youtube", + "playlist": null, + "playlist_index": null, + "display_id": "QW3-5OuWn4M", + "fulltitle": "IBM SPSS", + "duration_string": "2:02", + "is_live": false, + "was_live": false, + "requested_subtitles": { + "en": { + "ext": "vtt", + "url": "https://www.youtube.com/api/timedtext?v=QW3-5OuWn4M&caps=asr&xoaf=5&hl=en&ip=0.0.0.0&ipbits=0&expire=1676200746&sparams=ip%2Cipbits%2Cexpire%2Cv%2Ccaps%2Cxoaf&signature=A43F4C223A9DBC7E3BFBC61027FC5AF70D709AB5.B386EB52DD412DEFC3E8DBBCF7F30C442473CDA4&key=yt8&kind=asr&lang=en&fmt=vtt", + "name": "English" + } + }, + "_has_drm": null, + "format": "137 - 1920x1080 (1080p)+251 - audio only (medium)", + "format_id": "137+251", + "ext": "mkv", + "protocol": "https+https", + "language": null, + "format_note": "1080p+medium", + "filesize_approx": 12831366, + "tbr": 841.009, + "width": 1920, + "height": 1080, + "resolution": "1920x1080", + "fps": 30, + "dynamic_range": "SDR", + "vcodec": "avc1.640028", + "vbr": 691.069, + "stretched_ratio": null, + "acodec": "opus", + "abr": 149.94, + "asr": 48000, + "audio_channels": 2 + } + }, + "clips": [ + [ + "00:00:07.749", + "00:00:07.759" + ] + ] +} +``` + +And since we specified that captions_are_subtitles the txt file will have the subtitle for that given clip inside of it. For this particular example it would be: "analytics to assess performance based on" + +#### Multilingual Subtitles +To control the language/s of the subtitles from your videos, you can prvoide either `'first'` or `'all'` for `writesubtitles` (any value that evalutes to True will work also work as `'all'`). + +`first`: This will extract subtitles for the first language that is in `subtitleslangs` for which there exists subtitles. \ +`all`: Attempt to extract subtitles for every language in `subtitleslangs`. + +Below are some example outputs with `subtitleslangs: ['en', 'es', 'fr']`. + +Using `writesubtitles: 'first'`: +```json +{ + "url": "https://www.youtube.com/watch?v=CvHAfXKIvgw", + ... + "yt_meta_dict": { + ... + "subtitles": { + "en": [ + { + "start": "00:00:02.100", + "end": "00:00:03.360", + "lines": [ + "Good morning Lisa" + ] + }, + ... 
+ ] + } + }, + "clips": [ + [ + 2.1, + 3.36 + ] + ], + "clip_subtitles": [ + { + "start": "00:00:02.100", + "end": "00:00:03.360", + "lines": { + "en": [ + "Good morning Lisa" + ] + } + } + ] +} +``` + + +Using `writesubtitles: 'all'`: +```json +{ + "url": "https://www.youtube.com/watch?v=CvHAfXKIvgw", + ... + "yt_meta_dict": { + ... + "subtitles": { + "en": [ + { + "start": "00:00:02.100", + "end": "00:00:03.360", + "lines": [ + "Good morning Lisa" + ] + }, + ... + ], + "es": [ + { + "start": "00:00:02.100", + "end": "00:00:03.360", + "lines": [ + "Buenos d\u00edas Lisa" + ] + }, + ... + ], + "fr": [ + { + "start": "00:00:02.100", + "end": "00:00:03.360", + "lines": [ + "Bonjour Lisa" + ] + }, + ... + ] + } + }, + "clips": [ + [ + 2.1, + 3.36 + ] + ], + "clip_subtitles": [ + { + "start": "00:00:02.100", + "end": "00:00:03.360", + "lines": { + "en": [ + "Good morning Lisa" + ], + "es": [ + "Buenos d\u00edas Lisa" + ], + "fr": [ + "Bonjour Lisa" + ] + } + } + ] +} +``` diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..76636986 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,5 @@ +# Global options: + +[mypy] +python_version = 3.8 +ignore_missing_imports = True diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 00000000..7b6d7de9 --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,15 @@ +black==23.12.1 +mypy==1.8.0 +pylint==3.0.3 +pytest-cov==4.1.0 +pytest-xdist==3.5.0 +pytest==7.4.4 +types-requests==2.31.0.20240125 +types-PyYAML==6.0.12.12 +types-requests +webvtt-py +tensorflow +tensorflow_io +torch +types-pyyaml +whisperx @ git+https://github.com/m-bain/whisperx@main diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..ae5bc99f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +numpy>=1.19.5,<2 +fire>=0.4.0,<0.6.0 +requests>=2.27.1,<3 +ffmpeg-python +yt_dlp +pyarrow +fsspec +webdataset +webvtt-py +wandb +pandas +tqdm +timeout-decorator +scenedetect[opencv]==0.6.2 +decord +#eva-decord # use eva-decord instead of decord on Mac (https://github.com/dmlc/decord/issues/213) +torch==2.1.2 +langdetect +torchdata==0.7.1 +torchaudio==2.1.2 +soundfile +omegaconf +einops +transformers==4.37.1 +Pillow +accelerate +bitsandbytes +scipy diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..1d164d2d --- /dev/null +++ b/setup.py @@ -0,0 +1,50 @@ +from setuptools import setup, find_packages +from pathlib import Path +import os + +if __name__ == "__main__": + with Path(Path(__file__).parent, "README.md").open(encoding="utf-8") as file: + long_description = file.read() + + def _read_reqs(relpath): + fullpath = os.path.join(os.path.dirname(__file__), relpath) + with open(fullpath) as f: + return [s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))] + + REQUIREMENTS = _read_reqs("requirements.txt") + + setup( + name="video2dataset", + packages=find_packages(), + include_package_data=True, + version="1.2.0", + license="MIT", + description="Easily create large video dataset from video urls", + long_description=long_description, + long_description_content_type="text/markdown", + entry_points={"console_scripts": ["video2dataset=video2dataset.main:main"]}, + author="Maciej Kilian", + author_email="kilianmaciej6@gmail.com", + url="https://github.com/iejMac/video2dataset", + data_files=[ + ( + ".", + [ + "README.md", + "video2dataset/configs/caption.yaml", + "video2dataset/configs/default.yaml", + "video2dataset/configs/downsample_ml.yaml", + "video2dataset/configs/optical_flow.yaml", + ], + ) + ], 
+ keywords=["machine learning"], + install_requires=REQUIREMENTS, + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.6", + ], + ) diff --git a/tests/test_data_reader.py b/tests/test_data_reader.py new file mode 100644 index 00000000..1010b473 --- /dev/null +++ b/tests/test_data_reader.py @@ -0,0 +1,37 @@ +import os + +import pandas as pd +import pytest +import tempfile +import subprocess +import ffmpeg + +from video2dataset.data_reader import VideoDataReader + + +@pytest.mark.parametrize("input_file", ["test_yt.csv"]) +def test_data_reader(input_file): + encode_formats = {"video": "mp4", "audio": "mp3"} + current_folder = os.path.dirname(__file__) + url_list = pd.read_csv(os.path.join(current_folder, f"test_files/{input_file}"))["contentUrl"] + + reading_config = { + "yt_args": { + "download_size": 360, + "download_audio_rate": 12000, + }, + "timeout": 60, + "sampler": None, + } + + with tempfile.TemporaryDirectory() as tmpdir: + video_data_reader = VideoDataReader( + encode_formats=encode_formats, + tmp_dir=tmpdir, + reading_config=reading_config, + ) + for i, url in enumerate(url_list): + key, streams, yt_meta_dict, error_message = video_data_reader((i, url)) + + assert len(streams.get("audio", [])) > 0 + assert len(streams.get("video", [])) > 0 diff --git a/tests/test_data_writers.py b/tests/test_data_writers.py new file mode 100644 index 00000000..d3c573bf --- /dev/null +++ b/tests/test_data_writers.py @@ -0,0 +1,125 @@ +from video2dataset.data_writer import ( + FilesSampleWriter, + WebDatasetSampleWriter, + ParquetSampleWriter, + DummySampleWriter, + TFRecordSampleWriter, +) + +import os +import glob +import pytest +import tarfile +import pandas as pd +import pyarrow as pa + + +@pytest.mark.parametrize("modalities", [["video", "audio"], ["video"], ["audio"]]) +@pytest.mark.parametrize("writer_type", ["files", "webdataset", "parquet", "dummy", "tfrecord"]) +def test_writer(modalities, writer_type, tmp_path): + current_folder = os.path.dirname(__file__) + test_folder = str(tmp_path) + output_folder = test_folder + "/" + "test_write" + os.mkdir(output_folder) + + schema = pa.schema( + [ + pa.field("key", pa.string()), + pa.field("caption", pa.string()), + pa.field("status", pa.string()), + pa.field("error_message", pa.string()), + pa.field("width", pa.int32()), + pa.field("height", pa.int32()), + pa.field("audio_rate", pa.int32()), + ] + ) + if writer_type == "files": + writer_class = FilesSampleWriter + elif writer_type == "webdataset": + writer_class = WebDatasetSampleWriter + elif writer_type == "parquet": + writer_class = ParquetSampleWriter + elif writer_type == "dummy": + writer_class = DummySampleWriter + elif writer_type == "tfrecord": + writer_class = TFRecordSampleWriter + + streams = {} + encode_formats = {} + for mod in modalities: + encode_formats[mod] = "mp4" if mod == "video" else "mp3" + with open(os.path.join(current_folder, f"test_files/test_{mod}.{encode_formats[mod]}"), "rb") as f: + streams[mod] = f.read() + + n_samples = 1 + + writer = writer_class(0, output_folder, True, 5, schema, encode_formats) + i = 0 # TODO: maybe add more samples + writer.write( + streams=streams, + key=str(i), + caption=str(i), + meta={ + "key": str(i), + "caption": str(i), + "status": "ok", + "error_message": "", + "width": 100, + "height": 100, + "audio_rate": 12000, + }, + ) + writer.close() + + if 
writer_type != "dummy": + df = pd.read_parquet(output_folder + "/00000.parquet") + + expected_columns = [ + "key", + "caption", + "status", + "error_message", + "width", + "height", + "audio_rate", + ] + + if writer_type == "parquet": + for fmt in encode_formats.values(): + expected_columns.append(fmt) + + assert df.columns.tolist() == expected_columns + + assert df["key"].iloc[0] == "0" + assert df["caption"].iloc[0] == "0" + assert df["status"].iloc[0] == "ok" + assert df["error_message"].iloc[0] == "" + assert df["width"].iloc[0] == 100 + assert df["height"].iloc[0] == 100 + assert df["audio_rate"].iloc[0] == 12000 + + n_files = (len(encode_formats) + len(["caption", "meta"])) * n_samples + + if writer_type == "files": + saved_files = list(glob.glob(output_folder + "/00000/*")) + assert len(saved_files) == n_files + elif writer_type == "webdataset": + l = glob.glob(output_folder + "/*.tar") + assert len(l) == 1 + if l[0] != output_folder + "/00000.tar": + raise Exception(l[0] + " is not 00000.tar") + assert len(tarfile.open(output_folder + "/00000.tar").getnames()) == n_files + elif writer_type == "parquet": + l = glob.glob(output_folder + "/*.parquet") + assert len(l) == 1 + if l[0] != output_folder + "/00000.parquet": + raise Exception(l[0] + " is not 00000.parquet") + assert len(df.index) == n_samples + elif writer_type == "dummy": + l = glob.glob(output_folder + "/*") + assert len(l) == 0 + elif writer_type == "tfrecord": + l = glob.glob(output_folder + "/*.tfrecord") + assert len(l) == 1 + if l[0] != output_folder + "/00000.tfrecord": + raise Exception(l[0] + " is not 00000.tfrecord") diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py new file mode 100644 index 00000000..aa2bba6e --- /dev/null +++ b/tests/test_dataloaders.py @@ -0,0 +1,104 @@ +""" +Test for dataloader with and without return all +""" +import pytest +import webdataset as wds +from video2dataset.dataloader import get_video_dataset + +# Benchmark videos are the WebVid validation split (5000 videos) +SHARDS = "tests/test_files/return_all_test.tar" + + +@pytest.mark.parametrize("batch_size", [1, 4]) +def test_return_all(batch_size): + decoder_kwargs = {"n_frames": 10, "fps": None, "num_threads": 1} + + dset = get_video_dataset( + urls=SHARDS, + batch_size=batch_size, + decoder_kwargs=decoder_kwargs, + resize_size=None, + crop_size=None, + return_always=True, + keys_to_remove=["m4a"], + repeat=1, + handler=wds.warn_and_continue, + ) + dl = wds.WebLoader(dset, batch_size=None, num_workers=0) + + expected_keys = { + "000000014": True, + "000000025": True, + "000000356": True, + "0000008_00001": False, + "0000030_00005": False, + "0000038_00003": False, + } + received_keys = [] + for samp in dl: + for i in range(len(samp["__key__"])): + key = samp["__key__"][i] + received_keys.append(key) + assert expected_keys[key] == samp["__corrupted__"][i].item() + + assert set(received_keys) == set(expected_keys) + + +@pytest.mark.parametrize("batch_size", [1, 2]) +def test_default(batch_size): + decoder_kwargs = {"n_frames": 10, "fps": None, "num_threads": 1} + + dset = get_video_dataset( + urls=SHARDS, + batch_size=batch_size, + decoder_kwargs=decoder_kwargs, + resize_size=None, + crop_size=None, + return_always=False, + keys_to_remove=["m4a"], + repeat=1, + handler=wds.warn_and_continue, + ) + dl = wds.WebLoader(dset, batch_size=None, num_workers=0) + + expected_keys = ["0000008_00001", "0000030_00005", "0000038_00003"] + + received_keys = [] + for samp in dl: + assert "__corrupted__" not in samp + for i in 
range(len(samp["__key__"])): + key = samp["__key__"][i] + received_keys.append(key) + + assert set(received_keys) == set(expected_keys) + + +@pytest.mark.parametrize( + "batch_size, expected_keys", + [(1, ["0000008_00001", "0000030_00005", "0000038_00003"]), (2, ["0000008_00001", "0000030_00005"])], +) +def test_drop_last(batch_size, expected_keys): + decoder_kwargs = {"n_frames": 10, "fps": None, "num_threads": 1} + + dset = get_video_dataset( + urls=SHARDS, + batch_size=batch_size, + decoder_kwargs=decoder_kwargs, + resize_size=None, + crop_size=None, + return_always=False, + keys_to_remove=["m4a"], + drop_last=True, + repeat=1, + handler=wds.warn_and_continue, + ) + dl = wds.WebLoader(dset, batch_size=None, num_workers=0) + + received_keys = [] + for samp in dl: + assert "__corrupted__" not in samp + for i in range(len(samp["__key__"])): + key = samp["__key__"][i] + received_keys.append(key) + + assert set(received_keys) == set(expected_keys) diff --git a/tests/test_downloaders.py b/tests/test_downloaders.py new file mode 100644 index 00000000..be2ed923 --- /dev/null +++ b/tests/test_downloaders.py @@ -0,0 +1,53 @@ +"""test video2dataset downloaders""" +import os +import pytest +import ffmpeg + + +from video2dataset.data_reader import YtDlpDownloader, WebFileDownloader + + +YT_URL = "https://www.youtube.com/watch?v=jLX0D8qQUBM" +DM_URL = "https://www.dailymotion.com/video/x29ryo7" +MP4_URL = "https://ak.picdn.net/shutterstock/videos/1053841541/preview/stock-footage-travel-blogger-shoot-a-story-on-top-of-mountains-young-man-holds-camera-in-forest.mp4" + +full_encode_formats = {"video": "mp4", "audio": "m4a"} + + +@pytest.mark.parametrize("modalities", [["video", "audio"], ["video"], ["audio"]]) +@pytest.mark.parametrize("video_size", [361, 1080]) +@pytest.mark.parametrize("url", [YT_URL, DM_URL]) +def test_yt_downloader(modalities, video_size, url): + encode_formats = dict([(modality, full_encode_formats[modality]) for modality in modalities]) + yt_args = {"download_size": video_size, "yt_metadata_args": None} + + ytdlp_downloader = YtDlpDownloader(encode_formats, "/tmp", yt_args) + + modality_paths, yt_meta_dict, error_message = ytdlp_downloader(url) + assert error_message is None + + for modality, path in modality_paths.items(): + if modality == "video": + probe = ffmpeg.probe(modality_paths["video"]) + video_stream = [stream for stream in probe["streams"] if stream["codec_type"] == "video"][0] + height = int(video_stream["height"]) + + assert height == 480 + elif modality == "audio": + with open(path, "rb") as f: + audio_bytes = f.read() + assert len(audio_bytes) > 0 + else: + assert error_messsage == "Error: Untested modality" + os.remove(path) + + +def test_webfile_downloader(): + webfile_downloader = WebFileDownloader(timeout=10, tmp_dir="/tmp", encode_formats={"video": "mp4"}) + + modality_paths, error_message = webfile_downloader(MP4_URL) + + assert error_message is None + with open(modality_paths["video"], "rb") as f: + assert len(f.read()) > 0 + os.remove(modality_paths["video"]) diff --git a/tests/test_files/return_all_test.tar b/tests/test_files/return_all_test.tar new file mode 100644 index 00000000..13c4d882 Binary files /dev/null and b/tests/test_files/return_all_test.tar differ diff --git a/tests/test_files/test_audio.mp3 b/tests/test_files/test_audio.mp3 new file mode 100644 index 00000000..9d9d34d5 Binary files /dev/null and b/tests/test_files/test_audio.mp3 differ diff --git a/tests/test_files/test_video.mp4 b/tests/test_files/test_video.mp4 new file mode 100644 index 
00000000..7fb9f009 Binary files /dev/null and b/tests/test_files/test_video.mp4 differ diff --git a/tests/test_files/test_webvid.csv b/tests/test_files/test_webvid.csv new file mode 100644 index 00000000..59273f51 --- /dev/null +++ b/tests/test_files/test_webvid.csv @@ -0,0 +1,21 @@ +videoid,name,page_idx,page_dir,duration,contentUrl +1053841541,Travel blogger shoot a story on top of mountains. young man holds camera in forest.,27732,027701_027750,PT00H00M17S,https://ak.picdn.net/shutterstock/videos/1053841541/preview/stock-footage-travel-blogger-shoot-a-story-on-top-of-mountains-young-man-holds-camera-in-forest.mp4 +5641526,Horse grazing - seperated on green screen,28620,054101_054150,PT00H00M24S,https://ak.picdn.net/shutterstock/videos/5641526/preview/stock-footage-horse-grazing-seperated-on-green-screen.mp4 +1019039938,City traffic lights. blurred view,22404,022401_022450,PT00H00M13S,https://ak.picdn.net/shutterstock/videos/1019039938/preview/stock-footage-city-traffic-lights-blurred-view.mp4 +6251156,Asian girl sexy santa claus isolated greenscreen green background happy in love heart,4046,004001_004050,PT00H00M16S,https://ak.picdn.net/shutterstock/videos/6251156/preview/stock-footage-asian-girl-sexy-santa-claus-isolated-greenscreen-green-background-happy-in-love-heart.mp4 +23720764,Young woman flexing muscles with barbell in gym.the coach helps her.,13617,013601_013650,PT00H00M24S,https://ak.picdn.net/shutterstock/videos/23720764/preview/stock-footage-young-woman-flexing-muscles-with-barbell-in-gym-the-coach-helps-her.mp4 +1009929137,Bearded man in casual clothes sitts on the chair and takes the smartphone from the desk and starts browsing through it. dolly shot,5032,005001_005050,PT00H00M35S,https://ak.picdn.net/shutterstock/videos/1009929137/preview/stock-footage-bearded-man-in-casual-clothes-sitts-on-the-chair-and-takes-the-smartphone-from-the-desk-and-starts.mp4 +1020335461,"Moroccan flag waving in the wind in medieval town essaouira, morocco",873,000851_000900,PT00H00M05S,https://ak.picdn.net/shutterstock/videos/1020335461/preview/stock-footage-moroccan-flag-waving-in-the-wind-in-medieval-town-essaouira-morocco.mp4 +6543449,Unexpanded white peony flower bud covered with morning dew water drops move in wind.,27762,027751_027800,PT00H00M07S,https://ak.picdn.net/shutterstock/videos/6543449/preview/stock-footage-unexpanded-white-peony-flower-bud-covered-with-morning-dew-water-drops-move-in-wind.mp4 +30287869,Time lapse video of light cirrus clouds in a sky darkened by the eclipse,32306,032301_032350,PT00H00M34S,https://ak.picdn.net/shutterstock/videos/30287869/preview/stock-footage-time-lapse-video-of-light-cirrus-clouds-in-a-sky-darkened-by-the-eclipse.mp4 +20655292,Yellow and black tropical fish dart through coral,66015,040701_040750,PT00H00M15S,https://ak.picdn.net/shutterstock/videos/20655292/preview/stock-footage-yellow-and-black-tropical-fish-dart-through-coral.mp4 +1013992232,Legs of a crowd of people walking along the path in the woods in the evening. the camera moves with people. close-up of crowd feet. many legs of children walking in the forest,3874,175401_175450,PT00H00M11S,https://ak.picdn.net/shutterstock/videos/1013992232/preview/stock-footage-legs-of-a-crowd-of-people-walking-along-the-path-in-the-woods-in-the-evening-the-camera-moves-with.mp4 +1040044910,"Laundry outline icon animation footage/video. 
hand drawn like symbol animated with motion graphic, can be used as loop item, has alpha channel and it's at 4k video resolution.",84013,084001_084050,PT00H00M09S,https://ak.picdn.net/shutterstock/videos/1040044910/preview/stock-footage-laundry-outline-icon-animation-footage-video-hand-drawn-like-symbol-animated-with-motion-graphic.mp4 +7326946,"Saint petersburg, russia - april 15, 2014: city view of embankment and excursion boat parked at griboyedov channel station near nevsky prospekt in sunny day",89296,011001_011050,PT00H00M24S,https://ak.picdn.net/shutterstock/videos/7326946/preview/stock-footage-saint-petersburg-russia-april-city-view-of-embankment-and-excursion-boat-parked-at.mp4 +1015128580,Strong wind blows through the dunes grass at a sunny day in summer,17822,017801_017850,PT00H00M23S,https://ak.picdn.net/shutterstock/videos/1015128580/preview/stock-footage-strong-wind-blows-through-the-dunes-grass-at-a-sunny-day-in-summer.mp4 +9429470,Cat climbing on tree,2621,002601_002650,PT00H00M15S,https://ak.picdn.net/shutterstock/videos/9429470/preview/stock-footage-cat-climbing-on-tree.mp4 +29797315,Aerial uhd 4k view. mid-air flight over fresh and clean mountain river at sunny summer morning. green trees and sun rays on horizon. direct on sun,14848,014801_014850,PT00H01M00S,https://ak.picdn.net/shutterstock/videos/29797315/preview/stock-footage-aerial-uhd-k-view-mid-air-flight-over-fresh-and-clean-mountain-river-at-sunny-summer-morning.mp4 +24806498,Tourist couple enjoying tropical water gopro selfie,23275,023251_023300,PT00H00M06S,https://ak.picdn.net/shutterstock/videos/24806498/preview/stock-footage-tourist-couple-enjoying-tropical-water-gopro-selfie.mp4 +13080881,"Unidentifiable players dribbling, walking and shooting basket balls in an urban setting",84559,084551_084600,PT00H00M10S,https://ak.picdn.net/shutterstock/videos/13080881/preview/stock-footage-unidentifiable-players-dribbling-walking-and-shooting-basket-balls-in-an-urban-setting.mp4 +8799880,"Video 1920x1080 - coconut palm tree shadow on the beautiful tropical beach in island koh kood , thailand",20051,175901_175950,PT00H00M21S,https://ak.picdn.net/shutterstock/videos/8799880/preview/stock-footage-video-x-coconut-palm-tree-shadow-on-the-beautiful-tropical-beach-in-island-koh-kood.mp4 +13292570,"Kiev, ukraine 30.07.2015. cars go on the road. road traffic. 
peak hour.",83121,083101_083150,PT00H00M10S,https://ak.picdn.net/shutterstock/videos/13292570/preview/stock-footage-kiev-ukraine-cars-go-on-the-road-road-traffic-peak-hour.mp4 diff --git a/tests/test_files/test_yt.csv b/tests/test_files/test_yt.csv new file mode 100644 index 00000000..8df07d0d --- /dev/null +++ b/tests/test_files/test_yt.csv @@ -0,0 +1,4 @@ +videoid,name,contentUrl +jLX0D8qQUBM,Aesthetics predictor from side-by-side comparisions,https://www.youtube.com/watch?v=jLX0D8qQUBM +9Zn4pID_BJ0,Image Deduplication Testset,https://www.youtube.com/watch?v=9Zn4pID_BJ0 +aBfjTAQ8qpM,Merge embeddings from low abstraction level to higher levels with k-means (idea),https://www.youtube.com/watch?v=aBfjTAQ8qpM diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 00000000..48f7762a --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,64 @@ +"""end2end test""" +import os + +import pandas as pd +import pytest +import tarfile +import tempfile + +from omegaconf import OmegaConf +from video2dataset.configs import CONFIGS +from video2dataset.main import video2dataset + + +@pytest.mark.parametrize("input_file", ["test_webvid.csv", "test_yt.csv"]) +def test_e2e(input_file): + current_folder = os.path.dirname(__file__) + url_list = os.path.join(current_folder, f"test_files/{input_file}") + + sample_count = len(pd.read_csv(url_list)) + config = OmegaConf.to_container(CONFIGS["default"]) + + with tempfile.TemporaryDirectory() as tmpdir: + config["storage"]["number_sample_per_shard"] = 10 if "webvid" in input_file else 3 + samples_per_shard = config["storage"]["number_sample_per_shard"] + config["distribution"]["processes_count"] = 1 + + video2dataset( + url_list, + output_folder=tmpdir, + input_format="csv", + output_format="webdataset", + url_col="contentUrl", + caption_col="name", + save_additional_columns=["videoid"], + config=config, + ) + + for shard in ["00000", "00001"] if sample_count / samples_per_shard > 1.0 else ["00000"]: + for ext in ["mp4", "json", "txt"]: + assert ( + len([x for x in tarfile.open(tmpdir + f"/{shard}.tar").getnames() if x.endswith(f".{ext}")]) + == samples_per_shard + ) + + # multistage test + shard_list = tmpdir + ("/{00000..00001}.tar" if "webvid" in input_file else "/{00000..00000}.tar") + tmpdir2 = tmpdir + "/transformed" + + video2dataset( + shard_list, + input_format="webdataset", + output_folder=tmpdir2, + output_format="webdataset", + stage="subset", + encode_formats={}, # only copy over metadata and caption + ) + + for shard in ["00000", "00001"] if sample_count / samples_per_shard > 1.0 else ["00000"]: + for ext in ["json", "txt"]: + assert ( + len([x for x in tarfile.open(tmpdir2 + f"/{shard}.tar").getnames() if x.endswith(f".{ext}")]) + == samples_per_shard + ) + assert len([x for x in tarfile.open(tmpdir2 + f"/{shard}.tar").getnames() if x.endswith(f".mp4")]) == 0 diff --git a/tests/test_subsamplers.py b/tests/test_subsamplers.py new file mode 100644 index 00000000..be6045ac --- /dev/null +++ b/tests/test_subsamplers.py @@ -0,0 +1,328 @@ +"""test video2dataset subsamplers""" +import os +import subprocess +import pytest +import ffmpeg +import tempfile +import numpy as np +import cv2 + +from video2dataset.subsamplers import ( + ClippingSubsampler, + _get_seconds, + _split_time_frame, + Streams, + FFProbeSubsampler, + ResolutionSubsampler, + FrameSubsampler, + AudioRateSubsampler, + CutDetectionSubsampler, + OpticalFlowSubsampler, + WhisperSubsampler, +) + + +SINGLE = [[50.0, 60.0]] +MULTI = [ + ["00:00:09.000", "00:00:13.500"], + 
["00:00:13.600", "00:00:24.000"], + ["00:00:45.000", "00:01:01.230"], + ["00:01:01.330", "00:01:22.000"], + ["00:01:30.000", "00:02:00.330"], +] + + +@pytest.mark.parametrize("clips", [SINGLE, MULTI]) +def test_clipping_subsampler(clips): + current_folder = os.path.dirname(__file__) + # video lenght - 2:02 + video = os.path.join(current_folder, "test_files/test_video.mp4") + with open(video, "rb") as vid_f: + video_bytes = vid_f.read() + audio = os.path.join(current_folder, "test_files/test_audio.mp3") + with open(audio, "rb") as aud_f: + audio_bytes = aud_f.read() + + min_length = 5.0 if clips == MULTI else 2.0 + max_length = 999999.0 if clips == MULTI else 3.0 + subsampler = ClippingSubsampler( + oom_clip_count=3, + encode_formats={"video": "mp4", "audio": "mp3"}, + min_length=min_length, + max_length=max_length, + max_length_strategy="all", + precision="low", + ) + + metadata = { + "key": "000", + "clips": clips, + } + + streams: Streams = {"video": [video_bytes], "audio": [audio_bytes]} + stream_fragments, meta_fragments, error_message = subsampler(streams, metadata) + video_fragments = stream_fragments["video"] + audio_fragments = stream_fragments["audio"] + assert error_message is None + # first one is only 4.5s + assert len(audio_fragments) == len(video_fragments) == len(meta_fragments) + if clips == SINGLE: + assert len(video_fragments) == 3 + else: + assert len(video_fragments) == 4 + + for vid_frag, meta_frag in zip(video_fragments, meta_fragments): + with tempfile.NamedTemporaryFile() as tmp: + tmp.write(vid_frag) + key_ind = int(meta_frag["key"].split("_")[-1]) + s, e = meta_frag["clips"][0] + + if clips == MULTI: + key_ind += 1 + else: + key_ind = 0 + + s_target, e_target = clips[key_ind] + s_target, e_target = _get_seconds(s_target), _get_seconds(e_target) + expected_clips = _split_time_frame(s_target, e_target, min_length, max_length) + assert [_get_seconds(s), _get_seconds(e)] in expected_clips + assert _get_seconds(e) - _get_seconds(s) >= min_length + + s_s, e_s = _get_seconds(s), _get_seconds(e) + probe = ffmpeg.probe(tmp.name) + video_stream = [stream for stream in probe["streams"] if stream["codec_type"] == "video"][0] + frag_len = float(video_stream["duration"]) + + assert abs(frag_len - (e_s - s_s)) < 5.0 + + +@pytest.mark.parametrize("size,resize_mode", [(144, ["scale"]), (1620, ["scale", "crop", "pad"])]) +def test_resolution_subsampler_video_size(size, resize_mode): + current_folder = os.path.dirname(__file__) + # video lenght - 2:02, 1080x1920 + video = os.path.join(current_folder, "test_files/test_video.mp4") + with open(video, "rb") as vid_f: + video_bytes = vid_f.read() + + subsampler = ResolutionSubsampler(video_size=size, resize_mode=resize_mode) + + streams = {"video": [video_bytes]} + subsampled_streams, _, error_message = subsampler(streams) + assert error_message is None + subsampled_videos = subsampled_streams["video"] + + with tempfile.NamedTemporaryFile() as tmp: + tmp.write(subsampled_videos[0]) + + probe = ffmpeg.probe(tmp.name) + video_stream = [stream for stream in probe["streams"] if stream["codec_type"] == "video"][0] + h_vid, w_vid = video_stream["height"], video_stream["width"] + + assert h_vid == size + if resize_mode == ["scale"]: + assert w_vid == 256 # 1920 / (1080/144) + else: + assert w_vid == size + + +@pytest.mark.parametrize("height,width,resize_mode", [(-1, 128, ["scale"]), (1620, 1620, ["scale", "crop", "pad"])]) +def test_resolution_subsampler_height_and_width(height, width, resize_mode): + current_folder = os.path.dirname(__file__) 
+ # video lenght - 2:02, 1080x1920 + video = os.path.join(current_folder, "test_files/test_video.mp4") + with open(video, "rb") as vid_f: + video_bytes = vid_f.read() + + subsampler = ResolutionSubsampler(height=height, width=width, resize_mode=resize_mode) + + streams = {"video": [video_bytes]} + subsampled_streams, _, error_message = subsampler(streams) + assert error_message is None + subsampled_videos = subsampled_streams["video"] + + with tempfile.NamedTemporaryFile() as tmp: + tmp.write(subsampled_videos[0]) + + probe = ffmpeg.probe(tmp.name) + video_stream = [stream for stream in probe["streams"] if stream["codec_type"] == "video"][0] + h_vid, w_vid = video_stream["height"], video_stream["width"] + + if resize_mode == ["scale"]: + assert h_vid == 72 + assert w_vid == 128 + else: + assert h_vid == height + assert w_vid == width + + +@pytest.mark.parametrize("target_frame_rate", [6, 15, 30]) +def test_frame_rate_subsampler(target_frame_rate): + current_folder = os.path.dirname(__file__) + # video length - 2:02, 1080x1920, 30 fps + video = os.path.join(current_folder, "test_files/test_video.mp4") + with open(video, "rb") as vid_f: + video_bytes = vid_f.read() + + subsampler = FrameSubsampler(target_frame_rate) + + streams = {"video": [video_bytes]} + subsampled_streams, _, error_message = subsampler(streams) + assert error_message is None + subsampled_videos = subsampled_streams["video"] + + with tempfile.NamedTemporaryFile() as tmp: + tmp.write(subsampled_videos[0]) + + probe = ffmpeg.probe(tmp.name) + video_stream = [stream for stream in probe["streams"] if stream["codec_type"] == "video"][0] + frame_rate = int(video_stream["r_frame_rate"].split("/")[0]) + + assert frame_rate == target_frame_rate + + +@pytest.mark.parametrize("sample_rate,n_audio_channels", [(44100, 1), (24000, 2)]) +def test_audio_rate_subsampler(sample_rate, n_audio_channels): + current_folder = os.path.dirname(__file__) + audio = os.path.join(current_folder, "test_files/test_audio.mp3") + with open(audio, "rb") as aud_f: + audio_bytes = aud_f.read() + + streams = {"audio": [audio_bytes]} + subsampler = AudioRateSubsampler(sample_rate, "mp3", n_audio_channels) + + subsampled_streams, _, error_message = subsampler(streams) + assert error_message is None + subsampled_audios = subsampled_streams["audio"] + + with tempfile.NamedTemporaryFile(suffix=".mp3") as tmp: + tmp.write(subsampled_audios[0]) + + out = subprocess.check_output(f"file {tmp.name}".split()).decode("utf-8") + assert "Audio file with ID3 version" in out + + result = ffmpeg.probe(tmp.name) + read_sample_rate = result["streams"][0]["sample_rate"] + read_num_channels = result["streams"][0]["channels"] + assert int(read_sample_rate) == sample_rate + assert int(read_num_channels) == n_audio_channels + + +@pytest.mark.parametrize( + "cut_detection_mode,framerates", [("longest", []), ("longest", [1]), ("all", []), ("all", [1])] +) +def test_cut_detection_subsampler(cut_detection_mode, framerates): + current_folder = os.path.dirname(__file__) + video = os.path.join(current_folder, "test_files/test_video.mp4") + with open(video, "rb") as vid_f: + video_bytes = vid_f.read() + + subsampler = CutDetectionSubsampler(cut_detection_mode, framerates, threshold=5) + + streams = {"video": [video_bytes]} + streams, cuts, err_msg = subsampler(streams) + if cut_detection_mode == "longest": + assert len(cuts["cuts_original_fps"]) == 1 + assert cuts["cuts_original_fps"][0] == [0, 2096] + + if len(framerates) > 0: + assert cuts["cuts_1"][0] == [0, 2100] + + if 
cut_detection_mode == "all": + assert len(cuts["cuts_original_fps"]) > 1 + assert cuts["cuts_original_fps"][-1] == [3015, 3678] + + if len(framerates) > 0: + assert cuts["cuts_1"][-1] == [3420, 3678] + + +@pytest.mark.parametrize( + "detector,fps,params", [("cv2", 1, None), ("cv2", 2, None), ("cv2", 1, (0.25, 2, 10, 3, 5, 1, 0))] +) +def test_optical_flow_subsampler(detector, fps, params): + current_folder = os.path.dirname(__file__) + video = os.path.join(current_folder, "test_files/test_video.mp4") + + cap = cv2.VideoCapture(video) + native_fps = cap.get(cv2.CAP_PROP_FPS) + frames = [] + for _ in range(100): + ret, frame = cap.read() + if not ret: + break + frames.append(frame) + cap.release() + step = int(native_fps / fps) + raw_frames = np.array(frames)[::step] + + subsampler = OpticalFlowSubsampler(detector, detector_args=params) + + optical_flow, metrics, error_message = subsampler(raw_frames) + mean_magnitude, _ = metrics + assert error_message is None + + first_frame = optical_flow[0] + assert first_frame.shape == (1080, 1920, 2) + + if fps == 1: + if params is None: + assert np.isclose(mean_magnitude, 0.00225532218919966, rtol=1e-3) # verified independently on colab + elif params is not None: + assert np.isclose( + mean_magnitude, 0.0034578320931094217, rtol=1e-3 + ) # np.isclose due to potential numerical precision issues + elif fps == 2: # fps = 2, params = None + assert np.isclose(mean_magnitude, 0.0011257734728123598, rtol=1e-3) + + +@pytest.mark.parametrize("extract_keyframes", [False, True]) +def test_ffprobe_subsampler(extract_keyframes): + current_folder = os.path.dirname(__file__) + # video length - 2:02, 1080x1920, 30 fps + video = os.path.join(current_folder, "test_files/test_video.mp4") + with open(video, "rb") as vid_f: + video_bytes = vid_f.read() + + subsampler = FFProbeSubsampler(extract_keyframes) + + streams = {"video": [video_bytes]} + metadata = {} + subsampled_streams, metadata, error_message = subsampler(streams, metadata) + assert error_message is None + assert metadata is not None + assert "video_metadata" in metadata + + video_metadata = metadata["video_metadata"] + + # check some basic metadata + assert "format" in video_metadata + assert "duration" in video_metadata["format"] + assert "streams" in video_metadata + video_stream_info = next(stream for stream in video_metadata["streams"] if stream["codec_type"] == "video") + + assert "width" in video_stream_info + assert "height" in video_stream_info + assert "r_frame_rate" in video_stream_info + + if extract_keyframes: + assert "keyframe_timestamps" in video_metadata + assert isinstance(video_metadata["keyframe_timestamps"], list) + assert len(video_metadata["keyframe_timestamps"]) > 0 + else: + assert "keyframe_timestamps" not in metadata + + +def test_whisper_subsampler(): + current_folder = os.path.dirname(__file__) + audio = os.path.join(current_folder, "test_files/test_audio.mp3") + with open(audio, "rb") as aud_f: + audio_bytes = aud_f.read() + + subsampler = WhisperSubsampler("small", 4, "float32") + streams = {"audio": [audio_bytes]} + metadata = [{"key": "000"}] + + _, metadata, error_message = subsampler(streams, metadata) + assert error_message is None + transcript = metadata[0]["whisper_transcript"] + assert transcript["segments"][0]["text"].startswith(" Bob Jones University in Greenville, South Carolina") + assert len(transcript["segments"]) == 28 diff --git a/tests/test_yt_meta.py b/tests/test_yt_meta.py new file mode 100644 index 00000000..3af1f2c8 --- /dev/null +++ b/tests/test_yt_meta.py 
@@ -0,0 +1,73 @@ +import os + +import pandas as pd +import pytest +from video2dataset.data_reader import get_yt_meta + + +@pytest.mark.parametrize("input_file", ["test_yt.csv"]) +def test_meta(input_file): + yt_metadata_args = { + "writesubtitles": False, + "get_info": True, + } + current_folder = os.path.dirname(__file__) + url_list = pd.read_csv(os.path.join(current_folder, f"test_files/{input_file}"))["contentUrl"] + for url in url_list: + yt_meta_dict = get_yt_meta(url, yt_metadata_args) + + assert type(yt_meta_dict) == dict + assert type(yt_meta_dict["info"]) == dict + assert "id" in yt_meta_dict["info"].keys() + assert "title" in yt_meta_dict["info"].keys() + + +@pytest.mark.parametrize("input_file", ["test_yt.csv"]) +def test_no_meta(input_file): + yt_metadata_args = { + "writesubtitles": False, + "get_info": False, + } + current_folder = os.path.dirname(__file__) + url_list = pd.read_csv(os.path.join(current_folder, f"test_files/{input_file}"))["contentUrl"] + for url in url_list: + yt_meta_dict = get_yt_meta(url, yt_metadata_args) + + assert type(yt_meta_dict) == dict + assert yt_meta_dict["info"] == None + + +@pytest.mark.parametrize("input_file", ["test_yt.csv"]) +def test_subtitles(input_file): + yt_metadata_args = { + "writesubtitles": True, + "subtitleslangs": ["en"], + "writeautomaticsub": True, + "get_info": False, + } + current_folder = os.path.dirname(__file__) + url_list = pd.read_csv(os.path.join(current_folder, f"test_files/{input_file}"))["contentUrl"] + for url in url_list: + yt_meta_dict = get_yt_meta(url, yt_metadata_args) + + assert type(yt_meta_dict) == dict + assert type(yt_meta_dict["subtitles"]) == dict + assert type(yt_meta_dict["subtitles"]["en"]) == list + assert type(yt_meta_dict["subtitles"]["en"][0]) == dict + + +@pytest.mark.parametrize("input_file", ["test_yt.csv"]) +def test_no_subtitles(input_file): + yt_metadata_args = { + "writesubtitles": False, + "subtitleslangs": ["en"], + "writeautomaticsub": True, + "get_info": False, + } + current_folder = os.path.dirname(__file__) + url_list = pd.read_csv(os.path.join(current_folder, f"test_files/{input_file}"))["contentUrl"] + for url in url_list: + yt_meta_dict = get_yt_meta(url, yt_metadata_args) + + assert type(yt_meta_dict) == dict + assert yt_meta_dict["subtitles"] == None diff --git a/video2dataset/__init__.py b/video2dataset/__init__.py new file mode 100644 index 00000000..cf7f9b9b --- /dev/null +++ b/video2dataset/__init__.py @@ -0,0 +1,3 @@ +"""video2dataset""" + +from .main import video2dataset diff --git a/video2dataset/configs/__init__.py b/video2dataset/configs/__init__.py new file mode 100644 index 00000000..194a021c --- /dev/null +++ b/video2dataset/configs/__init__.py @@ -0,0 +1,12 @@ +"""video2dataset example configs""" +import os +from omegaconf import OmegaConf + +configs_path = os.path.dirname(os.path.realpath(__file__)) + +CONFIGS = { + "default": OmegaConf.load(os.path.join(configs_path, "default.yaml")), + "downsample_ml": OmegaConf.load(os.path.join(configs_path, "downsample_ml.yaml")), + "optical_flow": OmegaConf.load(os.path.join(configs_path, "optical_flow.yaml")), + "caption": OmegaConf.load(os.path.join(configs_path, "caption.yaml")), +} diff --git a/video2dataset/configs/caption.yaml b/video2dataset/configs/caption.yaml new file mode 100644 index 00000000..2cd30597 --- /dev/null +++ b/video2dataset/configs/caption.yaml @@ -0,0 +1,31 @@ +subsampling: + CaptionSubsampler: + args: + captioner_args: + video_captioner: "kpyu/video-blip-flan-t5-xl-ego4d" + prompt: null + +reading: + 
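+  # these reading options apply when samples are loaded from existing webdataset shards (as in the captioning stage); the dataloader_args below are presumably forwarded to the dataloader's get_video_dataset rather than to the downloaders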
dataloader_args: + batch_size: 32 + resize_size: [224, 224] + decoder_kwargs: + n_frames: 8 + uniformly_sample: True + fps: null + num_threads: 8 + return_bytes: True + timeout: 60 + sampler: null + +storage: + number_sample_per_shard: 1000 + oom_shard_count: 5 + captions_are_subtitles: False + +distribution: + processes_count: 1 + thread_count: 8 + subjob_size: 1000 + distributor: "multiprocessing" + diff --git a/video2dataset/configs/default.yaml b/video2dataset/configs/default.yaml new file mode 100644 index 00000000..cf0baf90 --- /dev/null +++ b/video2dataset/configs/default.yaml @@ -0,0 +1,24 @@ +subsampling: {} + +reading: + yt_args: + download_size: 360 + download_audio_rate: 44100 + yt_metadata_args: + writesubtitles: 'all' + subtitleslangs: ['en'] + writeautomaticsub: True + get_info: True + timeout: 60 + sampler: null + +storage: + number_sample_per_shard: 1000 + oom_shard_count: 5 + captions_are_subtitles: False + +distribution: + processes_count: 16 + thread_count: 32 + subjob_size: 1000 + distributor: "multiprocessing" diff --git a/video2dataset/configs/downsample_ml.yaml b/video2dataset/configs/downsample_ml.yaml new file mode 100644 index 00000000..83d56299 --- /dev/null +++ b/video2dataset/configs/downsample_ml.yaml @@ -0,0 +1,47 @@ +subsampling: + ResolutionSubsampler: + args: + video_size: 224 + resize_mode: "scale,crop,pad" + FrameSubsampler: + args: + frame_rate: 5 + CutDetectionSubsampler: + cuts_are_clips: True + args: + cut_detection_mode: "all" + framerates: null + threshold: 11.5 + min_scene_len: 15 + ClippingSubsampler: + args: + min_length: 4.0 + max_length: 10.0 + max_length_strategy: "all" + precision: "keyframe_adjusted" + FFProbeSubsampler: + args: + extract_keyframes: False + +reading: + yt_args: + download_size: 360 + download_audio_rate: 44100 + yt_metadata_args: + writesubtitles: 'all' + subtitleslangs: ['en'] + writeautomaticsub: True + get_info: True + timeout: 60 + sampler: null + +storage: + number_sample_per_shard: 1000 + oom_shard_count: 5 + captions_are_subtitles: False + +distribution: + processes_count: 16 + thread_count: 32 + subjob_size: 1000 + distributor: "multiprocessing" diff --git a/video2dataset/configs/optical_flow.yaml b/video2dataset/configs/optical_flow.yaml new file mode 100644 index 00000000..e8f625ee --- /dev/null +++ b/video2dataset/configs/optical_flow.yaml @@ -0,0 +1,28 @@ +subsampling: + OpticalFlowSubsampler: + args: + detector: "cv2" + detector_args: null + dtype: "fp16" + +reading: + dataloader_args: + resize_size: 16 + decoder_kwargs: + n_frames: null + fps: 2 + num_threads: 8 + return_bytes: True + timeout: 60 + sampler: null + +storage: + number_sample_per_shard: 1000 + oom_shard_count: 5 + captions_are_subtitles: False + +distribution: + processes_count: 16 + thread_count: 32 + subjob_size: 1000 + distributor: "multiprocessing" diff --git a/video2dataset/data_reader.py b/video2dataset/data_reader.py new file mode 100644 index 00000000..4dda4734 --- /dev/null +++ b/video2dataset/data_reader.py @@ -0,0 +1,276 @@ +"""classes and functions for downloading videos""" +import os +import uuid +import requests +import yt_dlp +import io +import webvtt +import ffmpeg + + +def video2audio(video, audio_format, tmp_dir): + """extract audio from video""" + path = f"{tmp_dir}/{str(uuid.uuid4())}.{audio_format}" + num_streams = len(ffmpeg.probe(video)["streams"]) + ffmpeg_args = {"f": audio_format} + + if int(num_streams) > 1: # video has audio stream + try: + video = ffmpeg.input(video) + (ffmpeg.output(video.audio, path, 
**ffmpeg_args).run(capture_stderr=True)) + except ffmpeg.Error as _: + path = None + else: + path = None + return path + + +def sub_to_dict(sub, dedupe=True, single=False) -> list: + """Convert WebVTT to JSON, optionally removing duplicate lines""" + + captions = webvtt.read_buffer(io.StringIO(sub)) + dicts = [{"start": c.start, "end": c.end, "lines": c.lines} for c in captions] + if dedupe: + dicts = [] + prev_line = None + for c in captions: + if any("" in l for l in c.lines): + continue + # Collect lines that are not dupes + not_dupe_lines = [] + for line in c.lines: + if not line.strip(): + continue + if line != prev_line: + not_dupe_lines.append(line) + prev_line = line + if not_dupe_lines: + dicts.append({"start": c.start, "end": c.end, "lines": not_dupe_lines}) + if single: + for d in dicts: + d["line"] = "\n".join(d.pop("lines")) + return dicts + + +def get_yt_meta(url, yt_metadata_args: dict) -> dict: + """Return yt meta dict with meta data and/or subtitles + yt_metadata_args is a dict of follwing format: + yt_metadata_args = { + 'writesubtitles': 'first', + 'subtitleslangs': ['en'], + 'writeautomaticsub': True, + 'get_info': True + } + + writesubtitles: Whether to write subtitles for each provided language or just the first present + writeautomaticsub: Write the automatically generated subtitles to a file + subtitleslangs: List of languages of the subtitles to download. + get_info: Whether to add info (title, description, tags etc) to the output. + """ + + write_subs = yt_metadata_args.get("writesubtitles", None) + + yt_metadata_args["skip_download"] = True + yt_metadata_args["ignoreerrors"] = True + yt_metadata_args["quiet"] = True + + info_dict, full_sub_dict = None, None + + with yt_dlp.YoutubeDL(yt_metadata_args) as yt: + info_dict = yt.extract_info(url, download=False) + if write_subs: + full_sub_dict = {} + for lang in yt_metadata_args["subtitleslangs"]: + if lang not in info_dict["requested_subtitles"]: + continue + sub_url = info_dict["requested_subtitles"][lang]["url"] + res = requests.get(sub_url, timeout=10) + sub = io.TextIOWrapper(io.BytesIO(res.content)).read() + full_sub_dict[lang] = sub_to_dict(sub) + + if write_subs == "first": + break + + if yt_metadata_args["get_info"]: + info_dict.pop("subtitles") + info_dict.pop("requested_formats") + info_dict.pop("formats") + info_dict.pop("thumbnails") + info_dict.pop("automatic_captions") + else: + info_dict = None + + yt_meta_dict = {"info": info_dict, "subtitles": full_sub_dict} + + return yt_meta_dict + + +def get_file_info(url): + """returns info about the url (currently extension and modality)""" + # TODO: make this nicer + video_extensions = ["mp4", "webm", "mov", "avi", "mkv"] + audio_extensions = ["mp3", "wav", "m4a"] + for ext in video_extensions: + if url.endswith(f".{ext}"): + return ext, "video" + for ext in audio_extensions: + if url.endswith(f".{ext}"): + return ext, "audio" + return None + + +class WebFileDownloader: + """Downloader class for mp4 links""" + + def __init__(self, timeout, tmp_dir, encode_formats): + self.timeout = timeout + self.tmp_dir = tmp_dir + self.encode_formats = encode_formats + + def __call__(self, url): + modality_paths = {} + + ext, modality = get_file_info(url) + if not os.path.isfile(url): + resp = requests.get(url, stream=True, timeout=self.timeout) + byts = resp.content + else: # local files (don't want to delete) + with open(url, "rb") as f: + byts = f.read() + + modality_path = f"{self.tmp_dir}/{str(uuid.uuid4())}.{ext}" + with open(modality_path, "wb") as f: + f.write(byts) 
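+        # the raw bytes fetched above were written to a temporary file for this modality; next, if an
+        # audio encode format is requested as well, the audio track is split out of the downloaded video
+        # with ffmpeg, and any modality not listed in encode_formats is removed again before returning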
+ + modality_paths[modality] = modality_path + + if modality == "video" and self.encode_formats.get("audio", None): + audio_format = self.encode_formats["audio"] + audio_path = video2audio(modality_paths["video"], audio_format, self.tmp_dir) + if audio_path is not None: + modality_paths["audio"] = audio_path + + for modality in list(modality_paths.keys()): + if modality not in self.encode_formats: + os.remove(modality_paths[modality]) + modality_paths.pop(modality) + + return modality_paths, None + + +class YtDlpDownloader: + """Downloader class for yt-dlp links + + yt_args: + download_size: preferred height of video to download. Will try to download smallest video >=download_size + download_audio_rate: same as size but with audio + yt_metadata_args: see get_yt_meta function docstring + """ + + # TODO: maybe we just include height and width in the metadata_args + def __init__(self, yt_args, tmp_dir, encode_formats): + self.metadata_args = yt_args.get("yt_metadata_args", {}) + self.video_size = yt_args.get("download_size", 360) + self.audio_rate = yt_args.get("download_audio_rate", 44100) + self.tmp_dir = tmp_dir + self.encode_formats = encode_formats + + # TODO: figure out when to do this + # was relevant with HD videos for loading with decord + self.specify_codec = False + + def __call__(self, url): + modality_paths = {} + + video_format_string = ( + f"wv*[height>={self.video_size}][ext=mp4]{'[codec=avc1]' if self.specify_codec else ''}/" + f"w[height>={self.video_size}][ext=mp4]{'[codec=avc1]' if self.specify_codec else ''}/" + f"bv/b[ext=mp4]{'[codec=avc1]' if self.specify_codec else ''}" + ) + audio_fmt_string = ( + f"wa[asr>={self.audio_rate}][ext=m4a] / ba[ext=m4a]" if self.audio_rate > 0 else "ba[ext=m4a]" + ) + + if self.encode_formats.get("audio", None): + audio_path_m4a = f"{self.tmp_dir}/{str(uuid.uuid4())}.m4a" + ydl_opts = { + "outtmpl": audio_path_m4a, + "format": audio_fmt_string, + "quiet": True, + } + + err = None + try: + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + ydl.download(url) + except Exception as e: # pylint: disable=(broad-except) + err = str(e) + os.remove(audio_path_m4a) + + if err is None: + # TODO: look into this, don't think we can just do this + # TODO: just figure out a way to download the preferred extension using yt-dlp + # audio_path = audio_path_m4a.replace(".m4a", f".{self.encode_formats['audio']}") + audio_path = audio_path_m4a + modality_paths["audio"] = audio_path + + if self.encode_formats.get("video", None): + video_path = f"{self.tmp_dir}/{str(uuid.uuid4())}.mp4" + ydl_opts = { + "outtmpl": video_path, + "format": video_format_string, + "quiet": True, + "no_warnings": True, + } + + err = None + try: + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + ydl.download(url) + except Exception as e: # pylint: disable=(broad-except) + err = str(e) + os.remove(video_path) + + if err is None: + modality_paths["video"] = video_path + + err = None + try: + if self.metadata_args: + yt_meta_dict = get_yt_meta(url, self.metadata_args) + else: + yt_meta_dict = {} + except Exception as e: # pylint: disable=(broad-except) + err = str(e) + yt_meta_dict = {} + + return modality_paths, yt_meta_dict, None + + +class VideoDataReader: + """Video data reader that provides the data for a video""" + + def __init__(self, encode_formats, tmp_dir, reading_config): + self.webfile_downloader = WebFileDownloader(reading_config["timeout"], tmp_dir, encode_formats) + self.yt_downloader = YtDlpDownloader(reading_config["yt_args"], tmp_dir, encode_formats) + + def __call__(self, row): + key, url = 
row + + meta_dict = None + try: + # TODO: make nice function to detect what type of link we're dealing with + if get_file_info(url): # web file that can be directly downloaded + modality_paths, error_message = self.webfile_downloader(url) + else: + modality_paths, meta_dict, error_message = self.yt_downloader(url) + except Exception as e: # pylint: disable=(broad-except) + modality_paths, meta_dict, error_message = {}, None, str(e) + + streams = {} + for modality, modality_path in modality_paths.items(): + with open(modality_path, "rb") as modality_file: + streams[modality] = modality_file.read() + os.remove(modality_path) + + return key, streams, meta_dict, error_message diff --git a/video2dataset/data_writer.py b/video2dataset/data_writer.py new file mode 100644 index 00000000..799a01c5 --- /dev/null +++ b/video2dataset/data_writer.py @@ -0,0 +1,341 @@ +""""writer module handle writing the videos to disk""" + +import json +import os + +import fsspec +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +import webdataset as wds + + +class BufferedParquetWriter: + """Write samples to parquet files incrementally with a buffer""" + + def __init__(self, output_file, schema, buffer_size=100): + self.buffer_size = buffer_size + self.schema = schema + self._initiatlize_buffer() + fs, output_path = fsspec.core.url_to_fs(output_file) + + self.output_fd = fs.open(output_path, "wb") + self.parquet_writer = pq.ParquetWriter(self.output_fd, schema) + + def _initiatlize_buffer(self): + self.current_buffer_size = 0 + self.buffer = {k: [] for k in self.schema.names} + + def _add_sample_to_buffer(self, sample): + for k in self.schema.names: + self.buffer[k].append(sample[k]) + self.current_buffer_size += 1 + + def write(self, sample): + if self.current_buffer_size >= self.buffer_size: + self.flush() + self._add_sample_to_buffer(sample) + + def flush(self): + """Write the buffer to disk""" + if self.current_buffer_size == 0: + return + + df = pa.Table.from_pydict(self.buffer, self.schema) + self.parquet_writer.write_table(df) + self._initiatlize_buffer() + + def close(self): + self.flush() + if self.parquet_writer is not None: + self.parquet_writer.close() + self.parquet_writer = None + self.output_fd.close() + + +class ParquetSampleWriter: + """ParquetSampleWriter is a video+caption writer to parquet""" + + def __init__( + self, + shard_id, + output_folder, + save_caption, + oom_shard_count, + schema, + encode_formats, + ): + self.oom_shard_count = oom_shard_count + for fmt in encode_formats.values(): + schema = schema.append(pa.field(fmt, pa.binary())) + shard_name = ( + shard_id + if isinstance(shard_id, str) + else "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string + shard_id=shard_id, oom_shard_count=oom_shard_count + ) + ) + output_file = f"{output_folder}/{shard_name}.parquet" + self.buffered_parquet_writer = BufferedParquetWriter(output_file, schema, 100) + self.save_caption = save_caption + self.encode_formats = encode_formats + + def write(self, streams, key, caption, meta): + """Keep sample in memory then write to disk when close() is called""" + sample = {"key": key} + for modality, stream in streams.items(): + ext = self.encode_formats[modality] if modality in self.encode_formats else modality + sample[ext] = stream + + if self.save_caption: + sample["txt"] = str(caption) if caption is not None else "" + sample.update(meta) + + self.buffered_parquet_writer.write(sample) + + def close(self): + self.buffered_parquet_writer.close() + + +class 
WebDatasetSampleWriter: + """WebDatasetSampleWriter is a video+caption writer to webdataset""" + + def __init__( + self, + shard_id, + output_folder, + save_caption, + oom_shard_count, + schema, + encode_formats, + ): + self.oom_shard_count = oom_shard_count + shard_name = ( + shard_id + if isinstance(shard_id, str) + else "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string + shard_id=shard_id, oom_shard_count=oom_shard_count + ) + ) + self.shard_id = shard_id + fs, output_path = fsspec.core.url_to_fs(output_folder) + self.tar_fd = fs.open(f"{output_path}/{shard_name}.tar", "wb") + self.tarwriter = wds.TarWriter(self.tar_fd) + self.save_caption = save_caption + self.buffered_parquet_writer = BufferedParquetWriter(output_folder + "/" + shard_name + ".parquet", schema, 100) + self.encode_formats = encode_formats + + def write(self, streams, key, caption, meta): + """write sample to tars""" + sample = {"__key__": key} + for modality, stream in streams.items(): + ext = self.encode_formats[modality] if modality in self.encode_formats else modality + sample[ext] = stream + + if self.save_caption: + sample["txt"] = str(caption) if caption is not None else "" + # some meta data may not be JSON serializable + for k, v in meta.items(): + if isinstance(v, np.ndarray): + meta[k] = v.tolist() + sample["json"] = json.dumps(meta, indent=4) + + self.tarwriter.write(sample) + self.buffered_parquet_writer.write(meta) + + def close(self): + self.buffered_parquet_writer.close() + self.tarwriter.close() + self.tar_fd.close() + + +class TFRecordSampleWriter: + """TFRecordSampleWriter is a video+caption writer to TFRecord""" + + def __init__( + self, + shard_id, + output_folder, + save_caption, + oom_shard_count, + schema, + encode_formats, + ): + try: + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + import tensorflow_io as _ # pylint: disable=import-outside-toplevel + from tensorflow.python.lib.io.tf_record import ( # pylint: disable=import-outside-toplevel + TFRecordWriter, + ) # pylint: disable=import-outside-toplevel + from tensorflow.python.training.training import ( # pylint: disable=import-outside-toplevel + BytesList, + Example, + Feature, + Features, + FloatList, + Int64List, + ) + + self._BytesList = BytesList # pylint: disable=invalid-name + self._Int64List = Int64List # pylint: disable=invalid-name + self._FloatList = FloatList # pylint: disable=invalid-name + self._Example = Example # pylint: disable=invalid-name + self._Features = Features # pylint: disable=invalid-name + self._Feature = Feature # pylint: disable=invalid-name + except ImportError as e: + raise ModuleNotFoundError( + "tfrecords require tensorflow and tensorflow_io to be installed." + "Run `pip install tensorflow tensorflow_io`." 
+ ) from e + + self.oom_shard_count = oom_shard_count + shard_name = ( + shard_id + if isinstance(shard_id, str) + else "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string + shard_id=shard_id, oom_shard_count=oom_shard_count + ) + ) + self.shard_id = shard_id + self.tf_writer = TFRecordWriter(f"{output_folder}/{shard_name}.tfrecord") + self.save_caption = save_caption + self.buffered_parquet_writer = BufferedParquetWriter(output_folder + "/" + shard_name + ".parquet", schema, 100) + self.encode_formats = encode_formats + + def write(self, streams, key, caption, meta): + """Write a sample using tfrecord writer""" + sample = {"key": self._bytes_feature(key.encode())} + for modality, stream in streams.items(): + ext = self.encode_formats[modality] if modality in self.encode_formats else modality + sample[ext] = self._bytes_feature(stream) + + if self.save_caption: + sample["txt"] = self._bytes_feature(str(caption) if caption is not None else "") + for k, v in meta.items(): + sample[k] = self._feature(v) + + tf_example = self._Example(features=self._Features(feature=sample)) + self.tf_writer.write(tf_example.SerializeToString()) + self.buffered_parquet_writer.write(meta) + + def close(self): + self.buffered_parquet_writer.close() + self.tf_writer.close() + + def _feature(self, value): + if isinstance(value, list): + return self._list_feature(value) + elif isinstance(value, int): + return self._int64_feature(value) + elif isinstance(value, float): + return self._float_feature(value) + else: + return self._bytes_feature(value) + + def _bytes_feature(self, value): + """Returns a bytes_list from a string / byte.""" + if value is None: + value = "" + if isinstance(value, str): + value = value.encode() + return self._Feature(bytes_list=self._BytesList(value=[value])) + + def _float_feature(self, value): + """Returns a float_list from a float / double.""" + return self._Feature(float_list=self._FloatList(value=[value])) + + def _int64_feature(self, value): + """Returns an int64_list from a bool / enum / int / uint.""" + return self._Feature(int64_list=self._Int64List(value=[value])) + + def _list_feature(self, value): + """Returns an list of int64_list, float_list, bytes_list.""" + if isinstance(value[0], int): + return self._Feature(int64_list=self._Int64List(value=value)) + elif isinstance(value[0], float): + return self._Feature(float_list=self._FloatList(value=value)) + else: + for i, bytes_feature in enumerate(value): + if bytes_feature is None: + value[i] = "" + if isinstance(bytes_feature, str): + value[i] = bytes_feature.encode() + return self._Feature(bytes_list=self._BytesList(value=value)) + + +class FilesSampleWriter: + """FilesSampleWriter is a caption+video writer to files""" + + def __init__( + self, + shard_id, + output_folder, + save_caption, + oom_shard_count, + schema, + encode_formats, + ): + self.oom_shard_count = oom_shard_count + shard_name = ( + shard_id + if isinstance(shard_id, str) + else "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string + shard_id=shard_id, oom_shard_count=oom_shard_count + ) + ) + self.shard_id = shard_id + self.fs, self.subfolder = fsspec.core.url_to_fs(f"{output_folder}/{shard_name}") + if not self.fs.exists(self.subfolder): + self.fs.mkdir(self.subfolder) + self.save_caption = save_caption + self.buffered_parquet_writer = BufferedParquetWriter(output_folder + "/" + shard_name + ".parquet", schema, 100) + self.encode_formats = encode_formats + + def write(self, streams, key, caption, meta): 
+ """Write sample to disk""" + for modality, stream in streams.items(): + ext = self.encode_formats[modality] if modality in self.encode_formats else modality + filename = f"{self.subfolder}/{key}.{ext}" + with self.fs.open(filename, "wb") as f: + f.write(stream) + + if self.save_caption: + caption = str(caption) if caption is not None else "" + caption_filename = f"{self.subfolder}/{key}.txt" + with self.fs.open(caption_filename, "w") as f: + f.write(str(caption)) + + # some meta data may not be JSON serializable + for k, v in meta.items(): + if isinstance(v, np.ndarray): + meta[k] = v.tolist() + j = json.dumps(meta, indent=4) + meta_filename = f"{self.subfolder}/{key}.json" + with self.fs.open(meta_filename, "w") as f: + f.write(j) + + self.buffered_parquet_writer.write(meta) + + def close(self): + self.buffered_parquet_writer.close() + + +class DummySampleWriter: + """Does not write""" + + def __init__( + self, + shard_id, + output_folder, + save_caption, + oom_shard_count, + schema, + encode_formats, + ): + pass + + def write(self, streams, key, caption, meta): + pass + + def close(self): + pass diff --git a/video2dataset/dataloader/__init__.py b/video2dataset/dataloader/__init__.py new file mode 100644 index 00000000..db65173b --- /dev/null +++ b/video2dataset/dataloader/__init__.py @@ -0,0 +1,5 @@ +""" +Load the data (videos and metadata) into bytes, tensors, strings etc. +""" + +from .dataloader import get_video_dataset diff --git a/video2dataset/dataloader/audio_decode.py b/video2dataset/dataloader/audio_decode.py new file mode 100644 index 00000000..f0dcfcb0 --- /dev/null +++ b/video2dataset/dataloader/audio_decode.py @@ -0,0 +1,49 @@ +"""Audio Decoders""" +import io +import torch +import torch.nn.functional as F +import torchaudio +import torchaudio.functional as Fa + + +def set_backend(extension): + """Sets torchaudio backend for different extensions (soundfile doesn't support M4A and MP3)""" + if extension in ["wav", "flac"]: + torchaudio.set_audio_backend("soundfile") + else: + torchaudio.set_audio_backend("sox_io") + + +class AudioDecoder: + """Basic audio decoder that converts audio into torch tensors""" + + def __init__(self, sample_rate=48000, num_channels=None, extension="wav", max_length=10): + self.sample_rate = sample_rate + self.num_channels = num_channels + self.max_length = max_length + set_backend(extension) + + def __call__(self, key, data): + extension = key.split(".")[-1] + if extension not in "mp3 wav flac m4a".split(): + return None + additional_info = {} + waveform, sample_rate = torchaudio.load(io.BytesIO(data), format=extension) + + waveform = Fa.resample(waveform, sample_rate, self.sample_rate) + pad_masks = torch.zeros((self.max_length * self.sample_rate,)) + pad_start = self.max_length * self.sample_rate - waveform.shape[1] + + if pad_start < 0: + waveform = waveform[:, : self.max_length * self.sample_rate] + if pad_start > 0: + waveform = F.pad( # pylint: disable=not-callable + waveform, (0, self.max_length * self.sample_rate - waveform.shape[1]), "constant" + ) + pad_masks[:pad_start] = 1.0 + + additional_info["audio_pad_masks"] = pad_masks + + additional_info["original_sample_rate"] = sample_rate + additional_info["sample_rate"] = self.sample_rate + return (waveform, additional_info) diff --git a/video2dataset/dataloader/custom_wds.py b/video2dataset/dataloader/custom_wds.py new file mode 100644 index 00000000..4f720524 --- /dev/null +++ b/video2dataset/dataloader/custom_wds.py @@ -0,0 +1,544 @@ +"""Custom WebDataset classes""" +import os +import 
numpy as np +import random +import tarfile +import warnings +import copy +from io import BufferedIOBase +from typing import Callable, Iterator, Tuple, cast, Optional, IO, Union, List, Iterable, Dict + +import torch +import torch.distributed as dist +from torch.utils.data import IterableDataset +from torch.utils.data.datapipes.iter import IterableWrapper +from torch.utils.data.datapipes.utils.common import StreamWrapper +from torchdata.datapipes.iter import S3FileLoader, IterDataPipe, FileOpener +from torchdata.datapipes.iter import TarArchiveLoader +from torchdata.datapipes.utils.common import validate_pathname_binary_tuple + + +import webdataset as wds +from webdataset import DataPipeline, filters, shardlists, cache, tariterators +from webdataset.compat import FluidInterface +from webdataset.autodecode import Decoder, ImageHandler + + +def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True): + """Take a list of samples (as dictionary) and create a batch, preserving the keys. + If `tensors` is True, `ndarray` objects are combined into + tensor batches. + :param dict samples: list of samples + :param bool tensors: whether to turn lists of ndarrays into a single ndarray + :returns: single sample consisting of a batch + :rtype: dict + """ + + # first, get all keys, for no missing a shared field + keys = set.union(*[set(sample.keys()) for sample in samples]) + if "__corrupted__" in keys: + # set default dummy + dummy = {key: None for key in keys if not key.startswith("__")} + # dummy = {key: None if key != '__corrupted__' else True for key in keys if not key.startswith('__')} + dummy_set = False + missed_ids = set() + # search for first non-corrupted sample and take that as the dummy + for i, sample in enumerate(samples): + if not (dummy_set or sample["__corrupted__"]): + # set dummy to reasonble output, but keep sample['__corrupted__']=True + for key in sample: + # keep meta data + if not key.startswith("__"): + dummy[key] = copy.deepcopy(sample[key]) + dummy_set = True + elif not dummy_set and sample["__corrupted__"]: + # add default dummy and remember id + missed_ids.add(i) + for key in dummy: + samples[i][key] = copy.deepcopy(dummy[key]) + elif sample["__corrupted__"]: + # set corrupted sample to dummy except for metadata + for key in dummy: + samples[i][key] = copy.deepcopy(dummy[key]) + + if missed_ids and dummy_set: + # this will be only to do if at least one element in the batch is not corrupted and + # if there is at least one missed sample + for i in missed_ids: + for key in dummy: + samples[i][key] = copy.deepcopy(dummy[key]) + # samples[i] = copy.deepcopy(dummy) + + # the resorting case if the all samples in the batch are corrupted, then the batch will be + keys = set.intersection(*[set(sample.keys()) for sample in samples]) + batched = {key: [s[key] for s in samples] for key in keys} + + result = {} + for key, values in batched.items(): # Iterate over both key and values + first_value = values[0] + + if isinstance(first_value, (int, float)): + if combine_scalars: + result[key] = np.array(values) + elif isinstance(first_value, torch.Tensor): + if combine_tensors: + result[key] = torch.stack(values) + elif isinstance(first_value, np.ndarray): + if combine_tensors: + result[key] = np.array(values) + elif isinstance(first_value, tuple): # tuple of torch tensor and a dict + dict_keys = first_value[1].keys() + if combine_tensors: + result[key] = torch.stack([v[0] for v in values]) + for k in dict_keys: + result[k] = torch.stack([torch.tensor(v[1][k]) for v in values]) 
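+            # decoders such as AudioDecoder return (tensor, info_dict) tuples: the tensors are stacked
+            # into one batch tensor and every entry of the info dict is stacked under its own key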
+ else: + result[key] = values + + return result + + +class KeyPassThroughDecoder(Decoder): + """Decoder which allows you to pass through some keys""" + + def __init__(self, *args, passthrough_keys=None, **kwargs): + """ + Initialize the KeyPassThroughDecoder. + + :param *args: Positional arguments to be passed to the base Decoder class. + :param passthrough_keys: List of keys to bypass the decoding process. + :param **kwargs: Keyword arguments to be passed to the base Decoder class. + """ + super().__init__(*args, **kwargs) + self.passthrough_keys = passthrough_keys or [] # Simplified passthrough_keys initialization + + def decode(self, sample): + """ + Decode an entire sample. + + :param dict sample: The sample, a dictionary of key-value pairs. + :return: Decoded sample. + :rtype: dict + """ + result = {} + assert isinstance(sample, dict), sample + for k, v in sample.items(): # Removed unnecessary list conversion + if k[0] == "_": + if isinstance(v, bytes): + v = v.decode("utf-8") + result[k] = v + continue + if self.only is not None and k not in self.only: + result[k] = v + continue + assert v is not None + if self.partial: + if isinstance(v, bytes): + result[k] = self.decode1(k, v) + else: + result[k] = v + else: + assert ( + isinstance(v, bytes) or k in self.passthrough_keys + ), f"key: {k}; passthrough_keys: {self.passthrough_keys}" + result[k] = self.decode1(k, v) + return result + + +class FluidInterfaceWithChangedDecode(FluidInterface): + """ + FluidInterface with more Decoder args and different decode function + """ + + # pylint: disable=missing-function-docstring + def decode( + self, + *args, + pre=None, + post=None, + only=None, + partial=False, + passthrough_keys=None, + handler=wds.reraise_exception, + ): + handlers = [ImageHandler(x) if isinstance(x, str) else x for x in args] + decoder = KeyPassThroughDecoder( + handlers, + passthrough_keys=passthrough_keys, + pre=pre, + post=post, + only=only, + partial=partial, + ) + return self.map(decoder, handler=handler) + + +# TODO: pylint says this needs __getitem__ +# pylint: disable=abstract-method +class WebDatasetWithChangedDecoder(DataPipeline, FluidInterfaceWithChangedDecode): + """Small fluid-interface wrapper for DataPipeline.""" + + def __init__( + self, + urls, + handler=wds.reraise_exception, + resampled=False, + shardshuffle=None, + cache_size=0, + cache_dir=None, + detshuffle=False, + nodesplitter=shardlists.single_node_only, + verbose=False, + ): + super().__init__() + if isinstance(urls, IterableDataset): + assert not resampled + self.append(urls) + elif isinstance(urls, dict): + assert "datasets" in urls + self.append(shardlists.MultiShardSample(urls)) + elif resampled: + self.append(shardlists.ResampledShards(urls)) + else: + self.append(shardlists.SimpleShardList(urls)) + self.append(nodesplitter) + self.append(shardlists.split_by_worker) + if shardshuffle is True: + shardshuffle = 100 + if shardshuffle is not None: + if detshuffle: + self.append(filters.detshuffle(shardshuffle)) + else: + self.append(filters.shuffle(shardshuffle)) + if cache_size == 0: + self.append(tariterators.tarfile_to_samples(handler=handler)) + else: + assert cache_size == -1 or cache_size > 0 + self.append( + cache.cached_tarfile_to_samples( + handler=handler, + verbose=verbose, + cache_size=cache_size, + cache_dir=cache_dir, + ) + ) + + +# pylint: disable=missing-function-docstring +def _return_always_map(data, f, return_always=False, handler=wds.reraise_exception): + """Map samples.""" + for sample in data: + try: + result = f(sample) + 
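+            # if f raises, the handler decides whether to skip the sample, abort, or (with return_always)
+            # emit a dummy sample marked __corrupted__ so downstream batching keeps working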
except Exception as exn: # pylint: disable=broad-except + if handler(exn): + if return_always: + result = {"__corrupted__": True} + else: + continue + else: + break + if result is None: + if return_always: + result = {"__corrupted__": True} + else: + continue + if isinstance(sample, dict) and isinstance(result, dict): + result["__key__"] = sample.get("__key__") + result["__url__"] = sample.get("__url__") + yield result + + +return_always_map = wds.pipelinefilter(_return_always_map) + + +# pylint: disable=missing-function-docstring +def _s3dataset2samples(data, return_always=False, handler=wds.reraise_exception): + for sample in data: + try: + # construct webdataset-style sample + key = os.path.split(sample[0][0])[-1].split(".")[0] + url = os.path.split(sample[0][0])[0] + sample = {s[0].split(".")[-1]: s[1].read() for s in sample} + sample["__key__"] = key + sample["__url__"] = url + + if return_always: + # assume sample is healthy in the beginning + sample["__corrupted__"] = False + yield sample + except Exception as exn: # pylint: disable=broad-except + if handler(exn): + continue + break + + +s3dataset2samples = filters.pipelinefilter(_s3dataset2samples) + + +class SplitByWorker(IterDataPipe): + """ + distributes data across workers to mimic the behavior of shard splitting in webdataset + """ + + def __init__(self, datapipe, drop_last=False): + super().__init__() + self.datapipe = datapipe + self._len = len(list(self.datapipe)) + self.drop_last = drop_last + self.worker_id = 0 + self.num_workers = 1 + self.max_it = (self._len // self.num_workers) * self.num_workers + + def reset(self): + # this will be called whenever __iter__ is invoked again (this should be kept in mind for shuffling) + worker_info = torch.utils.data.get_worker_info() + + if worker_info: + self.worker_id = worker_info.id + self.num_workers = worker_info.num_workers + self.max_it = (self._len // self.num_workers) * self.num_workers + + def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]: + for i, data in enumerate(self.datapipe): + # avoid hanging due to uneven number of shards per worker + if self.drop_last and i >= self.max_it: + break + if i % self.num_workers == self.worker_id: + yield data + + +class PrefixResampler(IterDataPipe): + """ + Resampling of prefixes with given probabilities, this is useful when mixing different datasets + """ + + def __init__( + self, + datapipe: IterDataPipe[str], + prefixes: List[str], + ps: Optional[Iterable[float]] = None, + ): + super().__init__() + urls = list(datapipe) + self._len = len(urls) + self.prefix2urls: Dict[str, List] = {p: [] for p in set(prefixes)} + if ps is None: + # uniformly distributed over all prefixes + ps = [1 / len(self.prefix2urls)] * len(self.prefix2urls) + self.ps = dict(zip(prefixes, ps)) # type: ignore + + print(f"{self.__class__.__name__} got the following prefixes: {prefixes}") + for u in urls: + self.prefix2urls[ + list( + # pylint: disable=unnecessary-lambda,cell-var-from-loop + filter(lambda x: u.startswith(x), prefixes) + )[0] + ].append( + u # pylint: disable=cell-var-from-loop + ) + + for p in list(self.prefix2urls.keys()): + if not self.prefix2urls[p]: + print(f"removing prefix {p} from prefix2urls since it has no entries") + self.prefix2urls.pop(p) + self.ps.pop(p) + + sum_ = sum(list(self.ps.values())) + self.ps = {k: self.ps[k] / sum_ for k in self.ps} + + print(f"Got the following (prob, prefix) pairs for {len(self.ps)} prefixes {list(self.ps.items())}") + + # internal iterator for one epoch + self.it = 0 + self.url_pool: Dict[str, List] = 
{} + + assert len(self.ps) == len(self.prefix2urls) and np.isclose( + sum(self.ps.values()), 1.0 + ), "Probabilities must have the same length than prefix and must sum up to 1" + + def reset(self): + # this will be called whenever __iter__ is invoked again (this should be kept in mind for shuffling + print("refilling url_pool") + self.url_pool = copy.deepcopy(self.prefix2urls) + self.it = 0 + + def refill_prefix(self, prefix): + # refill the buffer + self.url_pool[prefix] = copy.deepcopy(self.prefix2urls[prefix]) + + def __iter__(self): + while self.it < self.__len__(): + # sample prefix with corresponding probs + prefix_id = np.random.choice(len(self.ps), 1, p=list(self.ps.values())).item() + prefix = list(self.ps.keys())[prefix_id] + # refill the url pool for the selected prefix if empty + if not self.url_pool[prefix]: + self.refill_prefix(prefix) + + # uniformly sample from all available urls for the selected prefix + url_id = np.random.randint(len(self.url_pool[prefix]), dtype=int) + url = self.url_pool[prefix].pop(url_id) + + yield url + self.it += 1 + + def __len__(self): + return self._len + + +class TarArchiveLoaderAndCloser(TarArchiveLoader): + """ + Loads tar archive and closes it once iterated through + """ + + def __init__(self, *args, handler: Callable = wds.reraise_exception, **kwargs): + super().__init__(*args, **kwargs) + self.handler = handler + + def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]: + for data in self.datapipe: + validate_pathname_binary_tuple(data) + pathname, data_stream = data + try: + if isinstance(data_stream, StreamWrapper) and isinstance(data_stream.file_obj, tarfile.TarFile): + tar = data_stream.file_obj + else: + reading_mode = ( + self.mode + if hasattr(data_stream, "seekable") and data_stream.seekable() + else self.mode.replace(":", "|") + ) + # typing.cast is used here to silence mypy's type checker + # pylint: disable=consider-using-with + tar = tarfile.open( + fileobj=cast(Optional[IO[bytes]], data_stream), + mode=reading_mode, + ) + for tarinfo in tar: + if not tarinfo.isfile(): + continue + extracted_fobj = tar.extractfile(tarinfo) + if extracted_fobj is None: + warnings.warn(f"failed to extract file {tarinfo.name} from source tarfile {pathname}") + raise tarfile.ExtractError + inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name)) + yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname) # type: ignore[misc] + + # close tarfile after it's been exceeded + tar.close() + del tar + # if isinstance(data_stream, StreamWrapper): + # data_stream.autoclose() + except Exception as e: # pylint: disable=broad-except + warnings.warn(f"Unable to extract files from corrupted tarfile stream {pathname} due to: {e}, abort!") + if self.handler(e): + if hasattr(e, "args") and len(e.args) > 0: + e.args = (e.args[0] + " @ " + str(pathname),) + e.args[1:] + else: + raise e + finally: + if isinstance(data_stream, StreamWrapper): + data_stream.autoclose() + del data_stream + + +def grouper(x): + return x[0].split("/")[-1].split(".")[0] + + +class TorchDataWebdataset(DataPipeline, FluidInterfaceWithChangedDecode): + """ + Loads tars from s3 directly in memory which reduces failures due to failed downloads + """ + + def __init__( + self, + urls: Union[List[str], str], + repeat: Optional[int] = None, + shardshuffle: int = 10000, + sample_shuffle: int = 0, + buffer_size: Optional[int] = None, + resample_prefixes: bool = False, + prefix_probs: Optional[List[float]] = None, + drop_last: bool = False, + return_always: 
bool = False,
+        handler: Callable = wds.reraise_exception,
+    ):
+        """
+        :param urls: s3 prefixes to load the shards from, can be a list of different prefixes for dataset mixing.
+            Shards can be specified using braceexpand notation.
+        :param repeat: number of repetitions in the training data. Default is None which means looping perpetually.
+        :param shardshuffle: Shuffle buffer size for shard shuffling. A size of 1 means no shuffling. Default is 10k.
+        :param sample_shuffle: Shuffle buffer for sample-level shuffling. Default is 0 which means no shuffling.
+        :param buffer_size: memory size allocated for loading data from s3 in every connection.
+            The number of connections per worker is 25. Default is None which means 128M per connection.
+        :param resample_prefixes: Whether to resample when different prefixes are in the entire dataset.
+            This can be useful in combination with prefix_probs when training on merged datasets of non-equal size.
+        :param prefix_probs: list containing resampling probabilities for every prefix in `urls`
+        :param drop_last: whether to drop the last samples to prevent hanging (recommended)
+        :param return_always: Flag indicating whether to drop a sample and continue on error (default) or
+            to return a dummy output instead.
+        :param handler: handler for handling exceptions as in webdataset
+        """
+        super().__init__()
+        self.return_always = return_always
+        if isinstance(urls, (List, list)):
+            pass
+        elif isinstance(urls, str):
+            urls = [urls]
+        else:
+            raise TypeError(
+                "urls must be an s3 prefix, a path on a mounted filesystem, or a list of such paths"
+            )
+
+        load_from_s3 = urls[0].replace(" ", "").startswith("s3://")
+
+        # sharding filter ensures proper splitting for distributed environment
+        main_datapipe = IterableWrapper(urls).shard_expand().sharding_filter()
+
+        try:
+            # after this operation datapipes in the distinct processes contain different tars
+            global_rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            main_datapipe.apply_sharding(world_size, global_rank)
+            # synchronize data across processes to prevent hanging if sharding is uneven (which is likely)
+            main_datapipe = main_datapipe.fullsync()
+        except RuntimeError:
+            print("torch distributed not used, not applying sharding in dataloader")
+        # start shuffling across shards for the first time to mix different datasets
+        # (can be the same for all workers, just as an additional shuffled initialization)
+        if shardshuffle > 1 and not resample_prefixes:
+            raw_tars = list(main_datapipe)
+            random.shuffle(raw_tars)
+            print("Loader got the following concrete shards (first 25 are shown)")
+            print(raw_tars[:25])
+            # back to datapipes
+            main_datapipe = IterableWrapper(raw_tars)
+        elif resample_prefixes:
+            main_datapipe = PrefixResampler(main_datapipe, prefixes=urls, ps=prefix_probs)
+        main_datapipe = SplitByWorker(main_datapipe, drop_last=drop_last)
+        # different syntax than for webdataset
+        shardshuffle = max(shardshuffle, 1)
+        main_datapipe = main_datapipe.shuffle(buffer_size=shardshuffle).cycle(count=repeat)
+
+        if load_from_s3:
+            # Load data with S3FileLoader
+            main_datapipe = S3FileLoader(main_datapipe, buffer_size=buffer_size)
+        else:
+            # regular fileopener
+            main_datapipe = FileOpener(main_datapipe, mode="b")
+        # adapted TarLoader which closes open tarfile handles after iterating through them
+        main_datapipe = TarArchiveLoaderAndCloser(datapipe=main_datapipe, handler=handler).groupby(grouper)
+        if sample_shuffle > 0:
+            main_datapipe = main_datapipe.shuffle(buffer_size=sample_shuffle)
+
+        
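+        # register the assembled pipeline with the DataPipeline base class: the tar loading and
+        # grouping stage first, then the step that turns grouped tar members into sample dicts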
self.append(main_datapipe) + self.append(s3dataset2samples(return_always=self.return_always, handler=handler)) + + def map(self, f, handler=wds.reraise_exception): + return self.compose(return_always_map(f, return_always=self.return_always, handler=handler)) diff --git a/video2dataset/dataloader/dataloader.py b/video2dataset/dataloader/dataloader.py new file mode 100644 index 00000000..64d2d15c --- /dev/null +++ b/video2dataset/dataloader/dataloader.py @@ -0,0 +1,211 @@ +"""video dataset creation""" +import webdataset as wds +from functools import partial +from typing import List, Union + +from .custom_wds import ( + WebDatasetWithChangedDecoder, + dict_collation_fn, + TorchDataWebdataset, +) +from .transform import VideoResizer, CutsAdder, CustomTransforms +from .video_decode import VideoDecorder, VideoDecorderWithCutDetection +from .audio_decode import AudioDecoder +from .filters import ( + KeyFilter, + LanguageFilter, + AestheticsFilter, + UnsafeFilter, + UnusedKeyFilter, +) # pylint: disable=unused-import + + +def reassemble(x): + """ + Process a dictionary by updating its values based on certain conditions. + + :param dict x: The input dictionary to process. + :return: The processed dictionary. + :rtype: dict + """ + new_dict = {} + + for key in x: + if key not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split(): + continue + + # this is updating the output of video decoders + if isinstance(x[key], tuple) and len(x[key]) == 2: + new_dict.update({f"{subk}": x[key][-1][subk] for subk in x[key][-1]}) + + x[key] = x[key][0] + x.update(new_dict) + del new_dict + return x + + +def get_video_dataset( + urls: Union[str, List[str]], + batch_size, + shuffle=0, + repeat=1, + drop_last=False, + video_key="mp4", + cuts_key=None, + decoder_kwargs=None, + custom_transforms=None, + aesthetics_threshold=None, + allowed_languages=None, + p_unsafe_threshold=None, + resize_size=None, + crop_size=None, + random_crop=False, + original_height_key="original_height", + original_width_key="original_width", + keys_to_remove: Union[str, List[str], None] = None, + enforce_additional_keys=None, + return_always: bool = False, + handler=wds.reraise_exception, +): + """ + Generates a webdataset given the specified parameters. + Parameters: + urls (str, list(str)): The path to the dataset or a list of paths to the different locations of the dataset. + batch_size (int): The number of samples per batch. + shuffle (int, optional): Shuffle buffer size. Default is 0 means no shuffling. + repeat (int, optional): Whether to repeat the dataset. Default is 1. -1 means repeating infinitely + drop_last (bool, optional): Whether to drop the last incomplete batch. Default is False. + video_key (str, optional): The key for video files. Default is 'mp4'. + cuts_key (str, optional): The key for cut detection. Default is None. + decoder_kwargs (dict, optional): Keyword arguments for the video decoder. Default is an empty dictionary. + custom_transforms (dict, optional): Pairs of additional custom transforms to apply to samples. + aesthetics_threshold (float, optional): Aesthetic threshold for filtering. Default is None. + allowed_languages (list, optional): List of allowed languages. Default is None. + p_unsafe_threshold (float, optional): Probability threshold for unsafe content filtering. Default is None. + resize_size (tuple, optional): Tuple of (width, height) for resizing the video. Default is None. + crop_size (tuple, optional): Tuple of (width, height) for cropping the video. Default is None. 
+ random_crop (bool, optional): Whether to apply random cropping. Default is False. + original_height_key (str, optional): The key for the original video height. Default is 'original_height'. + original_width_key (str, optional): The key for the original video width. Default is 'original_width'. + keys_to_remove ((list, int), optional): Keys which, for the sake of speed, will be + removed before decoding. Default is None which means nothing will be removed. + enforce_additional_keys (list, optional): Which keys must be in each sample + Returns: + WebDataset: The processed webdataset. + """ + + if decoder_kwargs is None: + decoder_kwargs = {} + if enforce_additional_keys is None: + enforce_additional_keys = ["txt"] + if keys_to_remove is None: + keys_to_remove = [] + + if isinstance(urls, str): + urls = [urls] + # only use webdataset when using pipe + use_torchdata = not urls[0].replace(" ", "").startswith("pipe:") + + if not use_torchdata: + urls = urls[0] + + additional_decoder_kwargs = {} + if cuts_key: + dataset_cls = ( + partial( + WebDatasetWithChangedDecoder, + nodesplitter=wds.split_by_node, + ) + if not use_torchdata + else partial( + TorchDataWebdataset, repeat=repeat, drop_last=drop_last, return_always=return_always, handler=handler + ) + ) + video_decoder_cls = partial(VideoDecorderWithCutDetection, cuts_key=cuts_key) + additional_decoder_kwargs = {"passthrough_keys": [video_key]} + elif video_key in ["mp3", "wav", "flac", "m4a"] and decoder_kwargs != {}: + dataset_cls = wds.WebDataset + video_decoder_cls = AudioDecoder # type: ignore + decoder_kwargs["extension"] = video_key + elif decoder_kwargs == {}: # nothing means just read the bytes + dataset_cls = ( + partial( + wds.WebDataset, + nodesplitter=wds.split_by_node, + ) + if not use_torchdata + else partial( + TorchDataWebdataset, repeat=repeat, drop_last=drop_last, return_always=return_always, handler=handler + ) + ) + video_decoder_cls = None + else: + dataset_cls = ( + partial( + wds.WebDataset, + nodesplitter=wds.split_by_node, + ) + if not use_torchdata + else partial( + TorchDataWebdataset, repeat=repeat, drop_last=drop_last, return_always=return_always, handler=handler + ) + ) + video_decoder_cls = VideoDecorder # type: ignore + + dset = dataset_cls(urls, shardshuffle=shuffle, handler=handler) + + if not use_torchdata: + dset = dset.repeat(repeat).shuffle(shuffle, initial=shuffle) + + unused_key_filter = UnusedKeyFilter(keys=keys_to_remove) + dset = dset.map(unused_key_filter, handler=handler) + + # TODO: organize this such that you don't always need video. 
+ # should work with audio-text, just text or whatever you might want + enforce_keys = [video_key] + enforce_additional_keys + key_filter = KeyFilter(enforce_keys) + dset = dset.select(key_filter) + + if cuts_key: + cut_adder = CutsAdder(cuts_key=cuts_key, video_key=video_key) + dset = dset.map(cut_adder, handler=handler) + + aesthetics_filter = AestheticsFilter(aesthetic_thld=aesthetics_threshold) + language_filter = LanguageFilter(languages=allowed_languages) + unsafe_filter = UnsafeFilter(p_unsafe_threshold=p_unsafe_threshold) + # TODO: in the futuer only include filters we want to use based on params + filters = [aesthetics_filter, language_filter, unsafe_filter] + + # Decoding + if video_decoder_cls is not None: + dset = dset.decode( + video_decoder_cls(**decoder_kwargs), + handler=handler, + **additional_decoder_kwargs, + ).map(reassemble, handler=handler) + + # Filters + for fltr in filters: + dset = dset.select(fltr) + + # Resizing + if decoder_kwargs != {}: # bytes + dset = dset.map( + VideoResizer( + size=resize_size, + crop_size=crop_size, + random_crop=random_crop, + key=video_key, + width_key=original_width_key, + height_key=original_height_key, + ), + handler=handler, + ) + + if custom_transforms: + dset = dset.map(CustomTransforms(custom_transforms), handler=handler) + + if decoder_kwargs != {}: + dset = dset.batched(batch_size, partial=not drop_last, collation_fn=dict_collation_fn) + + return dset diff --git a/video2dataset/dataloader/filters.py b/video2dataset/dataloader/filters.py new file mode 100644 index 00000000..0fd12aad --- /dev/null +++ b/video2dataset/dataloader/filters.py @@ -0,0 +1,104 @@ +"""WebDataset filters""" +from langdetect import detect_langs, DetectorFactory # pylint: disable=unused-import +from typing import List, Union, Dict + +from webdataset.autodecode import decoders # pylint: disable=unused-import + + +class LanguageFilter: + """Filters the dataset based on the language""" + + def __init__(self, languages="en", lang_key="txt"): + self.languages = languages + if not isinstance(self.languages, list) and self.languages is not None: + self.languages = [self.languages] + + self.lang_key = lang_key + + def __call__(self, x): + valid = True + if self.languages: + try: + valid = False + for k in self.languages: + # langs = detect_langs(decoders[k](x[k])) + langs = detect_langs(x[k]) + valid |= max(langs, key=lambda x: x.prob).lang in self.languages + except Exception: # pylint: disable=broad-except + valid = False + return valid + + +class KeyFilter: + """Filters the dataset based on the key""" + + def __init__(self, enforce_keys=None): + self.enforce_keys = enforce_keys + if enforce_keys is None: + self.enforce_keys = ["mp4", "txt"] + + def __call__(self, sample): + try: + for key in self.enforce_keys: + if key not in sample: + return False + return True + except Exception as _: # pylint: disable=broad-except + return False + + +class AestheticsFilter: + """Filters the dataset based on aesthethics""" + + def __init__(self, aesthetic_thld=None, aesthetic_key="AESTHETIC_SCORE"): + self.aesthetic_thld = aesthetic_thld + self.aesthetic_key = aesthetic_key + + def __call__(self, sample): + if self.aesthetic_thld is not None: + try: + return sample["json"][self.aesthetic_key] >= self.aesthetic_thld + except Exception as e: # pylint: disable=broad-except + if self.aesthetic_key not in sample["json"]: + raise e + return True + else: + return True + + +class UnsafeFilter: + """Filters the dataset based on the probability a sample is unsafe""" + + def 
__init__(self, p_unsafe_threshold):
+        self.p_unsafe_threshold = p_unsafe_threshold
+
+    def __call__(self, sample):
+        valid = True
+        if self.p_unsafe_threshold is not None and "json" in sample:
+            try:
+                valid = sample["json"]["punsafe"] < self.p_unsafe_threshold
+            except Exception:  # pylint: disable=broad-except
+                if "punsafe" not in sample["json"]:
+                    raise
+                valid = False
+        return valid
+
+
+class UnusedKeyFilter:
+    """Removes specified keys that are not used during loading, speeding up sampling"""
+
+    def __init__(self, keys: Union[str, List[str], None] = None) -> None:
+        if keys is None:
+            self.unused_keys = set()
+        elif isinstance(keys, str):
+            self.unused_keys = {keys}
+        else:
+            self.unused_keys = set(keys)
+
+    def __call__(self, x: Dict) -> Dict:
+        if not self.unused_keys:
+            return x
+        for key in self.unused_keys.intersection(x.keys()):
+            del x[key]
+
+        return x
diff --git a/video2dataset/dataloader/transform.py b/video2dataset/dataloader/transform.py
new file mode 100644
index 00000000..e1e49acb
--- /dev/null
+++ b/video2dataset/dataloader/transform.py
@@ -0,0 +1,179 @@
+"""
+Transformations you might want to apply to data during loading
+"""
+import math
+import cv2
+import torch
+import numpy as np
+
+from .video_decode import PRNGMixin
+
+
+class VideoResizer(PRNGMixin):
+    """Resizes frames to specified height and width"""
+
+    def __init__(
+        self,
+        size=None,
+        crop_size=None,
+        random_crop=False,
+        key="mp4",
+        width_key="original_width",
+        height_key="original_height",
+    ):
+        self.key = key
+
+        self.height_key = height_key
+        self.width_key = width_key
+
+        # resize size as [h,w]
+        self.resize_size = size
+        # crop size as [h,w]
+        self.crop_size = crop_size
+        if isinstance(self.crop_size, int):
+            self.crop_size = [self.crop_size] * 2
+        self.random_crop = random_crop and self.crop_size is not None
+
+        if self.crop_size or self.resize_size:
+            print(f"{self.__class__.__name__} is resizing video to size {self.resize_size} ...")
+
+            if self.crop_size:
+                print(f'... and {"random" if self.random_crop else "center"} cropping to size {self.crop_size}.')
+        else:
+            print(f"WARNING: {self.__class__.__name__} is not resizing or cropping videos. 
Is this intended?") + + def _get_rand_reference(self, resize_size, h, w): + """gets random reference""" + assert resize_size is None or ( + self.crop_size[0] <= resize_size[0] and self.crop_size[1] <= resize_size[1] + ), "Resize size must be greater or equal than crop_size" + + # consistent random crop + min_x = math.ceil(self.crop_size[1] / 2) + max_x = w - min_x + if min_x == max_x: + # catch corner case + max_x = min(max_x + 1, w) + min_y = math.ceil(self.crop_size[0] / 2) + max_y = h - min_y + if min_y == max_y: + # catch corner case + max_y = min(max_y + 1, h) + + try: + x = self.prng.randint(min_x, max_x, 1).item() + y = self.prng.randint(min_y, max_y, 1).item() + except ValueError as e: + print("Video size not large enough, consider reducing size") + print(e) + raise e + + reference = [y, x] + return reference + + # def _get_resize_size(self, frame, orig_h, orig_w): + def _get_resize_size(self, frame): + """gets resize size""" + orig_h, orig_w = frame.shape[:2] + if self.resize_size is not None: + if isinstance(self.resize_size, int): + f = self.resize_size / min((orig_h, orig_w)) + resize_size = [int(round(orig_h * f)), int(round(orig_w * f))] + else: + resize_size = self.resize_size + h, w = resize_size + else: + resize_size = None + h, w = frame.shape[:2] + + return resize_size, (h, w) + + def _get_reference_frame(self, resize_size, h, w): + """gets reference frame""" + if self.random_crop: + reference = self._get_rand_reference(resize_size, h, w) + else: + reference = [s // 2 for s in [h, w]] + + return reference + + def __call__(self, data): + if self.crop_size is None and self.resize_size is None: + if isinstance(data[self.key], list): + # convert to tensor + data[self.key] = torch.from_numpy(np.stack(data[self.key])) + return data + + result = [] + + if self.key not in data: + raise KeyError(f"Specified key {self.key} not in data") + + vidkey = self.key + frames = data[vidkey] + + if isinstance(frames, int): + raise TypeError(f"Frames is int: {frames}") + + # for videos: take height and width of first frames since the same for all frames anyways, + # if resize size is integer, then this is used as the new size of the smaller size + # orig_h = data[self.height_key][0].item() + # orig_w = data[self.width_key][0].item() + + # resize_size, (h, w) = self._get_resize_size(frames[0], orig_h, orig_w) + resize_size, (h, w) = self._get_resize_size(frames[0]) + reference = self._get_reference_frame(resize_size, h, w) + + for frame in frames: + if resize_size is not None: + frame = cv2.resize( + frame, + tuple(reversed(resize_size)), + interpolation=cv2.INTER_LANCZOS4, + ) + + if self.crop_size is not None: + x_ = reference[1] - int(round(self.crop_size[1] / 2)) + y_ = reference[0] - int(round(self.crop_size[0] / 2)) + + frame = frame[ + int(y_) : int(y_) + self.crop_size[0], + int(x_) : int(x_) + self.crop_size[1], + ] + + # TODO: maybe lets add other options for normalization + # will need for VideoCLIP built on top of CLIP + + frame = torch.from_numpy(frame) + result.append(frame) + + data[vidkey] = torch.stack(result).to(torch.float16) + return data + + +class CutsAdder: + """Adds cuts to video sample""" + + def __init__(self, cuts_key, video_key="mp4"): + self.cuts_key = cuts_key + self.video_key = video_key + + def __call__(self, sample): + assert self.cuts_key in sample, f'no field with key "{self.cuts_key}" in sample, but this is required.' + assert self.video_key in sample, f'no field with key "{self.video_key}" in sample, but this is required.' 
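+        # nest the cuts under the video key so the video decoder receives the frames and the cut list together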
+ sample[self.video_key] = { + self.video_key: sample[self.video_key], + self.cuts_key: sample[self.cuts_key], + } + del sample[self.cuts_key] + return sample + + +class CustomTransforms: + def __init__(self, key_transform_dict): + self.key_transform_dict = key_transform_dict + + def __call__(self, sample): + for key, transform in self.key_transform_dict.items(): + sample[key] = transform(sample[key]) + return sample diff --git a/video2dataset/dataloader/video_decode.py b/video2dataset/dataloader/video_decode.py new file mode 100644 index 00000000..d302ce82 --- /dev/null +++ b/video2dataset/dataloader/video_decode.py @@ -0,0 +1,224 @@ +"""Video Decoders""" +import os +import re +import numpy as np +import torch +import torch.nn.functional as F +import decord +import tempfile +from typing import Iterable + +from webdataset.autodecode import decoders + + +decord.bridge.set_bridge("torch") + + +class PRNGMixin: + """ + Adds a prng property which is a numpy RandomState which gets + reinitialized whenever the pid changes to avoid synchronized sampling + behavior when used in conjunction with multiprocessing. + """ + + @property + def prng(self): + currentpid = os.getpid() + if getattr(self, "_initpid", None) != currentpid: + self._initpid = currentpid + self._prng = np.random.RandomState() + return self._prng + + +class AbstractVideoDecoder(PRNGMixin): + def get_frames(self, *args, **kwargs): + raise NotImplementedError(f"{self.__class__.__name__} is abstract") + + def __call__(self, *args, **kwargs): + raise NotImplementedError(f"{self.__class__.__name__} is abstract") + + +class VideoDecorder(AbstractVideoDecoder): + """Basic video decoder that uses decord""" + + def __init__( + self, + n_frames=None, + uniformly_sample=None, + fps=None, + num_threads=4, + tmpdir="/tmp/", + min_fps=1, + max_fps=32, + return_bytes=False, + pad_frames=False, + ): + super().__init__() + self.n_frames = n_frames + self.pad_frames = pad_frames + if fps is not None and not isinstance(fps, Iterable): + fps = [ + fps, + ] + if uniformly_sample: + assert fps is None, "fps not compatible with uniformly_sample..." + self.uniformly_sample = uniformly_sample + self.fps = fps + self.min_fps = min_fps + self.max_fps = max_fps + self.return_bytes = return_bytes + # for frame rate conditioning + if self.fps == "sample": + # this means we sample fps in every iteration + self.fs_ids = {fr: i for i, fr in enumerate(range(self.min_fps, self.max_fps + 1))} + elif isinstance(self.fps, list): + self.fs_ids = {fr: i for i, fr in enumerate(self.fps)} + else: + self.fs_ids = None + assert self.fps is None, f"invalid fps value specified: {self.fps}...." + + self.num_threads = num_threads + + infostring2 = "" + if self.fps == "sample": + infostring1 = "a randomly sampled" + infostring2 = f" in between [{self.min_fps},{self.max_fps}]" + elif isinstance(self.fps, list): + infostring1 = ",".join([str(f) for f in self.fps]) + else: + infostring1 = "native" + info = ( + f'Decoding video clips of length {self.n_frames} with "decord".' 
+ + f" Subsampling clips to {infostring1} fps {infostring2}" + + self.pad_frames * "Padding videos that are too short" + ) + + print(info) + + self.tmpdir = tmpdir + print(f"Setting {self.tmpdir} as temporary directory for the decoder") + + def get_frames(self, reader, n_frames, stride, **kwargs): # pylint: disable=arguments-differ + if n_frames * stride > len(reader): + if not self.pad_frames: + raise ValueError("video clip not long enough for decoding") + n_frames = len(reader) + # sample frame start and choose scene + if n_frames == len(reader) or n_frames == len(reader) // stride: + frame_start = 0 + else: + frame_start = self.prng.choice(int(len(reader)) - int(n_frames * stride), 1).item() + # only decode the frames which are actually needed + frames = reader.get_batch(np.arange(frame_start, frame_start + n_frames * stride, stride).tolist()) + + # TODO: maybe its useful to inform the user which frmaes are padded + # can just output first_pad_index or a mask or something + pad_start = len(frames) + if self.pad_frames and frames.shape[0] < self.n_frames: + frames = F.pad(frames, (0, 0) * 3 + (0, self.n_frames - frames.shape[0])) # pylint: disable=not-callable + + return frames, frame_start, pad_start + + def __call__(self, key, data, scene_list=None): # pylint: disable=arguments-differ + extension = re.sub(r".*[.]", "", key) + if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split(): + return None + + additional_info = {} + with tempfile.TemporaryDirectory(dir=self.tmpdir) as dirname: + fname = os.path.join(dirname, f"file.{extension}") + with open(fname, "wb") as stream: + stream.write(data) + reader = decord.VideoReader(fname, num_threads=self.num_threads) + + native_fps = int(np.round(reader.get_avg_fps())) + if isinstance(self.fps, list): + fps_choices = list(filter(lambda x: x <= native_fps, self.fps)) + if not fps_choices: + return None + chosen_fps = self.prng.choice(list(fps_choices), 1).item() + + elif self.fps == "sample": + if native_fps < self.min_fps: + return None + max_fps = min(native_fps, self.max_fps) + additional_info.update({"fps_id": 0}) + chosen_fps = self.prng.choice(np.arange(self.min_fps, max_fps + 1), 1).item() + else: + chosen_fps = native_fps + + fs_id = self.fs_ids[chosen_fps] if self.fs_ids else 0 + stride = int(np.round(native_fps / chosen_fps)) + if self.n_frames is None: + n_frames = len(reader) // stride + else: + n_frames = self.n_frames + + if self.uniformly_sample: + additional_info.update({"fps_id": torch.Tensor([fs_id] * self.n_frames).long()}) + t = len(reader) + indices = np.linspace(0, t - 1, self.n_frames) + indices = np.clip(indices, 0, t - 1).astype(int) + frames, start_frame, pad_start = reader.get_batch(indices), indices[0], len(indices) + else: + additional_info.update({"fps_id": torch.Tensor([fs_id] * n_frames).long()}) + frames, start_frame, pad_start = self.get_frames(reader, n_frames, stride, scene_list=scene_list) + frames = frames.float().numpy() + + additional_info["original_height"] = torch.full((frames.shape[0],), fill_value=frames.shape[1]).long() + additional_info["original_width"] = torch.full((frames.shape[0],), fill_value=frames.shape[2]).long() + + pad_masks = torch.zeros((frames.shape[0],)) + pad_masks[:pad_start] = 1.0 + additional_info["pad_masks"] = pad_masks + + if self.n_frames is not None and frames.shape[0] < self.n_frames: + raise ValueError("Decoded video not long enough, skipping") + + # return compatible with torchvisioin API + additional_info.update({"native_fps": chosen_fps if chosen_fps is not 
None else native_fps}) + additional_info.update({"start_frame": start_frame}) + + if self.return_bytes: + additional_info.update({"video_bytes": data}) + out = (list(frames), additional_info) + return out + + +class VideoDecorderWithCutDetection(VideoDecorder): + """Video decoder that uses decord with cut detection""" + + def __init__(self, *args, cuts_key="npy", **kwargs): + super().__init__(*args, **kwargs) + self.cuts_key = cuts_key + if self.cuts_key not in decoders: + raise KeyError( + f"{self.__class__.__name__} received {self.cuts_key} as cuts_key," + + " but that one is no decoder known to webdataset" + ) + + def __call__(self, key, data): # pylint: disable=arguments-differ,signature-differs + extension = re.sub(r".*[.]", "", key) + if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split(): + return None + + cut_list = decoders[self.cuts_key](data[self.cuts_key]) + data = data[extension] + + return super().__call__(key, data, scene_list=cut_list) + + def get_frames(self, reader, n_frames, stride, scene_list): # pylint: disable=arguments-differ + min_len = n_frames * stride + # filter out subclips shorther than minimal required length + scene_list = list(filter(lambda x: x[1] - x[0] > min_len, scene_list)) + if len(scene_list) == 0: + raise ValueError("video clips not long enough for decoding") + + clip_id = self.prng.choice(len(scene_list), 1).item() + start, end = scene_list[clip_id].tolist() + # sample frame start and choose scene + frame_start = self.prng.choice(int(end - start) - int(n_frames * stride), 1).item() + start + # only decode the frames which are actually needed + frames = reader.get_batch(np.arange(frame_start, frame_start + n_frames * stride, stride).tolist()) + + return frames, frame_start diff --git a/video2dataset/distributor.py b/video2dataset/distributor.py new file mode 100644 index 00000000..bce8d3c9 --- /dev/null +++ b/video2dataset/distributor.py @@ -0,0 +1,340 @@ +"""distributor defines the distribution strategies for img2dataset""" +import os +import time +import subprocess +import yaml +from datetime import datetime +from contextlib import contextmanager +from multiprocessing import get_context +from itertools import islice, chain + +import fsspec +from tqdm import tqdm + + +def retrier(runf, failed_shards, max_shard_retry): + # retry failed shards max_shard_retry times + for i in range(max_shard_retry): + if len(failed_shards) == 0: + break + print(f"Retrying {len(failed_shards)} shards, try {i+1}") + failed_shards = runf(failed_shards) + if len(failed_shards) != 0: + print( + f"Retried {max_shard_retry} times, but {len(failed_shards)} shards " + "still failed. You may restart the same command to retry again." 
+ ) + + +def no_distributor(process_count, worker, input_sharder, _, max_shard_retry): # pylint: disable=unused-argument + """Go through shards sequentially (useful for when things don't like multiprocessing)""" + + def run(gen): + failed_shards = [] + for shard in gen: + status, row = worker(shard) + if status is False: + failed_shards.append(row) + return failed_shards + + failed_shards = run(input_sharder) + retrier(run, failed_shards, max_shard_retry) + + +def multiprocessing_distributor(processes_count, worker, input_sharder, _, max_shard_retry): + """Distribute the work to the processes using multiprocessing""" + ctx = get_context("spawn") + with ctx.Pool(processes_count, maxtasksperchild=5) as process_pool: + + def run(gen): + failed_shards = [] + for status, row in tqdm(process_pool.imap_unordered(worker, gen)): + if status is False: + failed_shards.append(row) + return failed_shards + + failed_shards = run(input_sharder) + + retrier(run, failed_shards, max_shard_retry) + + process_pool.terminate() + process_pool.join() + del process_pool + + +def pyspark_distributor(processes_count, worker, input_sharder, subjob_size, max_shard_retry): + """Distribute the work to the processes using pyspark""" + + with _spark_session(processes_count) as spark: + + def batcher(iterable, batch_size): + iterator = iter(iterable) + for first in iterator: + yield list(chain([first], islice(iterator, batch_size - 1))) + + def run(gen): + failed_shards = [] + for batch in batcher(gen, subjob_size): + rdd = spark.sparkContext.parallelize(batch, len(batch)) + for status, row in rdd.map(worker).collect(): + if status is False: + failed_shards.append(row) + return failed_shards + + failed_shards = run(input_sharder) + + retrier(run, failed_shards, max_shard_retry) + + +@contextmanager +def _spark_session(processes_count: int): + """Create and close a spark session if none exist""" + + from pyspark.sql import SparkSession # pylint: disable=import-outside-toplevel + import pyspark # pylint: disable=import-outside-toplevel + + spark_major_version = int(pyspark.version.__version__[0]) + if spark_major_version >= 3: + spark = SparkSession.getActiveSession() + else: + spark = pyspark.sql.SparkSession._instantiatedSession # pylint: disable=protected-access + + if spark is None: + print("No pyspark session found, creating a new one!") + owned = True + spark = ( + SparkSession.builder.config("spark.driver.memory", "16G") + .master("local[" + str(processes_count) + "]") + .appName("spark-stats") + .getOrCreate() + ) + else: + owned = False + + try: + yield spark + finally: + if owned: + spark.stop() + + +class SlurmShardSampler: + """ + Should be callable to select samples based on the node_id + :param global_task_id: The global task id for the current task + :param num_tasks: The overall number of tasks + :return: + """ + + def __init__(self, global_task_id, num_tasks): + self.task_id = global_task_id + print(global_task_id) + self.num_tasks = num_tasks + + def __call__(self, shardfile_list): + shardlist = [ + (full_shard_id, shard_id) + for full_shard_id, shard_id in shardfile_list + if int(full_shard_id) % self.num_tasks == self.task_id + ] + return shardlist + + +class SlurmDistributor: + """Parallelism via slurm""" + + def __init__( + self, + worker_args, + cpus_per_task, + job_name, + partition, + n_nodes, + account, + gpus_per_node=0, + tasks_per_node=1, + nodelist=None, + constraint=None, + exclude=None, + cache_path=None, + timeout=None, + verbose_wait=False, + ): + self.cpus_per_task = cpus_per_task + 
self.job_name = job_name + self.partition = partition + self.n_nodes = n_nodes + self.gpus_per_node = gpus_per_node + self.account = account + self.tasks_per_node = tasks_per_node + self.nodelist = nodelist + self.constraint = constraint + self.exclude = exclude + self.cache_path = cache_path + if not cache_path: + cache_path = ".video2dataset_cache/" + self.timeout = timeout + self.verbose_wait = verbose_wait + + self.fs, self.cache_path = fsspec.core.url_to_fs(cache_path) + if not self.fs.exists(self.cache_path): + self.fs.mkdir(self.cache_path) + + self.timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + + # save worker args to file (this is written by the slurm_executor) + self.worker_args_as_file = os.path.join(self.cache_path, f"{self.timestamp}_worker_args.yaml") + with self.fs.open(self.worker_args_as_file, "w", encoding="utf-8") as f: + yaml.dump(worker_args, f, default_flow_style=False) + + self.launcher_path = os.path.join(self.cache_path, self.timestamp + "_launcher.sh") + with self.fs.open(self.launcher_path, "w", encoding="utf-8") as launcher_file: + launcher_file.write(self._make_launch_cpu()) + + self.sbatch_path = os.path.join(self.cache_path, self.timestamp + "_sbatch.sh") + with self.fs.open(self.sbatch_path, "w", encoding="utf-8") as sbatch_file: + sbatch_file.write(self._make_sbatch()) + + print(f"Wrote launcher to {self.launcher_path}") + print(f"Wrote sbatch to {self.sbatch_path}") + + def _make_sbatch(self): + nodelist = ("#SBATCH --nodelist " + self.nodelist) if self.nodelist is not None else "" + exclude = ("#SBATCH --exclude " + self.exclude) if self.exclude is not None else "" + account = ("#SBATCH --account " + self.account) if self.account is not None else "" + constraint = ("#SBATCH --constraint " + self.constraint) if self.constraint is not None else "" + return f"""#!/bin/bash +#SBATCH --partition={self.partition} +#SBATCH --job-name={self.job_name} +#SBATCH --output={self.cache_path}/slurm-%x_%j.out +#SBATCH --nodes={self.n_nodes} +#SBATCH --ntasks-per-node={self.tasks_per_node} +#SBATCH --cpus-per-task={self.cpus_per_task} +#SBATCH --gpus-per-node={self.gpus_per_node} +#SBATCH --exclusive +{nodelist} +{exclude} +{account} +{constraint} +#SBATCH --open-mode append + +srun --account {self.account} bash {self.launcher_path} + +""" + + def _make_launch_cpu( + self, + ): + """Create cpu launcher""" + + venv = os.environ.get("VIRTUAL_ENV") + if venv: + venv_activate = f"source {venv}/bin/activate" + else: + conda_env = os.environ.get("CONDA_ENV") + if conda_env: + venv_activate = f"conda activate {conda_env}" + else: + raise ValueError("You need to specify either a virtual environment or a conda environment.") + + cdir = os.path.abspath(os.path.dirname(__file__)) + script = os.path.join(cdir, "slurm_executor.py") + project_root = os.path.abspath(os.path.join(cdir, "..")) + return f"""#!/bin/bash +# mpi version for node rank +H=`hostname` +THEID=`echo -e $HOSTNAMES | python3 -c "import sys;[sys.stdout.write(str(i)) for i,line in enumerate(next(sys.stdin).split(' ')) if line.strip() == '$H'.strip()]"` +export NODE_RANK=${{THEID}} +echo THEID=$THEID + +# set new location for cache dir +export XDG_CACHE_HOME="{self.cache_path}" +#in case of accessing s3 disable ssl verification (for making torchdata s3 related functionality work) +export S3_VERIFY_SSL=0 +export CALLED_FROM_SLURM=1 + +cd {project_root} +{venv_activate} + +python {script} --worker_args {self.worker_args_as_file} --node_id $SLURM_NODEID --n_nodes $SLURM_JOB_NUM_NODES --num_tasks_per_node 
$SLURM_NTASKS_PER_NODE --subtask_id $SLURM_LOCALID +""" + + def __call__(self, *args, **kwargs): + print(f"{self.__class__.__name__} starting the job") + + status = self._run_job() + + # interpret the results + if status == "success": + print("job succeeded") + return True + elif status == "failed": + print("job failed") + return False + else: + print("exception occurred") + return False + + def _start_job(self, sbatch_file): + """start job""" + args = ["sbatch"] + args.append(sbatch_file) + sbatch_output = subprocess.check_output(args).decode("utf8") + lines = sbatch_output.split("\n") + + lines = [line for line in lines if "Submitted" in line] + if len(lines) == 0: + raise ValueError(f"slurm sbatch failed: {sbatch_output}") + + parsed_sbatch = lines[0].split(" ") + job_id = parsed_sbatch[3].strip() + return job_id + + def _run_job( + self, + ): + """ + Run a job and wait for it to finish. + """ + try: + job_id = self._start_job(self.sbatch_path) + + print(f"waiting for job {job_id}") + + timeout = self.timeout + + if timeout is None: + print("You have not specified a timeout, defaulting to 2 weeks.") + timeout = 1.21e6 + + status = self._wait_for_job_to_finish(job_id=job_id, timeout=timeout) + + if not status: + print(f"canceling {job_id}") + subprocess.check_output(["scancel", job_id]).decode("utf8") + status = self._wait_for_job_to_finish(job_id) + print("job cancelled") + return "failed" + else: + print("job succeeded") + return "success" + except Exception as e: # pylint: disable=broad-except + print(e) + return "exception occurred" + + def _wait_for_job_to_finish(self, job_id, timeout=30): + t = time.time() + while 1: + if time.time() - t > timeout: + return False + time.sleep(1) + if self._is_job_finished(job_id): + return True + + def _is_job_finished(self, job_id): + status = subprocess.check_output(["squeue", "-j", job_id]).decode("utf8") + + if self.verbose_wait: + print(f"job status is {status}") + + return status == "slurm_load_jobs error: Invalid job id specified" or len(status.split("\n")) == 2 diff --git a/video2dataset/input_sharder.py b/video2dataset/input_sharder.py new file mode 100644 index 00000000..5f0f9dc0 --- /dev/null +++ b/video2dataset/input_sharder.py @@ -0,0 +1,191 @@ +"""Reader is module to read the url list and return shards""" + +from multiprocessing.pool import ThreadPool +import math +import fsspec +import time +import pyarrow.parquet as pq +import pyarrow.csv as csv_pq +import pyarrow as pa +import pandas as pd + + +class InputSharder: + """ + The reader class reads an url list and returns shards + It provides an iter method + It provides attributes: + - column_list: the list of columns to read + - input_format: the format of the input file + - url_col: the column name of the url + - caption_col: the column name of the caption + - clip_col: the column name of the list of time ranges for video clips + - save_additional_columns: the list of additional columns to save + - number_sample_per_shard: the number of samples per shard + - done_shards: a set of already done shards + """ + + def __init__( + self, + url_list, + input_format, + url_col, + caption_col, + clip_col, + save_additional_columns, + number_sample_per_shard, + done_shards, + tmp_path, + sampler=lambda x: x, + ) -> None: + self.input_format = input_format + self.url_col = url_col + self.caption_col = caption_col + self.clip_col = clip_col + self.save_additional_columns = save_additional_columns + self.number_sample_per_shard = number_sample_per_shard + self.done_shards = done_shards + 
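+        # optional sampler used to subselect shards, e.g. per-task shard selection when running under slurm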
self.shard_sampler = sampler + + fs, url_path = fsspec.core.url_to_fs(url_list) + self.fs = fs + self.tmp_path = tmp_path + + if fs.isdir(url_path): + self.input_files = sorted(fs.glob(url_path + "/*." + input_format)) + if len(self.input_files) == 0: + raise ValueError(f"No file found at path {url_path} with extension {input_format}") + else: + self.input_files = [url_path] + + if self.input_format == "txt": + self.column_list = ["url"] + elif self.input_format in ["json", "csv", "tsv", "tsv.gz", "parquet"]: + self.column_list = self.save_additional_columns if self.save_additional_columns is not None else [] + self.column_list = ( + self.column_list + ["clips"] * bool(self.clip_col) + ["caption"] * bool(self.caption_col) + ["url"] + ) + else: + raise ValueError(f"Invalid input format {self.input_format}") + + def _save_to_arrow(self, input_file, start_shard_id): + """Read the input file and save to arrow files in a temporary directory""" + if self.input_format in ["txt", "json", "csv", "tsv"]: + with self.fs.open(input_file, mode="rb") as file: + if self.input_format == "txt": + df = csv_pq.read_csv(file, read_options=csv_pq.ReadOptions(column_names=["url"])) + elif self.input_format == "json": + df = pa.Table.from_pandas(pd.read_json(file)) + elif self.input_format == "csv": + df = pa.Table.from_pandas(pd.read_csv(file)) + elif self.input_format == "tsv": + df = csv_pq.read_csv(file, parse_options=csv_pq.ParseOptions(delimiter="\t")) + else: + raise ValueError(f"Unknown input format {self.input_format}") + elif self.input_format == "tsv.gz": + with self.fs.open(input_file, encoding="utf-8", mode="rb", compression="gzip") as file: + df = csv_pq.read_csv(file, parse_options=csv_pq.ParseOptions(delimiter="\t")) + elif self.input_format == "parquet": + with self.fs.open(input_file, mode="rb") as file: + columns_to_read = [self.url_col] + if self.caption_col is not None: + columns_to_read += [self.caption_col] + if self.clip_col is not None: + columns_to_read += [self.clip_col] + if self.save_additional_columns is not None: + columns_to_read += self.save_additional_columns + df = pq.read_table(file, columns=columns_to_read) + else: + raise ValueError(f"Unknown input format {self.input_format}") + + column_names = df.column_names + if self.caption_col is not None: + column_names = [c if c != self.caption_col else "caption" for c in column_names] + if self.clip_col is not None: + column_names = [c if c != self.clip_col else "clips" for c in column_names] + column_names = [c if c != self.url_col else "url" for c in column_names] + + df = df.rename_columns(column_names) + + number_samples = df.num_rows + + number_shards = math.ceil(df.num_rows / self.number_sample_per_shard) + shards_to_write = [ + (start_shard_id + shard_id, shard_id) + for shard_id in range(number_shards) + if start_shard_id + shard_id not in self.done_shards + ] + + shards_to_write = self.shard_sampler(shards_to_write) + number_shards = len(shards_to_write) + + if len(shards_to_write) == 0: + return [], number_shards + + def write_shard(t): + full_shard_id, shard_id = t + begin_shard = shard_id * self.number_sample_per_shard + end_shard = min(number_samples, (1 + shard_id) * self.number_sample_per_shard) + df_shard = df.slice(begin_shard, end_shard - begin_shard).select(self.column_list) + tmp_file = self.tmp_path + f"/{full_shard_id}.feather" + for i in range(10): + try: + fs, tmp_path = fsspec.core.url_to_fs(tmp_file) + with fs.open(tmp_path, "wb") as file: + with pa.ipc.new_file(file, df_shard.schema) as writer: + 
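+                            # write the shard as a single Arrow table in IPC (feather) format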
writer.write_table(df_shard) + return (full_shard_id, tmp_file) + except Exception as e: # pylint: disable=broad-except + if i != 9: + print("retrying to write to file due to error:", e) + time.sleep(1) + else: + raise e + # can't reach here + raise ValueError("Failed to write to file.") + + for i in range(10): + shards = [] + # thread pool to make it faster to write files to low latency file systems (ie s3, hdfs) + try: + with ThreadPool(32) as thread_pool: + for shard in thread_pool.imap_unordered(write_shard, shards_to_write): + shards.append(shard) + break + except Exception as e: # pylint: disable=broad-except + if i != 9: + print("retrying whole sharding to write to files due to error:", e) + time.sleep(2 * i) + else: + raise e + + shards.sort(key=lambda k: k[0]) + + del df + + return shards, number_shards + + def __iter__(self): + """ + Iterate over shards, yield shards of size number_sample_per_shard or less for the last one + Each shard is a tuple (shard_id, shard) + shard is a tuple (sample id, sample) + sample is a tuple of the columns + """ + start_shard_id = 0 + for i, input_file in enumerate(self.input_files): + print("Sharding file number " + str(i + 1) + " of " + str(len(self.input_files)) + " called " + input_file) + + shards, number_shards = self._save_to_arrow(input_file, start_shard_id) + print("File sharded in " + str(len(shards)) + " shards") + print( + "Downloading starting now, check your bandwidth speed (with bwm-ng)" + "your cpu (with htop), and your disk usage (with iotop)!" + ) + + for shard_id, arrow_file in shards: + yield ( + arrow_file, + shard_id, + ) + start_shard_id += number_shards diff --git a/video2dataset/logger.py b/video2dataset/logger.py new file mode 100644 index 00000000..f0da0054 --- /dev/null +++ b/video2dataset/logger.py @@ -0,0 +1,342 @@ +"""logging utils for the downloader""" + +import wandb +import time +from collections import Counter +import fsspec +import json +import multiprocessing +import queue +import traceback + + +class CappedCounter: + """Maintain a counter with a capping to avoid memory issues""" + + def __init__(self, max_size=10**5): + self.max_size = max_size + self.counter = Counter() + + def increment(self, key): + if len(self.counter) >= self.max_size: + self._keep_most_frequent() + self.counter[key] += 1 + + def _keep_most_frequent(self): + self.counter = Counter(dict(self.counter.most_common(int(self.max_size / 2)))) + + def most_common(self, k): + return self.counter.most_common(k) + + def update(self, counter): + self.counter.update(counter.counter) + if len(self.counter) >= self.max_size: + self._keep_most_frequent() + + def dump(self): + return self.counter + + @classmethod + def load(cls, d, max_size=10**5): + c = CappedCounter(max_size) + c.counter = Counter(d) + return c + + +class Logger: + """logger which logs when number of calls reaches a value or a time interval has passed""" + + def __init__(self, min_interval=0): + """Log only every if min_interval (seconds) have elapsed since last log""" + # wait for all processes to return + self.processes_returned = 0 + # min time (in seconds) before logging a new table (avoids too many logs) + self.min_interval = min_interval + self.last = time.perf_counter() + # keep track of whether we logged the last call + self.last_call_logged = False + self.last_args = None + self.last_kwargs = None + + def __call__(self, *args, **kwargs): + self.processes_returned += 1 + if time.perf_counter() - self.last > self.min_interval: + self.do_log(*args, **kwargs) + self.last = 
time.perf_counter() + self.last_call_logged = True + else: + self.last_call_logged = False + self.last_args = args + self.last_kwargs = kwargs + + def do_log(self, *args, **kwargs): + raise NotImplementedError() + + def sync(self): + """Ensure last call is logged""" + if not self.last_call_logged and self.last_args is not None: + self.do_log(*self.last_args, **self.last_kwargs) + # reset for next file + self.processes_returned = 0 + + +class SpeedLogger(Logger): + """Log performance metrics""" + + def __init__(self, prefix, enable_wandb, **logger_args): + super().__init__(**logger_args) + self.prefix = prefix + self.start_time = float("+inf") + self.end_time = float("-inf") + self.count = 0 + self.success = 0 + self.failed_to_download = 0 + self.failed_to_subsample = 0 + self.bytes_downloaded = 0 + self.enable_wandb = enable_wandb + + def __call__( + self, + count, + success, + failed_to_download, + failed_to_subsample, + bytes_downloaded, + start_time, + end_time, + ): # pylint: disable=arguments-differ + self.count += count + self.success += success + self.failed_to_download += failed_to_download + self.failed_to_subsample += failed_to_subsample + self.bytes_downloaded += bytes_downloaded + self.start_time = min(start_time, self.start_time) + self.end_time = max(end_time, self.end_time) + super().__call__( + self.count, + self.success, + self.failed_to_download, + self.failed_to_subsample, + self.bytes_downloaded, + self.start_time, + self.end_time, + ) + + def do_log( + self, + count, + success, + failed_to_download, + failed_to_subsample, + bytes_downloaded, + start_time, + end_time, + ): # pylint: disable=arguments-differ + duration = end_time - start_time + vid_per_sec = count / duration + bytes_per_sec = 1.0 * bytes_downloaded / duration + success_ratio = 1.0 * success / count + failed_to_download_ratio = 1.0 * failed_to_download / count + failed_to_subsample_ratio = 1.0 * failed_to_subsample / count + + print( + " - ".join( + [ + f"{self.prefix:<7}", + f"success: {success_ratio:.3f}", + f"failed to download: {failed_to_download_ratio:.3f}", + f"failed to resize: {failed_to_subsample_ratio:.3f}", + f"videos per sec: {vid_per_sec:.0f}", + f"bytes per sec: {bytes_per_sec:.0f}", + f"count: {count}", + ] + ) + ) + + if self.enable_wandb: + wandb.log( + { + f"{self.prefix}/vid_per_sec": vid_per_sec, + f"{self.prefix}/bytes_per_sec": bytes_per_sec, + f"{self.prefix}/success": success_ratio, + f"{self.prefix}/failed_to_download": failed_to_download_ratio, + f"{self.prefix}/failed_to_subsample": failed_to_subsample_ratio, + f"{self.prefix}/count": count, + } + ) + + +class StatusTableLogger(Logger): + """Log status table to W&B, up to `max_status` most frequent items""" + + def __init__(self, max_status=100, min_interval=60, enable_wandb=False, **logger_args): + super().__init__(min_interval=min_interval, **logger_args) + # avoids too many errors unique to a specific website (SSL certificates, etc) + self.max_status = max_status + self.enable_wandb = enable_wandb + + def do_log(self, status_dict, count): # pylint: disable=arguments-differ + if self.enable_wandb: + status_table = wandb.Table( + columns=["status", "frequency", "count"], + data=[[k, 1.0 * v / count, v] for k, v in status_dict.most_common(self.max_status)], + ) + wandb.run.log({"status": status_table}) + + +def write_stats( + output_folder, + shard_id, + count, + successes, + failed_to_download, + failed_to_subsample, + bytes_downloaded, + start_time, + end_time, + status_dict, + oom_shard_count, +): + """Write stats to disk""" 
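+    # collect the per-shard metrics into one dict; the logger process aggregates these json files later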
+ stats = { + "count": count, + "successes": successes, + "failed_to_download": failed_to_download, + "failed_to_subsample": failed_to_subsample, + "duration": end_time - start_time, + "bytes_downloaded": bytes_downloaded, + "start_time": start_time, + "end_time": end_time, + "status_dict": status_dict.dump(), + } + fs, output_path = fsspec.core.url_to_fs(output_folder) + shard_name = ( + shard_id + if isinstance(shard_id, str) + else "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string + shard_id=shard_id, oom_shard_count=oom_shard_count + ) + ) + json_file = f"{output_path}/{shard_name}_stats.json" + with fs.open(json_file, "w") as f: + json.dump(stats, f, indent=4) + + +# https://docs.python.org/3/library/multiprocessing.html +# logger process that reads stats files regularly, aggregates and send to wandb / print to terminal +class LoggerProcess(multiprocessing.context.SpawnProcess): + """Logger process that reads stats files regularly, aggregates and send to wandb / print to terminal""" + + def __init__( + self, + output_folder, + enable_wandb, + wandb_project, + config_parameters, + log_interval=5, + ): + super().__init__() + self.log_interval = log_interval + self.enable_wandb = enable_wandb + self.output_folder = output_folder + self.stats_files = set() + self.wandb_project = wandb_project + self.done_shards = set() + self.config_parameters = config_parameters + ctx = multiprocessing.get_context("spawn") + self.q = ctx.Queue() + + def run(self): + """Run logger process""" + + fs, output_path = fsspec.core.url_to_fs(self.output_folder, use_listings_cache=False) + + if self.enable_wandb: + self.current_run = wandb.init( + project=self.wandb_project, + config=self.config_parameters, + anonymous="allow", + ) + else: + self.current_run = None + self.total_speed_logger = SpeedLogger("total", enable_wandb=self.enable_wandb) + self.status_table_logger = StatusTableLogger(enable_wandb=self.enable_wandb) + last_check = 0 + total_status_dict = CappedCounter() + while True: + time.sleep(0.1) + try: + self.q.get(False) + last_one = True + except queue.Empty as _: + last_one = False + if not last_one and time.perf_counter() - last_check < self.log_interval: + continue + + try: + # read stats files + stats_files = fs.glob(output_path + "/*.json") + + # filter out files that have an id smaller that are already done + stats_files = [f for f in stats_files if int(f.split("/")[-1].split("_")[0]) not in self.done_shards] + + # get new stats files + new_stats_files = set(stats_files) - self.stats_files + if len(new_stats_files) == 0: + if last_one: + self.finish() + return + + # read new stats files + for stats_file in new_stats_files: + with fs.open(stats_file, "r") as f: + try: + stats = json.load(f) + SpeedLogger("worker", enable_wandb=self.enable_wandb)( + count=stats["count"], + success=stats["successes"], + failed_to_download=stats["failed_to_download"], + failed_to_subsample=stats["failed_to_subsample"], + bytes_downloaded=stats["bytes_downloaded"], + start_time=stats["start_time"], + end_time=stats["end_time"], + ) + self.total_speed_logger( + count=stats["count"], + success=stats["successes"], + failed_to_download=stats["failed_to_download"], + failed_to_subsample=stats["failed_to_subsample"], + bytes_downloaded=stats["bytes_downloaded"], + start_time=stats["start_time"], + end_time=stats["end_time"], + ) + status_dict = CappedCounter.load(stats["status_dict"]) + total_status_dict.update(status_dict) + self.status_table_logger(total_status_dict, 
self.total_speed_logger.count) + except Exception as err: # pylint: disable=broad-except + print(f"failed to parse stats file {stats_file}", err) + + self.stats_files.add(stats_file) + last_check = time.perf_counter() + + if last_one: + self.finish() + return + except Exception as e: # pylint: disable=broad-except + traceback.print_exc() + print("logger error", e) + self.finish() + return + + def finish(self): + """Finish logger process""" + self.total_speed_logger.sync() + self.status_table_logger.sync() + if self.current_run is not None: + self.current_run.finish() + + def join(self, timeout=None): + """Stop logger process""" + self.q.put("stop") + super().join() + self.q.close() diff --git a/video2dataset/main.py b/video2dataset/main.py new file mode 100644 index 00000000..aa1d998b --- /dev/null +++ b/video2dataset/main.py @@ -0,0 +1,292 @@ +"""Create dataset from video links and metadata.""" +import os +import sys +import signal +import fire +import fsspec + +from omegaconf import OmegaConf +from typing import List, Optional, Any +import numpy as np # pylint: disable=unused-import + +from video2dataset.logger import LoggerProcess +from video2dataset.data_writer import ( + WebDatasetSampleWriter, + FilesSampleWriter, + ParquetSampleWriter, + TFRecordSampleWriter, + DummySampleWriter, +) +from video2dataset.input_sharder import InputSharder +from video2dataset.output_sharder import OutputSharder +from video2dataset.distributor import ( + no_distributor, + multiprocessing_distributor, + pyspark_distributor, + SlurmDistributor, + SlurmShardSampler, +) +from video2dataset.workers import DownloadWorker, SubsetWorker, OpticalFlowWorker, CaptionWorker, WhisperWorker +from video2dataset.configs import CONFIGS +from video2dataset.types import EncodeFormats + + +def identity(x): + return x + + +# pylint: disable=unused-argument +# pylint: disable=eval-used +# pylint: disable=broad-except +def video2dataset( + url_list: str, + output_folder: str = "dataset", + output_format: str = "files", + input_format: str = "csv", + encode_formats: Optional[EncodeFormats] = None, + stage: str = "download", + url_col: str = "url", + caption_col: Optional[str] = None, + clip_col: Optional[str] = None, + save_additional_columns: Optional[List[str]] = None, + enable_wandb: bool = False, + wandb_project: str = "video2dataset", + incremental_mode: str = "incremental", + max_shard_retry: int = 1, + tmp_dir: str = "/tmp", + config: Any = "default", +): + """ + Create datasets from video/audio links + + Args: + url_list: list of input urls - can be any of the supported input formats + (csv, parquet, braceexpand tar paths etc.) + output_folder: Desired location of output dataset + output_format: Format of output dataset, can be + - files, samples saved in subdirectory for each shard (useful for debugging) + - webdataset, samples saved in tars (useful for efficient loading) + - parquet, sampels saved in parquet (as bytes) + - tfrecord, samples saved in tfrecord (as bytes) + - dummy, does not save (useful for benchmarks) + input_format: Format of the input, can be + - txt, text file with a url in each line + - csv, csv file with urls, (and captions + metadata) + - tsv, tsv - || - + - tsv.gz, - || - but compressed gzip + - json, loads urls and metadata as json + - parquet, loads urls and metadata as parquet + - webdataset, usually braceexpand format of mutliple paths to tars to re-process + encode_formats: Dict that specifies what extension each modality should use + f.e. 
{"video": "mp4", "audio": "m4a"} + stage: String that tells video2dataset what stage of processing is being performed. Can be + WARNING: To be depracated soon (this information should be deduced based on config) + - download, when input is some tabular format and data must be downloaded first + - subset, tar files are already written and we would like to re-process (input_format == "webdataset") + - optical_flow, tar files are written and we would like to compute optical_flow and save to md shards + - caption, tar files are written and we would like to generate caption and save to md shards + url_col: Column in input (if has columns) that contains the url + caption_col: Column in input (if has columns) that contains captions (to be written as txt) + clip_col: Column in input (if has columns) that contains timeframes of clips for how to split video + save_additional_columns: List of column names to save to json component of a sample + enable_wandb: Whether or not to log info to wandb + wandb_project: Name of wandb project to log runs to + incremental_mode: Decides how to handle restarting, Can be + - incremental, checks which shards are done and skips those + - overwrite, deletes and reprocesses shards as it goes + max_shard_retry: Maximum amount of attempts to retry a failed shard + tmp_dir: Path to temporary directory on your file system + config: Path to your config of choice or the config itself (more info on configs in API doc) + """ + local_args = dict(locals()) + if isinstance(config, str): + config = CONFIGS[config] if config in CONFIGS else OmegaConf.load(config) + config = OmegaConf.to_container(config) + for arg_type in ["subsampling", "reading", "storage", "distribution"]: + assert arg_type in config + + if config["reading"]["sampler"] is None: + config["reading"]["sampler"] = identity + + called_from_slurm = "CALLED_FROM_SLURM" in os.environ + if called_from_slurm: + global_task_id = int(os.environ["GLOBAL_RANK"]) + num_tasks = ( + config["distribution"]["distributor_args"]["n_nodes"] + * config["distribution"]["distributor_args"]["tasks_per_node"] + ) + config["reading"]["sampler"] = SlurmShardSampler(global_task_id=global_task_id, num_tasks=num_tasks) + config["distribution"]["distributor"] = "multiprocessing" + + # Only log from master + enable_wandb = enable_wandb and (global_task_id == 0) + + # TODO: find better location for this code + # TODO: figure out minimum yt_meta_args for subtitles to be added to metadata + if config["storage"]["captions_are_subtitles"]: + assert clip_col is None # no weird double-clipping + if config["reading"]["yt_args"]["yt_metadata_args"] is None: + config["reading"]["yt_args"]["yt_metadata_args"] = {} + if not config["reading"]["yt_args"]["yt_metadata_args"].get("writesubtitles", None): # type: ignore + config["reading"]["yt_args"]["yt_metadata_args"]["writesubtitles"] = "all" # type: ignore + + if encode_formats is None: + encode_formats = {"video": "mp4"} + + def make_path_absolute(path): + fs, p = fsspec.core.url_to_fs(path) + if fs.protocol == "file": + return os.path.abspath(p) + return path + + output_folder = make_path_absolute(output_folder) + url_list = make_path_absolute(url_list) + + logger_process = LoggerProcess(output_folder, enable_wandb, wandb_project, local_args) + tmp_path = output_folder + "/_tmp" + fs, run_tmp_dir = fsspec.core.url_to_fs(tmp_path) + if not fs.exists(run_tmp_dir): + fs.mkdir(run_tmp_dir) + + def signal_handler(signal_arg, frame): # pylint: disable=unused-argument + try: + fs.rm(run_tmp_dir, recursive=True) + 
except Exception as _: # pylint: disable=broad-except + pass + logger_process.terminate() + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + + save_caption = caption_col is not None or config["storage"]["captions_are_subtitles"] + + fs, output_path = fsspec.core.url_to_fs(output_folder) + + if not fs.exists(output_path): + fs.mkdir(output_path) + done_shards = set() + else: + if incremental_mode == "incremental": + done_shards = set(int(x.split("/")[-1].split("_")[0]) for x in fs.glob(output_path + "/*.json")) + elif incremental_mode == "overwrite": + fs.rm(output_path, recursive=True) + fs.mkdir(output_path) + done_shards = set() + else: + raise ValueError(f"Unknown incremental mode {incremental_mode}") + + logger_process.done_shards = done_shards + logger_process.start() + + if output_format == "webdataset": + sample_writer_class = WebDatasetSampleWriter + elif output_format == "parquet": + sample_writer_class = ParquetSampleWriter # type: ignore + elif output_format == "files": + sample_writer_class = FilesSampleWriter # type: ignore + elif output_format == "tfrecord": + sample_writer_class = TFRecordSampleWriter # type: ignore + elif output_format == "dummy": + sample_writer_class = DummySampleWriter # type: ignore + else: + raise ValueError(f"Invalid output format {output_format}") + + if input_format == "webdataset": + shard_iterator = OutputSharder( # type: ignore + url_list, input_format, done_shards, sampler=config["reading"]["sampler"] + ) + else: + shard_iterator = InputSharder( # type: ignore + url_list, + input_format, + url_col, + caption_col, + clip_col, + save_additional_columns, + config["storage"]["number_sample_per_shard"], + done_shards, + tmp_path, + config["reading"]["sampler"], + ) + + if stage == "download": + worker = DownloadWorker( + sample_writer_class=sample_writer_class, + save_caption=save_caption, + output_folder=output_folder, + column_list=shard_iterator.column_list, + tmp_dir=tmp_dir, + encode_formats=encode_formats, + config=config, + ) + elif stage == "subset": + worker = SubsetWorker( # type: ignore + sample_writer_class=sample_writer_class, + output_folder=output_folder, + encode_formats=encode_formats, + config=config, + ) + elif stage == "optical_flow": + is_slurm_task = "GLOBAL_RANK" in os.environ and config["distribution"]["distributor"] == "multiprocessing" + worker = OpticalFlowWorker( # type: ignore + sample_writer_class=sample_writer_class, + output_folder=output_folder, + encode_formats=encode_formats, + is_slurm_task=is_slurm_task, + config=config, + ) + elif stage == "caption": + is_slurm_task = "GLOBAL_RANK" in os.environ and config["distribution"]["distributor"] == "multiprocessing" + worker = CaptionWorker( # type: ignore + sample_writer_class=sample_writer_class, + output_folder=output_folder, + encode_formats=encode_formats, + is_slurm_task=is_slurm_task, + config=config, + ) + elif stage == "whisper": + is_slurm_task = "GLOBAL_RANK" in os.environ and config["distribution"]["distributor"] == "multiprocessing" + worker = WhisperWorker( # type: ignore + sample_writer_class=sample_writer_class, + output_folder=output_folder, + column_list=shard_iterator.column_list, + tmp_dir=tmp_dir, + encode_formats=encode_formats, + is_slurm_task=is_slurm_task, + config=config, + ) + else: + raise ValueError(f"Invalid stage: {stage}") + + print("Starting the downloading of this file") + if config["distribution"]["distributor"] == "multiprocessing" or called_from_slurm: + distributor_fn = multiprocessing_distributor if stage not in 
["whisper", "caption"] else no_distributor + called_from_slurm = "GLOBAL_RANK" in os.environ + elif config["distribution"]["distributor"] == "pyspark": + distributor_fn = pyspark_distributor + elif config["distribution"]["distributor"] == "slurm": + worker_args = {key: local_args[key] for key in local_args if not key.startswith("slurm")} + slurm_args = config["distribution"]["distributor_args"] + + distributor_fn = SlurmDistributor(worker_args=worker_args, **slurm_args) + else: + raise ValueError(f"Distributor {config['distribution']['distributor']} not supported") + + distributor_fn( + config["distribution"]["processes_count"], + worker, + shard_iterator, + config["distribution"]["subjob_size"], + max_shard_retry, + ) + logger_process.join() + if not called_from_slurm: + fs.rm(run_tmp_dir, recursive=True) + + +def main(): + fire.Fire(video2dataset) + + +if __name__ == "__main__": + main() diff --git a/video2dataset/output_sharder.py b/video2dataset/output_sharder.py new file mode 100644 index 00000000..b929d0d8 --- /dev/null +++ b/video2dataset/output_sharder.py @@ -0,0 +1,54 @@ +"""Reader is module to read the url list and return shards""" +import braceexpand +import fsspec + + +class OutputSharder: + """ + The reader class reads a shard list and returns shards + + It provides an iter method + It provides attributes: + - shard_list: a list of shards to read + - input_format: the format of the input dataset + - done_shards: a set of already done shards + - group_shards: the number of shards to group together + """ + + def __init__(self, shard_list, input_format, done_shards, sampler=lambda x: x) -> None: + self.input_format = input_format + self.done_shards = done_shards + self.column_list = None + fs, url_path = fsspec.core.url_to_fs(shard_list) + + if fs.isdir(url_path): + self.shard_list = sorted(fs.glob(url_path + "/*.tar")) + if "s3://" in shard_list: + self.shard_list = ["s3://" + s for s in self.shard_list] + if len(self.shard_list) == 0: + raise ValueError(f"No file found at path {url_path} with extension {input_format}") + else: + self.shard_list = list(braceexpand.braceexpand(shard_list)) + + num_shards = len(self.shard_list) + print(f"Found a total of {num_shards} shards!") + + if self.input_format == "webdataset": + self.shard_ids = [s.split("/")[-1][: -len(".tar")] for s in self.shard_list] + elif self.input_format == "files": + self.shard_ids = [s.split("/")[-1] for s in self.shard_list] + + self.shards = sampler( + [(s_id, s) for s_id, s in zip(self.shard_ids, self.shard_list) if int(s_id) not in self.done_shards] + ) + + num_shards = len(self.shards) + print(f"Processing a total of {num_shards} shards!") + + def __iter__(self): + """ + Iterate over shards, yield shards of size group_shards size + Each shard is a tuple (shard_id, shard) + """ + for s_id, s in self.shards: + yield (s, s_id) diff --git a/video2dataset/slurm_executor.py b/video2dataset/slurm_executor.py new file mode 100644 index 00000000..cd1cd452 --- /dev/null +++ b/video2dataset/slurm_executor.py @@ -0,0 +1,51 @@ +"""Executor for distribution via slurm""" + +import os +import yaml +import fire + +from video2dataset import video2dataset + + +def executor(worker_args, node_id, n_nodes, num_tasks_per_node, subtask_id): + """ + Spins up the individual workers on the nodes + :param worker_args: parameters to video2dataset on the respective node + :param node_id: node id + :param n_nodes: overall number of node + :param num_tasks_per_node: number of distinct calls to video2dataset from the current node + :param 
subtask_id: local id of current call to video2dataset
+    :return:
+    """
+
+    num_tasks = n_nodes * num_tasks_per_node
+    print("#" * 100)
+    print("args:")
+    print(worker_args)
+    print("node id", node_id)
+    print("n_nodes", n_nodes)
+    print("num_tasks_per_node", num_tasks_per_node)
+    print("subtask_id", subtask_id)
+    print("num_tasks", num_tasks)
+    print("#" * 100)
+
+    global_task_id = node_id * num_tasks_per_node + subtask_id
+    assert global_task_id < num_tasks, (
+        f"global_task_id is {global_task_id} but must be less than "
+        f"num_nodes*num_tasks_per_node={n_nodes * num_tasks_per_node}"
+    )
+
+    print(f"Starting task with id {global_task_id}")
+    os.environ["GLOBAL_RANK"] = str(global_task_id)
+    os.environ["LOCAL_RANK"] = str(subtask_id)
+
+    # Read the worker args from the file
+    with open(worker_args, "r", encoding="utf-8") as worker_args_file:
+        worker_args = yaml.load(worker_args_file, Loader=yaml.SafeLoader)
+
+    # call main script from every subprocess
+    video2dataset(**worker_args)
+
+
+if __name__ == "__main__":
+    fire.Fire(executor)
diff --git a/video2dataset/subsamplers/__init__.py b/video2dataset/subsamplers/__init__.py
new file mode 100644
index 00000000..53b4141f
--- /dev/null
+++ b/video2dataset/subsamplers/__init__.py
@@ -0,0 +1,16 @@
+"""
+Transform video and audio by reducing fps, extracting videos, changing resolution, reducing bitrate, etc.
+"""
+
+from .audio_rate_subsampler import AudioRateSubsampler
+from .clipping_subsampler import ClippingSubsampler, _get_seconds, _split_time_frame, Streams
+from .frame_subsampler import FrameSubsampler
+from .ffprobe_subsampler import FFProbeSubsampler
+from .noop_subsampler import NoOpSubsampler
+from .resolution_subsampler import ResolutionSubsampler
+from .cut_detection_subsampler import CutDetectionSubsampler
+from .optical_flow_subsampler import OpticalFlowSubsampler
+from .whisper_subsampler import WhisperSubsampler
+from .caption_subsampler import CaptionSubsampler
+
+from .subsampler import Subsampler
diff --git a/video2dataset/subsamplers/audio_rate_subsampler.py b/video2dataset/subsamplers/audio_rate_subsampler.py
new file mode 100644
index 00000000..ad8b12e1
--- /dev/null
+++ b/video2dataset/subsamplers/audio_rate_subsampler.py
@@ -0,0 +1,45 @@
+"""
+audio rate subsampler adjusts the sample rate of audio to some constant value
+"""
+
+
+import tempfile
+import os
+import ffmpeg
+
+
+class AudioRateSubsampler:
+    """
+    Adjusts the sample rate of the audio to the specified sample rate.
+    Args:
+        sample_rate (int): Target sample rate of the audio.
+        encode_format (str): Format to encode in (i.e. 
m4a) + """ + + def __init__(self, sample_rate, encode_format, n_audio_channels=None): + self.sample_rate = sample_rate + self.encode_format = encode_format + self.n_audio_channels = n_audio_channels + + def __call__(self, streams, metadata=None): + audio_bytes = streams.pop("audio") + subsampled_bytes = [] + for aud_bytes in audio_bytes: + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "input.m4a"), "wb") as f: + f.write(aud_bytes) + ext = self.encode_format + try: + # TODO: for now assuming m4a, change this + ffmpeg_args = {"ar": str(self.sample_rate), "f": ext} + if self.n_audio_channels is not None: + ffmpeg_args["ac"] = str(self.n_audio_channels) + _ = ffmpeg.input(f"{tmpdir}/input.m4a") + _ = _.output(f"{tmpdir}/output.{ext}", **ffmpeg_args).run(capture_stdout=True, quiet=True) + except Exception as err: # pylint: disable=broad-except + return [], None, str(err) + + with open(f"{tmpdir}/output.{ext}", "rb") as f: + subsampled_bytes.append(f.read()) + streams["audio"] = subsampled_bytes + return streams, metadata, None diff --git a/video2dataset/subsamplers/caption_subsampler.py b/video2dataset/subsamplers/caption_subsampler.py new file mode 100644 index 00000000..c5e08e80 --- /dev/null +++ b/video2dataset/subsamplers/caption_subsampler.py @@ -0,0 +1,178 @@ +""" +video captioner +""" +import os + +import torch +from torch import nn +from einops import rearrange +from typing import Optional + +from transformers import ( + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + Blip2Config, + Blip2ForConditionalGeneration, + Blip2Processor, + Blip2QFormerModel, + Blip2VisionModel, +) + +from transformers.modeling_outputs import BaseModelOutputWithPooling + +from .subsampler import Subsampler + + +class AttrDict(dict): + """ + Lets us access dict keys with .key + """ + + # pylint: disable=super-with-arguments + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +class VideoBlipVisionModel(Blip2VisionModel): + """ + A simple, augmented version of Blip2VisionModel to handle videos. 
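The wrapper above reuses the frame-level BLIP-2 vision tower on videos by folding the time axis into the batch dimension and unfolding it again afterwards. A minimal shape-only sketch of that reshape logic (plain PyTorch, with illustrative sizes that are not taken from this diff):

```python
import torch

# Illustrative sizes only: 2 videos, 8 frames each, 3 x 224 x 224 pixels.
batch, channels, time, height, width = 2, 3, 8, 224, 224
pixel_values = torch.randn(batch, channels, time, height, width)

# Fold time into the batch so a frame-level ViT can encode every frame:
# (batch, channels, time, H, W) -> (batch * time, channels, H, W)
flat_pixel_values = pixel_values.permute(0, 2, 1, 3, 4).flatten(end_dim=1)
assert flat_pixel_values.shape == (batch * time, channels, height, width)

# Pretend the frame encoder produced (batch * time, seq_len, hidden) features.
seq_len, hidden = 257, 1408  # placeholder ViT dimensions
last_hidden_state = torch.randn(batch * time, seq_len, hidden)

# Unfold so the Q-Former sees one long token sequence per video:
# (batch * time, seq_len, hidden) -> (batch, time * seq_len, hidden)
video_tokens = last_hidden_state.view(batch, time * seq_len, hidden)
assert video_tokens.shape == (batch, time * seq_len, hidden)
```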
+ Source: https://github.com/yukw777/VideoBLIP + """ + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + """Forward method for vision blip model""" + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + batch, _, time, _, _ = pixel_values.size() + flat_pixel_values = pixel_values.permute(0, 2, 1, 3, 4).flatten(end_dim=1) + vision_outputs: BaseModelOutputWithPooling = super().forward( + pixel_values=flat_pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + seq_len = vision_outputs.last_hidden_state.size(1) + last_hidden_state = vision_outputs.last_hidden_state.view(batch, time * seq_len, -1) + pooler_output = vision_outputs.pooler_output.view(batch, time, -1) + hidden_states = ( + tuple(hidden.view(batch, time * seq_len, -1) for hidden in vision_outputs.hidden_states) + if vision_outputs.hidden_states is not None + else None + ) + attentions = ( + tuple(hidden.view(batch, time, -1, seq_len, seq_len) for hidden in vision_outputs.attentions) + if vision_outputs.attentions is not None + else None + ) + if return_dict: + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=hidden_states, + attentions=attentions, + ) + return (last_hidden_state, pooler_output, hidden_states, attentions) + + +class VideoBlipForConditionalGeneration(Blip2ForConditionalGeneration): + """ + A simple, augmented version of Blip2ForConditionalGeneration to handle videos. + Source: https://github.com/yukw777/VideoBLIP + """ + + def __init__(self, config: Blip2Config) -> None: + super(Blip2ForConditionalGeneration, self).__init__(config) # pylint: disable=E1003 + self.vision_model = VideoBlipVisionModel(config.vision_config) + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = Blip2QFormerModel(config.qformer_config) + self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) + if config.use_decoder_only_language_model: + language_model = AutoModelForCausalLM.from_config(config.text_config) + else: + language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) + self.language_model = language_model + self.post_init() + + +class VideoCaptioner: + """Synthetic video captioner using the VideoBLIP model.""" + + def __init__(self, args): + assert args.video_captioner in ["kpyu/video-blip-flan-t5-xl-ego4d", "kpyu/video-blip-opt-2.7b-ego4d"] + model = args.video_captioner + if args.prompt is None: + self.prompt = "Question: Can you give a detailed description of what is happening in this video? 
Answer:" + else: + self.prompt = args.prompt + + self.device = args.get("device", "cuda") + self.processor = Blip2Processor.from_pretrained(model) + + if ":" in self.device: + device_map = {"": int(self.device.split(":")[-1])} + else: + device_map = {"": 0} + self.model = VideoBlipForConditionalGeneration.from_pretrained(model, load_in_8bit=True, device_map=device_map) + + def __call__(self, video): + video = video * 0.00392156862745098 + video = (video - torch.Tensor([0.48145466, 0.4578275, 0.40821073])) / torch.Tensor( + [0.26862954, 0.26130258, 0.27577711] + ) + video = rearrange(video, "b t h w c -> b c t h w").to(torch.float16) + + inputs = self.processor(images=None, text=[self.prompt] * video.shape[0], return_tensors="pt") + inputs["pixel_values"] = video + inputs = inputs.to(self.device) + + with torch.no_grad(): + generated_ids = self.model.generate( + **inputs, + max_new_tokens=100, + num_beams=10, + early_stopping=True, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + ) + generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True) + generated_text = [x.strip() for x in generated_text] + return generated_text + + +class CaptionSubsampler(Subsampler): + """A class to generate synthetic text caption from video frames.""" + + def __init__( + self, + captioner_args=None, + is_slurm_task=False, + ): + if is_slurm_task: + local_rank = os.environ["LOCAL_RANK"] + device = f"cuda:{local_rank}" + captioner_args["device"] = device + else: + captioner_args["device"] = "cuda" if torch.cuda.is_available() else "cpu" + + if not isinstance(captioner_args, AttrDict): + captioner_args = AttrDict(captioner_args) + + self.captioner = VideoCaptioner(captioner_args) + + def __call__(self, frames): + try: + combined_caption = self.captioner(frames) + return [ + combined_caption, + ], None + except Exception as err: # pylint: disable=broad-except + return [], str(err) diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py new file mode 100644 index 00000000..2e7c7d96 --- /dev/null +++ b/video2dataset/subsamplers/clipping_subsampler.py @@ -0,0 +1,321 @@ +""" +clipping subsampler turns full videos into clips of videos according to clip_col +""" +from collections.abc import Iterable +import copy +import datetime +import ffmpeg +import glob +import os +import tempfile +from typing import Any, Union, List, Tuple, Dict, Literal, cast + +from video2dataset.subsamplers.subsampler import Subsampler +from video2dataset.types import EncodeFormats, Streams + + +ClipSpan = List[float] # [start, end] + + +def _get_seconds(t: Union[str, float]) -> float: + """Converts time to seconds""" + if not isinstance(t, str): + return float(t) # already seconds + time_format = "%H:%M:%S.%f" # TODO: maybe parameterize this? + t_obj = datetime.datetime.strptime(t, time_format).time() + return t_obj.second + t_obj.microsecond / 1e6 + t_obj.minute * 60 + t_obj.hour * 3600 + + +def _get_strtime(t_sec: float) -> str: + """Converts time to string""" + hour = int(t_sec // 3600) + minute = int((t_sec // 60) % 60) + second = int(t_sec % 60) + # Use round to solve machine error problem (e.g. 
t_sec=13.6) + microsecond = round((t_sec - int(t_sec)) * 1000) + return f"{hour:02d}:{minute:02d}:{second:02d}.{microsecond:03d}" + + +def _split_time_frame(s: float, e: float, min_length: float, max_length: float) -> List[ClipSpan]: + """Filters out cuts by min and max length""" + time_d = e - s + n_full_clips = int(time_d // max_length) + clip_spans = [[s + i * max_length, s + (i + 1) * max_length] for i in range(n_full_clips)] + ( + [[s + (n_full_clips) * max_length, e]] if time_d % max_length > min_length else [] + ) + return clip_spans + + +def _adjust_clip_spans_to_keyframes(clip_spans: List[ClipSpan], keyframes: List[float]) -> List[ClipSpan]: + """Translates clip_spans into keyframe vocab""" + adjusted_clip_spans = [] + for start, end in clip_spans: + keyframes_in_range = [k for k in keyframes if start <= k <= end] + if keyframes_in_range: + adjusted_start = min(keyframes_in_range) + adjusted_end = max(keyframes_in_range) + if adjusted_start != adjusted_end: + adjusted_clip_spans.append([adjusted_start, adjusted_end]) + return adjusted_clip_spans + + +def _adjust_clip_spans( + clip_spans: List[ClipSpan], + keyframe_timestamps: Union[List[float], None], + min_length: float, + max_length: float, + max_length_strategy: str, +) -> List[ClipSpan]: + """Adjusts cut times around keyframes, filtering by min and max length""" + if not isinstance(clip_spans[0], Iterable): # make sure clip_spans looks like [[start, end]] and not [start, end] + clip_spans = cast(List[ClipSpan], [clip_spans]) + clip_spans = [[_get_seconds(s), _get_seconds(e)] for [s, e] in clip_spans] + + if keyframe_timestamps: + clip_spans = _adjust_clip_spans_to_keyframes(clip_spans, keyframe_timestamps) + + filtered_clip_spans = [] + for s, e in clip_spans: + max_len_clip_spans = _split_time_frame(s, e, min_length, max_length) + if max_length_strategy == "first": + max_len_clip_spans = max_len_clip_spans[:1] + filtered_clip_spans += max_len_clip_spans + return filtered_clip_spans + + +def _collate_clip_spans(clip_spans: List[ClipSpan]) -> Tuple[str, List[int]]: + """Collates clip spans into a single string for ffmpeg and a list of clip idxs""" + clip_times = [] + clip_idxs = [] + e_prev = 0.0 + clip_idx = 0 + + for s, e in clip_spans: + if s == e_prev: # clip starts where last one left off + clip_times += [e] + clip_idxs.append(clip_idx) + clip_idx += 1 + else: # next clip skips over some time + clip_times += [s, e] + clip_idxs.append(clip_idx + 1) + clip_idx += 2 + e_prev = e + + clip_times_str = ",".join([str(time) for time in clip_times]) + return clip_times_str, clip_idxs + + +def _process_stream( + tmpdir: Any, # BytesPath + stream_bytes: bytes, + encode_format: str, + ffmpeg_kwargs: dict, +) -> List[str]: + """Processes a stream into clips using ffmpeg""" + # TODO: we need to put the extension into the metadata + # TODO: This can be done better using pipes I just don't feel like sinking too much time into this rn + with open(os.path.join(tmpdir, f"input.{encode_format}"), "wb") as f: + f.write(stream_bytes) + try: + ( + ffmpeg.input(f"{tmpdir}/input.{encode_format}") + .output(f"{tmpdir}/clip_%d.{encode_format}", **ffmpeg_kwargs) + .run(capture_stdout=True, quiet=True) + ) + except Exception as err: # pylint: disable=broad-except + raise err + stream_clips = glob.glob(f"{tmpdir}/clip*.{encode_format}") + stream_clips.sort(key=lambda x: int(x.split("_")[-1].split(".")[0])) + return stream_clips + + +def _extract_subtitles(clip_span: ClipSpan, meta_clip: dict) -> List[dict]: + """Extracts subtitles and groups them by 
language""" + clip_subtitles: List[dict] = [] + s_c, e_c = _get_seconds(clip_span[0]), _get_seconds(clip_span[1]) + for lang_id, (lang, subtitles) in enumerate(meta_clip["yt_meta_dict"]["subtitles"].items()): + idx = 0 + for line in subtitles: + line_dict = {lang: line["lines"]} + s, e = _get_seconds(line["start"]), _get_seconds(line["end"]) + if max(s_c, s) < min(e_c, e): + if lang_id != 0: + clip_subtitles[idx]["lines"].update(line_dict) + idx += 1 + else: + temp_line = copy.deepcopy(line) + temp_line["lines"] = line_dict + clip_subtitles.append(temp_line) + elif s > e_c: + break + return clip_subtitles + + +def _get_clip_metadata( + clip_spans: List[ClipSpan], + clip_idxs: List[int], + metadata: dict, + oom_clip_count: int, + strtime_formatting: bool, +) -> List[dict]: + """Gets metadata for each clip""" + metadata_clips = [] + for clip_id, (clip_span, _) in enumerate(zip(clip_spans, clip_idxs)): + clip_key = "{clip_id:0{oom_clip_count}d}".format( # pylint: disable=consider-using-f-string + clip_id=clip_id, oom_clip_count=oom_clip_count + ) + meta_clip = copy.deepcopy(metadata) + # set the timeframe of this clip + if strtime_formatting: + # Keep clip_spans in the original format to be compatible with the data schema. + meta_clip["clips"] = [(_get_strtime(clip_span[0]), _get_strtime(clip_span[1]))] + else: + meta_clip["clips"] = [clip_span] + meta_clip["key"] = f"{meta_clip['key']}_{clip_key}" + + yt_md_dict = meta_clip.get("yt_meta_dict", {}) + if (yt_md_dict is not None) and (yt_md_dict.get("subtitles", None) is not None): + meta_clip["clip_subtitles"] = _extract_subtitles(clip_span, meta_clip) + metadata_clips.append(meta_clip) + + # remove redundant metadata from clips after the first + for m_clips in metadata_clips[1:]: + m_clips["yt_meta_dict"] = {} + + return metadata_clips + + +def _get_clips( + streams: Streams, + encode_formats: EncodeFormats, + precision: str, + clip_spans: List[ClipSpan], + metadata: dict, + oom_clip_count: int, + strtime_formatting: bool, +) -> Tuple[Dict[str, List[bytes]], List[dict]]: + """Gets clips from streams""" + clip_times, clip_idxs = _collate_clip_spans(clip_spans) + + ffmpeg_kwargs = { + "map": 0, + "f": "segment", + "segment_times": clip_times, + "reset_timestamps": 1, + } + if precision == "exact": + ffmpeg_kwargs["force_key_frames"] = clip_times + else: + ffmpeg_kwargs["c"] = "copy" + + clips: Dict[str, List[bytes]] = {} + for k in streams.keys(): + k = cast(Literal["audio", "video"], k) + with tempfile.TemporaryDirectory() as tmpdir: + stream_bytes = streams[k][0] # pre-broadcast so only one + if stream_bytes is None: + continue + try: + stream_clips = _process_stream( + tmpdir=tmpdir, + stream_bytes=stream_bytes, + encode_format=encode_formats[k], + ffmpeg_kwargs=ffmpeg_kwargs, + ) + except Exception as err: # pylint: disable=broad-except + raise err + + clips[k] = [] + for clip_idx in clip_idxs: + with open(stream_clips[clip_idx], "rb") as vid_f: + clip_bytes = vid_f.read() + clips[k].append(clip_bytes) + + clip_metadata = _get_clip_metadata( + clip_spans=clip_spans, + clip_idxs=clip_idxs, + metadata=metadata, + oom_clip_count=oom_clip_count, + strtime_formatting=strtime_formatting, + ) + + return clips, clip_metadata + + +class ClippingSubsampler(Subsampler): + """ + Cuts videos up into segments according to the 'clips' metadata + + Parameters: + oom_clip_count: int + The number of orders of magnitude for clip count, used for formatting clip keys. 
+ encode_formats: dict + A dictionary mapping stream keys to their corresponding file extensions, e.g., {"video": "mp4", "audio": "mp3"}. + min_length: float optional (default=0.0) + Minimum length in seconds of a clip. Below this the subsampler will reject the clips + max_length: float optional (default=999999.0) + Maximum clip length, if exceeded resolve according to max_length_strategy + max_length_strategy: str optional (defaul="all") + "all" - cut up long clip into as many clips of max_length as possible + "first" - take the first max_length clip from the long clip + precision: str, optional (default="low") + "low" - splits can be imprecise in any direction + "keyframe_adjusted" - translates cuts into the vocab of existing keyframes (a good middlepoint) + useful if you need to do fast clipping but information can't cross cut boundries + "exact" - keyframes are inserted to get exact splits (warning, slow) + + expects: + - clips to be sorted in increasing order and non-overlapping + - time to be in the format "%H:%M:%S.%f", or a number representing the second of the timestamp + """ + + def __init__( + self, + oom_clip_count: int, + encode_formats: EncodeFormats, + min_length: float = 0.0, + max_length: float = 999999.0, + max_length_strategy: Literal["all", "first"] = "all", + precision: Literal["low", "keyframe_adjusted", "exact"] = "low", + ): + assert max_length_strategy in ["all", "first"] + assert precision in ["exact", "low", "keyframe_adjusted"] + self.oom_clip_count = oom_clip_count + self.encode_formats = encode_formats + self.min_length = min_length + self.max_length = max_length + self.max_length_strategy = max_length_strategy + self.precision = precision + + def __call__(self, streams: Streams, metadata: dict): + strtime_formatting = isinstance(metadata["clips"][0][0], str) + + clip_spans = _adjust_clip_spans( + clip_spans=metadata.pop("clips"), + keyframe_timestamps=( + # TODO: make it so if keyframe timestamps not present, get it yourself + metadata["video_metadata"].pop("keyframe_timestamps") + if self.precision == "keyframe_adjusted" + else None + ), + min_length=self.min_length, + max_length=self.max_length, + max_length_strategy=self.max_length_strategy, + ) + if len(clip_spans) == 0: + return {}, [], f"Video had no clip_spans longer than {self.min_length}" + + try: + clips, clip_metadata = _get_clips( + streams=streams, + encode_formats=self.encode_formats, + precision=self.precision, + clip_spans=clip_spans, + metadata=metadata, + oom_clip_count=self.oom_clip_count, + strtime_formatting=strtime_formatting, + ) + except Exception as err: # pylint: disable=broad-except + return {}, [], str(err) + + return clips, clip_metadata, None diff --git a/video2dataset/subsamplers/cut_detection_subsampler.py b/video2dataset/subsamplers/cut_detection_subsampler.py new file mode 100644 index 00000000..a5349f7d --- /dev/null +++ b/video2dataset/subsamplers/cut_detection_subsampler.py @@ -0,0 +1,97 @@ +""" +cut detection subsampler detects cuts in a video +""" +import numpy as np +from scenedetect import ContentDetector, SceneManager, open_video +import os +import tempfile + +from .subsampler import Subsampler + +# TODO: this can be done more elegantly: +# from scenedetect import scene_manager and set that in correct namespace +# best solution is just figure out best value for them and submit PR +DEFAULT_MIN_WIDTH = 64 + + +def get_scenes_from_scene_manager(scene_manager, cut_detection_mode): + """ + Returns a list of cuts from a scene manager given a cut detection mode + """ + 
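To make the cut format concrete, a tiny illustration (toy frame numbers, not taken from this diff) of the `[start_frame, end_frame]` pairs built below and how the "longest" mode reduces them to a single cut:

```python
import numpy as np

# Toy scene list as assembled below from PySceneDetect's scene manager.
scene = [[0, 120], [120, 480], [480, 510]]

# "longest" mode keeps only the cut spanning the most frames.
longest_clip = np.argmax([end - start for start, end in scene])
print(scene[longest_clip])  # [120, 480]
```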
scene_list = scene_manager.get_scene_list(start_in_scene=True) + scene = [] + + for clip in scene_list: + scene.append([clip[0].get_frames(), clip[1].get_frames()]) + + if cut_detection_mode == "longest": # we have multiple cuts, pick the longest + longest_clip = np.argmax([clip[1] - clip[0] for clip in scene]) + scene = [scene[longest_clip]] + + return scene + + +class CutDetectionSubsampler(Subsampler): + """ + Detects cuts in input videos and returns contiguous segments in a video as metadata. + + non-initialization args: + - cuts_are_clips: whether to create video clips from the source video based on cuts + + initialization args: + - cut_detection_mode to be either "longest" to pick the longest cut or "all" to pick all cuts + - framerates to be None (for original fps only) or a list of target framerates to detect cuts in + - threshold - determines roughly how much motion is required for a "cut" (tunable parameter) + - min_scene_len - minimum scene length to not drop a scene (see pyscenedeteect docs for more explanation) + """ + + def __init__(self, cut_detection_mode="all", framerates=None, threshold=27, min_scene_len=15): + self.framerates = framerates + self.cut_detection_mode = cut_detection_mode + self.threshold = threshold + self.min_scene_len = min_scene_len + + def __call__(self, streams, metadata=None): + video_bytes = streams["video"][0] + + try: + with tempfile.TemporaryDirectory() as tmpdir: + video_path = os.path.join(tmpdir, "input.mp4") + with open(video_path, "wb") as f: + f.write(video_bytes) + + video = open_video(video_path) + + detector = ContentDetector(threshold=self.threshold, min_scene_len=self.min_scene_len) + scene_manager = SceneManager() + scene_manager.add_detector(detector) + scene_manager.auto_downscale = False + scene_manager.downscale = video.frame_size[0] // DEFAULT_MIN_WIDTH + + cuts = {} + original_fps = video.frame_rate + cuts["original_fps"] = original_fps + + scene_manager.detect_scenes(video=video) + cuts["cuts_original_fps"] = get_scenes_from_scene_manager(scene_manager, self.cut_detection_mode) + if self.framerates is not None: + for target_fps in self.framerates: + video.reset() + + detector = ContentDetector(threshold=self.threshold, min_scene_len=self.min_scene_len) + scene_manager = SceneManager() + scene_manager.add_detector(detector) + frame_skip = max( + int(original_fps // target_fps) - 1, 0 + ) # if we take 1 frame and skip N frames we're sampling 1/N+1 % of the video + # so if we desire to sample 1/N of the video, we need to subtract one when doing frame skipping + + scene_manager.detect_scenes(video=video, frame_skip=frame_skip) + cuts[f"cuts_{target_fps}"] = get_scenes_from_scene_manager( + scene_manager, self.cut_detection_mode + ) + scene_manager.clear() + except Exception as err: # pylint: disable=broad-except + return {}, None, str(err) + + return streams, cuts, None diff --git a/video2dataset/subsamplers/ffprobe_subsampler.py b/video2dataset/subsamplers/ffprobe_subsampler.py new file mode 100644 index 00000000..52c5b61e --- /dev/null +++ b/video2dataset/subsamplers/ffprobe_subsampler.py @@ -0,0 +1,58 @@ +"""extracts basic video compression metadata.""" +import os +import json +import subprocess +import tempfile + +from .subsampler import Subsampler + + +# TODO: figuer out why this is so slow (12 samples/s) +class FFProbeSubsampler(Subsampler): + """ + Extracts metadata from bytes. + Args: + extract_keyframes (bool): Whether to extract keyframe timestamps. 
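The subsampler below shells out to ffprobe and, when keyframe extraction is enabled, keeps the packets flagged as keyframes. A standalone sketch of the equivalent call (assumes ffprobe is on PATH and `input.mp4` is a local file; the subsampler itself writes the incoming bytes to a temporary file first):

```python
import json
import subprocess

command = [
    "ffprobe", "-v", "quiet", "-print_format", "json",
    "-show_format", "-show_streams",
    "-select_streams", "v:0", "-show_entries", "packet=pts_time,flags",
    "input.mp4",
]
probe = json.loads(
    subprocess.run(command, capture_output=True, text=True, check=True).stdout
)

# Keyframes are the packets whose flags contain "K"; their pts_time values become
# the keyframe_timestamps list later consumed by the clipping subsampler.
keyframe_timestamps = [
    float(p["pts_time"]) for p in probe["packets"] if "K" in p.get("flags", "")
]
if "duration" in probe["format"]:
    keyframe_timestamps.append(float(probe["format"]["duration"]))
print(keyframe_timestamps)
```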
+ """ + + def __init__(self, extract_keyframes=False): + self.extract_keyframes = extract_keyframes + + def __call__(self, streams, metadata): + # TODO: this should also work for audio (maybe others) + video_bytes = streams["video"][0] + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "input.mp4"), "wb") as f: + f.write(video_bytes) + try: + command = [ + "ffprobe", + "-v", + "quiet", + "-print_format", + "json", + "-show_format", + "-show_streams", + f"{tmpdir}/input.mp4", + ] + + if self.extract_keyframes: + command.extend(["-select_streams", "v:0", "-show_entries", "packet=pts_time,flags"]) + + process = subprocess.run(command, capture_output=True, text=True, check=True) + video_metadata = json.loads(process.stdout) + + if self.extract_keyframes: + keyframe_info = [entry for entry in video_metadata["packets"] if "K" in entry.get("flags", "")] + keyframe_timestamps = [float(entry["pts_time"]) for entry in keyframe_info] + if "duration" in video_metadata["format"]: + duration = float(video_metadata["format"]["duration"]) + keyframe_timestamps.append(duration) + video_metadata["keyframe_timestamps"] = keyframe_timestamps + video_metadata.pop("packets") # Don't need it anymore + metadata["video_metadata"] = video_metadata + + except Exception as err: # pylint: disable=broad-except + return streams, metadata, str(err) + + return streams, metadata, None diff --git a/video2dataset/subsamplers/frame_subsampler.py b/video2dataset/subsamplers/frame_subsampler.py new file mode 100644 index 00000000..e3c1b572 --- /dev/null +++ b/video2dataset/subsamplers/frame_subsampler.py @@ -0,0 +1,90 @@ +""" +frame subsampler adjusts the fps of the videos to some constant value +""" +import tempfile +import os +import copy +import ffmpeg + +from .subsampler import Subsampler +from .clipping_subsampler import _get_seconds + + +class FrameSubsampler(Subsampler): + """ + Adjusts the frame rate of the videos to the specified frame rate. + Subsamples the frames of the video in terms of spacing and quantity (frame_rate, which ones etc.) + Args: + frame_rate (int): Target frame rate of the videos. + downsample_method (str): determiens how to downsample the video frames: + fps: decreases the framerate but sample remains a valid video + first_frame: only use the first frame of a video of a video and output as image + yt_subtitle: temporary special case where you want a frame at the beginning of each yt_subtitle + we will want to turn this into something like frame_timestamps and introduce + this as a fusing option with clipping_subsampler + encode_format (str): Format to encode in (i.e. mp4) + + TODO: n_frame + TODO: generalize interface, should be like (frame_rate, n_frames, sampler, output_format) + # frame_rate - spacing + # n_frames - quantity + # sampler - from start, end, center out + # output_format - save as video, or images + """ + + def __init__(self, frame_rate, downsample_method="fps", encode_format="mp4"): + self.frame_rate = frame_rate + self.downsample_method = downsample_method + self.output_modality = "video" if downsample_method == "fps" else "jpg" + self.encode_format = encode_format + + def __call__(self, streams, metadata=None): + # TODO: you might not want to pop it (f.e. 
in case of other subsamplers) + video_bytes = streams.pop("video") + subsampled_bytes, subsampled_metas = [], [] + for i, vid_bytes in enumerate(video_bytes): + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "input.mp4"), "wb") as f: + f.write(vid_bytes) + try: + ext = "mp4" + if self.downsample_method == "fps": + _ = ffmpeg.input(f"{tmpdir}/input.mp4") + _ = _.filter("fps", fps=self.frame_rate) + _ = _.output(f"{tmpdir}/output.mp4", reset_timestamps=1).run(capture_stdout=True, quiet=True) + elif "frame" in self.downsample_method: + _ = ffmpeg.input(f"{tmpdir}/input.mp4") + _ = _.filter("select", "eq(n,0)") + _ = _.output(f"{tmpdir}/output.jpg").run(capture_stdout=True, quiet=True) + ext = "jpg" + elif self.downsample_method == "yt_subtitle": + subtitles = metadata[i]["yt_meta_dict"]["subtitles"] + starts = [_get_seconds(s["start"]) for s in subtitles] + + for frame_id, start_t in enumerate(starts): + frame_key = f"{frame_id:04d}" + meta_frame = copy.deepcopy(metadata[i]) + + meta_frame["frame_time"] = subtitles[frame_id]["start"] + meta_frame["frame_subtitle"] = subtitles[frame_id]["lines"] + meta_frame["key"] = f"{meta_frame['key']}_{frame_key}" + + _ = ffmpeg.input(f"{tmpdir}/input.mp4", ss=start_t) + _ = _.output(f"{tmpdir}/frame_{frame_id}.jpg", vframes=1, **{"q:v": 2}).run( + capture_stdout=True, quiet=True + ) + with open(f"{tmpdir}/frame_{frame_id}.jpg", "rb") as f: + subsampled_bytes.append(f.read()) + subsampled_metas.append(meta_frame) + + except Exception as err: # pylint: disable=broad-except + return [], None, str(err) + + if self.downsample_method != "yt_subtitle": + with open(f"{tmpdir}/output.{ext}", "rb") as f: + subsampled_bytes.append(f.read()) + else: + metadata = subsampled_metas + + streams[self.output_modality] = subsampled_bytes + return streams, metadata, None diff --git a/video2dataset/subsamplers/noop_subsampler.py b/video2dataset/subsamplers/noop_subsampler.py new file mode 100644 index 00000000..b12975a8 --- /dev/null +++ b/video2dataset/subsamplers/noop_subsampler.py @@ -0,0 +1,11 @@ +"""No operation subsampler""" + +from .subsampler import Subsampler + + +class NoOpSubsampler(Subsampler): + def __init__(self): + pass + + def __call__(self, streams, metadata): + return streams, [metadata], None diff --git a/video2dataset/subsamplers/optical_flow_subsampler.py b/video2dataset/subsamplers/optical_flow_subsampler.py new file mode 100644 index 00000000..837f3fb9 --- /dev/null +++ b/video2dataset/subsamplers/optical_flow_subsampler.py @@ -0,0 +1,265 @@ +""" +optical flow detection +""" +import cv2 +import numpy as np +import os + +try: + from raft import RAFT + from raft.utils.utils import InputPadder + import torch +except: # pylint: disable=broad-except,bare-except + pass + +from .subsampler import Subsampler + + +class AttrDict(dict): + """ + Lets us access dict keys with .key + """ + + # pylint: disable=super-with-arguments + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def resize_image_with_aspect_ratio(image, target_shortest_side=16): + """ + Resize an input image while maintaining its aspect ratio. + + This function takes an image and resizes it so that its shortest side + matches the specified target length. The other side is scaled accordingly + to maintain the original aspect ratio of the image. The function returns + the resized image and the scaling factor used for resizing. 
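A quick worked example of the dimension arithmetic described above (illustrative numbers): resizing a 1080x1920 frame so its shortest side becomes 16.

```python
height, width = 1080, 1920
target_shortest_side = 16

# height is the shorter side, so it is pinned to the target and width scales with it.
new_height = target_shortest_side                         # 16
new_width = int(width * (target_shortest_side / height))  # int(28.44...) = 28
scaling_factor = height / new_height                      # 67.5

print(new_height, new_width, scaling_factor)  # 16 28 67.5
```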
+ + Parameters + ---------- + image : numpy.ndarray + The input image represented as a NumPy array with shape (height, width, channels). + target_shortest_side : int, optional + The desired length for the shortest side of the resized image (default is 16). + + Returns + ------- + resized_image : numpy.ndarray + The resized image with the same number of channels as the input image. + scaling_factor : float + The scaling factor used to resize the image. + """ + # Get the original dimensions of the image + height, width = image.shape[:2] + + # Calculate the new dimensions while maintaining the aspect ratio + if height < width: + new_height = target_shortest_side + new_width = int(width * (target_shortest_side / height)) + scaling_factor = height / new_height + else: + new_width = target_shortest_side + new_height = int(height * (target_shortest_side / width)) + scaling_factor = width / new_width + # Resize the image using the calculated dimensions + resized_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA) + + return resized_image, scaling_factor + + +class RAFTDetector: + """ + Optical flow detection using RAFT. Steps to setup RAFT: + + 1. clone this repo fork: https://github.com/danielmend/RAFT + 2. in the root directory of the repo, run `pip install -e .` to install RAFT locally + 3. download optical flow models using the `download.sh` script in the repo, i.e. run `bash download.sh` + 4. when passing in args to the video2dataset call, be sure to include optical flow params as follows: + args = { + 'model': '/path/to/optical/flow.pth', + 'path': None, + 'small': False, + 'mixed_precision': True, + 'alternate_corr': False, + } + """ + + def __init__(self, args, downsample_size=None): + self.device = args.get("device", "cuda") + self.downsample_size = downsample_size + + model = RAFT(args) + + state_dict = torch.load(args.model) + real_state_dict = {k.split("module.")[-1]: v for k, v in state_dict.items()} + + model.load_state_dict(real_state_dict) + + model.to(self.device) + model.eval() + self.model = model + + def preprocess(self, frame1, frame2): + """ + Preprocesses a pair of input frames for use with the RAFT optical flow detector. + This function takes in two input frames, resizes them while maintaining the aspect + ratio (if downsample_size is set), converts the frames to uint8 data type, + creates torch tensors from the frames, permutes the axes, and pads the frames + to the same size using the InputPadder. + + Args: + frame1 (np.ndarray): The first input frame as a NumPy array with shape (height, width, 3). + frame2 (np.ndarray): The second input frame as a NumPy array with shape (height, width, 3). + + Returns: + tuple: A tuple containing three elements: + - frame1 (torch.Tensor): The preprocessed first input frame as + a torch tensor with shape (1, 3, height, width). + + - frame2 (torch.Tensor): The preprocessed second input frame as a + torch tensor with shape (1, 3, height, width). + + - scaling_factor (float): The scaling factor used to resize the input + frames while maintaining the aspect ratio. If no resizing is performed, + the scaling factor is 1. 
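A hypothetical usage sketch of the RAFT-based detector, following the args dict listed in the class docstring above; it assumes the RAFT fork is installed, a CUDA device is available, and the checkpoint path is a placeholder:

```python
import numpy as np
from video2dataset.subsamplers.optical_flow_subsampler import AttrDict, RAFTDetector

args = AttrDict(
    {
        "model": "/path/to/raft-things.pth",  # placeholder checkpoint path
        "path": None,
        "small": False,
        "mixed_precision": True,
        "alternate_corr": False,
        "device": "cuda",
    }
)
detector = RAFTDetector(args)

# Two dummy RGB frames; real callers pass consecutive decoded video frames.
frame1 = np.zeros((224, 224, 3), dtype=np.uint8)
frame2 = np.zeros((224, 224, 3), dtype=np.uint8)
flow = detector(frame1, frame2)  # (H_padded, W_padded, 2) flow field as a numpy array
```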
+ """ + frame1 = frame1.astype(np.uint8) + frame1 = torch.from_numpy(frame1).permute(2, 0, 1).float() + frame1 = frame1[None].to(self.device) + + frame2 = frame2.astype(np.uint8) + frame2 = torch.from_numpy(frame2).permute(2, 0, 1).float() + frame2 = frame2[None].to(self.device) + + padder = InputPadder(frame1.shape) + + frame1, frame2 = padder.pad(frame1, frame2) + return frame1, frame2 + + def __call__(self, frame1, frame2): + frame1, frame2 = self.preprocess(frame1, frame2) + with torch.no_grad(): + _, flow_up = self.model(frame1, frame2, iters=20, test_mode=True) + return flow_up[0].permute(1, 2, 0).cpu().numpy() + + +class Cv2Detector: + """ + A class to perform optical flow detection using OpenCV's Farneback method. + + Attributes: + pyr_scale (float): The pyramid scale. Defaults to 0.5. + levels (int): The number of pyramid layers. Defaults to 3. + winsize (int): The window size. Defaults to 15. + iterations (int): The number of iterations. Defaults to 3. + poly_n (int): The size of the pixel neighborhood. Defaults to 5. + poly_sigma (float): The standard deviation of the Gaussian. Defaults to 1.2. + flags (int): Additional flags for the cv2.calcOpticalFlowFarneback function. Defaults to 0. + """ + + def __init__( + self, + pyr_scale=0.5, + levels=3, + winsize=15, + iterations=3, + poly_n=5, + poly_sigma=1.2, + flags=0, + downsample_size=None, + ): + self.pyr_scale = pyr_scale + self.levels = levels + self.winsize = winsize + self.iterations = iterations + self.poly_n = poly_n + self.poly_sigma = poly_sigma + self.flags = flags + self.downsample_size = downsample_size + + def preprocess(self, frame): + out = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + return out + + def __call__(self, frame1, frame2): + """ + Calculate optical flow between two frames using Farneback method. + + Args: + frame1 (numpy.ndarray): The first frame (grayscale). + frame2 (numpy.ndarray): The second frame (grayscale). + + Returns: + numpy.ndarray: The computed optical flow. + """ + frame1 = self.preprocess(frame1) + frame2 = self.preprocess(frame2) + + return cv2.calcOpticalFlowFarneback( + frame1, + frame2, + None, + self.pyr_scale, + self.levels, + self.winsize, + self.iterations, + self.poly_n, + self.poly_sigma, + self.flags, + ) + + +class OpticalFlowSubsampler(Subsampler): + """ + A class to detect optical flow in video frames. + + Attributes: + detector (Cv2Detector or RAFTDetector): The optical flow detector. + fps (int): The target frames per second. Defaults to -1 (original FPS). 
+ """ + + def __init__( + self, + detector="cv2", + detector_args=None, + dtype="fp16", + is_slurm_task=False, + ): + if detector == "cv2": + detector_args = () if detector_args is None else detector_args + self.detector = Cv2Detector(*detector_args) + elif detector == "raft": + assert detector_args is not None + if is_slurm_task: + local_rank = os.environ["LOCAL_RANK"] + device = f"cuda:{local_rank}" + detector_args["device"] = device + if not isinstance(detector_args, AttrDict): + detector_args = AttrDict(args) + self.detector = RAFTDetector(detector_args) + else: + raise NotImplementedError() + + dtypes = {"fp16": np.float16, "fp32": np.float32} + self.dtype = dtypes[dtype] + + def __call__(self, frames): + optical_flow = [] + + try: + frame1 = frames[0] + prvs = frame1 + + for frame2 in frames[1:]: + next_frame = frame2 + flow = self.detector(prvs, next_frame) + optical_flow.append(flow) + prvs = next_frame + + opt_flow = np.array(optical_flow) + mean_magnitude_per_frame = np.linalg.norm(opt_flow, axis=-1).mean(axis=(1, 2)) + mean_magnitude = float(mean_magnitude_per_frame.mean()) + metrics = [mean_magnitude, mean_magnitude_per_frame.tolist()] + return opt_flow.astype(self.dtype), metrics, None + except Exception as err: # pylint: disable=broad-except + return [], None, str(err) diff --git a/video2dataset/subsamplers/resolution_subsampler.py b/video2dataset/subsamplers/resolution_subsampler.py new file mode 100644 index 00000000..9f95969b --- /dev/null +++ b/video2dataset/subsamplers/resolution_subsampler.py @@ -0,0 +1,70 @@ +""" +resolution subsampler adjusts the resolution of the videos to some constant value +""" +import os +import ffmpeg +import tempfile +from typing import Literal + +from .subsampler import Subsampler + + +class ResolutionSubsampler(Subsampler): + """ + Adjusts the resolution of the videos to the specified height and width. + Please do not set both video_size and height/width. This will result in an error. + If both height and width are set, scale mode output will have the specified height (ignoring width). + + Args: + resize_mode (list[str]): List of resize modes to apply. Possible options are: + scale: scale video keeping aspect ratios + crop: center crop to video_size x video_size + pad: center pad to video_size x video_size + height (int): Height of video. + width (int): Width of video. + video_size (int): Both height and width. + encode_format (str): Format to encode in (i.e. 
mp4) + """ + + def __init__( + self, + resize_mode: Literal["scale", "crop", "pad"], + height: int = -1, + width: int = -1, + video_size: int = -1, + encode_format: str = "mp4", + ): + if video_size > 0 and (height > 0 or width > 0): + raise ValueError("Either set video_size, or set height and/or width") + self.resize_mode = resize_mode + self.height = height if video_size < 0 else video_size + self.width = width if video_size < 0 else video_size + self.video_size = video_size + self.encode_format = encode_format + + def __call__(self, streams, metadata=None): + video_bytes = streams["video"] + subsampled_bytes = [] + for vid_bytes in video_bytes: + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "input.mp4"), "wb") as f: + f.write(vid_bytes) + try: + _ = ffmpeg.input(f"{tmpdir}/input.mp4") + if "scale" in self.resize_mode: + if self.height > 0: + _ = _.filter("scale", -2, self.height) + else: + _ = _.filter("scale", self.width, -2) + if "crop" in self.resize_mode: + _ = _.filter("crop", w=self.width, h=self.height) + if "pad" in self.resize_mode: + _ = _.filter("pad", w=self.width, h=self.height) + _ = _.output(f"{tmpdir}/output.mp4", reset_timestamps=1).run(capture_stdout=True, quiet=True) + except Exception as err: # pylint: disable=broad-except + return [], None, str(err) + + with open(f"{tmpdir}/output.mp4", "rb") as f: + subsampled_bytes.append(f.read()) + streams["video"] = subsampled_bytes + return streams, metadata, None diff --git a/video2dataset/subsamplers/subsampler.py b/video2dataset/subsamplers/subsampler.py new file mode 100644 index 00000000..0d5461ce --- /dev/null +++ b/video2dataset/subsamplers/subsampler.py @@ -0,0 +1,10 @@ +"""Base subsampler class""" +from abc import abstractmethod + + +class Subsampler: + """Subsamples input and returns in same format (stream dict + metadata)""" + + @abstractmethod + def __call__(self, streams, metadata): + raise NotImplementedError("Subsampler should not be called") diff --git a/video2dataset/subsamplers/whisper_subsampler.py b/video2dataset/subsamplers/whisper_subsampler.py new file mode 100644 index 00000000..842a8ce5 --- /dev/null +++ b/video2dataset/subsamplers/whisper_subsampler.py @@ -0,0 +1,88 @@ +""" +Whisper subsampler - transcribes audio using the Whisper model from OAI using WhisperX API + +code: https://github.com/m-bain/whisperX +""" +import os +import time +import tempfile + +try: + import whisperx + import torch +except: # pylint: disable=broad-except,bare-except + pass + +from .subsampler import Subsampler + + +class WhisperSubsampler(Subsampler): + """ + Transcribes audio samples using the OAI Whisper Model via WhisperX API + + Params: + model_name: https://github.com/guillaumekln/faster-whisper/blob/20d4e9418b5efb69ec5aa4819a39e3fb0e772a2a/faster_whisper/transcribe.py#LL90C1-L90C1 + batch_size: batch size used during inference (try to maximize this for perf) + compute_type: accuracy/mem tradeoff (float16, float32, int8) + """ + + def __init__( + self, + model_name="large-v2", + batch_size=16, + compute_type="float16", + download_root=None, + is_slurm_task=False, + ): + if is_slurm_task: + global_rank = os.environ["GLOBAL_RANK"] + if global_rank != 0: + time.sleep(20) # let master worker download model + + device, device_index = "cuda", int(os.environ["LOCAL_RANK"]) + while True: + try: + self.model = whisperx.load_model( + model_name, + device=device, + device_index=device_index, + compute_type=compute_type, + download_root=download_root, + ) + print("model_loaded", 
os.environ["GLOBAL_RANK"], flush=True) + break + except Exception as e: # pylint: disable=(broad-except) + print(str(e), flush=True) + print( + "loading failed, retrying...", + os.environ["GLOBAL_RANK"], + flush=True, + ) + continue + else: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.model = whisperx.load_model( + model_name, + device=device, + compute_type=compute_type, + download_root=download_root, + ) + + self.batch_size = batch_size + + def __call__(self, streams, metadata=None): + audio_bytes = streams.get("audio") + + for i, aud_bytes in enumerate(audio_bytes): + # TODO: .m4a not always + try: + with tempfile.NamedTemporaryFile(suffix=".mp3") as tmpfile: + tmpfile.write(aud_bytes) + tmpfile.flush() # ensure all data is written + audio = whisperx.load_audio(tmpfile.name) + result = self.model.transcribe(audio, batch_size=self.batch_size) + metadata[i]["whisper_transcript"] = result + except Exception as err: # pylint: disable=broad-except + return [], metadata, str(err) + + return streams, metadata, None diff --git a/video2dataset/types.py b/video2dataset/types.py new file mode 100644 index 00000000..03ae7c1e --- /dev/null +++ b/video2dataset/types.py @@ -0,0 +1,16 @@ +"""Type definitions for video2dataset.""" +from typing import List, TypedDict + + +class EncodeFormats(TypedDict, total=False): + video: str + audio: str + + +class Streams(TypedDict, total=False): + video: List[bytes] + audio: List[bytes] + + +# TODO: make more structured +Metadata = dict diff --git a/video2dataset/workers/__init__.py b/video2dataset/workers/__init__.py new file mode 100644 index 00000000..ec591e1b --- /dev/null +++ b/video2dataset/workers/__init__.py @@ -0,0 +1,7 @@ +"""video2dataset stage workers""" + +from .download_worker import DownloadWorker +from .subset_worker import SubsetWorker +from .optical_flow_worker import OpticalFlowWorker +from .whisper_worker import WhisperWorker +from .caption_worker import CaptionWorker diff --git a/video2dataset/workers/caption_worker.py b/video2dataset/workers/caption_worker.py new file mode 100644 index 00000000..ad47120d --- /dev/null +++ b/video2dataset/workers/caption_worker.py @@ -0,0 +1,226 @@ +""" +Worker for caption stage +""" +import time +import pyarrow as pa +import traceback +import numpy as np +import torch +import fsspec +import webdataset as wds + +from video2dataset.logger import CappedCounter, write_stats +from video2dataset.subsamplers import CaptionSubsampler +from video2dataset.dataloader import get_video_dataset + + +class CaptionWorker: + """ + A class to read shards, process them using CaptionSubsampler, and write the output. + + Attributes: + sample_writer_class (type): The class used to write samples. + output_folder (str): The folder to write the output. + thread_count (int): The number of threads. + number_sample_per_shard (int): The number of samples per shard. + oom_shard_count (int): The number of out-of-memory shards. + encode_formats (dict): The encoding formats. + fps (int): The target frames per second. + caption_subsampler (CaptionSubsampler): The CaptionSubsampler instance. 
+ """ + + def __init__( + self, + sample_writer_class, + output_folder, + encode_formats, + is_slurm_task, + config, + ) -> None: + self.sample_writer_class = sample_writer_class + self.output_folder = output_folder + self.encode_formats = encode_formats + self.save_caption = True + self.config = config + + self.caption_subsampler = CaptionSubsampler( + **self.config["subsampling"]["CaptionSubsampler"]["args"], + is_slurm_task=is_slurm_task, + ) + + def __call__( + self, + row, + ): + try: + self.process_shard(row) + return (True, row) + except Exception as err: # pylint: disable=broad-except + traceback.print_exc() + print(f"shard {row[0]} failed with error {err}") + return (False, row) + + def process_shard( + self, + row, + ): + """ + Process a video shard using the CaptionSubsampler. + + Args: + row (tuple): A tuple containing the shard and shard_id. + + Raises: + Except + """ + shard, shard_id = row + start_time = time.time() + + try: + fs, shard_path = fsspec.core.url_to_fs(shard[: -len(".tar")] + ".parquet") + + with fs.open(shard_path, "rb") as f: + df = pa.parquet.read_table(f) + schema = df.schema + except Exception as e: # pylint: disable=broad-except,unused-variable + fields = [ + pa.field("key", pa.string()), + pa.field("status", pa.string()), + pa.field("error_message", pa.string()), + ] + schema = pa.schema(fields) + + status_dict = CappedCounter() + + # give schema to writer + sample_writer = self.sample_writer_class( + shard_id, + self.output_folder, + self.save_caption, + self.config["storage"]["oom_shard_count"], + schema, + self.encode_formats, + ) + + successes = 0 + failed_to_subsample = 0 + + dset = get_video_dataset( + urls=shard, + batch_size=self.config["reading"]["dataloader_args"]["batch_size"], + decoder_kwargs=self.config["reading"]["dataloader_args"]["decoder_kwargs"], + resize_size=self.config["reading"]["dataloader_args"]["resize_size"], + crop_size=None, + enforce_additional_keys=[], + return_always=True, + handler=wds.warn_and_continue, + ) + count = 0 + for sample in dset: + batch_size = len(sample["__key__"]) + count += batch_size + + bad_batch_idx = [] + for batch_idx in range(batch_size): + corrupted = sample["__corrupted__"][batch_idx] + key = sample["__key__"][batch_idx] + if corrupted: + bad_batch_idx.append(batch_idx) + url = sample["__url__"][batch_idx] + meta = {} + meta["url"] = url + meta["key"] = key + error_message = "corrupted sample" + failed_to_subsample += 1 + status = "failed_to_subsample" + status_dict.increment(error_message) + meta["status"] = status + meta["error_message"] = error_message + meta["__corrupted__"] = True + sample_writer.write( + {}, + key, + None, + meta, + ) + continue + + if len(bad_batch_idx) != 0: + bad_batch_idx.sort(reverse=True) + for key in sample: + if isinstance(sample[key], list): + for i in bad_batch_idx: + del sample[key][i] + elif isinstance(sample[key], np.ndarray): + mask = np.ones(len(sample[key]), dtype=bool) + mask[bad_batch_idx] = False + sample[key] = sample[key][mask] + elif isinstance(sample[key], torch.Tensor): + mask = torch.ones(len(sample[key]), dtype=bool) + mask[bad_batch_idx] = False + sample[key] = sample[key][mask] + else: + raise TypeError(f"Unsupported data type: {type(sample[key])}") + + batch_size = len(sample["__key__"]) + # incase all elems are corrupted + if batch_size == 0: + continue + + meta = sample["json"] + streams = {} + caption, error_message = self.caption_subsampler(sample.get("mp4")) + + if error_message is not None: + for batch_idx in range(batch_size): + meta = 
sample["json"][batch_idx] + key = sample["__key__"][batch_idx] + failed_to_subsample += 1 + status = "failed_to_subsample" + status_dict.increment(error_message) + meta["status"] = status + meta["error_message"] = error_message + meta["__corrupted__"] = True + sample_writer.write( + streams, + key, + None, + meta, + ) + # continue + continue + + for batch_idx in range(batch_size): + sample["json"][batch_idx]["vblip"] = caption[0][batch_idx] + + meta = sample["json"][batch_idx] + key = sample["__key__"][batch_idx] + successes += 1 + status = "success" + status_dict.increment(status) + meta["status"] = status + meta["__corrupted__"] = False + + sample_writer.write( + streams, + key, + caption[0][batch_idx], + meta, + ) + + sample_writer.close() + end_time = time.time() + + write_stats( + self.output_folder, + shard_id, + count, # count + successes, + 0, # failed to download + failed_to_subsample, + 0, # bytes downloaded + start_time, + end_time, + status_dict, + self.config["storage"]["oom_shard_count"], + ) diff --git a/video2dataset/workers/download_worker.py b/video2dataset/workers/download_worker.py new file mode 100644 index 00000000..3c1bdac4 --- /dev/null +++ b/video2dataset/workers/download_worker.py @@ -0,0 +1,191 @@ +"""the downloader module handles the downloading""" +import fsspec +import math +from multiprocessing.pool import ThreadPool +import pyarrow as pa +import time +import traceback +from typing import cast + +from video2dataset.data_reader import VideoDataReader +from video2dataset.logger import write_stats +from video2dataset.workers.worker import ShardStatus, Streams, get_subsamplers, process_sample + + +def compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count): + true_key = (10**oom_sample_per_shard) * shard_id + key + key_format = oom_sample_per_shard + oom_shard_count + str_key = "{true_key:0{key_format}d}".format( # pylint: disable=consider-using-f-string + key_format=key_format, true_key=true_key + ) + return str_key + + +class DownloadWorker: + """The downloader class gets calls with shards, download them then call the writer to write them down""" + + def __init__( + self, + sample_writer_class, + save_caption, + output_folder, + column_list, + tmp_dir, + encode_formats, + config, + ) -> None: + self.sample_writer_class = sample_writer_class + self.save_caption = save_caption + self.output_folder = output_folder + self.column_list = column_list + self.input_encode_formats = encode_formats + self.config = config + self.data_reader = VideoDataReader(encode_formats, tmp_dir, config["reading"]) + self.url_indice = self.column_list.index("url") + self.caption_indice = self.column_list.index("caption") if "caption" in self.column_list else None + self.oom_sample_per_shard = math.ceil(math.log10(self.config["storage"]["number_sample_per_shard"])) + self.subsamplers, self.output_encode_formats = get_subsamplers( + config, + encode_formats, + do_clipping=("clips" in self.column_list), + ) + + def __call__( + self, + row, + ): + try: + shard_file, shard_id = row + self.process_shard(shard_file, shard_id) + return (True, row) + except Exception as err: # pylint: disable=broad-except + traceback.print_exc() + print(f"shard {row[0]} failed with error {err}") + return (False, row) + + def get_shard_processors( + self, + shard_file: str, + shard_id: int, + ): + """Get objects for loading and writing data""" + + fs, shard_path = fsspec.core.url_to_fs(shard_file) + print(shard_path) + with fs.open(shard_path, "rb") as f: + df = pa.ipc.open_file(f).read_all() + schema = 
df.schema + schema = df.schema + schema = ( + schema.append(pa.field("key", pa.string())) + .append(pa.field("status", pa.string())) + .append(pa.field("error_message", pa.string())) + ) + shard_sample_writer = self.sample_writer_class( + shard_id, + self.output_folder, + self.save_caption, + self.config["storage"]["oom_shard_count"], + schema, + self.output_encode_formats, + ) + pydict = df.select(self.column_list).to_pydict() + shard_to_dl = list(enumerate(zip(*(pydict[col] for col in self.column_list)))) + + def rm_shard_path(): + fs.rm(shard_path) + + return shard_sample_writer, shard_to_dl, rm_shard_path + + def process_shard( + self, + shard_file: str, + shard_id: int, + ): + """Function to start an video downloading in one process""" + + start_time = time.time() + shard_sample_writer, shard_to_dl, rm_shard_path = self.get_shard_processors(shard_file, shard_id) + shard_status = ShardStatus(count=len(shard_to_dl)) + + def data_generator(): + for key_and_url in [(key, x[self.url_indice]) for key, x in shard_to_dl]: + yield key_and_url + + data_reader_call_param_generator = data_generator() + + with ThreadPool(self.config["distribution"]["thread_count"]) as thread_pool: + for key, streams, yt_meta_dict, shard_status.error_message in thread_pool.imap_unordered( + self.data_reader, # pylint: disable=(unnecessary-lambda) + data_reader_call_param_generator, + ): + try: + _, sample_data = shard_to_dl[key] + str_key = compute_key( + key, shard_id, self.oom_sample_per_shard, self.config["storage"]["oom_shard_count"] + ) + caption = sample_data[self.caption_indice] if self.caption_indice is not None else None + metadata = { + **{self.column_list[i]: sample_data[i] for i in range(len(self.column_list))}, + "key": str_key, + "status": None, + "error_message": shard_status.error_message, + "yt_meta_dict": yt_meta_dict, + } + except Exception as err: # pylint: disable=broad-except + traceback.print_exc() + print(f"Sample {key} failed to download: {err}") + return + + try: + if shard_status.error_message is not None: + print(shard_status.error_message) + if "[youtube]" in shard_status.error_message: # video-specific error, remove videoID + shard_status.error_message = "ERROR: [youtube]:" + shard_status.error_message.split(":")[-1] + raise ValueError + except Exception: # pylint: disable=broad-except + shard_status.failed["failed_to_download"] += 1 + shard_status.status_dict.increment(shard_status.error_message) + metadata["status"] = "failed_to_download" + metadata["error_message"] = shard_status.error_message + shard_sample_writer.write( + {}, + str_key, + sample_data[self.caption_indice] if self.caption_indice is not None else None, + metadata, + ) + return + + for stream in streams.values(): + shard_status.bytes_downloaded += len(stream) + for modality in streams: + streams[modality] = [streams[modality]] + + process_sample( + subsamplers=self.subsamplers, + shard_status=shard_status, + streams=cast(Streams, streams), + key=str_key, + caption=cast(str, caption), + metadata=metadata, + captions_are_subtitles=self.config["storage"]["captions_are_subtitles"], + shard_sample_writer=shard_sample_writer, + ) + + shard_sample_writer.close() + rm_shard_path() + end_time = time.time() + + write_stats( + self.output_folder, + shard_id, + shard_status.count, + shard_status.successes, + shard_status.failed["failed_to_download"], + shard_status.failed["failed_to_subsample"], + shard_status.bytes_downloaded, + start_time, + end_time, + shard_status.status_dict, + self.config["storage"]["oom_shard_count"], + ) 
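A quick illustration of the key packing performed by compute_key in download_worker.py above: oom_sample_per_shard is derived as ceil(log10(number_sample_per_shard)), and the shard id plus the per-shard sample index are packed into one zero-padded string so keys sort lexicographically. This is a minimal sketch; the concrete numbers (10_000 samples per shard, oom_shard_count=5, shard_id=12, key=7) are illustrative assumptions, not values from any config:

import math

def compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count):
    # same packing as download_worker.compute_key: shard id in the leading digits,
    # per-shard sample index in the trailing oom_sample_per_shard digits
    true_key = (10**oom_sample_per_shard) * shard_id + key
    width = oom_sample_per_shard + oom_shard_count
    return f"{true_key:0{width}d}"

oom_sample_per_shard = math.ceil(math.log10(10_000))  # assumed number_sample_per_shard = 10_000 -> 4 digits
print(compute_key(key=7, shard_id=12, oom_sample_per_shard=oom_sample_per_shard, oom_shard_count=5))
# prints "000120007": "00012" identifies the shard, "0007" the sample within it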
diff --git a/video2dataset/workers/optical_flow_worker.py b/video2dataset/workers/optical_flow_worker.py new file mode 100644 index 00000000..b88833d7 --- /dev/null +++ b/video2dataset/workers/optical_flow_worker.py @@ -0,0 +1,216 @@ +""" +Worker for optical flow stage +""" +import time +import pyarrow as pa +import traceback +import io +import numpy as np +import fsspec +import webdataset as wds + +from video2dataset.logger import CappedCounter, write_stats +from video2dataset.subsamplers import OpticalFlowSubsampler +from video2dataset.dataloader import get_video_dataset + + +def numpy_npy_dumps(numpy_array): + """ + Dump a numpy array into a bytestring using numpy npy format. + + Args: + numpy_array (numpy.ndarray): A numpy array. + + Returns: + bytes: A bytestring representing the numpy array. + """ + + stream = io.BytesIO() + np.save(stream, numpy_array) + return stream.getvalue() + + +class OpticalFlowWorker: + """ + A class to read shards, process them using OpticalFlowSubsampler, and write the output. + + Attributes: + sample_writer_class (type): The class used to write samples. + output_folder (str): The folder to write the output. + thread_count (int): The number of threads. + number_sample_per_shard (int): The number of samples per shard. + oom_shard_count (int): The number of out-of-memory shards. + encode_formats (dict): The encoding formats. + detector (str): The optical flow detector type. + fps (int): The target frames per second. + optical_flow_subsampler (OpticalFlowSubsampler): The OpticalFlowSubsampler instance. + """ + + def __init__( + self, + sample_writer_class, + output_folder, + encode_formats, + is_slurm_task, + config, + ) -> None: + self.sample_writer_class = sample_writer_class + self.output_folder = output_folder + self.encode_formats = encode_formats + self.save_caption = False + self.config = config + + self.optical_flow_subsampler = OpticalFlowSubsampler( + **self.config["subsampling"]["OpticalFlowSubsampler"]["args"], + is_slurm_task=is_slurm_task, + ) + + def __call__( + self, + row, + ): + try: + self.process_shard(row) + return (True, row) + except Exception as err: # pylint: disable=broad-except + traceback.print_exc() + print(f"shard {row[0]} failed with error {err}") + return (False, row) + + def process_shard( + self, + row, + ): + """ + Process a video shard using the OpticalFlowSubsampler. + + Args: + row (tuple): A tuple containing the shard and shard_id. 
+ + Raises: + Except + """ + shard, shard_id = row + start_time = time.time() + + try: + fs, shard_path = fsspec.core.url_to_fs(shard[: -len(".tar")] + ".parquet") + + with fs.open(shard_path, "rb") as f: + df = pa.parquet.read_table(f) + schema = df.schema + except Exception as e: # pylint: disable=broad-except,unused-variable + fields = [ + pa.field("key", pa.string()), + pa.field("status", pa.string()), + pa.field("error_message", pa.string()), + ] + schema = pa.schema(fields) + + status_dict = CappedCounter() + + # give schema to writer + sample_writer = self.sample_writer_class( + shard_id, + self.output_folder, + self.save_caption, + self.config["storage"]["oom_shard_count"], + schema, + self.encode_formats, + ) + + successes = 0 + failed_to_subsample = 0 + + dset = get_video_dataset( + urls=shard, + batch_size=1, + decoder_kwargs=self.config["reading"]["dataloader_args"]["decoder_kwargs"], + resize_size=self.config["reading"]["dataloader_args"]["resize_size"], + crop_size=None, + enforce_additional_keys=[], + return_always=True, + handler=wds.warn_and_continue, + ) + count = 0 + for sample in dset: + count += 1 + corrupted = sample["__corrupted__"][0] + key = sample["__key__"][0] + dummy_npy = numpy_npy_dumps(np.array([])) + if corrupted: + url = sample["__url__"][0] + meta = {} + meta["url"] = url + meta["key"] = key + error_message = "corrupted sample" + failed_to_subsample += 1 + status = "failed_to_subsample" + status_dict.increment(error_message) + meta["status"] = status + meta["error_message"] = error_message + meta["__corrupted__"] = True + sample_writer.write( + {"optical_flow": dummy_npy}, + key, + None, + meta, + ) + continue + meta = sample["json"][0] + streams = {} + frames = np.array(sample.get("mp4")[0]).astype(np.float32) + optical_flow, metrics, error_message = self.optical_flow_subsampler(frames) + + if error_message is not None: + failed_to_subsample += 1 + status = "failed_to_subsample" + status_dict.increment(error_message) + meta["status"] = status + meta["error_message"] = error_message + meta["__corrupted__"] = True + sample_writer.write( + {"optical_flow": dummy_npy}, + key, + None, + meta, + ) + continue + + successes += 1 + status = "success" + status_dict.increment(status) + meta["status"] = status + meta["__corrupted__"] = False + + mean_magnitude, mean_magnitude_per_frame = metrics + meta["mean_optical_flow_magnitude"] = mean_magnitude + meta["mean_optical_flow_magnitude_per_frame"] = mean_magnitude_per_frame + meta["optical_flow_fps"] = self.config["reading"]["dataloader_args"]["decoder_kwargs"]["fps"] + meta["optical_flow_downsample_size"] = self.config["reading"]["dataloader_args"]["resize_size"] + meta["optical_flow_dtype"] = str(self.optical_flow_subsampler.dtype) + + streams["optical_flow"] = numpy_npy_dumps(optical_flow) + sample_writer.write( + streams, + key, + None, + meta, + ) + + sample_writer.close() + end_time = time.time() + + write_stats( + self.output_folder, + shard_id, + count, # count + successes, + 0, # failed to download + failed_to_subsample, + 0, # bytes downloaded + start_time, + end_time, + status_dict, + self.config["storage"]["oom_shard_count"], + ) diff --git a/video2dataset/workers/subset_worker.py b/video2dataset/workers/subset_worker.py new file mode 100644 index 00000000..519aad3e --- /dev/null +++ b/video2dataset/workers/subset_worker.py @@ -0,0 +1,134 @@ +"""creates a subset of an existing dataset inside the sample dimension""" +import fsspec +import json +import pyarrow as pa +import time +import traceback +from 
typing import Literal, cast +import webdataset as wds + +from video2dataset.dataloader import get_video_dataset +from video2dataset.logger import write_stats +from video2dataset.types import EncodeFormats, Streams +from video2dataset.workers.worker import ShardStatus, get_subsamplers, process_sample + + +class SubsetWorker: + """The loader class reads the shards, then the selected data is chosen and writen by the writer""" + + def __init__( + self, + sample_writer_class, + output_folder, + encode_formats: EncodeFormats, + config, + ) -> None: + self.sample_writer_class = sample_writer_class + self.output_folder = output_folder + self.config = config + self.input_encode_formats = encode_formats + self.subsamplers, self.output_encode_formats = get_subsamplers(config, self.input_encode_formats) + + def __call__( + self, + row, + ): + try: + shard_file, shard_id = row + self.process_shard(shard_file, shard_id) + return (True, row) + except Exception as err: # pylint: disable=broad-except + traceback.print_exc() + print(f"shard_file {row[0]} failed with error {err}") + return (False, row) + + def get_shard_processors( + self, + shard_file: str, + shard_id: int, + ): + """Get objects for loading and writing data""" + + try: + fs, shard_path = fsspec.core.url_to_fs(shard_file[: -len(".tar")] + ".parquet") + with fs.open(shard_path, "rb") as f: + df = pa.parquet.read_table(f) + schema = df.schema + except Exception: # pylint: disable=broad-except + fields = [ + pa.field("key", pa.string()), + pa.field("status", pa.string()), + pa.field("error_message", pa.string()), + ] + schema = pa.schema(fields) + shard_sample_writer = self.sample_writer_class( + shard_id, + self.output_folder, + True, # save_caption + self.config["storage"]["oom_shard_count"], + schema, + self.output_encode_formats, + ) + shard_dataloader = get_video_dataset( + urls=shard_file, + batch_size=1, + decoder_kwargs={}, + enforce_additional_keys=[], + handler=wds.warn_and_continue, + ) + return shard_sample_writer, shard_dataloader + + def process_shard( + self, + shard_file: str, + shard_id: int, + ): + """Function to start an video processing in one process""" + + start_time = time.time() + shard_sample_writer, shard_dataloader = self.get_shard_processors(shard_file, shard_id) + shard_status = ShardStatus() + + for sample in shard_dataloader: + shard_status.count += 1 + key = sample["__key__"] + try: + caption = sample.get("txt", b"").decode("utf-8") + metadata = json.loads(sample.get("json", b"{}").decode("utf-8")) + except Exception as err: # pylint: disable=broad-except + traceback.print_exc() + print(f"Sample {key} failed to download: {err}") + return + + streams: Streams = {} + for modality, encode_format in self.input_encode_formats.items(): + modality = cast(Literal["audio", "video"], modality) + streams[modality] = [sample[encode_format]] + + process_sample( + subsamplers=self.subsamplers, + shard_status=shard_status, + streams=streams, + key=key, + caption=caption, + metadata=metadata, + captions_are_subtitles=self.config["storage"]["captions_are_subtitles"], + shard_sample_writer=shard_sample_writer, + ) + + shard_sample_writer.close() + end_time = time.time() + + write_stats( + self.output_folder, + shard_id, + shard_status.count, + shard_status.successes, + 0, # failed to download + shard_status.failed["failed_to_subsample"], + 0, # bytes downloaded + start_time, + end_time, + shard_status.status_dict, + self.config["storage"]["oom_shard_count"], + ) diff --git a/video2dataset/workers/whisper_worker.py 
b/video2dataset/workers/whisper_worker.py new file mode 100644 index 00000000..c39f820e --- /dev/null +++ b/video2dataset/workers/whisper_worker.py @@ -0,0 +1,285 @@ +""" +Worker for whisper transcription stage +""" +import time +import math +import traceback +import json +import fsspec +import pyarrow as pa +import webdataset as wds + +from multiprocessing.pool import ThreadPool +from threading import Semaphore + +from video2dataset.data_reader import VideoDataReader +from video2dataset.logger import CappedCounter, write_stats +from video2dataset.subsamplers import WhisperSubsampler +from video2dataset.dataloader import get_video_dataset + + +def compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count): + true_key = (10**oom_sample_per_shard) * shard_id + key + key_format = oom_sample_per_shard + oom_shard_count + str_key = "{true_key:0{key_format}d}".format( # pylint: disable=consider-using-f-string + key_format=key_format, true_key=true_key + ) + return str_key + + +class WhisperWorker: + """ + A class to read shards, process them using WhisperSubsampler, and write the output. + + Attributes: + sample_writer_class (type): The class used to write samples. + output_folder (str): The folder to write the output. + thread_count (int): The number of threads. + number_sample_per_shard (int): The number of samples per shard. + oom_shard_count (int): The order of magnitude of the shard count (zero-padding width for shard names). + encode_formats (dict): The encoding formats. + whisper_subsampler (WhisperSubsampler): The WhisperSubsampler instance. + """ + + def __init__( + self, + sample_writer_class, + output_folder, + column_list, + tmp_dir, + encode_formats, + is_slurm_task, + config, + ) -> None: + self.sample_writer_class = sample_writer_class + self.output_folder = output_folder + self.column_list = column_list + self.encode_formats = encode_formats + self.save_caption = False + self.config = config + + self.data_reader = VideoDataReader(encode_formats, tmp_dir, config["reading"]) + + if config["distribution"]["distributor"] != "slurm": + self.whisper_subsampler = WhisperSubsampler( + **self.config["subsampling"]["WhisperSubsampler"]["args"], + is_slurm_task=is_slurm_task, + ) + + def __call__( + self, + row, + ): + try: + self.process_shard(row) + return (True, row) + except Exception as err: # pylint: disable=broad-except + traceback.print_exc() + print(f"shard {row[0]} failed with error {err}") + return (False, row) + + def process_shard( + self, + row, + ): + """ + Process a video shard using the WhisperSubsampler. + + Args: + row (tuple): A tuple containing the shard and shard_id.
+ + Raises: + Except + """ + shard, shard_id = row + start_time = time.time() + + _ = from_wds = shard.endswith(".tar") + + if from_wds: + try: + fs, shard_path = fsspec.core.url_to_fs(shard[: -len(".tar")] + ".parquet") + + with fs.open(shard_path, "rb") as f: + df = pa.parquet.read_table(f) + schema = df.schema + except Exception as e: # pylint: disable=broad-except,unused-variable + fields = [ + pa.field("key", pa.string()), + pa.field("status", pa.string()), + pa.field("error_message", pa.string()), + ] + schema = pa.schema(fields) + semaphore = None + else: + fs, shard_path = fsspec.core.url_to_fs(shard) + with fs.open(shard_path, "rb") as f: + df = pa.ipc.open_file(f).read_all() + schema = df.schema + schema = ( + schema.append(pa.field("key", pa.string())) + .append(pa.field("status", pa.string())) + .append(pa.field("error_message", pa.string())) + ) + + pydict = df.select(self.column_list).to_pydict() + shard_to_dl = list(enumerate(zip(*(pydict[col] for col in self.column_list)))) + del pydict + del df + url_indice = self.column_list.index("url") + key_url_list = [(key, x[url_indice]) for key, x in shard_to_dl] + semaphore = Semaphore(self.config["distribution"]["thread_count"]) + + def data_generator(): + for e in key_url_list: + semaphore.acquire() # pylint: disable=(consider-using-with) + yield e + + loader = data_generator() + + status_dict = CappedCounter() + + # give schema to writer + sample_writer = self.sample_writer_class( + shard_id, + self.output_folder, + self.save_caption, + self.config["storage"]["oom_shard_count"], + schema, + self.encode_formats, + ) + oom_sample_per_shard = math.ceil(math.log10(self.config["storage"]["number_sample_per_shard"])) + + successes = 0 + failed_to_subsample = 0 + + if from_wds: + dset = get_video_dataset( + urls=shard, + batch_size=1, + decoder_kwargs={}, + video_key=self.encode_formats["audio"], + enforce_additional_keys=[], + return_always=True, + handler=wds.warn_and_continue, + ) + else: + + def create_dset(): + with ThreadPool(self.config["distribution"]["thread_count"]) as thread_pool: + for ( + key, + streams, + yt_meta_dict, + error_message, + ) in thread_pool.imap_unordered(self.data_reader, loader): + str_key = compute_key( + key, + shard_id, + oom_sample_per_shard, + self.config["storage"]["oom_shard_count"], + ) + sample = { + "__key__": str_key, + "__url__": shard, + "__corrupted__": error_message is not None, + } + meta = [ + { + "key": str_key, + "status": None, + "error_message": error_message, + "yt_meta_dict": yt_meta_dict, + } + ] + sample.update(streams) + sample["meta"] = meta + yield sample + + dset = create_dset() + + count = 0 + for sample in dset: + count += 1 + corrupted = sample["__corrupted__"] + key = sample["__key__"] + if corrupted: + url = sample["__url__"] + meta = {} + meta["url"] = url + meta["key"] = key + # error_message = "corrupted sample" + error_message = sample["meta"][0]["error_message"] + failed_to_subsample += 1 + status = "failed_to_subsample" + status_dict.increment(error_message) + meta["status"] = status + meta["error_message"] = error_message + meta["__corrupted__"] = True + sample_writer.write( + {}, + key, + None, + meta, + ) + _ = semaphore.release() if not from_wds else None + continue + meta = [json.loads(sample.get("json", b"{}").decode("utf-8"))] if from_wds else sample.pop("meta") + + streams = {"audio": [sample[self.encode_formats["audio"]]]} if from_wds else {"audio": [sample["audio"]]} + streams, meta, error_message = self.whisper_subsampler(streams, meta) + if error_message 
is not None: + failed_to_subsample += 1 + status = "failed_to_subsample" + status_dict.increment(error_message) + meta = meta[0] + meta["key"] = key + meta["url"] = sample["__url__"] + meta["status"] = status + meta["error_message"] = error_message + meta["__corrupted__"] = True + sample_writer.write( + {}, + key, + None, + meta, + ) + _ = semaphore.release() if not from_wds else None + continue + + streams.pop("audio") # only write metadata shards + meta = meta[0] # remove when unifying workers (needs to be list) + + successes += 1 + status = "success" + status_dict.increment(status) + meta["url"] = sample["__url__"] + meta["status"] = status + meta["__corrupted__"] = False + + sample_writer.write( + streams, + key, + None, + meta, + ) + _ = semaphore.release() if not from_wds else None + + sample_writer.close() + end_time = time.time() + + write_stats( + self.output_folder, + shard_id, + count, # count + successes, + 0, # failed to download + failed_to_subsample, + 0, # bytes downloaded + start_time, + end_time, + status_dict, + self.config["storage"]["oom_shard_count"], + ) diff --git a/video2dataset/workers/worker.py b/video2dataset/workers/worker.py new file mode 100644 index 00000000..45650829 --- /dev/null +++ b/video2dataset/workers/worker.py @@ -0,0 +1,197 @@ +"""Standard worker for video2dataset.""" +from dataclasses import dataclass, field +import numpy as np +from typing import Any, List, Tuple, Optional + +from video2dataset.logger import CappedCounter +from video2dataset.subsamplers import ( + ClippingSubsampler, + CutDetectionSubsampler, + FrameSubsampler, + FFProbeSubsampler, + NoOpSubsampler, + ResolutionSubsampler, + AudioRateSubsampler, + Subsampler, +) +from video2dataset.types import EncodeFormats, Streams, Metadata + + +@dataclass +class ShardStatus: + """Shard processing status""" + + successes: int = 0 + failed: dict = field( + default_factory=lambda: { + "failed_to_download": 0, + "failed_to_subsample": 0, + } + ) + status_dict: CappedCounter = field(default_factory=CappedCounter) + error_message: Optional[str] = None + count: int = 0 + bytes_downloaded: int = 0 + + +@dataclass +class Subsamplers: + """Subsamplers used in processing""" + + ffprobe_subsampler: Optional[FFProbeSubsampler] = None + modal_subsamplers: dict = field(default_factory=dict) + cut_detection_subsampler: Optional[CutDetectionSubsampler] = None + cuts_are_clips: bool = False + broadcast_subsampler: Subsampler = field(default_factory=NoOpSubsampler) + + +def get_subsamplers( + config: dict, + input_encode_formats: EncodeFormats, + do_clipping: bool = False, +) -> Tuple[Subsamplers, EncodeFormats]: + """Initialize all subsamplers using config""" + + clipping_subsampler = ClippingSubsampler( + oom_clip_count=5, + encode_formats=input_encode_formats, + **config["subsampling"].get("ClippingSubsampler", {"args": {}})["args"], + ) + need_keyframes = clipping_subsampler.precision == "keyframe_adjusted" + + cut_detection_subsampler = None + cuts_are_clips = False + if "CutDetectionSubsampler" in config["subsampling"]: + if "args" in config["subsampling"]["CutDetectionSubsampler"]: + cut_detection_subsampler = CutDetectionSubsampler(**config["subsampling"]["CutDetectionSubsampler"]["args"]) + cuts_are_clips = config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False) + + broadcast_subsampler = ( + clipping_subsampler + if (do_clipping or config["storage"]["captions_are_subtitles"] or cuts_are_clips) + else NoOpSubsampler() + ) + + ffprobe_subsampler = None + if "FFProbeSubsampler" in 
config["subsampling"] or need_keyframes: + ffprobe_subsampler = FFProbeSubsampler(**config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"]) + ffprobe_subsampler.extract_keyframes |= need_keyframes + + video_subsamplers: List[Any] = [] + if "ResolutionSubsampler" in config["subsampling"]: + video_subsamplers.append(ResolutionSubsampler(**config["subsampling"]["ResolutionSubsampler"]["args"])) + if "FrameSubsampler" in config["subsampling"]: + video_subsamplers.append(FrameSubsampler(**config["subsampling"]["FrameSubsampler"]["args"])) + + audio_subsamplers: List[Any] = [] + if "AudioRateSubsampler" in config["subsampling"]: + audio_subsamplers.append(AudioRateSubsampler(**config["subsampling"]["AudioRateSubsampler"]["args"])) + + modal_subsamplers = {"video": video_subsamplers, "audio": audio_subsamplers} + + # output encoding formats + output_encode_formats = input_encode_formats.copy() + if modal_subsamplers["audio"]: + assert ( + len({s.encode_format for s in modal_subsamplers["audio"]}) == 1 + ) # assert that all audio subsamplers have the same output format + output_encode_formats["audio"] = modal_subsamplers["audio"][0].encode_format + if modal_subsamplers["video"]: + assert ( + len({s.encode_format for s in modal_subsamplers["video"]}) == 1 + ) # assert that all video subsamplers have the same output format + output_encode_formats["video"] = modal_subsamplers["video"][0].encode_format + + return ( + Subsamplers( + ffprobe_subsampler=ffprobe_subsampler, + modal_subsamplers=modal_subsamplers, + cut_detection_subsampler=cut_detection_subsampler, + cuts_are_clips=cuts_are_clips, + broadcast_subsampler=broadcast_subsampler, + ), + output_encode_formats, + ) + + +def process_sample( + subsamplers: Subsamplers, + shard_status: ShardStatus, + streams: Streams, + key: str, + caption: str, + metadata: Metadata, + captions_are_subtitles: bool, + shard_sample_writer: Any, # TODO: type correctly +): + """Process a single video""" + + try: + if subsamplers.ffprobe_subsampler is not None: + streams, metadata, shard_status.error_message = subsamplers.ffprobe_subsampler(streams, metadata) + assert shard_status.error_message is None + + if captions_are_subtitles: # create clips + subtitles = metadata["yt_meta_dict"]["subtitles"] + metadata["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles] + elif subsamplers.cut_detection_subsampler is not None: # apply cut detection to get clips + streams, cuts, shard_status.error_message = subsamplers.cut_detection_subsampler(streams) + assert shard_status.error_message is None + metadata["cuts"] = cuts + assert cuts is not None + if subsamplers.cuts_are_clips: + metadata["clips"] = (np.array(cuts["cuts_original_fps"]) / cuts["original_fps"]).tolist() + + # 1 video -> many videos (either clipping or noop which does identity broadcasting) + subsampled_streams, metadatas, shard_status.error_message = subsamplers.broadcast_subsampler(streams, metadata) + if shard_status.error_message is not None: + metadata["clips"] = [] + assert False + + for modality in list(subsampled_streams.keys()): + for modality_subsampler in subsamplers.modal_subsamplers[modality]: + subsampled_streams, metadatas, shard_status.error_message = modality_subsampler( + subsampled_streams, metadatas + ) + assert shard_status.error_message is None + + shard_status.successes += 1 + status = "success" + shard_status.status_dict.increment(status) + + subsampled_streams_list = [dict(zip(subsampled_streams, s)) for s in zip(*subsampled_streams.values())] + if 
len(subsampled_streams_list) == 0: # no audio or video, just write metadata + metadata["status"] = status + shard_sample_writer.write( + {}, + key, + caption, + metadata, + ) + return + for subsampled_streams, subsampled_metadata in zip(subsampled_streams_list, metadatas): + subsampled_metadata["status"] = status + text_caption = caption + if captions_are_subtitles: + clip_subtitles = subsampled_metadata.get("clip_subtitles") + first_clip_subtitles = clip_subtitles[0] if clip_subtitles else None + subtitle_lines = first_clip_subtitles["lines"] if first_clip_subtitles else None + text_caption = subtitle_lines[0] if subtitle_lines else text_caption + shard_sample_writer.write( + subsampled_streams, + subsampled_metadata["key"], + text_caption, + subsampled_metadata, + ) + except Exception as err: # pylint: disable=broad-except + print(err) + shard_status.failed["failed_to_subsample"] += 1 + shard_status.status_dict.increment(shard_status.error_message) + metadata["status"] = "failed_to_subsample" + metadata["error_message"] = shard_status.error_message + shard_sample_writer.write( + {}, + key, + caption, + metadata, + )
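As a closing illustration of process_sample above: after subsampling, the per-modality dict of clip lists is turned into one dict per clip so that each clip is written as its own sample. A minimal, self-contained sketch with hypothetical byte values (not taken from any real shard):

# dict of lists keyed by modality, as produced by the broadcast/modal subsamplers
subsampled_streams = {
    "video": [b"clip0-video-bytes", b"clip1-video-bytes"],
    "audio": [b"clip0-audio-bytes", b"clip1-audio-bytes"],
}

# the same dict-of-lists -> list-of-dicts transform process_sample applies before writing
per_clip = [dict(zip(subsampled_streams, s)) for s in zip(*subsampled_streams.values())]

assert per_clip == [
    {"video": b"clip0-video-bytes", "audio": b"clip0-audio-bytes"},
    {"video": b"clip1-video-bytes", "audio": b"clip1-audio-bytes"},
]
# each entry is then written together with its corresponding per-clip metadata dict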