问题遇到的现象和发生背景
问题相关代码,请勿粘贴截图
运行结果及报错内容
我的解答思路和尝试过的方法
我想要达到的结果
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
FLASS = tf.compat.v1.app.flags.FLAGS
tf.compat.v1.app.flags.DEFINE_string("job_name", "", "启动服务的类型ps or worker")
tf.compat.v1.app.flags.DEFINE_integer("task_index", 0, "指定ps或者worker当中的那一台服务器以task:0 ,task:1")
def main(argv):
# 定义全局计数的op,给钩子列表当中的训练步数使用
global_step = tf.compat.v1.train.get_global_step()
# 指定集群描述对象, ps , worker
cluster = tf.compat.v1.train.ClusterSpec({"ps": ["192.168.254.128:2223"], "worker": ["10.213.26.88:2222"]})
server = tf.compat.v1.train.Server(cluster, job_name=FLASS.job_name, task_index=FLASS.task_index)
# 根据不同服务做不同的事情, ps:去更新保存参数 worker:指定设备去运行模型计算
if FLASS.job_name == "ps":
# 参数服务器什么都不用干, 是需要等待worker传递参数
server.join()
else:
worker_device = "/job:worker/task:0/cpu:0/"
# 可以指定设备去运行
with tf.compat.v1.device(
tf.compat.v1.train.replica_device_setter(
worker_device = worker_device,
cluster=cluster
)):
# 简单做一个矩阵乘法运算
x = tf.compat.v1.Variable([[1, 2, 3, 4]])
w = tf.compat.v1.Variable([2], [2], [2], [2])
mat = tf.compat.v1.matmul(x, w)
# 创建分布式会话
with tf.compat.v1.train.MonitoredTrainingSession(
master="grpc://10.213.26.88:2222", # 指定主worker
is_chief=(FLASS.task_index == 0), # 判断是否是住worker
config=tf.compat.v1.ConfigProto(log_device_placement=True), # 打印设备信息
hooks=[tf.compat.v1.train.StopAtStepHook(last_step=200)]
) as mon_sess:
while not mon_sess.should_stop():
print(mon_sess.run(mat))
if __name__ == "__main__":
tf.compat.v1.app.run()