Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
[pre-commit.ci] pre-commit suggestions (#1697)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
pre-commit-ci[bot] authored Oct 3, 2023
1 parent 1ce6710 commit a0e4620
Show file tree
Hide file tree
Showing 70 changed files with 209 additions and 23 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ repos:
- id: detect-private-key

- repo: https://github.com/asottile/pyupgrade
rev: v3.8.0
rev: v3.14.0
hooks:
- id: pyupgrade
args: [--py38-plus]
Expand All @@ -48,20 +48,20 @@ repos:
- id: nbstripout

- repo: https://github.com/PyCQA/docformatter
rev: v1.7.3
rev: v1.7.5
hooks:
- id: docformatter
additional_dependencies: [tomli]
args: ["--in-place"]

- repo: https://github.com/psf/black
rev: 23.3.0
rev: 23.9.1
hooks:
- id: black
name: Format code

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.276
rev: v0.0.292
hooks:
- id: ruff
args: ["--fix"]
10 changes: 8 additions & 2 deletions src/flash/audio/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@


class AudioClassificationData(DataModule):
"""The ``AudioClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of
class methods for loading data for audio classification."""
"""The ``AudioClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of class
methods for loading data for audio classification."""

input_transform_cls = AudioClassificationInputTransform

Expand Down Expand Up @@ -141,6 +141,7 @@ def from_files(
>>> import os
>>> _ = [os.remove(f"spectrogram_{i}.png") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_spectrogram_{i}.png") for i in range(1, 4)]
"""

ds_kw = {
Expand Down Expand Up @@ -275,6 +276,7 @@ def from_folders(
>>> import shutil
>>> shutil.rmtree("train_folder")
>>> shutil.rmtree("predict_folder")
"""

ds_kw = {
Expand Down Expand Up @@ -365,6 +367,7 @@ def from_numpy(
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
"""

ds_kw = {
Expand Down Expand Up @@ -453,6 +456,7 @@ def from_tensors(
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
"""

ds_kw = {
Expand Down Expand Up @@ -607,6 +611,7 @@ def from_data_frame(
>>> shutil.rmtree("predict_folder")
>>> del train_data_frame
>>> del predict_data_frame
"""

ds_kw = {
Expand Down Expand Up @@ -854,6 +859,7 @@ def from_csv(
>>> shutil.rmtree("predict_folder")
>>> os.remove("train_data.tsv")
>>> os.remove("predict_data.tsv")
"""

ds_kw = {
Expand Down
4 changes: 4 additions & 0 deletions src/flash/audio/speech_recognition/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def from_files(
>>> import os
>>> _ = [os.remove(f"speech_{i}.wav") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)]
"""

ds_kw = {"sampling_rate": sampling_rate}
Expand Down Expand Up @@ -302,6 +303,7 @@ def from_csv(
>>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)]
>>> os.remove("train_data.tsv")
>>> os.remove("predict_data.tsv")
"""

ds_kw = {"input_key": input_field, "sampling_rate": sampling_rate}
Expand Down Expand Up @@ -424,6 +426,7 @@ def from_json(
>>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)]
>>> os.remove("train_data.json")
>>> os.remove("predict_data.json")
"""

ds_kw = {"input_key": input_field, "sampling_rate": sampling_rate, "field": field}
Expand Down Expand Up @@ -570,6 +573,7 @@ def from_datasets(
>>> import os
>>> _ = [os.remove(f"speech_{i}.wav") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)]
"""

ds_kw = {"sampling_rate": sampling_rate}
Expand Down
1 change: 1 addition & 0 deletions src/flash/audio/speech_recognition/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class SpeechRecognition(Task):
learning_rate: Learning rate to use for training, defaults to ``1e-5``.
optimizer: Optimizer to use for training.
lr_scheduler: The LR scheduler to use during training.
"""

backbones: FlashRegistry = SPEECH_RECOGNITION_BACKBONES
Expand Down
2 changes: 2 additions & 0 deletions src/flash/core/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def from_task(cls, task: "flash.Task", **kwargs) -> "Adapter":
"""Instantiate the adapter from the given :class:`~flash.core.model.Task`.
This includes resolution / creation of backbones / heads and any other provider specific options.
"""

def forward(self, x: Any) -> Any:
Expand Down Expand Up @@ -73,6 +74,7 @@ class AdapterTask(Task):
Args:
adapter: The :class:`~flash.core.adapter.Adapter` to wrap.
kwargs: Keyword arguments to be passed to the base :class:`~flash.core.model.Task`.
"""

def __init__(self, adapter: Adapter, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions src/flash/core/data/base_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage):
As the :class:`~flash.core.data.io.input_transform.InputTransform` hooks are injected within
the threaded workers of the DataLoader,
the data won't be accessible when using ``num_workers > 0``.
"""

def _show(
Expand Down
2 changes: 2 additions & 0 deletions src/flash/core/data/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class FlashCallback(Callback):
Same as PyTorch Lightning, Callbacks can be provided directly to the Trainer::
trainer = Trainer(callbacks=[MyCustomCallback()])
"""

def on_per_sample_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
Expand Down Expand Up @@ -146,6 +147,7 @@ def from_inputs(
'val': {},
'predict': {}
}
"""

batches: dict
Expand Down
5 changes: 3 additions & 2 deletions src/flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@

class DatasetInput(Input):
"""The ``DatasetInput`` implements default behaviours for data sources which expect the input to
:meth:`~flash.core.data.io.input.Input.load_data` to be a :class:`torch.utils.data.dataset.Dataset`
"""
:meth:`~flash.core.data.io.input.Input.load_data` to be a :class:`torch.utils.data.dataset.Dataset`"""

def load_sample(self, sample: Any) -> Dict[str, Any]:
if isinstance(sample, tuple) and len(sample) == 2:
Expand Down Expand Up @@ -103,6 +102,7 @@ class DataModule(pl.LightningDataModule):
>>> datamodule = DataModule(train_input, sampler=WeightedRandomSampler([0.1, 0.5], 2), batch_size=1)
>>> print(datamodule.train_dataloader().sampler) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
<torch.utils.data.sampler.WeightedRandomSampler object at ...>
"""

input_transform_cls = InputTransform
Expand Down Expand Up @@ -399,6 +399,7 @@ def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
"""This function is used to configure a :class:`~flash.core.data.callback.BaseDataFetcher`.
Override with your custom one.
"""
return BaseDataFetcher()

Expand Down
2 changes: 2 additions & 0 deletions src/flash/core/data/io/classification_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class ClassificationInputMixin(Properties):
targets and store metadata like ``labels`` and ``num_classes``.
* In the ``load_sample`` method, use ``format_target`` to convert the target to a standard format for use with our
tasks.
"""

target_formatter: TargetFormatter
Expand All @@ -46,6 +47,7 @@ def load_target_metadata(
rather than inferring from the targets.
add_background: If ``True``, a background class will be inserted as class zero if ``labels`` and
``num_classes`` are being inferred.
"""
self.target_formatter = target_formatter
if target_formatter is None and targets is not None:
Expand Down
20 changes: 17 additions & 3 deletions src/flash/core/data/io/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _validate_input(input: "InputBase") -> None:
Raises:
RuntimeError: If the ``input`` is of type ``Input`` and it's ``data`` attribute does not support ``len``.
RuntimeError: If the ``input`` is of type ``IterableInput`` and it's ``data`` attribute does support ``len``.
"""
if input.data is not None:
if isinstance(input, Input) and not _has_len(input.data):
Expand All @@ -122,6 +123,7 @@ def _wrap_init(class_dict: Dict[str, Any]) -> None:
Args:
class_dict: The class construction dict, optionally containing an init to wrap.
"""
if "__init__" in class_dict:
fn = class_dict["__init__"]
Expand Down Expand Up @@ -153,14 +155,15 @@ def __new__(mcs, name: str, bases: Tuple, class_dict: Dict[str, Any]) -> "_Itera

class InputBase(Properties, metaclass=_InputMeta):
"""``InputBase`` is the base class for the :class:`~flash.core.data.io.input.Input` and
:class:`~flash.core.data.io.input.IterableInput` dataset implementations in Flash. These datasets are
constructed via the ``load_data`` and ``load_sample`` hooks, which allow a single dataset object to include custom
loading logic according to the running stage (e.g. train, validate, test, predict).
:class:`~flash.core.data.io.input.IterableInput` dataset implementations in Flash. These datasets are constructed
via the ``load_data`` and ``load_sample`` hooks, which allow a single dataset object to include custom loading logic
according to the running stage (e.g. train, validate, test, predict).
Args:
running_stage: The running stage for which the input will be used.
*args: Any arguments that are to be passed to the ``load_data`` hook.
**kwargs: Any additional keyword arguments to pass to the ``load_data`` hook.
"""

def __init__(self, running_stage: RunningStage, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -194,6 +197,7 @@ def load_data(*args: Any, **kwargs: Any) -> Union[Sequence, Iterable]:
Args:
*args: Any arguments that the input requires.
**kwargs: Any additional keyword arguments that the input requires.
"""
return args[0]

Expand All @@ -203,6 +207,7 @@ def train_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable
Args:
*args: Any arguments that the input requires.
**kwargs: Any additional keyword arguments that the input requires.
"""
return self.load_data(*args, **kwargs)

Expand All @@ -212,6 +217,7 @@ def val_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable]:
Args:
*args: Any arguments that the input requires.
**kwargs: Any additional keyword arguments that the input requires.
"""
return self.load_data(*args, **kwargs)

Expand All @@ -221,6 +227,7 @@ def test_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable]
Args:
*args: Any arguments that the input requires.
**kwargs: Any additional keyword arguments that the input requires.
"""
return self.load_data(*args, **kwargs)

Expand All @@ -230,6 +237,7 @@ def predict_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterab
Args:
*args: Any arguments that the input requires.
**kwargs: Any additional keyword arguments that the input requires.
"""
return self.load_data(*args, **kwargs)

Expand All @@ -240,6 +248,7 @@ def load_sample(sample: Dict[str, Any]) -> Any:
Args:
sample: A single sample from the output of the ``load_data`` hook.
"""
return sample

Expand All @@ -248,6 +257,7 @@ def train_load_sample(self, sample: Dict[str, Any]) -> Any:
Args:
sample: A single sample from the output of the ``load_data`` hook.
"""
return self.load_sample(sample)

Expand All @@ -256,6 +266,7 @@ def val_load_sample(self, sample: Dict[str, Any]) -> Any:
Args:
sample: A single sample from the output of the ``load_data`` hook.
"""
return self.load_sample(sample)

Expand All @@ -264,6 +275,7 @@ def test_load_sample(self, sample: Dict[str, Any]) -> Any:
Args:
sample: A single sample from the output of the ``load_data`` hook.
"""
return self.load_sample(sample)

Expand All @@ -272,13 +284,15 @@ def predict_load_sample(self, sample: Dict[str, Any]) -> Any:
Args:
sample: A single sample from the output of the ``load_data`` hook.
"""
return self.load_sample(sample)

def __bool__(self):
"""If ``self.data`` is ``None`` then the ``InputBase`` is considered falsey.
This allows for quickly checking whether or not the ``InputBase`` is populated with data.
"""
return self.data is not None

Expand Down
1 change: 1 addition & 0 deletions src/flash/core/data/io/transform_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class TransformPredictions(Callback):
Args:
output_transform: The :class:`~flash.core.data.io.output_transform.OutputTransform` to apply.
output: The :class:`~flash.core.data.io.output.Output` to apply.
"""

def __init__(self, output_transform: OutputTransform, output: Output):
Expand Down
1 change: 1 addition & 0 deletions src/flash/core/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class ApplyToKeys(nn.Sequential):
Args:
keys: The key (``str``) or sequence of keys (``Sequence[str]``) to extract and forward to the transforms.
args: The transforms, passed to the ``nn.Sequential`` super constructor.
"""

def __init__(self, keys: Union[str, Sequence[str]], *args):
Expand Down
Loading

0 comments on commit a0e4620

Please sign in to comment.