From 42b1095564cec368b149099e98d073cf9e56e822 Mon Sep 17 00:00:00 2001 From: Smit Chaudhary Date: Wed, 12 Feb 2025 15:32:31 +0100 Subject: [PATCH] Implement Exploratory Landscape Analysis (ELA) with Information Content (IC) (#647) --- .../tutorials/qml/ml_tools/data_and_config.md | 4 +- docs/tutorials/qml/ml_tools/trainer.md | 90 ++++- qadence/ml_tools/__init__.py | 1 + qadence/ml_tools/information/__init__.py | 3 + .../information/information_content.py | 339 ++++++++++++++++++ qadence/ml_tools/trainer.py | 107 +++++- tests/ml_tools/test_information_content.py | 148 ++++++++ 7 files changed, 687 insertions(+), 5 deletions(-) create mode 100644 qadence/ml_tools/information/__init__.py create mode 100644 qadence/ml_tools/information/information_content.py create mode 100644 tests/ml_tools/test_information_content.py diff --git a/docs/tutorials/qml/ml_tools/data_and_config.md b/docs/tutorials/qml/ml_tools/data_and_config.md index 1a83b4305..445fa6115 100644 --- a/docs/tutorials/qml/ml_tools/data_and_config.md +++ b/docs/tutorials/qml/ml_tools/data_and_config.md @@ -54,7 +54,7 @@ The [`TrainConfig`][qadence.ml_tools.config.TrainConfig] tells [`Trainer`][qaden It is also possible to provide custom callback functions by instantiating a [`Callback`][qadence.ml_tools.callbacks.Callback] with a function `callback`. -For example of how to use the TrainConfig with `Trainer`, please see [Examples in Trainer](/trainer.md) +For example of how to use the TrainConfig with `Trainer`, please see [Examples in Trainer](../trainer) ### 2.1 Explanation of `TrainConfig` Attributes @@ -172,7 +172,7 @@ def validation_criterion(val_loss: float, best_val_loss: float, val_epsilon: flo #### Custom Callbacks `TrainConfig` supports custom callbacks that can be triggered at specific stages of training. The `callbacks` attribute accepts a list of callback instances, which allow for custom behaviors like early stopping or additional logging. -See [Callbacks](/callbacks.md) for more details. +See [Callbacks](../callbacks) for more details. - `callbacks` (**list[Callback]**): List of custom callbacks to execute during training. diff --git a/docs/tutorials/qml/ml_tools/trainer.md b/docs/tutorials/qml/ml_tools/trainer.md index cefd1ecdc..3fccbe067 100644 --- a/docs/tutorials/qml/ml_tools/trainer.md +++ b/docs/tutorials/qml/ml_tools/trainer.md @@ -469,7 +469,93 @@ for i in range(n_epochs): -### 6.4. Custom `train` loop +### 6.4. Performing pre-training Exploratory Landscape Analysis (ELA) with Information Content (IC) + +Before one embarks on training a model, one may wish to analyze the loss landscape to judge the trainability and catch vanishing gradient issues early. +One way of doing this is made possible via calculating the [Information Content of the loss landscape](https://www.nature.com/articles/s41534-024-00819-8). +This is done by discretizing the gradient in the loss landscapes and then calculating the information content therein. +This serves as a measure of flatness or ruggedness of the loss landscape. +Quantitatively, the information content allows us to get bounds on the average norm of the gradient in the loss landscape. + +Using the information content technique, we can get two types of bounds on the average of the norm of the gradient. +1. The bounds as achieved in the maximum Information Content regime: Gives us a lower and upper bound on the average norm of the gradient in case high Information Content is achieved. +2. The bounds as achieved in the sensitivity regime: Gives us an upper bound on the average norm of the gradient corresponding to the sensitivity IC achieved. + +Thus, we get 3 bounds. The upper and lower bounds for the maximum IC and the upper bound for the sensitivity IC. + +The `Trainer` class provides a method to calculate these gradient norms. + +```python exec="on" source="material-block" html="1" +import torch +from torch.optim.adam import Adam + +from qadence.constructors import ObservableConfig +from qadence.ml_tools.config import AnsatzConfig, FeatureMapConfig, TrainConfig +from qadence.ml_tools.data import to_dataloader +from qadence.ml_tools.models import QNN +from qadence.ml_tools.optimize_step import optimize_step +from qadence.ml_tools.trainer import Trainer +from qadence.operations.primitive import Z + +fm_config = FeatureMapConfig(num_features=1) +ansatz_config = AnsatzConfig(depth=4) +obs_config = ObservableConfig(detuning=Z) + +qnn = QNN.from_configs( + register=4, + obs_config=obs_config, + fm_config=fm_config, + ansatz_config=ansatz_config, +) + +optimizer = Adam(qnn.parameters(), lr=0.001) + +batch_size = 25 +x = torch.linspace(0, 1, 32).reshape(-1, 1) +y = torch.sin(x) +train_loader = to_dataloader(x, y, batch_size=batch_size, infinite=True) + +train_config = TrainConfig(max_iter=100) + +trainer = Trainer( + model=qnn, + optimizer=optimizer, + config=train_config, + loss_fn="mse", + train_dataloader=train_loader, + optimize_step=optimize_step, +) + +# Perform exploratory landscape analysis with Information Content +ic_sensitivity_threshold = 1e-4 +epsilons = torch.logspace(-2, 2, 10) + +max_ic_lower_bound, max_ic_upper_bound, sensitivity_ic_upper_bound = ( + trainer.get_ic_grad_bounds( + eta=ic_sensitivity_threshold, + epsilons=epsilons, + ) +) + +print( + f"Using maximum IC, the gradients are bound between {max_ic_lower_bound:.3f} and {max_ic_upper_bound:.3f}\n" +) +print( + f"Using sensitivity IC, the gradients are bounded above by {sensitivity_ic_upper_bound:.3f}" +) + +# Resume training as usual... + +trainer.fit(train_loader) +``` + +The `get_ic_grad_bounds` function returns a tuple containing a tuple containing the lower bound as achieved in maximum IC case, upper bound as achieved in maximum IC case, and the upper bound for the sensitivity IC case. + +The sensitivity IC bound is guaranteed to appear, while the usually much tighter bounds that we get via the maximum IC case is only meaningful in the case of the maximum achieved information content $H(\epsilon)_{max} \geq log_6(2)$. + + + +### 6.5. Custom `train` loop If you need custom training functionality that goes beyond what is available in `qadence.ml_tools.Trainer` you can write your own @@ -546,6 +632,6 @@ def train( return model, optimizer ``` -### 6.5. Gradient-free optimization using `Trainer` +### 6.6. Gradient-free optimization using `Trainer` We can achieve gradient free optimization with `Trainer.set_use_grad(False)` or `trainer.disable_grad_opt(ng_optimizer)`. An example solving a QUBO using gradient free optimization based on `Nevergrad` optimizers and `Trainer` is shown in the [analog QUBO Tutorial](../../digital_analog_qc/analog-qubo.md). diff --git a/qadence/ml_tools/__init__.py b/qadence/ml_tools/__init__.py index 4a4334c32..e1af7a6b3 100644 --- a/qadence/ml_tools/__init__.py +++ b/qadence/ml_tools/__init__.py @@ -4,6 +4,7 @@ from .config import AnsatzConfig, FeatureMapConfig, TrainConfig from .constructors import create_ansatz, create_fm_blocks, observable_from_config from .data import DictDataLoader, InfiniteTensorDataset, OptimizeResult, to_dataloader +from .information import InformationContent from .models import QNN from .optimize_step import optimize_step as default_optimize_step from .parameters import get_parameters, num_parameters, set_parameters diff --git a/qadence/ml_tools/information/__init__.py b/qadence/ml_tools/information/__init__.py new file mode 100644 index 000000000..1e1dec061 --- /dev/null +++ b/qadence/ml_tools/information/__init__.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +from .information_content import InformationContent diff --git a/qadence/ml_tools/information/information_content.py b/qadence/ml_tools/information/information_content.py new file mode 100644 index 000000000..d9d5b1997 --- /dev/null +++ b/qadence/ml_tools/information/information_content.py @@ -0,0 +1,339 @@ +from __future__ import annotations + +import functools +from logging import getLogger +from math import log, sqrt +from statistics import NormalDist +from typing import Any, Callable + +import torch +from torch import nn +from torch.func import functional_call # type: ignore + +logger = getLogger("ml_tools") + + +class InformationContent: + def __init__( + self, + model: nn.Module, + loss_fn: Callable, + xs: Any, + epsilons: torch.Tensor, + variation_multiple: int = 20, + ) -> None: + """Information Landscape class. + + This class handles the study of loss landscape from information theoretic + perspective and provides methods to get bounds on the norm of the + gradient from the Information Content of the loss landscape. + + Args: + model: The quantum or classical model to analyze. + loss_fn: Loss function that takes model output and calculates loss + xs: Input data to evaluate the model on + epsilons: The thresholds to use for discretization of the finite derivatives + variation_multiple: The number of sets of variational parameters to generate per each + variational parameter. The number of variational parameters required for the + statistical analysis scales linearly with the amount of them present in the + model. This is that linear factor. + + Notes: + This class provides flexibility in terms of what the model, the loss function, + and the xs are. The only requirement is that the loss_fn takes the model and xs as + arguments and returns the loss, and another dictionary of other metrics. + + Thus, assumed structure: + loss_fn(model, xs) -> (loss, metrics, ...) + + Example: A Classifier + ```python + model = nn.Linear(10, 1) + + def loss_fn( + model: nn.Module, + xs: tuple[torch.Tensor, torch.Tensor] + ) -> tuple[torch.Tensor, dict[str, float]: + criterion = nn.MSELoss() + inputs, labels = xs + outputs = model(inputs) + loss = criterion(outputs, labels) + metrics = {"loss": loss.item()} + return loss, metrics + + xs = (torch.randn(10, 10), torch.randn(10, 1)) + + info_landscape = InfoLandscape(model, loss_fn, xs) + ``` + In this example, the model is a linear classifier, and the `xs` include both the + inputs and the target labels. The logic for calculation of the loss from this lies + entirely within the `loss_fn` function. This can then further be used to obtain the + bounds on the average norm of the gradient of the loss function. + + Example: A Physics Informed Neural Network + ```python + class PhysicsInformedNN(nn.Module): + // + + def forward(self, xs: dict[str, torch.Tensor]): + return { + "pde_residual": pde_residual(xs["pde"]), + "boundary_condition": bc_term(xs["bc"]), + } + + def loss_fn( + model: PhysicsInformedNN, + xs: dict[str, torch.Tensor] + ) -> tuple[torch.Tensor, dict[str, float]: + pde_residual, bc_term = model(xs) + loss = torch.mean(torch.sum(pde_residual**2, dim=1), dim=0) + + torch.mean(torch.sum(bc_term**2, dim=1), dim=0) + + return loss, {"pde_residual": pde_residual, "bc_term": bc_term} + + xs = { + "pde": torch.linspace(0, 1, 10), + "bc": torch.tensor([0.0]), + } + + info_landscape = InfoLandscape(model, loss_fn, xs) + ``` + + In this example, the model is a Physics Informed Neural Network, and the `xs` + are the inputs to the different residual components of the model. The logic + for calculation of the residuals lies within the PhysicsInformedNN class, and + the loss function is defined to calculate the loss that is to be optimized + from these residuals. This can then further be used to obtain the + bounds on the average norm of the gradient of the loss function. + + The first value that the `loss_fn` returns is the loss value that is being optimized. + The function is also expected to return other value(s), often the metrics that are + used to calculate the loss. These values are ignored for the purpose of this class. + """ + self.model = model + self.loss_fn = loss_fn + self.xs = xs + self.epsilons = epsilons + self.device = next(model.parameters()).device + + self.param_shapes = {} + self.total_params = 0 + + for name, param in model.named_parameters(): + self.param_shapes[name] = param.shape + self.total_params += param.numel() + self.n_variations = variation_multiple * self.total_params + self.all_variations = torch.empty( + (self.n_variations, self.total_params), device=self.device + ).uniform_(0, 2 * torch.pi) + + def reshape_param_variations(self) -> dict[str, torch.Tensor]: + """Reshape variations of the model's variational parameters. + + Returns: + Dictionary of parameter tensors, each with shape [n_variations, *param_shape] + """ + param_variations = {} + start_idx = 0 + + for name, shape in self.param_shapes.items(): + param_size = torch.prod(torch.tensor(shape)).item() + param_variations[name] = self.all_variations[ + :, start_idx : start_idx + param_size + ].view(self.n_variations, *shape) + start_idx += param_size + + return param_variations + + def batched_loss(self) -> torch.Tensor: + """Calculate loss for all parameter variations in a batched manner. + + Returns: Tensor of loss values for each parameter variation + """ + param_variations = self.reshape_param_variations() + losses = torch.zeros(self.n_variations, device=self.device) + + for i in range(self.n_variations): + params = {name: param[i] for name, param in param_variations.items()} + current_model = lambda x: functional_call(self.model, params, (x,)) + losses[i] = self.loss_fn(current_model, self.xs)[0] + + return losses + + def randomized_finite_der(self) -> torch.Tensor: + """ + Calculate normalized finite difference of loss on doing random walk in the parameter space. + + This serves as a proxy for the derivative of the loss with respect to parameters. + + Returns: + Tensor containing normalized finite differences (approximate directional derivatives) + between consecutive points in the random walk. Shape: [n_variations - 1] + """ + losses = self.batched_loss() + + return (losses[1:] - losses[:-1]) / ( + torch.norm(self.all_variations[1:] - self.all_variations[:-1], dim=1) + 1e-8 + ) + + def discretize_derivatives(self) -> torch.Tensor: + """ + Convert finite derivatives into discrete values. + + Returns: + Tensor containing discretized derivatives with shape [n_epsilons, n_variations-2] + Each row contains {-1, 0, 1} values for that epsilon + """ + derivatives = self.randomized_finite_der() + + derivatives = derivatives.unsqueeze(0) + epsilons = self.epsilons.unsqueeze(1) + + discretized = torch.zeros((len(epsilons), len(derivatives[0])), device=self.device) + discretized[derivatives > epsilons] = 1 + discretized[derivatives < -epsilons] = -1 + + return discretized + + def calculate_transition_probabilities_batch(self) -> torch.Tensor: + """ + Calculate transition probabilities for multiple epsilon values. + + Returns: + Tensor of shape [n_epsilons, 6] containing probabilities for each transition type + Columns order: [+1to0, +1to-1, 0to+1, 0to-1, -1to0, -1to+1] + """ + discretized = self.discretize_derivatives() + + current = discretized[:, :-1] + next_val = discretized[:, 1:] + + transitions = torch.stack( + [ + ((current == 1) & (next_val == 0)).sum(dim=1), + ((current == 1) & (next_val == -1)).sum(dim=1), + ((current == 0) & (next_val == 1)).sum(dim=1), + ((current == 0) & (next_val == -1)).sum(dim=1), + ((current == -1) & (next_val == 0)).sum(dim=1), + ((current == -1) & (next_val == 1)).sum(dim=1), + ], + dim=1, + ).float() + + total_transitions = current.size(1) + probabilities = transitions / total_transitions + + return probabilities + + @functools.cached_property + def calculate_IC(self) -> torch.Tensor: + """ + Calculate Information Content for multiple epsilon values. + + Returns: Tensor of IC values for each epsilon [n_epsilons] + """ + probs = self.calculate_transition_probabilities_batch() + + mask = probs > 1e-4 + + ic_terms = torch.where(mask, -probs * torch.log(probs), torch.zeros_like(probs)) + ic_values = ic_terms.sum(dim=1) / torch.log(torch.tensor(6.0)) + + return ic_values + + def max_IC(self) -> tuple[float, float]: + """ + Get the maximum Information Content and its corresponding epsilon. + + Returns: Tuple of (maximum IC value, optimal epsilon) + """ + max_ic, max_idx = torch.max(self.calculate_IC, dim=0) + max_epsilon = self.epsilons[max_idx] + return max_ic.item(), max_epsilon.item() + + def sensitivity_IC(self, eta: float) -> float: + """ + Find the minimum value of epsilon such that the information content is less than eta. + + Args: + eta: Threshold value, the sensitivity IC. + + Returns: The epsilon value that gives IC that is less than the sensitivity IC. + """ + ic_values = self.calculate_IC + mask = ic_values < eta + epsilons = self.epsilons[mask] + return float(epsilons.min().item()) + + @staticmethod + @functools.lru_cache + def q_value(H_value: float) -> float: + """ + Compute the q value. + + q is the solution to the equation: + H(x) = 4h(x) + 2h(1/2 - 2x) + + It is the value of the probability of 4 of the 6 transitions such that + the IC is the same as the IC of our system. + + This quantity is useful in calculating the bounds on the norms of the gradients. + + Args: + H_value (float): The information content. + + Returns: + float: The q value + """ + + x = torch.linspace(0.001, 0.16667, 10000) + + H = -4 * x * torch.log(x) / torch.log(torch.tensor(6)) - 2 * (0.5 - 2 * x) * torch.log( + 0.5 - 2 * x + ) / torch.log(torch.tensor(6)) + err = torch.abs(H - H_value) + idx = torch.argmin(err) + return float(x[idx].item()) + + def get_grad_norm_bounds_max_IC(self) -> tuple[float, float]: + """ + Compute the bounds on the average norm of the gradient. + + Returns: + tuple[Tensor, Tensor]: The lower and upper bounds. + """ + max_IC, epsilon_m = self.max_IC() + lower_bound = ( + epsilon_m + * sqrt(self.total_params) + / (NormalDist().inv_cdf(1 - 2 * self.q_value(max_IC))) + ) + upper_bound = ( + epsilon_m + * sqrt(self.total_params) + / (NormalDist().inv_cdf(0.5 * (1 + 2 * self.q_value(max_IC)))) + ) + + if max_IC < log(2, 6): + logger.warning( + "Warning: The maximum IC is less than the required value. The bounds may be" + + " inaccurate." + ) + + return lower_bound, upper_bound + + def get_grad_norm_bounds_sensitivity_IC(self, eta: float) -> float: + """ + Compute the bounds on the average norm of the gradient. + + Args: + eta (float): The sensitivity IC. + + Returns: + Tensor: The lower bound. + """ + epsilon_sensitivity = self.sensitivity_IC(eta) + upper_bound = ( + epsilon_sensitivity * sqrt(self.total_params) / (NormalDist().inv_cdf(1 - 3 * eta / 2)) + ) + return upper_bound diff --git a/qadence/ml_tools/trainer.py b/qadence/ml_tools/trainer.py index 6c8df77ef..4187ef654 100644 --- a/qadence/ml_tools/trainer.py +++ b/qadence/ml_tools/trainer.py @@ -14,7 +14,8 @@ from torch.utils.data import DataLoader from qadence.ml_tools.config import TrainConfig -from qadence.ml_tools.data import DictDataLoader, OptimizeResult +from qadence.ml_tools.data import DictDataLoader, OptimizeResult, data_to_device +from qadence.ml_tools.information import InformationContent from qadence.ml_tools.optimize_step import optimize_step, update_ng_parameters from qadence.ml_tools.stages import TrainingStage @@ -711,3 +712,107 @@ def build_optimize_result( self.opt_result = OptimizeResult( self.current_epoch, self.model_old, self.optimizer_old, loss, metrics ) + + def get_ic_grad_bounds( + self, + eta: float, + epsilons: torch.Tensor, + variation_multiple: int = 20, + dataloader: DataLoader | DictDataLoader | None = None, + ) -> tuple[float, float, float]: + """ + Calculate the bounds on the gradient norm of the loss using Information Content. + + Args: + eta (float): The sensitivity IC. + epsilons (torch.Tensor): The epsilons to use for thresholds to for discretization of the + finite derivatives. + variation_multiple (int): The number of sets of variational parameters to generate per + each variational parameter. The number of variational parameters required for the + statisctiacal analysis scales linearly with the amount of them present in the + model. This is that linear factor. + dataloader (DataLoader | DictDataLoader | None): The dataloader for training data. A + new dataloader can be provided, or the dataloader provided in the trinaer will be + used. In case no dataloaders are provided at either places, it assumes that the + model does not require any input data. + + Returns: + tuple[float, float, float]: The max IC lower bound, max IC upper bound, and sensitivity + IC upper bound. + + Examples: + ```python + import torch + from torch.optim.adam import Adam + + from qadence.constructors import ObservableConfig + from qadence.ml_tools.config import AnsatzConfig, FeatureMapConfig, TrainConfig + from qadence.ml_tools.data import to_dataloader + from qadence.ml_tools.models import QNN + from qadence.ml_tools.optimize_step import optimize_step + from qadence.ml_tools.trainer import Trainer + from qadence.operations.primitive import Z + + fm_config = FeatureMapConfig(num_features=1) + ansatz_config = AnsatzConfig(depth=4) + obs_config = ObservableConfig(detuning=Z) + + qnn = QNN.from_configs( + register=4, + obs_config=obs_config, + fm_config=fm_config, + ansatz_config=ansatz_config, + ) + + optimizer = Adam(qnn.parameters(), lr=0.001) + + batch_size = 25 + x = torch.linspace(0, 1, 32).reshape(-1, 1) + y = torch.sin(x) + train_loader = to_dataloader(x, y, batch_size=batch_size, infinite=True) + + train_config = TrainConfig(max_iter=100) + + trainer = Trainer( + model=qnn, + optimizer=optimizer, + config=train_config, + loss_fn="mse", + train_dataloader=train_loader, + optimize_step=optimize_step, + ) + + # Perform exploratory landscape analysis with Information Content + ic_sensitivity_threshold = 1e-4 + epsilons = torch.logspace(-2, 2, 10) + + max_ic_lower_bound, max_ic_upper_bound, sensitivity_ic_upper_bound = ( + trainer.get_ic_grad_bounds( + eta=ic_sensitivity_threshold, + epsilons=epsilons, + ) + ) + + # Resume training as usual... + + trainer.fit(train_loader) + ``` + """ + if not self._use_grad: + logger.warning( + "Gradient norm bounds are only relevant when using a gradient based optimizer. \ + Currently the trainer is set to use a gradient-free optimizer." + ) + + dataloader = dataloader if dataloader is not None else self.train_dataloader + + batch = next(iter(self._batch_iter(dataloader, num_batches=1))) + + xs = data_to_device(batch, device=self.device, dtype=self.data_dtype) + + ic = InformationContent(self.model, self.loss_fn, xs, epsilons) + + max_ic_lower_bound, max_ic_upper_bound = ic.get_grad_norm_bounds_max_IC() + sensitivity_ic_upper_bound = ic.get_grad_norm_bounds_sensitivity_IC(eta) + + return max_ic_lower_bound, max_ic_upper_bound, sensitivity_ic_upper_bound diff --git a/tests/ml_tools/test_information_content.py b/tests/ml_tools/test_information_content.py new file mode 100644 index 000000000..adc572a1b --- /dev/null +++ b/tests/ml_tools/test_information_content.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from math import isclose + +import pytest +import torch +import torch.nn as nn + +from qadence.ml_tools.information.information_content import InformationContent + + +class SimpleModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.layer = nn.Linear(2, 1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.layer(x) + + +def loss_fn(model: nn.Module, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + output = model(x) + return torch.mean(output**2), output + + +@pytest.fixture +def setup_ic() -> InformationContent: + model = SimpleModel() + xs = torch.randn(10, 2) # 10 samples, 2 features each + epsilons = torch.logspace(-4, 4, 10) + landscape = InformationContent( + model=model, loss_fn=loss_fn, xs=xs, epsilons=epsilons, variation_multiple=5 + ) + return landscape + + +def test_initialization(setup_ic: InformationContent) -> None: + info_content = setup_ic + + assert callable(info_content.loss_fn) + assert isinstance(info_content.epsilons, torch.Tensor) + + assert len(info_content.param_shapes) > 0 + assert info_content.total_params > 0 + assert info_content.n_variations == info_content.total_params * 5 + + +def test_reshape_param_variations(setup_ic: InformationContent) -> None: + info_content = setup_ic + param_variations = info_content.reshape_param_variations() + + assert set(param_variations.keys()) == set(info_content.param_shapes.keys()) + + for name, tensor in param_variations.items(): + expected_shape = (info_content.n_variations, *info_content.param_shapes[name]) + assert tensor.shape == expected_shape + + +def test_batched_loss(setup_ic: InformationContent) -> None: + info_content = setup_ic + losses = info_content.batched_loss() + + assert isinstance(losses, torch.Tensor) + assert losses.shape == (info_content.n_variations,) + assert not torch.isnan(losses).any() + assert not torch.isinf(losses).any() + + +def test_randomized_finite_der(setup_ic: InformationContent) -> None: + info_content = setup_ic + derivatives = info_content.randomized_finite_der() + + assert isinstance(derivatives, torch.Tensor) + assert derivatives.shape == (info_content.n_variations - 1,) + assert not torch.isnan(derivatives).any() + assert not torch.isinf(derivatives).any() + + +def test_discretize_derivatives(setup_ic: InformationContent) -> None: + info_content = setup_ic + discretized = info_content.discretize_derivatives() + + assert discretized.shape == (len(info_content.epsilons), info_content.n_variations - 1) + + unique_values = torch.unique(discretized) + assert all(val in [-1.0, 0.0, 1.0] for val in unique_values.tolist()) + + +def test_calculate_transition_probabilities_batch(setup_ic: InformationContent) -> None: + info_content = setup_ic + probs = info_content.calculate_transition_probabilities_batch() + + assert probs.shape == (len(info_content.epsilons), 6) + + assert torch.all(probs >= 0) + assert torch.all(probs <= 1) + assert torch.all(torch.sum(probs, dim=1) <= torch.ones(len(info_content.epsilons))) + + +def test_calculate_IC(setup_ic: InformationContent) -> None: + info_content = setup_ic + ic_values = info_content.calculate_IC + + assert ic_values.shape == (len(info_content.epsilons),) + assert torch.all(ic_values >= 0) + assert torch.all(ic_values <= 1) + + +def test_max_IC(setup_ic: InformationContent) -> None: + info_content = setup_ic + max_ic, optimal_epsilon = info_content.max_IC() + + assert isinstance(max_ic, float) + assert isinstance(optimal_epsilon, float) + assert 0 <= max_ic <= 1 + assert optimal_epsilon > 0 + + +def test_sensitivity_IC(setup_ic: InformationContent) -> None: + info_content = setup_ic + eta = 0.5 + epsilon = info_content.sensitivity_IC(eta) + + assert isinstance(epsilon, float) + assert epsilon > 0 + + +def test_q_value(setup_ic: InformationContent) -> None: + info_content = setup_ic + H_value = 1.0 + q = info_content.q_value(H_value) + + assert isinstance(q, float) + assert isclose(q, 1 / 6, abs_tol=1e-5) + + +def test_grad_norm_bounds(setup_ic: InformationContent) -> None: + info_content = setup_ic + + lower, upper = info_content.get_grad_norm_bounds_max_IC() + assert isinstance(lower, float) + assert isinstance(upper, float) + assert lower <= upper + + eta = 2e-2 + upper_bound = info_content.get_grad_norm_bounds_sensitivity_IC(eta) + assert isinstance(upper_bound, float) + assert upper_bound > 0