Neural Context Flow

Official JAX implementation of Neural Context Flows for Meta-Learning of Dynamical Systems, accepted at ICLR 2025. A PyTorch version is available at this link.


Overview

Neural Context Flow (NCF) is a framework for learning dynamical systems that can adapt to different environments/contexts, making it particularly valuable for scientific machine learning applications where the underlying system dynamics may vary across physical parameter values.

NCF is powered by contextual self-modulation, a regularization mechanism. This inductive bias performs a Taylor expansion of the vector field about the context vectors, which yields several candidate trajectories. Although the training loss may be higher than that of a naive Neural ODE, we observe lower losses at adaptation time.
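The expansion step can be illustrated with a minimal first-order sketch. It assumes a toy tanh vector field and an explicit Euler solver in place of the package's actual model and Diffrax solver. All function names here are hypothetical and are not the NCF API. For a target context, the vector field is expanded about each context in a pool, and each expansion is rolled out into one candidate trajectory.

```python
import jax
import jax.numpy as jnp


def vector_field(x, ctx, params):
    # Hypothetical toy vector field f(x, ctx): one dense layer on [x, ctx] plus tanh.
    W, b = params
    return jnp.tanh(W @ jnp.concatenate([x, ctx]) + b)


def taylor1_vector_field(x, ctx_target, ctx_anchor, params):
    # First-order Taylor expansion of f(x, .) about ctx_anchor, evaluated at ctx_target:
    #   f(x, ctx_anchor) + J_ctx f(x, ctx_anchor) @ (ctx_target - ctx_anchor)
    # jax.jvp returns f(ctx_anchor) and the Jacobian-vector product in one pass.
    f = lambda c: vector_field(x, c, params)
    value, directional = jax.jvp(f, (ctx_anchor,), (ctx_target - ctx_anchor,))
    return value + directional


def rollout(x0, ctx_target, ctx_anchor, params, dt=0.01, n_steps=100):
    # Explicit Euler integration (a simple stand-in for an adaptive Diffrax solver).
    def step(x, _):
        x_next = x + dt * taylor1_vector_field(x, ctx_target, ctx_anchor, params)
        return x_next, x_next

    _, traj = jax.lax.scan(step, x0, None, length=n_steps)
    return traj  # shape (n_steps, state_dim)


def candidate_trajectories(x0, ctx_target, ctx_pool, params, dt=0.01, n_steps=100):
    # One candidate trajectory per anchor context in ctx_pool, vectorized with vmap.
    return jax.vmap(
        lambda c: rollout(x0, ctx_target, c, params, dt, n_steps)
    )(ctx_pool)  # shape (n_anchors, n_steps, state_dim)


if __name__ == "__main__":
    key_w, key_c = jax.random.split(jax.random.PRNGKey(0))
    state_dim, ctx_dim, n_envs = 2, 3, 4
    params = (jax.random.normal(key_w, (state_dim, state_dim + ctx_dim)) * 0.5,
              jnp.zeros(state_dim))
    contexts = jax.random.normal(key_c, (n_envs, ctx_dim))  # one context per environment
    x0 = jnp.array([1.0, 0.0])
    trajs = candidate_trajectories(x0, contexts[0], contexts, params)
    print(trajs.shape)  # (4, 100, 2)
```

Using jax.jvp gives the Jacobian-vector product without forming the full context Jacobian. When the anchor equals the target context, the expansion reduces exactly to the original vector field.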


The NCF package is built around 4 extensible modules:

  • a DataLoader: to load the dynamics datasets
  • a Learner: a model, a context and the loss function
  • a Trainer: the training and adaptation algorithms
  • a VisualTester: to test and visualize the results
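The Learner and Trainer roles can be illustrated with a hypothetical minimal sketch. It assumes a toy linear vector field, an explicit Euler rollout, and plain gradient descent. It also assumes that the adaptation stage keeps the trained model weights frozen and fits only a context vector for a new environment. None of these names belong to the package's API, and the sketch omits contextual self-modulation entirely.

```python
import jax
import jax.numpy as jnp


def model(x, ctx, W):
    # Hypothetical linear vector field dx/dt = W @ [x, ctx] (illustrative only).
    return W @ jnp.concatenate([x, ctx])


def euler_rollout(x0, ctx, W, dt=0.1, n_steps=20):
    # Explicit Euler rollout returning the state after each step.
    def step(x, _):
        x = x + dt * model(x, ctx, W)
        return x, x

    _, traj = jax.lax.scan(step, x0, None, length=n_steps)
    return traj  # shape (n_steps, state_dim)


def loss_fn(ctx, W, x0, target_traj):
    # Mean squared error between the predicted and observed trajectories.
    return jnp.mean((euler_rollout(x0, ctx, W) - target_traj) ** 2)


# Gradient with respect to the context only (argnums=0); W is never updated.
grad_ctx = jax.jit(jax.grad(loss_fn, argnums=0))


def adapt_context(W, x0, target_traj, ctx_init, lr=0.5, n_iters=200):
    # Adaptation sketch: model weights W stay frozen, only the new context is fitted.
    ctx = ctx_init
    for _ in range(n_iters):
        ctx = ctx - lr * grad_ctx(ctx, W, x0, target_traj)
    return ctx


if __name__ == "__main__":
    W = jnp.array([[-0.5, 0.0, 1.0], [0.0, -0.5, 1.0]])  # pretend these were trained
    x0 = jnp.array([1.0, -1.0])
    true_ctx = jnp.array([0.7])  # unknown context of a new environment
    observed = euler_rollout(x0, true_ctx, W)
    print(adapt_context(W, x0, observed, jnp.zeros(1)))  # approximately [0.7]
```

Because only the low-dimensional context is optimized, adaptation in this toy setting is a small, well-conditioned problem compared with retraining the whole model.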

Getting Started

To run any of the experiments described in the NCF paper, follow the steps below:

  1. Install the package: pip install -e .
  2. Navigate to the problem of interest in the examples folder
  3. (Optional) Download the data from Gen-Dynamics and place it in the data folder
  4. Set the hyperparameters and run the main.py script to both train the model and adapt it to various environments. The script can be run in either notebook or script mode. We recommend using nohup to log the results: nohup python main.py > nohup.log &
  5. Once the model is trained, move to the corresponding run folder saved in runs, set the train flag in main.py to False, and rerun main.py to perform additional experiments such as uncertainty estimation and interpretability.

Notes

Our package mainly requires JAX and its ecosystem (Equinox, Diffrax, Optax). In addition, PyTorch (CPU build) is required to generate data for the Navier-Stokes problem, following the approach of CoDA.

To-Do

  • Provide links to some weights and contexts
  • Delete the long commit history
  • Test the installation in neutral conda environments

If you use this work, please cite the corresponding paper:

@inproceedings{
    nzoyem2025neural,
    title={Neural Context Flows for Meta-Learning of Dynamical Systems},
    author={Roussel Desmond Nzoyem and David A.W. Barton and Tom Deakin},
    booktitle={The Thirteenth International Conference on Learning Representations},
    year={2025},
    url={https://openreview.net/forum?id=8vzMLo8LDN}
}