Skip to content

Commit

Permalink
Merge branch 'main' of github.com:probml/dynamax
Browse files Browse the repository at this point in the history
  • Loading branch information
slinderman committed Mar 7, 2025
2 parents 51dfdb0 + a14929c commit c13fe76
Show file tree
Hide file tree
Showing 27 changed files with 2,361 additions and 2,135 deletions.
4 changes: 0 additions & 4 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,7 @@ Low-level inference
-------------------

.. autofunction:: dynamax.nonlinear_gaussian_ssm.extended_kalman_filter
.. autofunction:: dynamax.nonlinear_gaussian_ssm.iterated_extended_kalman_filter
.. autofunction:: dynamax.nonlinear_gaussian_ssm.extended_kalman_smoother
.. autofunction:: dynamax.nonlinear_gaussian_ssm.iterated_extended_kalman_smoother

.. autofunction:: dynamax.nonlinear_gaussian_ssm.unscented_kalman_filter
.. autofunction:: dynamax.nonlinear_gaussian_ssm.unscented_kalman_smoother
Expand All @@ -219,9 +217,7 @@ Low-level inference
-------------------

.. autofunction:: dynamax.generalized_gaussian_ssm.conditional_moments_gaussian_filter
.. autofunction:: dynamax.generalized_gaussian_ssm.iterated_conditional_moments_gaussian_filter
.. autofunction:: dynamax.generalized_gaussian_ssm.conditional_moments_gaussian_smoother
.. autofunction:: dynamax.generalized_gaussian_ssm.iterated_conditional_moments_gaussian_smoother

Types
-----
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@
"id": "LOf_GnJovd81"
},
"source": [
"\n",
"\n",
"Online training of an multilayer perceptron (MLP) classifier using conditional moments Gaussian filter (CMGF).\n",
"\n",
"\n",
"We perform sequential (recursive) Bayesian inference for the parameters of a binary MLP classifier.\n",
"This notebook is similar to the previous notebook that demonstrated how to use CMGF for online Bayesian logistic regression. Here, we perform sequential (recursive) Bayesian inference for the parameters of a binary MLP classifier.\n",
"To do this, we treat the parameters of the model as the unknown hidden states.\n",
"We assume that these are approximately constant over time (we add a small amount of Gaussian drift,\n",
"for numerical stability.)\n",
Expand Down Expand Up @@ -120,6 +115,13 @@
"from jax.flatten_util import ravel_pytree"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper function to plot the posterior predictive distribution"
]
},
{
"cell_type": "code",
"execution_count": 5,
Expand Down Expand Up @@ -499,19 +501,10 @@
"id": "ld7GSZ2PsxLh"
},
"source": [
"Finally, we generate a video of the MLP-Classifier being trained."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "PaL3hY7lhTd4"
},
"outputs": [],
"source": [
"import matplotlib.animation as animation\n",
"from IPython.display import HTML"
"### Animation\n",
"Finally, we generate a video of the MLP-Classifier being trained.\n",
"\n",
"Note: This code is commented out by default since it takes a while time to run."
]
},
{
Expand All @@ -522,47 +515,22 @@
},
"outputs": [],
"source": [
"def animate(i):\n",
" ax.cla()\n",
" w_curr = w_means[i]\n",
" Zi = posterior_predictive_grid(input_grid, w_means[i], sigmoid_fn)\n",
" title = f'CMGF-EKF-MLP ({i+1}/500)'\n",
" plot_posterior_predictive(ax, input[:i+1], output[:i+1], title, input_grid, Zi) \n",
" return ax"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 319
},
"id": "tMbQ8HU9iUr7",
"outputId": "c5c55414-8720-432b-e3f5-a20db8c5ba43"
},
"outputs": [],
"source": [
"#fig, ax = plt.subplots(figsize=(6, 5))\n",
"#anim = animation.FuncAnimation(fig, animate, frames=500, interval=50)\n",
"#anim.save(\"cmgf_mlp_classifier.mp4\", dpi=200, bitrate=-1, fps=24)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 381
},
"id": "TAMU4Qc6rYM5",
"outputId": "799e598f-4c85-4eb6-b3c0-6a84ae327ea6"
},
"outputs": [],
"source": [
"#HTML(anim.to_html5_video())"
"# import matplotlib.animation as animation\n",
"# from IPython.display import HTML\n",
"\n",
"# def animate(i):\n",
"# ax.cla()\n",
"# w_curr = w_means[i]\n",
"# Zi = posterior_predictive_grid(input_grid, w_means[i], sigmoid_fn)\n",
"# title = f'CMGF-EKF-MLP ({i+1}/500)'\n",
"# plot_posterior_predictive(ax, input[:i+1], output[:i+1], title, input_grid, Zi) \n",
"# return ax\n",
"\n",
"# fig, ax = plt.subplots(figsize=(6, 5))\n",
"# anim = animation.FuncAnimation(fig, animate, frames=500, interval=50)\n",
"# anim.save(\"cmgf_mlp_classifier.mp4\", dpi=200, bitrate=-1, fps=24)\n",
"\n",
"# HTML(anim.to_html5_video())"
]
}
],
Expand Down
Loading

0 comments on commit c13fe76

Please sign in to comment.