Skip to content

Commit

Permalink
Support for Python 3.11 (#197)
Browse files Browse the repository at this point in the history
  • Loading branch information
wistuba committed Aug 25, 2024
1 parent 940dda7 commit 2267823
Show file tree
Hide file tree
Showing 84 changed files with 4,002 additions and 3,743 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
- name: Install and configure Poetry
uses: snok/install-poetry@v1
with:
version: 1.2.2
version: 1.8.3
virtualenvs-create: true
virtualenvs-in-project: false
installer-parallel: true
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
# Select the Python versions to test against
python-version: ['3.8', '3.9', '3.10', '3.11']
python-version: ["3.9", "3.10", "3.11"]
steps:
- name: Check out the code
uses: actions/checkout@v3
Expand All @@ -28,7 +28,7 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1.3.3
with:
version: 1.4.0
version: 1.8.3

# Configure Poetry to use the virtual environment in the project
- name: Setup Poetry
Expand All @@ -38,7 +38,7 @@ jobs:
# Install the dependencies
- name: Install Package
run: |
poetry install --with tests
poetry install --with test
# Run the unit tests and build the coverage report
- name: Run Tests
Expand Down
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version: 2
build:
os: ubuntu-20.04
tools:
python: '3.10'
python: '3.11'
jobs:
post_create_environment:
# Install poetry
Expand Down
4 changes: 2 additions & 2 deletions examples/bring_in_your_own.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __call__(self, x):
from fortuna.utils.random import generate_rng_like_tree
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map
from jax._src.prng import PRNGKeyArray
import jax
import jax.numpy as jnp


Expand All @@ -86,7 +86,7 @@ def log_joint_prob(self, params: Params) -> float:
v = jnp.mean((ravel_pytree(params)[0] <= 1) & (ravel_pytree(params)[0] >= 0))
return jnp.where(v == 1.0, jnp.array(0), -jnp.inf)

def sample(self, params_like: Params, rng: Optional[PRNGKeyArray] = None) -> Params:
def sample(self, params_like: Params, rng: Optional[jax.Array] = None) -> Params:
if rng is None:
rng = self.rng.get()
keys = generate_rng_like_tree(rng, params_like)
Expand Down
6 changes: 3 additions & 3 deletions fortuna/calib_model/calib_model_calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
)

from flax.core import FrozenDict
from jax._src.prng import PRNGKeyArray
import jax
import jax.numpy as jnp
from optax._src.base import PyTree

Expand All @@ -35,7 +35,7 @@ def training_loss_step(
params: Params,
batch: Batch,
mutable: Mutable,
rng: PRNGKeyArray,
rng: jax.Array,
n_data: int,
unravel: Optional[Callable[[any], PyTree]] = None,
calib_params: Optional[CalibParams] = None,
Expand Down Expand Up @@ -71,7 +71,7 @@ def validation_step(
state: CalibState,
batch: Batch,
loss_fun: Callable[[Any], Union[float, Tuple[float, dict]]],
rng: PRNGKeyArray,
rng: jax.Array,
n_data: int,
metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None,
unravel: Optional[Callable[[any], PyTree]] = None,
Expand Down
2 changes: 1 addition & 1 deletion fortuna/calib_model/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _get_model_manager(
)
else:
model_manager = model_manager_cls(model, model_editor)
except ModuleNotFoundError as e:
except ModuleNotFoundError:
logging.warning(
"No module named 'transformer' is installed. "
"If you are not working with models from the `transformers` library ignore this warning, otherwise "
Expand Down
2 changes: 1 addition & 1 deletion fortuna/calib_model/config/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
f"All metrics in `metrics` must be callable objects, but {metric} is not."
)
if uncertainty_fn is not None and not callable(uncertainty_fn):
raise ValueError(f"`uncertainty_fn` must be a a callable function.")
raise ValueError("`uncertainty_fn` must be a a callable function.")

self.metrics = metrics
self.uncertainty_fn = uncertainty_fn
Expand Down
4 changes: 2 additions & 2 deletions fortuna/calib_model/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Tuple,
)

from jax._src.prng import PRNGKeyArray
import jax
import jax.numpy as jnp

from fortuna.likelihood.base import Likelihood
Expand Down Expand Up @@ -40,7 +40,7 @@ def __call__(
return_aux: Optional[List[str]] = None,
train: bool = False,
outputs: Optional[jnp.ndarray] = None,
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
**kwargs,
) -> Tuple[jnp.ndarray, Any]:
if return_aux is None:
Expand Down
6 changes: 3 additions & 3 deletions fortuna/calib_model/predictive/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from jax._src.prng import PRNGKeyArray
import jax
import jax.numpy as jnp

from fortuna.data.loader import (
Expand Down Expand Up @@ -60,7 +60,7 @@ def sample(
self,
inputs_loader: InputsLoader,
n_samples: int = 1,
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
Expand All @@ -80,7 +80,7 @@ def sample(
A loader of input data points.
n_samples : int
Number of target samples to sample for each input data point.
rng : Optional[PRNGKeyArray]
rng : Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
distribute: bool
Whether to distribute computation over multiple devices, if available.
Expand Down
14 changes: 7 additions & 7 deletions fortuna/calib_model/predictive/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Union,
)

from jax._src.prng import PRNGKeyArray
import jax
import jax.numpy as jnp

from fortuna.calib_model.predictive.base import Predictive
Expand All @@ -21,7 +21,7 @@ def entropy(
self,
inputs_loader: InputsLoader,
n_samples: int = 30,
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
Expand All @@ -42,7 +42,7 @@ def entropy(
A loader of input data points.
n_samples : int
Number of samples to draw for each input.
rng : Optional[PRNGKeyArray]
rng : Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
distribute: bool
Whether to distribute computation over multiple devices, if available.
Expand All @@ -67,7 +67,7 @@ def quantile(
q: Union[float, Array, List],
inputs_loader: InputsLoader,
n_samples: Optional[int] = 30,
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
distribute: bool = True,
) -> Union[float, jnp.ndarray]:
r"""
Expand All @@ -81,7 +81,7 @@ def quantile(
A loader of input data points.
n_samples : int
Number of target samples to sample for each input data point.
rng: Optional[PRNGKeyArray]
rng: Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
distribute: bool
Whether to distribute computation over multiple devices, if available.
Expand Down Expand Up @@ -109,7 +109,7 @@ def credible_interval(
n_samples: int = 30,
error: float = 0.05,
interval_type: str = "two-tailed",
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
Expand All @@ -126,7 +126,7 @@ def credible_interval(
`error=0.05` corresponds to a 95% level of credibility.
interval_type: str
The interval type. We support "two-tailed" (default), "right-tailed" and "left-tailed".
rng : Optional[PRNGKeyArray]
rng : Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
distribute: bool
Whether to distribute computation over multiple devices, if available.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Optional

from jax import vmap
import jax.numpy as jnp

from fortuna.conformal.multivalid.mixins.multicalibrator import MulticalibratorMixin
Expand Down
2 changes: 0 additions & 2 deletions fortuna/conformal/multivalid/one_shot/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import abc
import logging
from typing import (
Dict,
Optional,
Tuple,
Union,
)

Expand Down
14 changes: 7 additions & 7 deletions fortuna/data/dataset/huggingface_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
Dataset,
DatasetDict,
)
import jax
from jax import numpy as jnp
import jax.random
from jax.random import PRNGKeyArray
from tqdm import tqdm
from transformers import (
BatchEncoding,
Expand Down Expand Up @@ -90,7 +90,7 @@ def get_data_loader(
self,
dataset: Dataset,
per_device_batch_size: int,
rng: PRNGKeyArray,
rng: jax.Array,
shuffle: bool = False,
drop_last: bool = False,
verbose: bool = False,
Expand All @@ -105,7 +105,7 @@ def get_data_loader(
A tokenizeed dataset (see :meth:`.HuggingFaceClassificationDatasetABC.get_tokenized_datasets`).
per_device_batch_size: bool
Batch size for each device.
rng: PRNGKeyArray
rng: jax.Array
Random number generator.
shuffle: bool
if True, shuffle the data so that each batch is a ranom sample from the dataset.
Expand Down Expand Up @@ -141,7 +141,7 @@ def _collate(self, batch: Dict[str, Array], batch_size: int) -> Dict[str, Array]

@staticmethod
def _get_batches_idxs(
rng: PRNGKeyArray,
rng: jax.Array,
dataset_size: int,
batch_size: int,
shuffle: bool = False,
Expand All @@ -167,7 +167,7 @@ def _get_data_loader(
batch_size: int,
shuffle: bool = False,
drop_last: bool = False,
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
verbose: bool = False,
) -> Union[Iterable[Dict[str, Array]], Iterable[Tuple[Dict[str, Array], Array]]]:
batch_idxs_gen = self._get_batches_idxs(
Expand Down Expand Up @@ -375,7 +375,7 @@ def __init__(
super(HuggingFaceMaskedLMDataset, self).__init__(*args, **kwargs)
if not self.tokenizer.is_fast:
logger.warning(
f"You are not using a Fast Tokenizer, so whole words cannot be masked, only tokens."
"You are not using a Fast Tokenizer, so whole words cannot be masked, only tokens."
)
self.mlm = mlm
self.mlm_probability = mlm_probability
Expand Down Expand Up @@ -407,7 +407,7 @@ def get_tokenized_datasets(
), "Only one text column should be passed when the task is MaskedLM."

def _tokenize_fn(
batch: Dict[str, List[Union[str, int]]]
batch: Dict[str, List[Union[str, int]]],
) -> Dict[str, List[int]]:
tokenized_inputs = self.tokenizer(
*[batch[col] for col in text_columns],
Expand Down
10 changes: 5 additions & 5 deletions fortuna/distribution/gaussian.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Union

import jax
from jax import (
random,
vmap,
)
from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp
from jax.scipy.stats import (
multivariate_normal,
Expand All @@ -30,11 +30,11 @@ def __init__(self, mean: Union[float, Array], std: Union[float, Array]):
self.std = std
self.dim = 1 if type(mean) in [int, float] else len(mean)

def sample(self, rng: PRNGKeyArray, n_samples: int = 1) -> jnp.ndarray:
def sample(self, rng: jax.Array, n_samples: int = 1) -> jnp.ndarray:
"""
Sample from the diagonal Gaussian.
:param rng: PRNGKeyArray
:param rng: jax.Array
Random number generator.
:param n_samples: int
Number of samples.
Expand Down Expand Up @@ -72,11 +72,11 @@ def __init__(self, mean: Array, cov: Array):
self.cov = cov
self.dim = len(mean)

def sample(self, rng: PRNGKeyArray, n_samples: int = 1) -> jnp.ndarray:
def sample(self, rng: jax.Array, n_samples: int = 1) -> jnp.ndarray:
"""
Sample from the multivariate Gaussian.
:param rng: PRNGKeyArray
:param rng: jax.Array
Random number generator.
:param n_samples: int
Number of samples.
Expand Down
1 change: 0 additions & 1 deletion fortuna/hallucination/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
List,
Optional,
Tuple,
Union,
)

import numpy as np
Expand Down
Loading

0 comments on commit 2267823

Please sign in to comment.