Skip to content

Commit a14929c

Browse files
authored
Merge pull request #405 from probml/joss_review
Clean up the notebooks for the JOSS Review
2 parents 35e1217 + d0d6b32 commit a14929c

19 files changed

+2032
-1700
lines changed

docs/notebooks/generalized_gaussian_ssm/cmgf_logistic_regression_demo.ipynb

+364-350
Large diffs are not rendered by default.

docs/notebooks/generalized_gaussian_ssm/cmgf_mlp_classification_demo.ipynb

+28-60
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,7 @@
1717
"id": "LOf_GnJovd81"
1818
},
1919
"source": [
20-
"\n",
21-
"\n",
22-
"Online training of an multilayer perceptron (MLP) classifier using conditional moments Gaussian filter (CMGF).\n",
23-
"\n",
24-
"\n",
25-
"We perform sequential (recursive) Bayesian inference for the parameters of a binary MLP classifier.\n",
20+
"This notebook is similar to the previous notebook that demonstrated how to use CMGF for online Bayesian logistic regression. Here, we perform sequential (recursive) Bayesian inference for the parameters of a binary MLP classifier.\n",
2621
"To do this, we treat the parameters of the model as the unknown hidden states.\n",
2722
"We assume that these are approximately constant over time (we add a small amount of Gaussian drift,\n",
2823
"for numerical stability.)\n",
@@ -120,6 +115,13 @@
120115
"from jax.flatten_util import ravel_pytree"
121116
]
122117
},
118+
{
119+
"cell_type": "markdown",
120+
"metadata": {},
121+
"source": [
122+
"## Helper function to plot the posterior predictive distribution"
123+
]
124+
},
123125
{
124126
"cell_type": "code",
125127
"execution_count": 5,
@@ -499,19 +501,10 @@
499501
"id": "ld7GSZ2PsxLh"
500502
},
501503
"source": [
502-
"Finally, we generate a video of the MLP-Classifier being trained."
503-
]
504-
},
505-
{
506-
"cell_type": "code",
507-
"execution_count": 16,
508-
"metadata": {
509-
"id": "PaL3hY7lhTd4"
510-
},
511-
"outputs": [],
512-
"source": [
513-
"import matplotlib.animation as animation\n",
514-
"from IPython.display import HTML"
504+
"### Animation\n",
505+
"Finally, we generate a video of the MLP-Classifier being trained.\n",
506+
"\n",
507+
"Note: This code is commented out by default since it takes a while time to run."
515508
]
516509
},
517510
{
@@ -522,47 +515,22 @@
522515
},
523516
"outputs": [],
524517
"source": [
525-
"def animate(i):\n",
526-
" ax.cla()\n",
527-
" w_curr = w_means[i]\n",
528-
" Zi = posterior_predictive_grid(input_grid, w_means[i], sigmoid_fn)\n",
529-
" title = f'CMGF-EKF-MLP ({i+1}/500)'\n",
530-
" plot_posterior_predictive(ax, input[:i+1], output[:i+1], title, input_grid, Zi) \n",
531-
" return ax"
532-
]
533-
},
534-
{
535-
"cell_type": "code",
536-
"execution_count": 18,
537-
"metadata": {
538-
"colab": {
539-
"base_uri": "https://localhost:8080/",
540-
"height": 319
541-
},
542-
"id": "tMbQ8HU9iUr7",
543-
"outputId": "c5c55414-8720-432b-e3f5-a20db8c5ba43"
544-
},
545-
"outputs": [],
546-
"source": [
547-
"#fig, ax = plt.subplots(figsize=(6, 5))\n",
548-
"#anim = animation.FuncAnimation(fig, animate, frames=500, interval=50)\n",
549-
"#anim.save(\"cmgf_mlp_classifier.mp4\", dpi=200, bitrate=-1, fps=24)"
550-
]
551-
},
552-
{
553-
"cell_type": "code",
554-
"execution_count": 19,
555-
"metadata": {
556-
"colab": {
557-
"base_uri": "https://localhost:8080/",
558-
"height": 381
559-
},
560-
"id": "TAMU4Qc6rYM5",
561-
"outputId": "799e598f-4c85-4eb6-b3c0-6a84ae327ea6"
562-
},
563-
"outputs": [],
564-
"source": [
565-
"#HTML(anim.to_html5_video())"
518+
"# import matplotlib.animation as animation\n",
519+
"# from IPython.display import HTML\n",
520+
"\n",
521+
"# def animate(i):\n",
522+
"# ax.cla()\n",
523+
"# w_curr = w_means[i]\n",
524+
"# Zi = posterior_predictive_grid(input_grid, w_means[i], sigmoid_fn)\n",
525+
"# title = f'CMGF-EKF-MLP ({i+1}/500)'\n",
526+
"# plot_posterior_predictive(ax, input[:i+1], output[:i+1], title, input_grid, Zi) \n",
527+
"# return ax\n",
528+
"\n",
529+
"# fig, ax = plt.subplots(figsize=(6, 5))\n",
530+
"# anim = animation.FuncAnimation(fig, animate, frames=500, interval=50)\n",
531+
"# anim.save(\"cmgf_mlp_classifier.mp4\", dpi=200, bitrate=-1, fps=24)\n",
532+
"\n",
533+
"# HTML(anim.to_html5_video())"
566534
]
567535
}
568536
],

0 commit comments

Comments
 (0)