2020-10-09

Communicating with the central node via requests in federated learning
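The client talks to the central node over plain HTTP using the requests library: it joins the federation, aligns its vocabulary with the global one, then pushes its local model and pulls the averaged one each round. For reference, here is a sketch of the JSON the /join endpoint is expected to return, inferred from the client-side parsing code below (the field values are placeholders, not real server output):

# Expected shape of the /join response, inferred from the parsing below; values are placeholders.
join_response = {
    "code": 1,                      # 1 = join accepted, any other value = rejected
    "model_id": "model_0001",       # placeholder
    "model_name": "C_BiLSTMCRF_NER",
    "model_version": "1",           # placeholder
    "msg": "ok"                     # error message when code != 1
}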

import requests
import json
import os

# Timing decorator (currently unused); if enabled it needs `import time` and
# `from functools import wraps`.
'''def cost_time(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        t = func(*args, **kwargs)
        print('[run time] [%s] run time is %.2f' % (func.__name__, time.time() - start))
        return t
    return wrapper'''


class RequestCommunicate(object):
    def __init__(self):
        # Address of the central node, e.g. "http://host:port"; set this before calling any API.
        self.ip_port = ""

    def request_join_federated(self, client_id, model_name, model_description):
        """
        Call the API to request joining the federation.
        (A full usage sketch of one federated round is given after the class.)
        Arguments:
            client_id(string): the mark for this client
            model_name(string): example --> 'C_BiLSTMCRF_NER'
            model_description(string): a short description of the model
        Return:
            whether_join(boolean): True if the client was accepted into
                                   federated learning, False otherwise.
            model_id(string): unique ID of the model, used for push and pull
            model_name(string): the model name, used for push and pull
            model_version(string): the model version, used for push
        """
        print('[join api] start!!',flush=True)
        url = self.ip_port + "/join"
        data = {
            "client_id": client_id,
            "model_name": model_name,
            "model_description":model_description
        }
        ret = requests.get(url, params=data)
        ret = json.loads(ret.content)
        print('[join api] return: ', ret,flush=True)
        if ret['code'] == 1:
            whether_join = True
            model_id = ret['model_id']
            model_name = ret['model_name']
            model_version = ret['model_version']
            print('join accept: ', model_id, model_name, model_version,flush=True)
        else:
            print('join failed! ', ret['code'], ret['msg'],flush=True)
            whether_join = False
            model_id, model_name, model_version = "", "", ""

        return whether_join, model_id, model_name, model_version

    def request_check_model(self, voc_tag_maps, client_id):
        """
        POST request: call the API to check the model and return the global dictionary.
        Arguments:
            voc_tag_maps(dictionary): vocabulary and label dictionary
                   example: {'vocab': {'安': 0, '平': 1, '中': 2, '国': 3},
                             'tag_map': {'O': 0, 'B-BODY': 5, 'E-BODY': 6},
                             }
            client_id(string): the mark for this client
        Return:
            vob_dictionary(dictionary): global dictionary from the central node,
                                        with the same structure as voc_tag_maps

        Notes:
            format: "pytorch", "tensorflow", etc.
            class_name: the name of the model class. example: BiLSTMCRF
            class_filename: the model script name. example: base_model
            Change these if the model file or class is named differently.
        """
        print('[check api] start!!', flush=True)
        filename = os.path.join(os.getcwd(), 'base_model.py')
        url = self.ip_port + '/check_model'
        postdata = {"client_id": client_id,
                    "format": "pytorch",
                    "class_name": "BiLSTMCRF",
                    "class_filename": "base_model",
                    "reserve_param": json.dumps(voc_tag_maps)}
        # Upload the model script so the central node can check it; the with-block
        # keeps the file open only for the duration of the request.
        with open(filename, "rb") as model_file:
            files = {"model_binary": model_file}
            ret = requests.post(url, data=postdata, files=files)
        ret = json.loads(ret.content)
        vob_dictionary = None
        if ret['code'] == 1:
            vob_dictionary = {'vocab': ret['vocab'],
                              'tag_map': ret['tag_map'],
                             }
            print('check accept!',flush=True)
        else:
            print('check failed! ', ret['msg'],flush=True)
        return vob_dictionary

    def push_model_to_cloud(self, post_data, modelfilename='model_cache/local_model.pth'):
        """
        Call the API to push the local model to the cloud.
        Arguments:
            post_data(dictionary):
                client_id (string): the mark for this client
                model_id (string): obtained from request_join_federated
                model_name(string): obtained from request_join_federated
                model_version(string): incremented by 1 each round, starting from
                                       the value returned by request_join_federated
            modelfilename(string): path of the local model file to upload
        Notes:
            format: "pytorch", "tensorflow", etc.
        """
        print('[push api] start!!', flush=True)
        url = self.ip_port + '/update_model'
        postdata = {"client_id": post_data["client_id"],
                    "model_id": post_data["model_id"],
                    "model_name": post_data["model_name"],
                    "model_version": post_data["model_version"],
                    "format": "pytorch",
                    "reserve_param": ""}

        # Upload the serialized local model; the with-block closes the file afterwards.
        with open(modelfilename, "rb") as model_file:
            files = {"files": model_file}
            r = requests.post(url, data=postdata, files=files)
        print('push response: ', r.content,flush=True)

    def pull_model_from_cloud(self, post_data, modelfilename='model_cache/fl_model.pth'):
        """
        Call the API to pull the averaged model from the cloud.
        The averaged model is saved to 'modelfilename'.
        Arguments:
            post_data(dictionary):
                client_id (string): the mark for this client
                model_id (string): obtained from request_join_federated
                model_name(string): obtained from request_join_federated
            modelfilename(string): local path where the pulled model is saved
        Return:
            the model version parsed from the response's Content-Disposition header
        Notes:
            format: "pytorch", "tensorflow", etc.
        """
        print('[pull request api] : start !!',flush=True)
        url = self.ip_port + '/download_model'
        data = {"client_id": post_data["client_id"],
                "model_id": post_data["model_id"],
                "model_name": post_data["model_name"],
                "format": "pytorch",
                "source":"hangyeai",
                "reserve_param": ""}
        save_model_name=modelfilename
        ret = requests.get(url, params=data)
        with open(save_model_name, "wb") as f:
            f.write(ret.content)
        print("[pull success!!!]",flush=True)
        new_name=ret.headers['Content-Disposition']
        return(new_name[:-4].split("_")[-1])

    def upload_status(self, post_data):
        """Call the API to upload the client's training status/metrics."""
        print('[upload metric api] start!!', flush=True)
        url = self.ip_port + '/upload_status'
        r = requests.get(url, params=post_data)
        print('upload status response: ', r.content, flush=True)

    def quit(self, post_data):
        """Call the API to quit the federation."""
        print('[quit api] start!!', flush=True)
        url = self.ip_port + '/quit'
        r = requests.get(url, params=post_data)
        print('quit response: ', r.content, flush=True)
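
A minimal usage sketch of one federated round. The server address, client id, and dictionary contents are placeholders, the local training step is elided, and the model_cache directory is assumed to exist with a saved local model:

if __name__ == '__main__':
    comm = RequestCommunicate()
    comm.ip_port = "http://127.0.0.1:5000"   # placeholder central-node address

    # 1. Ask to join the federation
    joined, model_id, model_name, model_version = comm.request_join_federated(
        client_id="client_001",
        model_name="C_BiLSTMCRF_NER",
        model_description="NER demo client")

    if joined:
        # 2. Align the local vocab/tag dictionaries with the central node
        voc_tag_maps = {'vocab': {'安': 0, '平': 1}, 'tag_map': {'O': 0, 'B-BODY': 5}}
        global_maps = comm.request_check_model(voc_tag_maps, client_id="client_001")

        post_data = {"client_id": "client_001",
                     "model_id": model_id,
                     "model_name": model_name,
                     "model_version": model_version}

        # 3. ... train locally and save weights to model_cache/local_model.pth ...

        # 4. Push the local model, then pull the averaged global model
        comm.push_model_to_cloud(post_data)
        new_version = comm.pull_model_from_cloud(post_data)

        # 5. Report status and leave the federation
        comm.upload_status(post_data)
        comm.quit(post_data)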
