Skip to content

Commit

Permalink
Merge branch 'lululxvi:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
amostof authored Oct 27, 2022
2 parents c548b59 + 87d63a6 commit 128b051
Show file tree
Hide file tree
Showing 47 changed files with 1,738 additions and 215 deletions.
32 changes: 16 additions & 16 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Reference: https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/use-conda-with-travis-ci.html
dist: focal
dist: jammy
language: python
python:
# We don't actually use the Travis Python, but this keeps it organized.
Expand All @@ -9,11 +9,7 @@ python:
install:
# We do this conditionally because it saves us some downloading if the
# version is the same.
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh;
else
wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
fi
- wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
- bash miniconda.sh -b -p $HOME/miniconda
- source "$HOME/miniconda/etc/profile.d/conda.sh"
- hash -r
Expand All @@ -23,19 +19,23 @@ install:
- conda info -a

# Replace dep1 dep2 ... with your dependencies
- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION matplotlib numpy scikit-learn scipy
- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION
- conda activate test-environment

- conda install matplotlib numpy scikit-learn scipy
- conda install -c conda-forge scikit-optimize
- pip install tensorflow tensorflow-probability
- conda install pytorch cpuonly -c pytorch
- pip install jax flax optax
- pip install paddlepaddle
# - pip install tensorflow tensorflow-probability
# - conda install pytorch cpuonly -c pytorch
# - pip install jax flax optax
# - pip install paddlepaddle

# - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/

script:
# Your test script goes here
- DDEBACKEND=tensorflow.compat.v1 python -c "import deepxde"
- DDEBACKEND=tensorflow python -c "import deepxde"
- DDEBACKEND=pytorch python -c "import deepxde"
- DDEBACKEND=jax python -c "import deepxde"
- DDEBACKEND=paddle python -c "import deepxde"
- python -c "print('Hello World')"
# - DDE_BACKEND=tensorflow.compat.v1 python -c "import deepxde"
# - DDE_BACKEND=tensorflow python -c "import deepxde"
# - DDE_BACKEND=pytorch python -c "import deepxde"
# - DDE_BACKEND=jax python -c "import deepxde"
# - DDE_BACKEND=paddle python -c "import deepxde"
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ DeepXDE is a library for scientific machine learning and physics-informed learni
- NN-arbitrary polynomial chaos (NN-aPC): solving forward/inverse stochastic PDEs (sPDEs) [[J. Comput. Phys.](https://doi.org/10.1016/j.jcp.2019.07.048)]
- PINN with hard constraints (hPINN): solving inverse design/topology optimization [[SIAM J. Sci. Comput.](https://doi.org/10.1137/21M1397908)]
- improving PINN accuracy
- residual-based adaptive sampling [[SIAM Rev.](https://doi.org/10.1137/19M1274067), [Comput. Methods Appl. Mech. Eng.](https://doi.org/10.1016/j.cma.2022.115671)]
- gradient-enhanced PINN (gPINN) [[Comput. Methods Appl. Mech. Eng.](https://doi.org/10.1016/j.cma.2022.114823)]
- PINN with multi-scale Fourier features [[Comput. Methods Appl. Mech. Eng.](https://doi.org/10.1016/j.cma.2021.113938)]
- [Slides](https://github.com/lululxvi/tutorials/blob/master/20211210_pinn/pinn.pdf), [Video](https://www.youtube.com/watch?v=Wfgr1pMA9fY&list=PL1e3Jic2_DwwJQ528agJYMEpA0oMaDSA9&index=13), [Video in Chinese](http://tianyuan.xmu.edu.cn/cn/minicourses/637.html)
Expand Down Expand Up @@ -52,7 +53,7 @@ DeepXDE has implemented many algorithms as shown above and supports many feature
- **complex domain geometries** without tyranny mesh generation. The primitive geometries are interval, triangle, rectangle, polygon, disk, cuboid, sphere, hypercube, and hypersphere. Other geometries can be constructed as constructive solid geometry (CSG) using three boolean operations: union, difference, and intersection. DeepXDE also supports a geometry represented by a point cloud.
- 5 types of **boundary conditions** (BCs): Dirichlet, Neumann, Robin, periodic, and a general BC, which can be defined on an arbitrary domain or on a point set.
- different **neural networks**: fully connected neural network (FNN), stacked FNN, residual neural network, (spatio-temporal) multi-scale Fourier feature networks, etc.
- 6 **sampling methods**: uniform, pseudorandom, Latin hypercube sampling, Halton sequence, Hammersley sequence, and Sobol sequence. The training points can keep the same during training or be resampled every certain iterations.
- many **sampling methods**: uniform, pseudorandom, Latin hypercube sampling, Halton sequence, Hammersley sequence, and Sobol sequence. The training points can keep the same during training or be resampled (adaptively) every certain iterations.
- 4 **function spaces**: power series, Chebyshev polynomial, Gaussian random field (1D/2D).
- different **optimizers**: Adam, L-BFGS, etc.
- conveniently **save** the model during training, and **load** a trained model.
Expand All @@ -67,7 +68,7 @@ All the components of DeepXDE are loosely coupled, and thus DeepXDE is well-stru

DeepXDE requires one of the following backend-specific dependencies to be installed:

- TensorFlow 1.x: [TensorFlow](https://www.tensorflow.org)>=2.2.0
- TensorFlow 1.x: [TensorFlow](https://www.tensorflow.org)>=2.7.0
- TensorFlow 2.x: [TensorFlow](https://www.tensorflow.org)>=2.2.0, [TensorFlow Probability](https://www.tensorflow.org/probability)>=0.10.0
- PyTorch: [PyTorch](https://pytorch.org)>=1.9.0
- JAX: [JAX](https://jax.readthedocs.io), [Flax](https://flax.readthedocs.io), [Optax](https://optax.readthedocs.io)
Expand Down Expand Up @@ -96,9 +97,10 @@ $ git clone https://github.com/lululxvi/deepxde.git
## Explore more

- [Install and Setup](https://deepxde.readthedocs.io/en/latest/user/installation.html)
- [Demos of function approximation](https://deepxde.readthedocs.io/en/latest/demos/function.html)
- [Demos of forward problems](https://deepxde.readthedocs.io/en/latest/demos/pinn_forward.html)
- [Demos of inverse problems](https://deepxde.readthedocs.io/en/latest/demos/pinn_inverse.html)
- [Demos of function approximation](https://deepxde.readthedocs.io/en/latest/demos/function.html)
- [Demos of operator learning](https://deepxde.readthedocs.io/en/latest/demos/operator.html)
- [FAQ](https://deepxde.readthedocs.io/en/latest/user/faq.html)
- [Research papers used DeepXDE](https://deepxde.readthedocs.io/en/latest/user/research.html)
- [API](https://deepxde.readthedocs.io/en/latest/modules/deepxde.html)
Expand Down Expand Up @@ -136,7 +138,7 @@ First off, thanks for taking the time to contribute!

DeepXDE was developed by [Lu Lu](https://lu.seas.upenn.edu) under the supervision of Prof. [George Karniadakis](https://www.brown.edu/research/projects/crunch/george-karniadakis) at [Brown University](https://www.brown.edu) from the summer of 2018 to 2020, supported by [PhILMs](https://www.pnnl.gov/computing/philms). DeepXDE was originally self-hosted in Subversion at Brown University, under the name SciCoNet (Scientific Computing Neural Networks). On Feb 7, 2019, SciCoNet was moved from Subversion to GitHub, renamed to DeepXDE.

DeepXDE is currently maintained by [Lu Lu](https://lu.seas.upenn.edu) at [University of Pennsylvania](https://www.upenn.edu) with major contributions coming from several talented individuals in various forms and means. A non-exhaustive but growing list needs to mention: [Zongren Zou](https://github.com/ZongrenZou), [Shunyuan Mao](https://github.com/smao-astro).
DeepXDE is currently maintained by [Lu Lu](https://lu.seas.upenn.edu) at [University of Pennsylvania](https://www.upenn.edu) with major contributions coming from several talented individuals in various forms and means. A non-exhaustive but growing list needs to mention: [Zongren Zou](https://github.com/ZongrenZou), [Zhongyi Jiang](https://github.com/Jerry-Jzy), [Shunyuan Mao](https://github.com/smao-astro), [Paul Escapil-Inchauspé](https://github.com/pescap).

## License

Expand Down
2 changes: 1 addition & 1 deletion deepxde/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.6.0"
__version__ = "1.7.0"
4 changes: 2 additions & 2 deletions deepxde/backend/tensorflow_compat_v1/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import tensorflow.compat.v1 as tf


if LooseVersion(tf.__version__) < LooseVersion("2.2.0"):
raise RuntimeError("DeepXDE requires TensorFlow>=2.2.0.")
if LooseVersion(tf.__version__) < LooseVersion("2.7.0"):
raise RuntimeError("DeepXDE requires TensorFlow>=2.7.0.")


# The major changes from TensorFlow 1.x to TensorFlow 2.x are:
Expand Down
46 changes: 39 additions & 7 deletions deepxde/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,24 @@ class ModelCheckpoint(Callback):
monitored. Model is only checked at validation step according to
``display_every`` in ``Model.train``.
period: Interval (number of epochs) between checkpoints.
monitor: The loss function that is monitored. Either 'train loss' or 'test loss'.
"""

def __init__(self, filepath, verbose=0, save_better_only=False, period=1):
def __init__(
self,
filepath,
verbose=0,
save_better_only=False,
period=1,
monitor="train loss",
):
super().__init__()
self.filepath = filepath
self.verbose = verbose
self.save_better_only = save_better_only
self.period = period

self.monitor = "train loss"
self.monitor = monitor
self.monitor_op = np.less
self.epochs_since_last_save = 0
self.best = np.Inf
Expand All @@ -137,7 +145,7 @@ def on_epoch_end(self):
return
self.epochs_since_last_save = 0
if self.save_better_only:
current = self.model.train_state.best_loss_train
current = self.get_monitor_value()
if self.monitor_op(current, self.best):
save_path = self.model.save(self.filepath, verbose=0)
if self.verbose > 0:
Expand All @@ -154,6 +162,16 @@ def on_epoch_end(self):
else:
self.model.save(self.filepath, verbose=self.verbose)

def get_monitor_value(self):
if self.monitor == "train loss":
result = sum(self.model.train_state.loss_train)
elif self.monitor == "test loss":
result = sum(self.model.train_state.loss_test)
else:
raise ValueError("The specified monitor function is incorrect.")

return result


class EarlyStopping(Callback):
"""Stop training when a monitored quantity (training or testing loss) has stopped improving.
Expand Down Expand Up @@ -325,6 +343,10 @@ def on_epoch_end(self):
self.epochs_since_last = 0
self.on_train_begin()

def on_train_end(self):
if not self.epochs_since_last == 0:
self.on_train_begin()

def get_value(self):
"""Return the variable values."""
return self.value
Expand Down Expand Up @@ -484,12 +506,22 @@ def on_train_end(self):
)


class PDEResidualResampler(Callback):
"""Resample the training points for PDE losses every given period."""
class PDEPointResampler(Callback):
"""Resample the training points for PDE and/or BC losses every given period.
Args:
period: How often to resample the training points (default is 100 iterations).
pde_points: If True, resample the training points for PDE losses (default is
True).
bc_points: If True, resample the training points for BC losses (default is
False; only supported by pytorch backend currently).
"""

def __init__(self, period=100):
def __init__(self, period=100, pde_points=True, bc_points=False):
super().__init__()
self.period = period
self.pde_points = pde_points
self.bc_points = bc_points

self.num_bcs_initial = None
self.epochs_since_last_resample = 0
Expand All @@ -502,7 +534,7 @@ def on_epoch_end(self):
if self.epochs_since_last_resample < self.period:
return
self.epochs_since_last_resample = 0
self.model.data.resample_train_points()
self.model.data.resample_train_points(self.pde_points, self.bc_points)

if not np.array_equal(self.num_bcs_initial, self.model.data.num_bcs):
print("Initial value of self.num_bcs:", self.num_bcs_initial)
Expand Down
8 changes: 5 additions & 3 deletions deepxde/data/fpde.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
meshtype="dynamic",
num_domain=0,
num_boundary=0,
train_distribution="Sobol",
train_distribution="Hammersley",
anchors=None,
solution=None,
num_test=None,
Expand Down Expand Up @@ -217,7 +217,7 @@ def __init__(
num_domain=0,
num_boundary=0,
num_initial=0,
train_distribution="Sobol",
train_distribution="Hammersley",
anchors=None,
solution=None,
num_test=None,
Expand Down Expand Up @@ -312,7 +312,9 @@ def train_points(self):
if self.train_distribution == "uniform":
tmp = self.geom.uniform_initial_points(self.num_initial)
else:
tmp = self.geom.random_initial_points(self.num_initial, random="Sobol")
tmp = self.geom.random_initial_points(
self.num_initial, random=self.train_distribution
)
X = np.vstack((tmp, X))
return X

Expand Down
10 changes: 6 additions & 4 deletions deepxde/data/func_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):

@run_if_any_none("train_x", "train_y")
def train_next_batch(self, batch_size=None):
if self.dist_train == "log uniform":
if self.dist_train == "uniform":
self.train_x = self.geom.uniform_points(self.num_train, False)
elif self.dist_train == "log uniform":
self.train_x = self.geom.log_uniform_points(self.num_train, False)
elif self.dist_train == "random":
self.train_x = self.geom.random_points(self.num_train, "Sobol")
else:
self.train_x = self.geom.uniform_points(self.num_train, False)
self.train_x = self.geom.random_points(
self.num_train, random=self.dist_train
)
if self.anchors is not None:
self.train_x = np.vstack((self.anchors, self.train_x))
self.train_y = self.func(self.train_x)
Expand Down
2 changes: 1 addition & 1 deletion deepxde/data/ide.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
kernel=None,
num_domain=0,
num_boundary=0,
train_distribution="Sobol",
train_distribution="Hammersley",
anchors=None,
solution=None,
num_test=None,
Expand Down
4 changes: 2 additions & 2 deletions deepxde/data/mf.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def train_next_batch(self, batch_size=None):
else:
self.X_train = np.vstack(
(
self.geom.random_points(self.num_lo, "Sobol"),
self.geom.random_points(self.num_hi, "Sobol"),
self.geom.random_points(self.num_lo, random=self.dist_train),
self.geom.random_points(self.num_hi, random=self.dist_train),
)
)
y_lo_train = self.func_lo(self.X_train)
Expand Down
38 changes: 18 additions & 20 deletions deepxde/data/pde.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PDE(Data):
error= dde.metrics.l2_relative_error(y_true, y_pred)
Attributes:
train_x_all: A Numpy array of all points for training. `train_x_all` is
train_x_all: A Numpy array of points for PDE training. `train_x_all` is
unordered, and does not have duplication. If there is PDE, then
`train_x_all` is used as the training points of PDE.
train_x_bc: A Numpy array of the training points for BCs. `train_x_bc` is
Expand All @@ -78,7 +78,7 @@ def __init__(
bcs,
num_domain=0,
num_boundary=0,
train_distribution="Sobol",
train_distribution="Hammersley",
anchors=None,
exclusions=None,
solution=None,
Expand All @@ -91,19 +91,6 @@ def __init__(

self.num_domain = num_domain
self.num_boundary = num_boundary
if train_distribution not in [
"uniform",
"pseudo",
"LHS",
"Halton",
"Hammersley",
"Sobol",
]:
raise ValueError(
"train_distribution == {} is not available choices.".format(
train_distribution
)
)
self.train_distribution = train_distribution
self.anchors = None if anchors is None else anchors.astype(config.real(np))
self.exclusions = exclusions
Expand All @@ -113,11 +100,14 @@ def __init__(

self.auxiliary_var_fn = auxiliary_var_function

# TODO: train_x_all is used for PDE losses. It is better to add train_x_pde explicitly.
# TODO: train_x_all is used for PDE losses. It is better to add train_x_pde
# explicitly.
self.train_x_all = None
self.train_x, self.train_y = None, None
self.train_x_bc = None
self.num_bcs = None

# these include both BC and PDE points
self.train_x, self.train_y = None, None
self.test_x, self.test_y = None, None
self.train_aux_vars, self.test_aux_vars = None, None

Expand Down Expand Up @@ -193,8 +183,12 @@ def test(self):
)
return self.test_x, self.test_y, self.test_aux_vars

def resample_train_points(self):
"""Resample the training points for PDEs. The BC points will not be updated."""
def resample_train_points(self, pde_points=True, bc_points=True):
"""Resample the training points for PDE and/or BC."""
if pde_points:
self.train_x_all = None
if bc_points:
self.train_x_bc = None
self.train_x, self.train_y, self.train_aux_vars = None, None, None
self.train_next_batch()

Expand Down Expand Up @@ -228,6 +222,7 @@ def replace_with_anchors(self, anchors):
config.real(np)
)

@run_if_all_none("train_x_all")
def train_points(self):
X = np.empty((0, self.geom.dim), dtype=config.real(np))
if self.num_domain > 0:
Expand All @@ -253,6 +248,7 @@ def is_not_excluded(x):
return not np.any([np.allclose(x, y) for y in self.exclusions])

X = np.array(list(filter(is_not_excluded, X)))
self.train_x_all = X
return X

@run_if_all_none("train_x_bc")
Expand Down Expand Up @@ -289,7 +285,7 @@ def __init__(
num_domain=0,
num_boundary=0,
num_initial=0,
train_distribution="Sobol",
train_distribution="Hammersley",
anchors=None,
exclusions=None,
solution=None,
Expand All @@ -311,6 +307,7 @@ def __init__(
auxiliary_var_function=auxiliary_var_function,
)

@run_if_all_none("train_x_all")
def train_points(self):
X = super().train_points()
if self.num_initial > 0:
Expand All @@ -327,4 +324,5 @@ def is_not_excluded(x):

tmp = np.array(list(filter(is_not_excluded, tmp)))
X = np.vstack((tmp, X))
self.train_x_all = X
return X
Loading

0 comments on commit 128b051

Please sign in to comment.