Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add common strs #257

Merged
merged 9 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Pre-commit hook configuration (https://pre-commit.com).
repos:
  # Strip unused imports/variables automatically before committing.
  - repo: https://github.com/PyCQA/autoflake
    rev: v2.2.1
    hooks:
      - id: autoflake
        args:
          - --in-place
          - --remove-all-unused-imports
          # Fix: the autoflake option is plural ("--remove-unused-variables");
          # the singular form is not a recognized flag.
          - --remove-unused-variables
          - --ignore-init-module-imports
  # Code formatting.
  - repo: https://github.com/psf/black
    rev: 23.9.1
    hooks:
      - id: black
        language_version: python3.10
  # Linting (with autofix enabled; fail the hook when fixes were applied
  # so the fixed files get re-staged).
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.0.291
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix]
  # Generic sanity checks.
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: check-merge-conflict
      - id: check-toml
      - id: check-yaml
      - id: name-tests-test
      - id: debug-statements
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
from pathlib import Path
import sys

import toml

sys.path.insert(0, os.path.abspath(".."))
sys.path.insert(0, str(Path("..").resolve()))
tmke8 marked this conversation as resolved.
Show resolved Hide resolved


# -- Project information -----------------------------------------------------
Expand Down
628 changes: 388 additions & 240 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ optional = true
[tool.poetry.group.torchcpu.dependencies]
torch = { version = "*", source = "torchcpu", markers = "sys_platform == 'linux'" }


[tool.poetry.group.dev.dependencies]
pre-commit = "^3.4.0"

tmke8 marked this conversation as resolved.
Show resolved Hide resolved
[[tool.poetry.source]]
name = "torchcpu"
url = "https://download.pytorch.org/whl/cpu"
Expand Down
16 changes: 15 additions & 1 deletion ranzen/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
import copy
from enum import Enum
from enum import Enum, auto
import functools
import operator
import sys
Expand All @@ -19,6 +19,8 @@
"some",
"str_to_enum",
"unwrap_or",
"Stage",
"Split",
]


Expand Down Expand Up @@ -286,3 +288,15 @@ def unwrap_or(value: T | None, /, *, default: T) -> T:


default_if_none = unwrap_or


class Stage(StrEnum):
FIT = auto()
VALIDATE = auto()
TEST = auto()


class Split(StrEnum):
TRAIN = auto()
VAL = auto()
TEST = auto()
1 change: 0 additions & 1 deletion ranzen/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

__all__ = ["Addable", "DataclassInstance", "Sized", "is_td_instance"]


T_co = TypeVar("T_co", covariant=True)


Expand Down
6 changes: 3 additions & 3 deletions tests/torch_data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ def dummy_ds() -> TensorDataset:
@pytest.mark.parametrize("as_indices", [False, True])
@pytest.mark.parametrize("props", [0.5, [-0.2, 0.5], [0.1, 0.3, 0.4], [0.5, 0.6]])
def test_prop_random_split(
dummy_ds: TensorDataset, props: float | list[float], as_indices: bool
dummy_ds: TensorDataset, props: float | int | list[float] | list[int], as_indices: bool
tmke8 marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
sum_ = props if isinstance(props, float) else sum(props)
props_ls = [props] if isinstance(props, float) else props
sum_ = props if isinstance(props, (float, int)) else sum(props)
props_ls = [props] if isinstance(props, (float, int)) else props
tmke8 marked this conversation as resolved.
Show resolved Hide resolved
if sum_ > 1 or any(not (0 <= prop <= 1) for prop in props_ls):
with pytest.raises(ValueError):
splits = prop_random_split(dataset_or_size=dummy_ds, props=props, as_indices=as_indices)
Expand Down
Loading