forked from tmulc18/Distributed-TensorFlow-Guide
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdist_setup_sup.py
70 lines (59 loc) · 1.86 KB
/
dist_setup_sup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Simple example with one parameter server and one worker.
Author: Tommy Mulc
"""
from __future__ import print_function
import tensorflow as tf
import argparse
import time
import os
# FLAGS: parsed command-line flags; stays None until the __main__ block
#        assigns the result of parse_known_args().
# log_dir: Supervisor log/checkpoint directory, appended to os.getcwd() in main().
FLAGS, log_dir = None, '/logdir'
def main():
    """Run this process as either the parameter server or the worker.

    Reads the module-level ``FLAGS`` (filled in by argparse in the
    ``__main__`` block): ``FLAGS.job_name`` selects the role and
    ``FLAGS.task_index`` selects the task within that job.

    * ``'ps'``: start a ``tf.train.Server`` for the ps job and block in
      ``server.join()``.
    * anything else: start a worker server, build a toy graph that pulls
      ``a + b`` toward a constant target of 100, and run up to 1000
      gradient-descent steps under a ``tf.train.Supervisor`` that saves a
      checkpoint under ``os.getcwd() + log_dir`` every 30 seconds.
    """
    # Distributed baggage: every process is given the full cluster layout so
    # it knows the addresses of all other nodes.
    cluster = tf.train.ClusterSpec({
        'ps': ['localhost:2222'],
        'worker': ['localhost:2223'],
    })

    if FLAGS.job_name == 'ps':
        server = tf.train.Server(cluster, job_name="ps",
                                 task_index=FLAGS.task_index)
        # Serve the worker's requests; per the TF docs join() blocks forever.
        server.join()
        return

    # Task 0 of the worker job is the chief; the Supervisor uses this flag to
    # decide which task initializes the model and writes checkpoints.
    is_chief = (FLAGS.task_index == 0)
    server = tf.train.Server(cluster, job_name="worker",
                             task_index=FLAGS.task_index)

    # Graph
    # NOTE(review): '/cpu:0' names no job/task, so these variables presumably
    # live on this worker rather than on the ps;
    # tf.train.replica_device_setter(cluster=cluster) would place them on the
    # ps. Confirm which placement this example intends.
    with tf.device('/cpu:0'):
        a = tf.Variable(tf.truncated_normal(shape=[2]), dtype=tf.float32)
        b = tf.Variable(tf.truncated_normal(shape=[2]), dtype=tf.float32)
        c = a + b

        target = tf.constant(100., shape=[2], dtype=tf.float32)
        # Mean squared error between a + b and the all-100 target vector.
        loss = tf.reduce_mean(tf.square(c - target))

        opt = tf.train.GradientDescentOptimizer(.0001).minimize(loss)

    # Session / Supervisor: checkpoint every 30 seconds.
    sv = tf.train.Supervisor(logdir=os.getcwd() + log_dir,
                             is_chief=is_chief,
                             save_model_secs=30)
    # managed_session() prepares (chief) or waits for (non-chief) the session
    # exactly like prepare_or_wait_for_session(). On exit it also stops the
    # Supervisor's service threads, closes the summary writer and closes the
    # session. The previous code left all of those running/open.
    with sv.managed_session(server.target) as sess:
        for step in range(1000):
            if sv.should_stop():
                break
            sess.run(opt)
            if step % 10 == 0:
                # Periodically report the current value of a + b.
                print(sess.run(c))
            # Throttle the loop so progress is easy to follow.
            time.sleep(.1)
if __name__ == '__main__':
    # Command-line flags that tell this process its place in the cluster.
    parser = argparse.ArgumentParser()
    # --job_name: which tf.train.ClusterSpec job this process belongs to.
    #             main() runs the ps role only for 'ps'; any other value
    #             (including the empty default) runs the worker role.
    parser.add_argument("--job_name", type=str, default="",
                        help="One of 'ps', 'worker'")
    # --task_index: position of this process within its job, passed to
    #               tf.train.Server.
    parser.add_argument("--task_index", type=int, default=0,
                        help="Index of task within the job")
    # parse_known_args() tolerates extra flags instead of erroring out;
    # whatever it does not recognise lands in `unparsed`, which is unused.
    FLAGS, unparsed = parser.parse_known_args()
    main()