Skip to content

Commit 12d3e7e

Browse files
committed
Prevent saver node generation during training phase.
1 parent a9b4ddb commit 12d3e7e

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

gailtf/baselines/common/tf_util.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -274,22 +274,31 @@ def FileWriter(dir_path):
274274
# ================================================================
275275
# Saving variables
276276
# ================================================================
277-
278277
def load_state(fname, var_list=None):
279278
if var_list is not None: saver = tf.train.Saver(var_list=var_list)
280279
else: saver = tf.train.Saver()
281280
saver.restore(get_session(), fname)
282281

282+
_SAVER_CACHE = {} # name -> saver
283283

284284
def save_state(fname, var_list=None, counter=None):
285285

286286
os.makedirs(os.path.dirname(fname), exist_ok=True)
287-
if var_list is not None: saver = tf.train.Saver(var_list=var_list)
288-
else: saver = tf.train.Saver()
287+
288+
saver = get_saver(var_list=var_list)
289289

290290
if counter is not None: saver.save(get_session(), fname, global_step=counter)
291291
else: saver.save(get_session(), fname)
292292

293+
def get_saver(name='default_saver', var_list=None):
294+
if name in _SAVER_CACHE:
295+
return _SAVER_CACHE[name]
296+
else:
297+
saver = tf.train.Saver(var_list=var_list)
298+
_SAVER_CACHE[name] = saver
299+
return saver
300+
301+
293302
# ================================================================
294303
# Model components
295304
# ================================================================

0 commit comments

Comments
 (0)