Distributed Training in TensorFlow

  • API

    tf.train.ClusterSpec    Create a cluster from the parameter server and worker hosts.
    tf.train.Server         Create and start a server for the local task.
    
    tf.train.MonitoredTrainingSession
    tf.train.SyncReplicasOptimizer
    tf.train.Supervisor    # deprecated
    
  • Session

    TensorFlow distributed deployment

    Configuring distributed TensorFlow

    tf.train.SyncReplicasOptimizer

    tf.distribute.Strategy (see the sketch at the end of this section)

    Distribution Strategy TF-Example

    import tensorflow as tf

    def cluster_server(config):
        # Describe the cluster: which hosts act as parameter servers and which as workers.
        ps_hosts = config['ps_hosts']
        worker_hosts = config['worker_hosts']
        cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

        # Start the in-process server for this task.
        job_name = config['job_name']   # 'worker' or 'ps'
        task_index = config["index"]    # 0, 1, 2, ...
        server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)

        return cluster, server, job_name, task_index
    
    # tf.train.MonitoredTrainingSession
    def main(config):
        cluster, server, job_name, task_index = cluster_server(config)
        if job_name == "ps":
            # Parameter servers only host variables; block here forever.
            server.join()
        elif job_name == "worker":
            # Pin variables to the parameter servers and ops to this worker.
            with tf.device(tf.train.replica_device_setter(
                    worker_device="/job:worker/task:%d" % task_index,
                    cluster=cluster)):

                model = ...
                loss = ...
                optimizer = ...
                global_step = tf.train.get_or_create_global_step()
                train_op = optimizer.minimize(loss, global_step=global_step)

            hooks = [tf.train.StopAtStepHook(last_step=1000000)]
            # The chief (task 0) handles initialization and checkpointing.
            with tf.train.MonitoredTrainingSession(master=server.target,
                                                   is_chief=(task_index == 0),
                                                   hooks=hooks,
                                                   checkpoint_dir="logs") as sess:
                while not sess.should_stop():
                    sess.run(train_op, feed_dict={})
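
    A usage sketch for the functions above: every process in the cluster runs the same script, differing only in job_name and index. The host addresses below are hypothetical placeholders.

    if __name__ == "__main__":
        # Example: this process is worker 0. The parameter-server process would use
        # job_name='ps', index=0; the second worker would use job_name='worker', index=1.
        config = {
            'ps_hosts': ["10.0.0.1:2222"],
            'worker_hosts': ["10.0.0.2:2222", "10.0.0.3:2222"],
            'job_name': 'worker',
            'index': 0,
        }
        main(config)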
    
    # Synchronous training: aggregate gradients from all replicas before applying an update.
    opt = tf.train.SyncReplicasOptimizer(optimizer,
                                         replicas_to_aggregate=len(worker_hosts),
                                         total_num_replicas=len(worker_hosts),
                                         use_locking=True)
    train_op = opt.minimize(loss, global_step=global_step)
    sync_replicas_hook = opt.make_session_run_hook(is_chief=(task_index == 0))
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(task_index == 0),
                                           hooks=[sync_replicas_hook],
                                           checkpoint_dir="logs") as sess:
        while not sess.should_stop():
            sess.run(train_op)
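
    The tf.distribute.Strategy API referenced above is the higher-level replacement for this manual ClusterSpec/Server wiring. A minimal sketch, assuming a recent TensorFlow 2.x with MultiWorkerMirroredStrategy, TF_CONFIG set on each worker, and a toy Keras model (the dataset is a placeholder):

    import tensorflow as tf

    # Each worker reads its role in the cluster from the TF_CONFIG environment variable.
    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    with strategy.scope():
        # Variables created inside the scope are replicated and kept in sync across workers.
        model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
        model.compile(optimizer="sgd", loss="mse")

    # Placeholder dataset; in practice each worker would read a shard of the real data.
    dataset = tf.data.Dataset.from_tensors(([1.0], [1.0])).repeat(1000).batch(32)
    model.fit(dataset, epochs=2)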
    
  • Estimator

    # Set the distributed-training environment variable TF_CONFIG directly, then train with an Estimator.
    import json
    import os

    tf_config = {
        'cluster': {'chief': chief_hosts, 'worker': worker_hosts, 'ps': ps_hosts},
        'task': {'type': task_type, 'index': task_index}
    }
    os.environ['TF_CONFIG'] = json.dumps(tf_config)
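
    With TF_CONFIG exported as above, every process (chief, worker, ps) runs the same Estimator code and picks up its role from the environment. A minimal TF 1.x-style sketch with a hypothetical model_fn and input_fn:

    def input_fn():
        # Placeholder input pipeline for illustration only.
        dataset = tf.data.Dataset.from_tensors(({"x": [1.0]}, [1.0]))
        return dataset.repeat(10000).batch(32)

    def model_fn(features, labels, mode):
        # Toy linear model: one dense layer with a squared-error loss.
        predictions = tf.layers.dense(features["x"], 1)
        loss = tf.losses.mean_squared_error(labels, predictions)
        optimizer = tf.train.GradientDescentOptimizer(0.01)
        train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op,
                                          predictions=predictions)

    estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="logs")
    train_spec = tf.estimator.TrainSpec(input_fn=input_fn, max_steps=10000)
    eval_spec = tf.estimator.EvalSpec(input_fn=input_fn)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)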
    