diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 79ef9f36..78a19f7c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: # Select the Python versions to test against - python-version: ['3.8', '3.9', '3.10'] + python-version: ['3.8', '3.9', '3.10', '3.11'] steps: - name: Check out the code uses: actions/checkout@v3 diff --git a/fortuna/output_calib_model/state.py b/fortuna/output_calib_model/state.py index 9c400a1a..11fa55d4 100644 --- a/fortuna/output_calib_model/state.py +++ b/fortuna/output_calib_model/state.py @@ -16,13 +16,13 @@ CalibParams, OptaxOptimizer, ) -from fortuna.utils.strings import convert_string_to_jnp_array +from fortuna.utils.strings import convert_string_to_tuple class OutputCalibState(TrainState): params: CalibParams mutable: Optional[CalibMutable] = None - encoded_name: jnp.ndarray = convert_string_to_jnp_array("OutputCalibState") + encoded_name: tuple = convert_string_to_tuple("OutputCalibState") @classmethod def init( diff --git a/fortuna/prob_model/posterior/laplace/laplace_state.py b/fortuna/prob_model/posterior/laplace/laplace_state.py index 51f41caa..2970811b 100644 --- a/fortuna/prob_model/posterior/laplace/laplace_state.py +++ b/fortuna/prob_model/posterior/laplace/laplace_state.py @@ -20,7 +20,7 @@ ) from fortuna.utils.nested_dicts import nested_pair from fortuna.utils.strings import ( - convert_string_to_jnp_array, + convert_string_to_tuple, encode_tuple_of_lists_of_strings_to_numpy, ) @@ -36,7 +36,7 @@ class LaplaceState(PosteriorState): """ prior_log_var: float = 0.0 - encoded_name: jnp.ndarray = convert_string_to_jnp_array("LaplaceState") + encoded_name: tuple = convert_string_to_tuple("LaplaceState") _encoded_which_params: Optional[Dict[str, Array]] = None @classmethod diff --git a/fortuna/prob_model/posterior/map/map_state.py b/fortuna/prob_model/posterior/map/map_state.py index 39db5769..159233da 100644 --- a/fortuna/prob_model/posterior/map/map_state.py +++ b/fortuna/prob_model/posterior/map/map_state.py @@ -3,7 +3,7 @@ import jax.numpy as jnp from fortuna.prob_model.posterior.state import PosteriorState -from fortuna.utils.strings import convert_string_to_jnp_array +from fortuna.utils.strings import convert_string_to_tuple class MAPState(PosteriorState): @@ -14,4 +14,4 @@ class MAPState(PosteriorState): MAP state name encoded as an array. """ - encoded_name: jnp.ndarray = convert_string_to_jnp_array("MAPState") + encoded_name: tuple = convert_string_to_tuple("MAPState") diff --git a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_state.py b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_state.py index 1c7ba0eb..2db8bb78 100644 --- a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_state.py +++ b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_state.py @@ -12,7 +12,7 @@ NormalizingFlowState, ) from fortuna.typing import Array -from fortuna.utils.strings import convert_string_to_jnp_array +from fortuna.utils.strings import convert_string_to_tuple class ADVIState(NormalizingFlowState): @@ -23,5 +23,5 @@ class ADVIState(NormalizingFlowState): ADVI state name encoded as an array. """ - encoded_name: jnp.ndarray = convert_string_to_jnp_array("ADVIState") + encoded_name: tuple = convert_string_to_tuple("ADVIState") _encoded_which_params: Optional[Dict[str, List[Array]]] = None diff --git a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_state.py b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_state.py index 89bf4aed..5aec1731 100644 --- a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_state.py +++ b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_state.py @@ -17,7 +17,7 @@ OptaxOptimizer, ) from fortuna.utils.strings import ( - convert_string_to_jnp_array, + convert_string_to_tuple, encode_tuple_of_lists_of_strings_to_numpy, ) @@ -30,7 +30,7 @@ class CyclicalSGLDState(PosteriorState): CyclicalSGLDState state name encoded as an array. """ - encoded_name: jnp.ndarray = convert_string_to_jnp_array("CyclicalSGLDState") + encoded_name: tuple = convert_string_to_tuple("CyclicalSGLDState") _encoded_which_params: Optional[Dict[str, List[Array]]] = None @classmethod diff --git a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_state.py b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_state.py index e3dec9cd..28605cb4 100644 --- a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_state.py +++ b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_state.py @@ -17,7 +17,7 @@ OptaxOptimizer, ) from fortuna.utils.strings import ( - convert_string_to_jnp_array, + convert_string_to_tuple, encode_tuple_of_lists_of_strings_to_numpy, ) @@ -30,7 +30,7 @@ class SGHMCState(PosteriorState): SGHMC state name encoded as an array. """ - encoded_name: jnp.ndarray = convert_string_to_jnp_array("SGHMCState") + encoded_name: tuple = convert_string_to_tuple("SGHMCState") _encoded_which_params: Optional[Dict[str, List[Array]]] = None @classmethod diff --git a/fortuna/prob_model/posterior/state.py b/fortuna/prob_model/posterior/state.py index 68475afd..110c148a 100644 --- a/fortuna/prob_model/posterior/state.py +++ b/fortuna/prob_model/posterior/state.py @@ -18,7 +18,7 @@ OptaxOptimizer, Params, ) -from fortuna.utils.strings import convert_string_to_jnp_array +from fortuna.utils.strings import convert_string_to_tuple class PosteriorState(TrainState): @@ -33,7 +33,7 @@ class PosteriorState(TrainState): calib_mutable: Optional[CalibMutable] = None grad_accumulated: Optional[jnp.ndarray] = None dynamic_scale: Optional[dynamic_scale.DynamicScale] = None - encoded_name: jnp.ndarray = convert_string_to_jnp_array("PosteriorState") + encoded_name: tuple = convert_string_to_tuple("PosteriorState") @classmethod def init( diff --git a/fortuna/prob_model/posterior/swag/swag_state.py b/fortuna/prob_model/posterior/swag/swag_state.py index 10540e57..42f25dfe 100644 --- a/fortuna/prob_model/posterior/swag/swag_state.py +++ b/fortuna/prob_model/posterior/swag/swag_state.py @@ -15,7 +15,7 @@ Array, OptaxOptimizer, ) -from fortuna.utils.strings import convert_string_to_jnp_array +from fortuna.utils.strings import convert_string_to_tuple class SWAGState(PosteriorState): @@ -35,7 +35,7 @@ class SWAGState(PosteriorState): mean: Optional[jnp.ndarray] = None std: Optional[jnp.ndarray] = None dev: Optional[jnp.ndarray] = None - encoded_name: jnp.ndarray = convert_string_to_jnp_array("SWAGState") + encoded_name: tuple = convert_string_to_tuple("SWAGState") _encoded_which_params: Optional[Dict[str, List[Array]]] = None @classmethod diff --git a/fortuna/training/mixin.py b/fortuna/training/mixin.py index ce9772d6..353c8aa6 100755 --- a/fortuna/training/mixin.py +++ b/fortuna/training/mixin.py @@ -81,7 +81,7 @@ def restore_checkpoint( raise ValueError( f"No checkpoint was found in `restore_checkpoint_path={restore_checkpoint_path}`." ) - name = "".join([chr(n) for n in d["encoded_name"].tolist()]) + name = "".join([chr(n) for n in d["encoded_name"].values()]) return name_to_train_state[name].value.init_from_dict(d, optimizer, **kwargs) def get_path_latest_checkpoint( diff --git a/fortuna/training/train_state.py b/fortuna/training/train_state.py index cf1e85e6..0707cdb7 100644 --- a/fortuna/training/train_state.py +++ b/fortuna/training/train_state.py @@ -12,11 +12,11 @@ import jax.numpy as jnp from fortuna.typing import Params -from fortuna.utils.strings import convert_string_to_jnp_array +from fortuna.utils.strings import convert_string_to_tuple class TrainState(train_state.TrainState): - encoded_name: jnp.ndarray = convert_string_to_jnp_array("TrainState") + encoded_name: tuple = convert_string_to_tuple("TrainState") frozen_params: Optional[Params] = None dynamic_scale: Optional[dynamic_scale.DynamicScale] = None diff --git a/fortuna/utils/strings.py b/fortuna/utils/strings.py index 306e6d49..5dae4eaf 100644 --- a/fortuna/utils/strings.py +++ b/fortuna/utils/strings.py @@ -12,8 +12,8 @@ from fortuna.typing import Array -def convert_string_to_jnp_array(s: str) -> jnp.ndarray: - return jnp.array([ord(c) for c in s]) +def convert_string_to_tuple(s: str) -> Tuple: + return tuple([ord(c) for c in s]) def convert_string_to_np_array(s: str) -> np.ndarray: diff --git a/poetry.lock b/poetry.lock index ecab5a9e..9d785657 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2494,13 +2494,14 @@ toml = ["toml"] [[package]] name = "keras" -version = "2.12.0" +version = "2.13.1" description = "Deep learning for humans." category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "keras-2.12.0-py2.py3-none-any.whl", hash = "sha256:35c39534011e909645fb93515452e98e1a0ce23727b55d4918b9c58b2308c15e"}, + {file = "keras-2.13.1-py3-none-any.whl", hash = "sha256:5ce5f706f779fa7330e63632f327b75ce38144a120376b2ae1917c00fa6136af"}, + {file = "keras-2.13.1.tar.gz", hash = "sha256:5df12cc241a015a11b65ddb452c0eeb2744fce21d9b54ba48db87492568ccc68"}, ] [[package]] @@ -2948,6 +2949,7 @@ files = [ numpy = [ {version = ">1.20", markers = "python_version <= \"3.9\""}, {version = ">=1.21.2", markers = "python_version > \"3.9\""}, + {version = ">=1.23.3", markers = "python_version > \"3.10\""}, ] [package.extras] @@ -3847,6 +3849,7 @@ files = [ numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -5597,13 +5600,13 @@ files = [ [[package]] name = "tensorboard" -version = "2.12.3" +version = "2.13.0" description = "TensorBoard lets you watch Tensors Flow" category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "tensorboard-2.12.3-py3-none-any.whl", hash = "sha256:b4a69366784bc347e02fbe7d847e01896a649ca52f8948a11005e205dcf724fb"}, + {file = "tensorboard-2.13.0-py3-none-any.whl", hash = "sha256:ab69961ebddbddc83f5fa2ff9233572bdad5b883778c35e4fe94bf1798bd8481"}, ] [package.dependencies] @@ -5635,49 +5638,48 @@ files = [ [[package]] name = "tensorflow-cpu" -version = "2.12.0" +version = "2.13.1" description = "TensorFlow is an open source machine learning framework for everyone." category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "tensorflow_cpu-2.12.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:734ce850e2b3493041bdc071b594f0f78d35e4bfce5a7e0a98d449b20420e01d"}, - {file = "tensorflow_cpu-2.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:361b19b5a64bf611beccd22de1fc04f614a8c157ac99893d9702ed24932018d6"}, - {file = "tensorflow_cpu-2.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:d5ad746bf8c87d9a9fcea4698828ba1d101a7f7bfd323a2571130374a192578b"}, - {file = "tensorflow_cpu-2.12.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:b9c8f0d0658da8a5b25a4fe5ca315f86c449eb11e30d79cea49c7658be75a825"}, - {file = "tensorflow_cpu-2.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8c7047552a2d759f3e65ac13e36dd24bb5fec2e6576e848287811ec44b3d62f"}, - {file = "tensorflow_cpu-2.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:8fdb636736f95094368bc7d26bb3b8ed93ba820cc5d95f847e00bf4a7645463d"}, - {file = "tensorflow_cpu-2.12.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:5beeb99d2a1cc1383ca981513c35a4a18157e52d91a89e69c94cb7b7e411f0d8"}, - {file = "tensorflow_cpu-2.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a406f751180fe5282776e8bc84f39a2dc2b796c3ae35fbe20e4edc86ec580dd3"}, - {file = "tensorflow_cpu-2.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:b6ba926f9a56cdf0657defc6d046735e31ded383054f67c1a16ef2b0511f68d7"}, - {file = "tensorflow_cpu-2.12.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:ef4f142b6fe75fcc71ada6331ed2a15ed61b7034187049d0ef1dac482d52db78"}, - {file = "tensorflow_cpu-2.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55685b9a19c8ecb2587fb53914c045b188ed0289a2c6495e4e59d5fb082da9cc"}, - {file = "tensorflow_cpu-2.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:374b15d1cec1a62006e388062e89dd4899a121272d41ea5d3fcbcc96e2d875c9"}, + {file = "tensorflow_cpu-2.13.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a05ae27c373224e3db3a56c5fc3330ad9831e5a1a67fec11126ee846f5cd48ee"}, + {file = "tensorflow_cpu-2.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfedeee206ccf8130d14d8990dbe991946324df92458ee1bdbeed1c21497beb3"}, + {file = "tensorflow_cpu-2.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:db4cb514ec72cdf0a25eb545ed8b0c0c56cc3f38bdbcaac3f729181c4632f699"}, + {file = "tensorflow_cpu-2.13.1-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:20f278a756665c4eccd04423035cde334d65d4cfc7bb6f7f42cf8067764de982"}, + {file = "tensorflow_cpu-2.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87be6e61314b2c0fa062f71e26e41d6a17e162ffe4da72c820fe8f1fc94d97fe"}, + {file = "tensorflow_cpu-2.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:21c9f87851bac10d5d4e1b252cf577c28e7f27339361ff67a1287187881868ad"}, + {file = "tensorflow_cpu-2.13.1-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:53e414d9130f0849c9d2cf05b50eba137525d3442b92f8d3efee6238b31dc7be"}, + {file = "tensorflow_cpu-2.13.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba6eec4d9a1e86d36fd7e7d010e07ce69702bb3cb68bc33a02a668b1928772a9"}, + {file = "tensorflow_cpu-2.13.1-cp38-cp38-win_amd64.whl", hash = "sha256:a50b1ccd420757de33a1c9b4152e79b71dfaa311615736a0adc92d6333f58b6b"}, + {file = "tensorflow_cpu-2.13.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:2087fe00d1050991ffa295c56c2c22b2ead01e527452ed252bfc8395ae386567"}, + {file = "tensorflow_cpu-2.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a6d6b57d1336eafc4b980be1f618382eb95de12657c6e34ce63d3f9869a8f92"}, + {file = "tensorflow_cpu-2.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:548a3ea205a2de8e4c4686855834e949bb1602a577891518bcb425aafa790705"}, ] [package.dependencies] absl-py = ">=1.0.0" astunparse = ">=1.6.0" -flatbuffers = ">=2.0" +flatbuffers = ">=23.1.21" gast = ">=0.2.1,<=0.4.0" google-pasta = ">=0.1.1" grpcio = ">=1.24.3,<2.0" h5py = ">=2.9.0" -jax = ">=0.3.15" -keras = ">=2.12.0,<2.13" +keras = ">=2.13.1,<2.14" libclang = ">=13.0.0" -numpy = ">=1.22,<1.24" +numpy = ">=1.22,<=1.24.3" opt-einsum = ">=2.3.2" packaging = "*" protobuf = ">=3.20.3,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" setuptools = "*" six = ">=1.12.0" -tensorboard = ">=2.12,<2.13" -tensorflow-estimator = ">=2.12.0,<2.13" +tensorboard = ">=2.13,<2.14" +tensorflow-estimator = ">=2.13.0,<2.14" tensorflow-io-gcs-filesystem = {version = ">=0.23.1", markers = "platform_machine != \"arm64\" or platform_system != \"Darwin\""} termcolor = ">=1.1.0" -typing-extensions = ">=3.6.6" -wrapt = ">=1.11.0,<1.15" +typing-extensions = ">=3.6.6,<4.6.0" +wrapt = ">=1.11.0" [[package]] name = "tensorflow-datasets" @@ -5748,13 +5750,13 @@ youtube-vis = ["pycocotools"] [[package]] name = "tensorflow-estimator" -version = "2.12.0" +version = "2.13.0" description = "TensorFlow Estimator." category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "tensorflow_estimator-2.12.0-py2.py3-none-any.whl", hash = "sha256:59b191bead4883822de3d63ac02ace11a83bfe6c10d64d0c4dfde75a50e60ca1"}, + {file = "tensorflow_estimator-2.13.0-py2.py3-none-any.whl", hash = "sha256:6f868284eaa654ae3aa7cacdbef2175d0909df9fcf11374f5166f8bf475952aa"}, ] [[package]] @@ -5790,43 +5792,39 @@ tensorflow-rocm = ["tensorflow-rocm (>=2.12.0,<2.13.0)"] [[package]] name = "tensorflow-macos" -version = "2.12.0" +version = "2.13.1" description = "TensorFlow is an open source machine learning framework for everyone." category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "tensorflow_macos-2.12.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:db464c88e10e927725997f9b872a21c9d057789d3b7e9a26e4ef1af41d0bcc8c"}, - {file = "tensorflow_macos-2.12.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:172277c33cb1ae0da19f98c5bcd4946149cfa73c8ea05c6ba18365d58dd3c6f2"}, - {file = "tensorflow_macos-2.12.0-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:e3fa53e63672fd71998bbd71cc5478c74dbe5a2d9291d1801c575358c28403c2"}, - {file = "tensorflow_macos-2.12.0-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:5499312c21ed3ed47cc6b4cf861896e9564c2c32d8d3c2ef1437c5ca31adfc73"}, - {file = "tensorflow_macos-2.12.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:84cb873c90be63efabfecca53fdc48b734a037d0750532b55cb7ce7c343b5cac"}, - {file = "tensorflow_macos-2.12.0-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:85d9451a691324490e1d644b1051972e14edc249004eef5831b3510df9e36515"}, + {file = "tensorflow_macos-2.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:7a09d250532c87083b00396b0367baa73558a50d0bd8982907297422858f6af6"}, + {file = "tensorflow_macos-2.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:8f95c332ee60a32691963adab31b6afa29a2302de05fc4cd6009cd96866a6886"}, + {file = "tensorflow_macos-2.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:f2bade2725c5a81dc0d96acd1c5b8475593bc9bdf1da0055bd72a09a501d7866"}, ] [package.dependencies] absl-py = ">=1.0.0" astunparse = ">=1.6.0" -flatbuffers = ">=2.0" +flatbuffers = ">=23.1.21" gast = ">=0.2.1,<=0.4.0" google-pasta = ">=0.1.1" grpcio = ">=1.24.3,<2.0" h5py = ">=2.9.0" -jax = ">=0.3.15" -keras = ">=2.12.0,<2.13" +keras = ">=2.13.1,<2.14" libclang = ">=13.0.0" -numpy = ">=1.22,<1.24" +numpy = ">=1.22,<=1.24.3" opt-einsum = ">=2.3.2" packaging = "*" protobuf = ">=3.20.3,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" setuptools = "*" six = ">=1.12.0" -tensorboard = ">=2.12,<2.13" -tensorflow-estimator = ">=2.12.0,<2.13" +tensorboard = ">=2.13,<2.14" +tensorflow-estimator = ">=2.13.0,<2.14" tensorflow-io-gcs-filesystem = {version = ">=0.23.1", markers = "platform_machine != \"arm64\" or platform_system != \"Darwin\""} termcolor = ">=1.1.0" -typing-extensions = ">=3.6.6" -wrapt = ">=1.11.0,<1.15" +typing-extensions = ">=3.6.6,<4.6.0" +wrapt = ">=1.11.0" [[package]] name = "tensorflow-metadata" @@ -6771,5 +6769,5 @@ transformers = ["datasets", "transformers"] [metadata] lock-version = "2.0" -python-versions = ">=3.8,<3.11" -content-hash = "56d6ba9935cf9ebda97c1aaedfa43be401154d396f9d7e8df08a2fb80df8b1fe" +python-versions = ">=3.8,<3.12" +content-hash = "0d5afd19b3364e869fbb1409b172f1c1d0cb3f0b679dc47eea3d821a92a01c2e" diff --git a/pyproject.toml b/pyproject.toml index 70c22e15..850731a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,14 +9,14 @@ documentation = "https://aws-fortuna.readthedocs.io/en/latest/" packages = [{include = "fortuna"}] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.8,<3.12" flax = "^0.6.2" optax = "^0.1.3" matplotlib = "^3.6.2" tqdm = "^4.64.1" numpy = "^1.23.4" tensorflow-cpu = { version = "^2.11.0", markers = "sys_platform != 'darwin'" } -tensorflow-macos = { version = "^2.11.0", markers = "sys_platform == 'darwin'" } +tensorflow-macos = { version = "^2.13.1", markers = "sys_platform == 'darwin'" } Sphinx = { version = "^5.3.0", optional = true } sphinx-autodoc-typehints = { version = "^1.19.5", optional = true } nbsphinx = { version = "^0.8.10", optional = true } diff --git a/tests/fortuna/test_mixin.py b/tests/fortuna/test_mixin.py index 642ab8c7..bdd3403c 100755 --- a/tests/fortuna/test_mixin.py +++ b/tests/fortuna/test_mixin.py @@ -113,16 +113,6 @@ def test_restore_checkpoint(self): calib_params=None, calib_mutable=None, ) - restored_state = trainer.restore_checkpoint( - tmp_dir, prefix="test_prefix_" - ) - mc.restore_checkpoint.assert_called_with( - ckpt_dir=tmp_dir, - target=None, - step=None, - prefix="test_prefix_", - parallel=True, - ) class TestEarlyStoppingMixins(unittest.TestCase):