Skip to content

Commit

Permalink
Allow python3.11 (#153)
Browse files Browse the repository at this point in the history
gianlucadetommaso authored Nov 20, 2023
1 parent 76ad7a2 commit afbf206
Showing 15 changed files with 66 additions and 78 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
# Select the Python versions to test against
python-version: ['3.8', '3.9', '3.10']
python-version: ['3.8', '3.9', '3.10', '3.11']
steps:
- name: Check out the code
uses: actions/checkout@v3
4 changes: 2 additions & 2 deletions fortuna/output_calib_model/state.py
Original file line number Diff line number Diff line change
@@ -16,13 +16,13 @@
CalibParams,
OptaxOptimizer,
)
from fortuna.utils.strings import convert_string_to_jnp_array
from fortuna.utils.strings import convert_string_to_tuple


class OutputCalibState(TrainState):
params: CalibParams
mutable: Optional[CalibMutable] = None
encoded_name: jnp.ndarray = convert_string_to_jnp_array("OutputCalibState")
encoded_name: tuple = convert_string_to_tuple("OutputCalibState")

@classmethod
def init(
4 changes: 2 additions & 2 deletions fortuna/prob_model/posterior/laplace/laplace_state.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@
)
from fortuna.utils.nested_dicts import nested_pair
from fortuna.utils.strings import (
convert_string_to_jnp_array,
convert_string_to_tuple,
encode_tuple_of_lists_of_strings_to_numpy,
)

@@ -36,7 +36,7 @@ class LaplaceState(PosteriorState):
"""

prior_log_var: float = 0.0
encoded_name: jnp.ndarray = convert_string_to_jnp_array("LaplaceState")
encoded_name: tuple = convert_string_to_tuple("LaplaceState")
_encoded_which_params: Optional[Dict[str, Array]] = None

@classmethod
4 changes: 2 additions & 2 deletions fortuna/prob_model/posterior/map/map_state.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
import jax.numpy as jnp

from fortuna.prob_model.posterior.state import PosteriorState
from fortuna.utils.strings import convert_string_to_jnp_array
from fortuna.utils.strings import convert_string_to_tuple


class MAPState(PosteriorState):
@@ -14,4 +14,4 @@ class MAPState(PosteriorState):
MAP state name encoded as an array.
"""

encoded_name: jnp.ndarray = convert_string_to_jnp_array("MAPState")
encoded_name: tuple = convert_string_to_tuple("MAPState")
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@
NormalizingFlowState,
)
from fortuna.typing import Array
from fortuna.utils.strings import convert_string_to_jnp_array
from fortuna.utils.strings import convert_string_to_tuple


class ADVIState(NormalizingFlowState):
@@ -23,5 +23,5 @@ class ADVIState(NormalizingFlowState):
ADVI state name encoded as an array.
"""

encoded_name: jnp.ndarray = convert_string_to_jnp_array("ADVIState")
encoded_name: tuple = convert_string_to_tuple("ADVIState")
_encoded_which_params: Optional[Dict[str, List[Array]]] = None
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
OptaxOptimizer,
)
from fortuna.utils.strings import (
convert_string_to_jnp_array,
convert_string_to_tuple,
encode_tuple_of_lists_of_strings_to_numpy,
)

@@ -30,7 +30,7 @@ class CyclicalSGLDState(PosteriorState):
CyclicalSGLDState state name encoded as an array.
"""

encoded_name: jnp.ndarray = convert_string_to_jnp_array("CyclicalSGLDState")
encoded_name: tuple = convert_string_to_tuple("CyclicalSGLDState")
_encoded_which_params: Optional[Dict[str, List[Array]]] = None

@classmethod
4 changes: 2 additions & 2 deletions fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_state.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
OptaxOptimizer,
)
from fortuna.utils.strings import (
convert_string_to_jnp_array,
convert_string_to_tuple,
encode_tuple_of_lists_of_strings_to_numpy,
)

@@ -30,7 +30,7 @@ class SGHMCState(PosteriorState):
SGHMC state name encoded as an array.
"""

encoded_name: jnp.ndarray = convert_string_to_jnp_array("SGHMCState")
encoded_name: tuple = convert_string_to_tuple("SGHMCState")
_encoded_which_params: Optional[Dict[str, List[Array]]] = None

@classmethod
4 changes: 2 additions & 2 deletions fortuna/prob_model/posterior/state.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@
OptaxOptimizer,
Params,
)
from fortuna.utils.strings import convert_string_to_jnp_array
from fortuna.utils.strings import convert_string_to_tuple


class PosteriorState(TrainState):
@@ -33,7 +33,7 @@ class PosteriorState(TrainState):
calib_mutable: Optional[CalibMutable] = None
grad_accumulated: Optional[jnp.ndarray] = None
dynamic_scale: Optional[dynamic_scale.DynamicScale] = None
encoded_name: jnp.ndarray = convert_string_to_jnp_array("PosteriorState")
encoded_name: tuple = convert_string_to_tuple("PosteriorState")

@classmethod
def init(
4 changes: 2 additions & 2 deletions fortuna/prob_model/posterior/swag/swag_state.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
Array,
OptaxOptimizer,
)
from fortuna.utils.strings import convert_string_to_jnp_array
from fortuna.utils.strings import convert_string_to_tuple


class SWAGState(PosteriorState):
@@ -35,7 +35,7 @@ class SWAGState(PosteriorState):
mean: Optional[jnp.ndarray] = None
std: Optional[jnp.ndarray] = None
dev: Optional[jnp.ndarray] = None
encoded_name: jnp.ndarray = convert_string_to_jnp_array("SWAGState")
encoded_name: tuple = convert_string_to_tuple("SWAGState")
_encoded_which_params: Optional[Dict[str, List[Array]]] = None

@classmethod
2 changes: 1 addition & 1 deletion fortuna/training/mixin.py
Original file line number Diff line number Diff line change
@@ -81,7 +81,7 @@ def restore_checkpoint(
raise ValueError(
f"No checkpoint was found in `restore_checkpoint_path={restore_checkpoint_path}`."
)
name = "".join([chr(n) for n in d["encoded_name"].tolist()])
name = "".join([chr(n) for n in d["encoded_name"].values()])
return name_to_train_state[name].value.init_from_dict(d, optimizer, **kwargs)

def get_path_latest_checkpoint(
4 changes: 2 additions & 2 deletions fortuna/training/train_state.py
Original file line number Diff line number Diff line change
@@ -12,11 +12,11 @@
import jax.numpy as jnp

from fortuna.typing import Params
from fortuna.utils.strings import convert_string_to_jnp_array
from fortuna.utils.strings import convert_string_to_tuple


class TrainState(train_state.TrainState):
encoded_name: jnp.ndarray = convert_string_to_jnp_array("TrainState")
encoded_name: tuple = convert_string_to_tuple("TrainState")
frozen_params: Optional[Params] = None
dynamic_scale: Optional[dynamic_scale.DynamicScale] = None

4 changes: 2 additions & 2 deletions fortuna/utils/strings.py
Original file line number Diff line number Diff line change
@@ -12,8 +12,8 @@
from fortuna.typing import Array


def convert_string_to_jnp_array(s: str) -> jnp.ndarray:
return jnp.array([ord(c) for c in s])
def convert_string_to_tuple(s: str) -> Tuple:
return tuple([ord(c) for c in s])


def convert_string_to_np_array(s: str) -> np.ndarray:
86 changes: 42 additions & 44 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -9,14 +9,14 @@ documentation = "https://aws-fortuna.readthedocs.io/en/latest/"
packages = [{include = "fortuna"}]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
python = ">=3.8,<3.12"
flax = "^0.6.2"
optax = "^0.1.3"
matplotlib = "^3.6.2"
tqdm = "^4.64.1"
numpy = "^1.23.4"
tensorflow-cpu = { version = "^2.11.0", markers = "sys_platform != 'darwin'" }
tensorflow-macos = { version = "^2.11.0", markers = "sys_platform == 'darwin'" }
tensorflow-macos = { version = "^2.13.1", markers = "sys_platform == 'darwin'" }
Sphinx = { version = "^5.3.0", optional = true }
sphinx-autodoc-typehints = { version = "^1.19.5", optional = true }
nbsphinx = { version = "^0.8.10", optional = true }
10 changes: 0 additions & 10 deletions tests/fortuna/test_mixin.py
Original file line number Diff line number Diff line change
@@ -113,16 +113,6 @@ def test_restore_checkpoint(self):
calib_params=None,
calib_mutable=None,
)
restored_state = trainer.restore_checkpoint(
tmp_dir, prefix="test_prefix_"
)
mc.restore_checkpoint.assert_called_with(
ckpt_dir=tmp_dir,
target=None,
step=None,
prefix="test_prefix_",
parallel=True,
)


class TestEarlyStoppingMixins(unittest.TestCase):

0 comments on commit afbf206

Please sign in to comment.