|
17 | 17 | "id": "LOf_GnJovd81"
|
18 | 18 | },
|
19 | 19 | "source": [
|
20 |
| - "\n", |
21 |
| - "\n", |
22 |
| - "Online training of an multilayer perceptron (MLP) classifier using conditional moments Gaussian filter (CMGF).\n", |
23 |
| - "\n", |
24 |
| - "\n", |
25 |
| - "We perform sequential (recursive) Bayesian inference for the parameters of a binary MLP classifier.\n", |
| 20 | + "This notebook is similar to the previous notebook that demonstrated how to use CMGF for online Bayesian logistic regression. Here, we perform sequential (recursive) Bayesian inference for the parameters of a binary MLP classifier.\n", |
26 | 21 | "To do this, we treat the parameters of the model as the unknown hidden states.\n",
|
27 | 22 | "We assume that these are approximately constant over time (we add a small amount of Gaussian drift,\n",
|
28 | 23 | "for numerical stability.)\n",
|
|
120 | 115 | "from jax.flatten_util import ravel_pytree"
|
121 | 116 | ]
|
122 | 117 | },
|
| 118 | + { |
| 119 | + "cell_type": "markdown", |
| 120 | + "metadata": {}, |
| 121 | + "source": [ |
| 122 | + "## Helper function to plot the posterior predictive distribution" |
| 123 | + ] |
| 124 | + }, |
123 | 125 | {
|
124 | 126 | "cell_type": "code",
|
125 | 127 | "execution_count": 5,
|
|
499 | 501 | "id": "ld7GSZ2PsxLh"
|
500 | 502 | },
|
501 | 503 | "source": [
|
502 |
| - "Finally, we generate a video of the MLP-Classifier being trained." |
503 |
| - ] |
504 |
| - }, |
505 |
| - { |
506 |
| - "cell_type": "code", |
507 |
| - "execution_count": 16, |
508 |
| - "metadata": { |
509 |
| - "id": "PaL3hY7lhTd4" |
510 |
| - }, |
511 |
| - "outputs": [], |
512 |
| - "source": [ |
513 |
| - "import matplotlib.animation as animation\n", |
514 |
| - "from IPython.display import HTML" |
| 504 | + "### Animation\n", |
| 505 | + "Finally, we generate a video of the MLP-Classifier being trained.\n", |
| 506 | + "\n", |
| 507 | + "Note: This code is commented out by default since it takes a while time to run." |
515 | 508 | ]
|
516 | 509 | },
|
517 | 510 | {
|
|
522 | 515 | },
|
523 | 516 | "outputs": [],
|
524 | 517 | "source": [
|
525 |
| - "def animate(i):\n", |
526 |
| - " ax.cla()\n", |
527 |
| - " w_curr = w_means[i]\n", |
528 |
| - " Zi = posterior_predictive_grid(input_grid, w_means[i], sigmoid_fn)\n", |
529 |
| - " title = f'CMGF-EKF-MLP ({i+1}/500)'\n", |
530 |
| - " plot_posterior_predictive(ax, input[:i+1], output[:i+1], title, input_grid, Zi) \n", |
531 |
| - " return ax" |
532 |
| - ] |
533 |
| - }, |
534 |
| - { |
535 |
| - "cell_type": "code", |
536 |
| - "execution_count": 18, |
537 |
| - "metadata": { |
538 |
| - "colab": { |
539 |
| - "base_uri": "https://localhost:8080/", |
540 |
| - "height": 319 |
541 |
| - }, |
542 |
| - "id": "tMbQ8HU9iUr7", |
543 |
| - "outputId": "c5c55414-8720-432b-e3f5-a20db8c5ba43" |
544 |
| - }, |
545 |
| - "outputs": [], |
546 |
| - "source": [ |
547 |
| - "#fig, ax = plt.subplots(figsize=(6, 5))\n", |
548 |
| - "#anim = animation.FuncAnimation(fig, animate, frames=500, interval=50)\n", |
549 |
| - "#anim.save(\"cmgf_mlp_classifier.mp4\", dpi=200, bitrate=-1, fps=24)" |
550 |
| - ] |
551 |
| - }, |
552 |
| - { |
553 |
| - "cell_type": "code", |
554 |
| - "execution_count": 19, |
555 |
| - "metadata": { |
556 |
| - "colab": { |
557 |
| - "base_uri": "https://localhost:8080/", |
558 |
| - "height": 381 |
559 |
| - }, |
560 |
| - "id": "TAMU4Qc6rYM5", |
561 |
| - "outputId": "799e598f-4c85-4eb6-b3c0-6a84ae327ea6" |
562 |
| - }, |
563 |
| - "outputs": [], |
564 |
| - "source": [ |
565 |
| - "#HTML(anim.to_html5_video())" |
| 518 | + "# import matplotlib.animation as animation\n", |
| 519 | + "# from IPython.display import HTML\n", |
| 520 | + "\n", |
| 521 | + "# def animate(i):\n", |
| 522 | + "# ax.cla()\n", |
| 523 | + "# w_curr = w_means[i]\n", |
| 524 | + "# Zi = posterior_predictive_grid(input_grid, w_means[i], sigmoid_fn)\n", |
| 525 | + "# title = f'CMGF-EKF-MLP ({i+1}/500)'\n", |
| 526 | + "# plot_posterior_predictive(ax, input[:i+1], output[:i+1], title, input_grid, Zi) \n", |
| 527 | + "# return ax\n", |
| 528 | + "\n", |
| 529 | + "# fig, ax = plt.subplots(figsize=(6, 5))\n", |
| 530 | + "# anim = animation.FuncAnimation(fig, animate, frames=500, interval=50)\n", |
| 531 | + "# anim.save(\"cmgf_mlp_classifier.mp4\", dpi=200, bitrate=-1, fps=24)\n", |
| 532 | + "\n", |
| 533 | + "# HTML(anim.to_html5_video())" |
566 | 534 | ]
|
567 | 535 | }
|
568 | 536 | ],
|
|
0 commit comments