Skip to content

Commit

Permalink
Backend TensorFlow 1.x: L-BFGS outputs trainable variables and test loss (lululxvi#817)
Browse files Browse the repository at this point in the history
  • Loading branch information
tsarikahin authored Oct 19, 2022
1 parent 5e40e61 commit 7ddaf6f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 8 deletions.
4 changes: 4 additions & 0 deletions deepxde/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,10 @@ def on_epoch_end(self):
self.epochs_since_last = 0
self.on_train_begin()

def on_train_end(self):
    """Flush a final variable-value record when training ends.

    If epochs have elapsed since the last periodic output (i.e. the
    last values were never written), emit one final record so the
    end-of-training values are not lost.
    """
    # Idiomatic truthiness test instead of `not ... == 0`.
    if self.epochs_since_last:
        self.on_train_begin()

def get_value(self):
    """Return the variable values recorded at the last update."""
    recorded = self.value
    return recorded
Expand Down
31 changes: 26 additions & 5 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from . import utils
from .backend import backend_name, tf, torch, jax, paddle
from .callbacks import CallbackList
from .utils import list_to_str


class Model:
Expand Down Expand Up @@ -553,7 +554,6 @@ def train(
" Use iterations instead."
)
iterations = epochs

self.batch_size = batch_size
self.callbacks = CallbackList(callbacks=callbacks)
self.callbacks.set_model(self)
Expand Down Expand Up @@ -621,17 +621,35 @@ def _train_sgd(self, iterations, display_every):
break

def _train_tensorflow_compat_v1_scipy(self, display_every):
def loss_callback(loss_train):
def loss_callback(loss_train, loss_test, *args):
self.train_state.epoch += 1
self.train_state.step += 1
if self.train_state.step % display_every == 0:
self.train_state.loss_train = loss_train
self.train_state.loss_test = None
self.train_state.loss_test = loss_test
self.train_state.metrics_test = None
self.losshistory.append(
self.train_state.step, self.train_state.loss_train, None, None
self.train_state.step,
self.train_state.loss_train,
self.train_state.loss_test,
None,
)
display.training_display(self.train_state)
for cb in self.callbacks.callbacks:
if type(cb).__name__ == "VariableValue":
cb.epochs_since_last += 1
if cb.epochs_since_last >= cb.period:
cb.epochs_since_last = 0

print(
cb.model.train_state.epoch,
list_to_str(
[float(arg) for arg in args],
precision=cb.precision,
),
file=cb.file,
)
cb.file.flush()

self.train_state.set_data_train(*self.data.train_next_batch(self.batch_size))
feed_dict = self.net.feed_dict(
Expand All @@ -640,10 +658,13 @@ def loss_callback(loss_train):
self.train_state.y_train,
self.train_state.train_aux_vars,
)
fetches = [self.outputs_losses_train[1], self.outputs_losses_test[1]]
if self.external_trainable_variables:
fetches += self.external_trainable_variables
self.train_step.minimize(
self.sess,
feed_dict=feed_dict,
fetches=[self.outputs_losses_train[1]],
fetches=fetches,
loss_callback=loss_callback,
)
self._test()
Expand Down
17 changes: 14 additions & 3 deletions examples/pinn_inverse/Lorenz_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,20 @@ def boundary(_, on_initial):

net = dde.nn.FNN([1] + [40] * 3 + [3], "tanh", "Glorot uniform")
model = dde.Model(data, net)
model.compile("adam", lr=0.001, external_trainable_variables=[C1, C2, C3])

external_trainable_variables = [C1, C2, C3]
variable = dde.callbacks.VariableValue(
[C1, C2, C3], period=600, filename="variables.dat"
external_trainable_variables, period=600, filename="variables.dat"
)

# train adam
model.compile(
"adam", lr=0.001, external_trainable_variables=external_trainable_variables
)
losshistory, train_state = model.train(iterations=60000, callbacks=[variable])
losshistory, train_state = model.train(iterations=20000, callbacks=[variable])

# train lbfgs
model.compile("L-BFGS", external_trainable_variables=external_trainable_variables)
losshistory, train_state = model.train(callbacks=[variable])

dde.saveplot(losshistory, train_state, issave=True, isplot=True)

0 comments on commit 7ddaf6f

Please sign in to comment.