Skip to content

Commit

Permalink
TripleCartesianProd and QuadrupleCartesianProd support mini-batch for…
Browse files Browse the repository at this point in the history
… both branch and trunk nets (lululxvi#977)
  • Loading branch information
Jerry-Jzy authored Oct 22, 2022
1 parent 3101bcb commit 87d63a6
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 16 deletions.
21 changes: 15 additions & 6 deletions deepxde/data/quadruple.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,29 @@ def __init__(self, X_train, y_train, X_test, y_test):
self.train_x, self.train_y = X_train, y_train
self.test_x, self.test_y = X_test, y_test

self.train_sampler = BatchSampler(len(X_train[0]), shuffle=True)
self.branch_sampler = BatchSampler(len(X_train[0]), shuffle=True)
self.trunk_sampler = BatchSampler(len(X_train[2]), shuffle=True)

def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
    """Compute the data loss between the targets and the network outputs."""
    loss_value = loss_fn(targets, outputs)
    return loss_value

def train_next_batch(self, batch_size=None):
    """Return one batch of training data.

    Args:
        batch_size: ``None`` (return the full training set), an integer
            (mini-batch over the branch-net inputs only), or a tuple/list of
            two integers ``(branch_batch_size, trunk_batch_size)``
            (independent mini-batches over branch- and trunk-net inputs).

    Returns:
        A tuple ``(inputs, targets)``: ``inputs`` is the 3-tuple of
        (sub-sampled) input arrays, and ``targets`` is the matching slice of
        the label matrix of shape ``(N1, N2)``.
    """
    if batch_size is None:
        # No mini-batching: hand back the full training set.
        return self.train_x, self.train_y
    if not isinstance(batch_size, (tuple, list)):
        # Integer batch size: sample only along the branch axis (N1); the
        # trunk inputs are passed through unchanged.
        indices = self.branch_sampler.get_next(batch_size)
        return (
            self.train_x[0][indices],
            self.train_x[1][indices],
            self.train_x[2],
        ), self.train_y[indices]
    # Tuple batch size: sample branch and trunk inputs independently.
    indices_branch = self.branch_sampler.get_next(batch_size[0])
    indices_trunk = self.trunk_sampler.get_next(batch_size[1])
    # Chained indexing keeps the Cartesian-product structure: the labels have
    # shape (len(indices_branch), len(indices_trunk)). A single fancy index
    # train_y[indices_branch, indices_trunk] would instead pair the two index
    # arrays elementwise (and raise for unequal batch sizes).
    return (
        self.train_x[0][indices_branch],
        self.train_x[1][indices_branch],
        self.train_x[2][indices_trunk],
    ), self.train_y[indices_branch][:, indices_trunk]

def test(self):
    """Return the complete test dataset as a tuple ``(X_test, y_test)``."""
    return self.test_x, self.test_y
17 changes: 12 additions & 5 deletions deepxde/data/triple.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ class TripleCartesianProd(Data):
Args:
        X_train: A tuple of two NumPy arrays. The first element has the shape (`N1`,
            `dim1`), and the second element has the shape (`N2`, `dim2`).
y_train: A NumPy array of shape (`N1`, `N2`).
"""

Expand All @@ -71,16 +70,24 @@ def __init__(self, X_train, y_train, X_test, y_test):
self.train_x, self.train_y = X_train, y_train
self.test_x, self.test_y = X_test, y_test

self.train_sampler = BatchSampler(len(X_train[0]), shuffle=True)
self.branch_sampler = BatchSampler(len(X_train[0]), shuffle=True)
self.trunk_sampler = BatchSampler(len(X_train[1]), shuffle=True)

def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
    """Compute the data loss between the targets and the network outputs."""
    loss_value = loss_fn(targets, outputs)
    return loss_value

def train_next_batch(self, batch_size=None):
    """Return one batch of training data.

    Args:
        batch_size: ``None`` (return the full training set), an integer
            (mini-batch over the branch-net inputs only), or a tuple/list of
            two integers ``(branch_batch_size, trunk_batch_size)``
            (independent mini-batches over branch- and trunk-net inputs).

    Returns:
        A tuple ``(inputs, targets)``: ``inputs`` is the pair of
        (sub-sampled) input arrays, and ``targets`` is the matching slice of
        the label matrix of shape ``(N1, N2)``.
    """
    if batch_size is None:
        # No mini-batching: hand back the full training set.
        return self.train_x, self.train_y
    if not isinstance(batch_size, (tuple, list)):
        # Integer batch size: sample only along the branch axis (N1); the
        # trunk inputs are passed through unchanged.
        indices = self.branch_sampler.get_next(batch_size)
        return (self.train_x[0][indices], self.train_x[1]), self.train_y[indices]
    # Tuple batch size: sample branch and trunk inputs independently.
    indices_branch = self.branch_sampler.get_next(batch_size[0])
    indices_trunk = self.trunk_sampler.get_next(batch_size[1])
    # Chained indexing keeps the Cartesian-product structure: the labels have
    # shape (len(indices_branch), len(indices_trunk)). A single fancy index
    # train_y[indices_branch, indices_trunk] would instead pair the two index
    # arrays elementwise (and raise for unequal batch sizes).
    return (
        self.train_x[0][indices_branch],
        self.train_x[1][indices_trunk],
    ), self.train_y[indices_branch][:, indices_trunk]

def test(self):
    """Return the complete test dataset as a tuple ``(X_test, y_test)``."""
    return self.test_x, self.test_y
15 changes: 10 additions & 5 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,11 +533,16 @@ def train(
Args:
iterations (Integer): Number of iterations to train the model, i.e., number
of times the network weights are updated.
batch_size: Integer or ``None``. If you solve PDEs via ``dde.data.PDE`` or
``dde.data.TimePDE``, do not use `batch_size`, and instead use
`dde.callbacks.PDEResidualResampler
<https://deepxde.readthedocs.io/en/latest/modules/deepxde.html#deepxde.callbacks.PDEResidualResampler>`_,
see an `example <https://github.com/lululxvi/deepxde/blob/master/examples/diffusion_1d_resample.py>`_.
batch_size: Integer, tuple, or ``None``.
- If you solve PDEs via ``dde.data.PDE`` or ``dde.data.TimePDE``, do not use `batch_size`, and instead use
`dde.callbacks.PDEResidualResampler
<https://deepxde.readthedocs.io/en/latest/modules/deepxde.html#deepxde.callbacks.PDEResidualResampler>`_,
see an `example <https://github.com/lululxvi/deepxde/blob/master/examples/diffusion_1d_resample.py>`_.
- For DeepONet in the format of Cartesian product, if `batch_size` is an Integer,
then it is the batch size for the branch input; if you want to also use mini-batch for the trunk net input,
                set `batch_size` as a tuple, where the first number is the batch size for the branch net input
and the second number is the batch size for the trunk net input.
display_every (Integer): Print the loss and metrics every this steps.
disregard_previous_best: If ``True``, disregard the previous saved best
model.
Expand Down

0 comments on commit 87d63a6

Please sign in to comment.