Support for Python 3.11 #197

Merged
merged 15 commits on Aug 25, 2024
2 changes: 1 addition & 1 deletion .github/workflows/documentation.yml
@@ -32,7 +32,7 @@ jobs:
- name: Install and configure Poetry
uses: snok/install-poetry@v1
with:
- version: 1.2.2
+ version: 1.8.3
virtualenvs-create: true
virtualenvs-in-project: false
installer-parallel: true
6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
@@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
# Select the Python versions to test against
- python-version: ['3.8', '3.9', '3.10', '3.11']
+ python-version: ["3.9", "3.10", "3.11"]
steps:
- name: Check out the code
uses: actions/checkout@v3
@@ -28,7 +28,7 @@
- name: Install Poetry
uses: snok/install-poetry@v1
with:
- version: 1.4.0
+ version: 1.8.3

# Configure Poetry to use the virtual environment in the project
- name: Setup Poetry
@@ -38,7 +38,7 @@
# Install the dependencies
- name: Install Package
run: |
- poetry install --with tests
+ poetry install --with test

# Run the unit tests and build the coverage report
- name: Run Tests
4 changes: 2 additions & 2 deletions examples/bring_in_your_own.pct.py
@@ -77,7 +77,7 @@ def __call__(self, x):
from fortuna.utils.random import generate_rng_like_tree
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map
- from jax._src.prng import PRNGKeyArray
+ import jax
import jax.numpy as jnp


@@ -86,7 +86,7 @@ def log_joint_prob(self, params: Params) -> float:
v = jnp.mean((ravel_pytree(params)[0] <= 1) & (ravel_pytree(params)[0] >= 0))
return jnp.where(v == 1.0, jnp.array(0), -jnp.inf)

- def sample(self, params_like: Params, rng: Optional[PRNGKeyArray] = None) -> Params:
+ def sample(self, params_like: Params, rng: Optional[jax.Array] = None) -> Params:
if rng is None:
rng = self.rng.get()
keys = generate_rng_like_tree(rng, params_like)
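Throughout this PR, annotations move from the private `jax._src.prng.PRNGKeyArray` to the public `jax.Array` type, which is what newer JAX releases (needed for Python 3.11 support) expose for PRNG keys. A minimal standalone sketch of the sampling pattern above; the per-leaf key split and the uniform draw are illustrative assumptions, not Fortuna's exact `generate_rng_like_tree` logic:

from typing import Optional

import jax

def sample_like(params_like, rng: Optional[jax.Array] = None):
    # Fall back to a fixed seed when no key is passed; the library instead
    # pulls a key from its internal RNG state (self.rng.get()).
    if rng is None:
        rng = jax.random.PRNGKey(0)
    # Split one subkey per leaf so each parameter draw is independent.
    leaves, treedef = jax.tree_util.tree_flatten(params_like)
    keys = jax.random.split(rng, len(leaves))
    return treedef.unflatten(
        [jax.random.uniform(k, leaf.shape) for k, leaf in zip(keys, leaves)]
    )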
6 changes: 3 additions & 3 deletions fortuna/calib_model/calib_model_calibrator.py
@@ -8,7 +8,7 @@
)

from flax.core import FrozenDict
- from jax._src.prng import PRNGKeyArray
+ import jax
import jax.numpy as jnp
from optax._src.base import PyTree

@@ -35,7 +35,7 @@ def training_loss_step(
params: Params,
batch: Batch,
mutable: Mutable,
- rng: PRNGKeyArray,
+ rng: jax.Array,
n_data: int,
unravel: Optional[Callable[[any], PyTree]] = None,
calib_params: Optional[CalibParams] = None,
@@ -71,7 +71,7 @@ def validation_step(
state: CalibState,
batch: Batch,
loss_fun: Callable[[Any], Union[float, Tuple[float, dict]]],
- rng: PRNGKeyArray,
+ rng: jax.Array,
n_data: int,
metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None,
unravel: Optional[Callable[[any], PyTree]] = None,
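The training and validation steps now annotate `rng` as a plain `jax.Array`. A hedged sketch of how such a key is usually threaded through a sequence of steps; the placeholder loss is an assumption, not the calibrator's loss:

import jax
import jax.numpy as jnp

def training_loop(rng: jax.Array, n_steps: int):
    loss = jnp.inf
    for _ in range(n_steps):
        # Split off a fresh subkey per step; reusing one key would repeat
        # the same dropout masks and samples at every step.
        rng, step_key = jax.random.split(rng)
        loss = jnp.mean(jax.random.normal(step_key, (8,)) ** 2)  # placeholder loss
    return loss

print(training_loop(jax.random.PRNGKey(0), n_steps=3))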
2 changes: 1 addition & 1 deletion fortuna/calib_model/classification.py
@@ -124,7 +124,7 @@ def _get_model_manager(
)
else:
model_manager = model_manager_cls(model, model_editor)
- except ModuleNotFoundError as e:
+ except ModuleNotFoundError:
logging.warning(
"No module named 'transformer' is installed. "
"If you are not working with models from the `transformers` library ignore this warning, otherwise "
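The updated except clause drops the unused `as e` binding. A hedged sketch of the optional-dependency pattern it guards; the imported name is an illustrative assumption, not necessarily what `classification.py` imports:

import logging

try:
    from transformers import FlaxAutoModel  # optional Hugging Face dependency
    has_transformers = True
except ModuleNotFoundError:
    # No `as e`: the exception object was never used, so binding it
    # only created an unused variable.
    has_transformers = False
    logging.warning(
        "`transformers` is not installed; Hugging Face model managers are unavailable."
    )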
2 changes: 1 addition & 1 deletion fortuna/calib_model/config/monitor.py
@@ -66,7 +66,7 @@ def __init__(
f"All metrics in `metrics` must be callable objects, but {metric} is not."
)
if uncertainty_fn is not None and not callable(uncertainty_fn):
raise ValueError(f"`uncertainty_fn` must be a a callable function.")
raise ValueError("`uncertainty_fn` must be a a callable function.")

self.metrics = metrics
self.uncertainty_fn = uncertainty_fn
4 changes: 2 additions & 2 deletions fortuna/calib_model/loss.py
@@ -6,7 +6,7 @@
Tuple,
)

- from jax._src.prng import PRNGKeyArray
+ import jax
import jax.numpy as jnp

from fortuna.likelihood.base import Likelihood
@@ -40,7 +40,7 @@ def __call__(
return_aux: Optional[List[str]] = None,
train: bool = False,
outputs: Optional[jnp.ndarray] = None,
- rng: Optional[PRNGKeyArray] = None,
+ rng: Optional[jax.Array] = None,
**kwargs,
) -> Tuple[jnp.ndarray, Any]:
if return_aux is None:
6 changes: 3 additions & 3 deletions fortuna/calib_model/predictive/base.py
@@ -1,6 +1,6 @@
from typing import Optional

- from jax._src.prng import PRNGKeyArray
+ import jax
import jax.numpy as jnp

from fortuna.data.loader import (
@@ -60,7 +60,7 @@ def sample(
self,
inputs_loader: InputsLoader,
n_samples: int = 1,
- rng: Optional[PRNGKeyArray] = None,
+ rng: Optional[jax.Array] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
@@ -80,7 +80,7 @@ def sample(
A loader of input data points.
n_samples : int
Number of target samples to sample for each input data point.
- rng : Optional[PRNGKeyArray]
+ rng : Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
distribute: bool
Whether to distribute computation over multiple devices, if available.
14 changes: 7 additions & 7 deletions fortuna/calib_model/predictive/regression.py
@@ -4,7 +4,7 @@
Union,
)

- from jax._src.prng import PRNGKeyArray
+ import jax
import jax.numpy as jnp

from fortuna.calib_model.predictive.base import Predictive
@@ -21,7 +21,7 @@ def entropy(
self,
inputs_loader: InputsLoader,
n_samples: int = 30,
- rng: Optional[PRNGKeyArray] = None,
+ rng: Optional[jax.Array] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
@@ -42,7 +42,7 @@
A loader of input data points.
n_samples : int
Number of samples to draw for each input.
- rng : Optional[PRNGKeyArray]
+ rng : Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
distribute: bool
Whether to distribute computation over multiple devices, if available.
@@ -67,7 +67,7 @@ def quantile(
q: Union[float, Array, List],
inputs_loader: InputsLoader,
n_samples: Optional[int] = 30,
- rng: Optional[PRNGKeyArray] = None,
+ rng: Optional[jax.Array] = None,
distribute: bool = True,
) -> Union[float, jnp.ndarray]:
r"""
@@ -81,7 +81,7 @@
A loader of input data points.
n_samples : int
Number of target samples to sample for each input data point.
- rng: Optional[PRNGKeyArray]
+ rng: Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
distribute: bool
Whether to distribute computation over multiple devices, if available.
@@ -109,7 +109,7 @@ def credible_interval(
n_samples: int = 30,
error: float = 0.05,
interval_type: str = "two-tailed",
- rng: Optional[PRNGKeyArray] = None,
+ rng: Optional[jax.Array] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
@@ -126,7 +126,7 @@
`error=0.05` corresponds to a 95% level of credibility.
interval_type: str
The interval type. We support "two-tailed" (default), "right-tailed" and "left-tailed".
- rng : Optional[PRNGKeyArray]
+ rng : Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
distribute: bool
Whether to distribute computation over multiple devices, if available.
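Each of these predictive methods now accepts `rng: Optional[jax.Array]`. As a hedged illustration of the two-tailed interval described in the docstring above, here is a standalone helper with toy normal samples; both are assumptions for illustration, not the library's implementation:

import jax
import jax.numpy as jnp

def credible_interval_from_samples(samples: jnp.ndarray, error: float = 0.05) -> jnp.ndarray:
    # samples has shape (n_samples, n_inputs); error=0.05 gives a 95% interval.
    lo = jnp.quantile(samples, error / 2, axis=0)
    hi = jnp.quantile(samples, 1 - error / 2, axis=0)
    return jnp.stack([lo, hi], axis=-1)

key = jax.random.PRNGKey(42)               # a jax.Array PRNG key
samples = jax.random.normal(key, (30, 4))  # 30 draws for each of 4 inputs
print(credible_interval_from_samples(samples).shape)  # (4, 2)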
@@ -1,6 +1,5 @@
from typing import Optional

- from jax import vmap
import jax.numpy as jnp

from fortuna.conformal.multivalid.mixins.multicalibrator import MulticalibratorMixin
2 changes: 0 additions & 2 deletions fortuna/conformal/multivalid/one_shot/base.py
@@ -1,9 +1,7 @@
import abc
import logging
from typing import (
Dict,
Optional,
Tuple,
Union,
)

14 changes: 7 additions & 7 deletions fortuna/data/dataset/huggingface_datasets.py
@@ -15,9 +15,9 @@
Dataset,
DatasetDict,
)
+ import jax
from jax import numpy as jnp
import jax.random
- from jax.random import PRNGKeyArray
from tqdm import tqdm
from transformers import (
BatchEncoding,
@@ -90,7 +90,7 @@ def get_data_loader(
self,
dataset: Dataset,
per_device_batch_size: int,
- rng: PRNGKeyArray,
+ rng: jax.Array,
shuffle: bool = False,
drop_last: bool = False,
verbose: bool = False,
@@ -105,7 +105,7 @@
A tokenized dataset (see :meth:`.HuggingFaceClassificationDatasetABC.get_tokenized_datasets`).
per_device_batch_size: int
Batch size for each device.
- rng: PRNGKeyArray
+ rng: jax.Array
Random number generator.
shuffle: bool
If True, shuffle the data so that each batch is a random sample from the dataset.
@@ -141,7 +141,7 @@ def _collate(self, batch: Dict[str, Array], batch_size: int) -> Dict[str, Array]

@staticmethod
def _get_batches_idxs(
- rng: PRNGKeyArray,
+ rng: jax.Array,
dataset_size: int,
batch_size: int,
shuffle: bool = False,
@@ -167,7 +167,7 @@ def _get_data_loader(
batch_size: int,
shuffle: bool = False,
drop_last: bool = False,
- rng: Optional[PRNGKeyArray] = None,
+ rng: Optional[jax.Array] = None,
verbose: bool = False,
) -> Union[Iterable[Dict[str, Array]], Iterable[Tuple[Dict[str, Array], Array]]]:
batch_idxs_gen = self._get_batches_idxs(
@@ -375,7 +375,7 @@ def __init__(
super(HuggingFaceMaskedLMDataset, self).__init__(*args, **kwargs)
if not self.tokenizer.is_fast:
logger.warning(
f"You are not using a Fast Tokenizer, so whole words cannot be masked, only tokens."
"You are not using a Fast Tokenizer, so whole words cannot be masked, only tokens."
)
self.mlm = mlm
self.mlm_probability = mlm_probability
@@ -407,7 +407,7 @@ def get_tokenized_datasets(
), "Only one text column should be passed when the task is MaskedLM."

def _tokenize_fn(
- batch: Dict[str, List[Union[str, int]]]
+ batch: Dict[str, List[Union[str, int]]],
) -> Dict[str, List[int]]:
tokenized_inputs = self.tokenizer(
*[batch[col] for col in text_columns],
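The data-loader changes are the same key-type migration. A minimal sketch of shuffled batch-index generation under the new `rng: jax.Array` signature; the body is an assumption, not Fortuna's `_get_batches_idxs`:

from typing import Iterator

import jax
import jax.numpy as jnp

def get_batches_idxs(
    rng: jax.Array,
    dataset_size: int,
    batch_size: int,
    shuffle: bool = False,
    drop_last: bool = False,
) -> Iterator[jnp.ndarray]:
    idxs = jnp.arange(dataset_size)
    if shuffle:
        # jax.random.permutation consumes a jax.Array key directly.
        idxs = jax.random.permutation(rng, idxs)
    n_full = dataset_size // batch_size
    for i in range(n_full):
        yield idxs[i * batch_size : (i + 1) * batch_size]
    if not drop_last and n_full * batch_size < dataset_size:
        yield idxs[n_full * batch_size :]  # final short batch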
10 changes: 5 additions & 5 deletions fortuna/distribution/gaussian.py
@@ -1,10 +1,10 @@
from typing import Union

+ import jax
from jax import (
random,
vmap,
)
- from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp
from jax.scipy.stats import (
multivariate_normal,
@@ -30,11 +30,11 @@ def __init__(self, mean: Union[float, Array], std: Union[float, Array]):
self.std = std
self.dim = 1 if type(mean) in [int, float] else len(mean)

- def sample(self, rng: PRNGKeyArray, n_samples: int = 1) -> jnp.ndarray:
+ def sample(self, rng: jax.Array, n_samples: int = 1) -> jnp.ndarray:
"""
Sample from the diagonal Gaussian.

- :param rng: PRNGKeyArray
+ :param rng: jax.Array
Random number generator.
:param n_samples: int
Number of samples.
@@ -72,11 +72,11 @@ def __init__(self, mean: Array, cov: Array):
self.cov = cov
self.dim = len(mean)

- def sample(self, rng: PRNGKeyArray, n_samples: int = 1) -> jnp.ndarray:
+ def sample(self, rng: jax.Array, n_samples: int = 1) -> jnp.ndarray:
"""
Sample from the multivariate Gaussian.

- :param rng: PRNGKeyArray
+ :param rng: jax.Array
Random number generator.
:param n_samples: int
Number of samples.
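A hedged usage sketch of the updated `sample` signature; this free-standing reparameterized sampler is an illustrative assumption, not a copy of the file's diagonal-Gaussian class:

import jax
import jax.numpy as jnp

def diag_gaussian_sample(
    mean: jnp.ndarray, std: jnp.ndarray, rng: jax.Array, n_samples: int = 1
) -> jnp.ndarray:
    # Reparameterization: mean + std * eps with eps ~ N(0, I).
    eps = jax.random.normal(rng, (n_samples,) + mean.shape)
    return mean + std * eps

key = jax.random.PRNGKey(0)
print(diag_gaussian_sample(jnp.zeros(3), jnp.ones(3), key, n_samples=5).shape)  # (5, 3)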
1 change: 0 additions & 1 deletion fortuna/hallucination/base.py
@@ -6,7 +6,6 @@
List,
Optional,
Tuple,
- Union,
)

import numpy as np
12 changes: 6 additions & 6 deletions fortuna/likelihood/base.py
@@ -8,11 +8,11 @@
Union,
)

+ import jax
from jax import (
jit,
pmap,
)
- from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp

from fortuna.data.loader import (
@@ -133,7 +133,7 @@ def _batched_log_joint_prob(
return_aux: Optional[List[str]] = None,
train: bool = False,
outputs: Optional[jnp.ndarray] = None,
- rng: Optional[PRNGKeyArray] = None,
+ rng: Optional[jax.Array] = None,
**kwargs,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Any]]:
"""
@@ -161,7 +161,7 @@
Whether the method is called during training.
outputs : Optional[jnp.ndarray]
Pre-computed batch of outputs.
- rng: Optional[PRNGKeyArray]
+ rng: Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.

Returns
@@ -272,7 +272,7 @@ def sample(
calib_params: Optional[CalibParams] = None,
calib_mutable: Optional[CalibMutable] = None,
return_aux: Optional[List[str]] = None,
- rng: Optional[PRNGKeyArray] = None,
+ rng: Optional[jax.Array] = None,
distribute: bool = True,
**kwargs,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, dict]]:
@@ -296,7 +296,7 @@
return_aux : Optional[List[str]]
The auxiliary objects to return. We support 'outputs'. If this argument is not given, no auxiliary object
is returned.
- rng: Optional[PRNGKeyArray]
+ rng: Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
distribute: bool
Whether to distribute computation over multiple devices, if available.
@@ -345,7 +345,7 @@ def _batched_sample(
calib_params: Optional[CalibParams] = None,
calib_mutable: Optional[CalibMutable] = None,
return_aux: Optional[List[str]] = None,
- rng: Optional[PRNGKeyArray] = None,
+ rng: Optional[jax.Array] = None,
**kwargs,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, dict]]:
if return_aux is None:
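The `distribute` flag in these likelihood methods refers to sharding batches across local devices with `pmap`, which the file imports above. A hedged mini-example of that pattern; the toy log-probability function and shapes are assumptions:

import jax
import jax.numpy as jnp
from jax import pmap

n_dev = jax.local_device_count()
# Reshape a batch to (n_devices, per_device_batch) and map over devices.
batch = jnp.arange(n_dev * 4.0).reshape(n_dev, 4)
log_probs = pmap(lambda x: -0.5 * jnp.sum(x**2, axis=-1))(batch)
print(log_probs.shape)  # (n_dev,)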