bug: MNIST classification tutorial not running #165

Closed
yucenli opened this issue Dec 4, 2023 · 6 comments · Fixed by #170
Labels
bug Something isn't working

Comments

@yucenli

yucenli commented Dec 4, 2023

Bug Report

Fortuna version: v0.1.42

Current behavior: When I run the MNIST classification tutorial, I run into a broadcasting error. I think the issue occurs during the calibration step of training for SWAG.

Here's the traceback:

Traceback (most recent call last):
  File "/home/yl9959/23_09_uncertainty/src/ftest.py", line 88, in <module>
    status = prob_model.train(
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/fortuna/prob_model/classification.py", line 254, in train
    return super().train(
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/fortuna/prob_model/base.py", line 101, in train
    calib_status = self.calibrate(
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/fortuna/prob_model/classification.py", line 289, in calibrate
    return super()._calibrate(
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/fortuna/prob_model/base.py", line 204, in _calibrate
    state, status = calibrator.train(
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/fortuna/training/output_calibrator.py", line 117, in train
    ) = self._training_loop(
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/fortuna/training/output_calibrator.py", line 195, in _training_loop
    state, aux = self.training_step(
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/fortuna/training/output_calibrator.py", line 650, in training_step
    return super().training_step(state, batch, outputs, loss_fun, rng, n_data)
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/fortuna/training/output_calibrator.py", line 252, in training_step
    (loss, aux), grad = grad_fn(state.params)
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/fortuna/training/output_calibrator.py", line 247, in <lambda>
    lambda params: self.training_loss_step(
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/fortuna/prob_model/prob_model_calibrator.py", line 44, in training_loss_step
    loss, aux = loss_fun(
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/fortuna/prob_model/predictive/base.py", line 297, in _batched_negative_log_joint_prob
    outs = self._batched_log_joint_prob(
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/fortuna/prob_model/predictive/base.py", line 271, in _batched_log_joint_prob
    outs = lax.map(_lik_log_joint_prob, ensemble_outputs)
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/fortuna/prob_model/predictive/base.py", line 259, in _lik_log_joint_prob
    return self.likelihood._batched_log_joint_prob(
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/fortuna/likelihood/base.py", line 248, in _batched_log_joint_prob
    self.prob_output_layer.log_prob(outputs, targets, train=train, **kwargs)
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/fortuna/prob_output_layer/classification.py", line 29, in log_prob
    return jnp.sum(targets * outputs, -1) - jsp.special.logsumexp(outputs, -1)
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 728, in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 256, in deferring_binary_op
    return binary_op(*args)
  File "/home/yl9959/.conda/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/ufuncs.py", line 97, in fn
    return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
TypeError: mul got incompatible shapes for broadcasting: (128, 10), (3840, 10)

Is this due to a versioning issue? Thanks!
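
The last Fortuna frame in the traceback can be reproduced in isolation. The following is a minimal sketch, not Fortuna's actual code path: the `log_prob` function copies the expression shown in the traceback, while the array contents and variable names are made up for illustration. It shows that the expression works when `outputs` and `targets` share a shape and fails when `outputs` has 3840 rows against 128 rows of `targets`. Note that 3840 = 30 × 128, so the outputs appear to cover 30 times as many rows as the targets; what produces that factor is not stated in the thread.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_prob(outputs, targets):
    # Same expression as the final Fortuna frame in the traceback.
    return jnp.sum(targets * outputs, -1) - logsumexp(outputs, -1)


batch_size, n_classes = 128, 10

# Illustrative one-hot targets: every row puts its mass on class 0.
targets = jnp.zeros((batch_size, n_classes)).at[:, 0].set(1.0)

# Matching shapes (128, 10) and (128, 10): one log-probability per row.
ok = log_prob(jnp.zeros((batch_size, n_classes)), targets)
print(ok.shape)

# Mismatched shapes (3840, 10) and (128, 10): neither leading dimension is 1,
# so the element-wise product cannot broadcast and JAX raises an error.
try:
    log_prob(jnp.zeros((3840, n_classes)), targets)
    error = None
except (TypeError, ValueError) as exc:
    error = exc

print(type(error).__name__, error)
print(3840 // batch_size, 3840 % batch_size)
```

Broadcasting only succeeds when each pair of trailing dimensions is equal or one of them is 1, which is why the 128 versus 3840 mismatch cannot be reconciled silently.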

yucenli added the bug label Dec 4, 2023
@gianlucadetommaso
Contributor

gianlucadetommaso commented Dec 4, 2023

Hi! Are you running the example on CPU or GPU? If on GPU, how many devices do you have available?

@yucenli
Author

yucenli commented Dec 6, 2023

I'm running it on 1 GPU!

@gianlucadetommaso
Contributor

gianlucadetommaso commented Dec 7, 2023

I haven't managed to replicate the issue. Could you tell me exactly what changes, if any, you made when running the example? In the example, the probabilistic model uses a Laplace approximation for posterior inference. I guess you changed something to use SWAG instead?

If no changes were made, I guess you are right and it must be some versioning issue.

@yucenli
Author

yucenli commented Dec 19, 2023

I'm running the notebook exactly as is! I run into the same error in the calibration step for Laplace as well.

Is there documentation I can refer to for the exact versions of jax, flax, etc. that I should install? I installed fortuna using pip and Python 3.10, but I'm not sure what else I should look for.

gianlucadetommaso linked a pull request Dec 19, 2023 that will close this issue
@gianlucadetommaso
Contributor

Hi Yucen, I eventually managed to replicate the issue. Thanks a lot for raising this! I fixed the small bug and pushed to PyPI, so you should be able to get a working version by upgrading Fortuna with pip, i.e. pip install --upgrade aws-fortuna.

Please let me know if it works now!

Although it may no longer be needed, to answer your question: you can find the main dependencies in pyproject.toml, and all dependencies in poetry.lock. In fact, if installing with pip causes problems in the future, you can also try installing the repo using poetry.
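
To compare an environment against the versions pinned in pyproject.toml or poetry.lock, the installed versions can be listed from Python. This is a small sketch using only the standard library; `aws-fortuna` is the distribution name given in this thread, while the other names in `PACKAGES` are assumed to be the relevant dependencies and should be adjusted to match what pyproject.toml actually lists.

```python
from importlib import metadata

# "aws-fortuna" is the PyPI name mentioned in the thread. The remaining
# entries are assumptions about which dependencies matter here.
PACKAGES = ["aws-fortuna", "jax", "jaxlib", "flax", "optax"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


report = installed_versions(PACKAGES)
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```

Pasting this output into a bug report alongside the Fortuna version makes it easier for maintainers to tell a version mismatch apart from a bug in the library.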

@yucenli
Author

yucenli commented Dec 21, 2023

It works for me now! Thank you so much.
