-
API
API overview:
- tf.train.ClusterSpec — create a cluster from the parameter-server and worker hosts.
- tf.train.Server — create and start a server for the local task.
- tf.train.MonitoredTrainingSession
- tf.train.SyncReplicasOptimizer
- tf.train.Supervisor  # deprecated
-
Session
tf.train.SyncReplicasOptimizer
Distribution Strategy — TensorFlow example
def cluster_server(config): ps_hosts = config['ps_hosts'] worker_hosts = config['worker_hosts'] cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) job_name = config['job_name'] # 'worker', 'ps' task_index = config["index"] # 0, 1, 2, ... server = tf.train.Server(cluster, job_name=job_name, task_index=task_index) return cluster, server, job_name, task_index
# tf.train.MonitoredTrainingSession def main(): cluster, server, job_name, task_index = cluster_server() if job_name == "ps": server.join() elif job_name == "worker": with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % task_index, cluster=cluster)): model = ... loss = ... optimizer = ... global_step=tf.get_or_create_global_step() train_op = optimizer.minimize(loss, global_step=global_step) hooks=[tf.train.StopAtStepHook(last_step=1000000)] with tf.train.MonitoredTrainingSession(master=server.target, is_chief=(task_index == 0), checkpoint_dir="logs") as sess: while not sess.should_stop(): sess.run(train_op, feed_dict={})
# 同步训练 opt= tf.train.SyncReplicasOptimizer(optimizer, replicas_to_aggregate=len(worker_hosts), total_num_replicas=len(worker_hosts), use_locking=True) train_op = opt.minimize(loss, global_step=global_step) sync_replicas_hook = opt.make_session_run_hook(is_chief) with training.MonitoredTrainingSession(master=server.target, is_chief=(task_index == 0), hooks=[sync_replicas_hook] checkpoint_dir="logs") as sess: while not sess.should_stop(): mon_sess.run(train_op)
-
Estimator
# 直接设置分布式训练环境参数 TF_CONFIG, 使用 estimator 进行训练 tf_config = { 'cluster': {'chief': chief_hosts, 'worker': worker_hosts, 'ps': ps_hosts}, 'task': {'type': task_type, 'index': task_index} } os.environ['TF_CONFIG'] = json.dumps(tf_config)