Skip to content

Commit

Permalink
Common strs (#257)
Browse files Browse the repository at this point in the history
This is not a very exciting PR. It just adds two enums to avoid
potentially mis-typing common (albeit short) words.

On the housekeeping side, ran `poetry update`, and added `pre-commit` to
the dev dependencies (doesn't really fit into an existing group).
  • Loading branch information
olliethomas authored Sep 29, 2023
2 parents cf694b9 + 6c3a489 commit 080d5d3
Show file tree
Hide file tree
Showing 8 changed files with 441 additions and 247 deletions.
29 changes: 29 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
repos:
  # autoflake: strips unused imports/variables in place.
  - repo: https://github.com/PyCQA/autoflake
    rev: v2.2.1
    hooks:
      - id: autoflake
        args:
          - --in-place
          - --remove-all-unused-imports
          # NOTE: the flag is plural; "--remove-unused-variable" is not a
          # recognized autoflake option and would make the hook fail.
          - --remove-unused-variables
          - --ignore-init-module-imports
  # black: opinionated code formatter.
  - repo: https://github.com/psf/black
    rev: 23.9.1
    hooks:
      - id: black
        language_version: python3.10
  # ruff: fast linter; --exit-non-zero-on-fix ensures CI notices autofixes.
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.0.291
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix]
  # assorted sanity checks from the pre-commit project itself.
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: check-merge-conflict
      - id: check-toml
      - id: check-yaml
      - id: name-tests-test
      - id: debug-statements
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
from pathlib import Path
import sys

import toml

sys.path.insert(0, os.path.abspath(".."))
sys.path.insert(0, str(Path("..").resolve()))


# -- Project information -----------------------------------------------------
Expand Down
628 changes: 388 additions & 240 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ optional = true
[tool.poetry.group.torchcpu.dependencies]
torch = { version = "*", source = "torchcpu", markers = "sys_platform == 'linux'" }


[tool.poetry.group.dev.dependencies]
pre-commit = "^3.4.0"

[[tool.poetry.source]]
name = "torchcpu"
url = "https://download.pytorch.org/whl/cpu"
Expand Down
16 changes: 15 additions & 1 deletion ranzen/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
import copy
from enum import Enum
from enum import Enum, auto
import functools
import operator
import sys
Expand All @@ -19,6 +19,8 @@
"some",
"str_to_enum",
"unwrap_or",
"Stage",
"Split",
]


Expand Down Expand Up @@ -286,3 +288,15 @@ def unwrap_or(value: T | None, /, *, default: T) -> T:


default_if_none = unwrap_or


class Stage(StrEnum):
FIT = auto()
VALIDATE = auto()
TEST = auto()


class Split(StrEnum):
TRAIN = auto()
VAL = auto()
TEST = auto()
2 changes: 1 addition & 1 deletion ranzen/torch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def prop_random_split(
)
len_ = len(dataset_or_size)

if isinstance(props, float):
if isinstance(props, (float, int)):
props = [props]
sum_ = np.sum(props)
if (sum_ > 1.0) or any(prop < 0 for prop in props):
Expand Down
1 change: 0 additions & 1 deletion ranzen/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

__all__ = ["Addable", "DataclassInstance", "Sized", "is_td_instance"]


T_co = TypeVar("T_co", covariant=True)


Expand Down
4 changes: 2 additions & 2 deletions tests/torch_data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def dummy_ds() -> TensorDataset:
def test_prop_random_split(
dummy_ds: TensorDataset, props: float | list[float], as_indices: bool
) -> None:
sum_ = props if isinstance(props, float) else sum(props)
props_ls = [props] if isinstance(props, float) else props
sum_ = props if isinstance(props, (float, int)) else sum(props)
props_ls = [props] if isinstance(props, (float, int)) else props
if sum_ > 1 or any(not (0 <= prop <= 1) for prop in props_ls):
with pytest.raises(ValueError):
splits = prop_random_split(dataset_or_size=dummy_ds, props=props, as_indices=as_indices)
Expand Down

0 comments on commit 080d5d3

Please sign in to comment.