@@ -274,22 +274,31 @@ def FileWriter(dir_path):
274
274
# ================================================================
275
275
# Saving variables
276
276
# ================================================================
277
-
278
277
def load_state (fname , var_list = None ):
279
278
if var_list is not None : saver = tf .train .Saver (var_list = var_list )
280
279
else : saver = tf .train .Saver ()
281
280
saver .restore (get_session (), fname )
282
281
282
+ _SAVER_CACHE = {} # name -> saver
283
283
284
284
def save_state (fname , var_list = None , counter = None ):
285
285
286
286
os .makedirs (os .path .dirname (fname ), exist_ok = True )
287
- if var_list is not None : saver = tf . train . Saver ( var_list = var_list )
288
- else : saver = tf . train . Saver ( )
287
+
288
+ saver = get_saver ( var_list = var_list )
289
289
290
290
if counter is not None : saver .save (get_session (), fname , global_step = counter )
291
291
else : saver .save (get_session (), fname )
292
292
293
+ def get_saver (name = 'default_saver' , var_list = None ):
294
+ if name in _SAVER_CACHE :
295
+ return _SAVER_CACHE [name ]
296
+ else :
297
+ saver = tf .train .Saver (var_list = var_list )
298
+ _SAVER_CACHE [name ] = saver
299
+ return saver
300
+
301
+
293
302
# ================================================================
294
303
# Model components
295
304
# ================================================================
0 commit comments