# Communicate with the central node of federated learning via the `requests` library (联邦学习中通过request与中心节点通信).
import requests
import json
import os
# Disabled timing decorator, kept only as an unused string literal: it is never
# executed or applied to anything in this module. When active it would print the
# wall-clock run time of the wrapped function. Re-enabling it requires turning it
# back into real (indented) code and importing `time` and `functools.wraps`,
# neither of which this module currently imports.
'''def cost_time(func):
@wraps(func)
def wraper(*args, **kwargs):
start = time.time()
t = func(*args, **kwargs)
print('[run time] [%s] run time is %.2f' % (func.__name__, time.time() - start))
return t
return wraper'''
class RequestCommunicate(object):
    """HTTP client a federated-learning participant uses to talk to the central node.

    Every API call sends a request to ``<ip_port><endpoint>`` and prints
    progress messages to stdout.

    Arguments:
        ip_port(string): base URL of the central node including the scheme,
            e.g. "http://host:port". Defaults to "" for backward compatibility;
            in that case ``ip_port`` must be assigned before any call is made.
        timeout(float, tuple or None): forwarded as ``timeout=`` to every
            ``requests`` call. ``None`` (the default) waits indefinitely, which
            matches the previous behavior.
    """

    # Class-level fallback so subclasses that override __init__ without calling
    # super().__init__() still get the original "no timeout" behavior.
    timeout = None

    def __init__(self, ip_port="", timeout=None):
        self.ip_port = ip_port
        self.timeout = timeout

    def _url(self, endpoint):
        """Join the base address and an endpoint path without doubling the slash."""
        return self.ip_port.rstrip('/') + endpoint

    @staticmethod
    def _version_from_disposition(disposition):
        """Extract the version token from a Content-Disposition header value.

        The token is the last "_"-separated piece of the attachment filename
        with its extension removed, e.g.
        'attachment; filename=C_BiLSTMCRF_NER_3.pth' -> '3'.
        Surrounding whitespace and quotes (filename="...") are ignored.
        """
        name = disposition.strip().rstrip('"\'')
        stem, _ext = os.path.splitext(name)
        return stem.split("_")[-1]

    def request_join_federated(self, client_id, model_name, model_description):
        """
        Call the /join api to ask the central node to admit this client.
        Arguments:
            client_id(string): the mark for client
            model_name(string): example--> 'C_BiLSTMCRF_NER'
            model_description: sent to the server as the `model_description`
                query parameter
        Return:
            whether_join(boolean): True if the server answered with code == 1,
                False otherwise.
            model_id(string): unique ID of the model used for push and pull
                ("" when not joined)
            model_name(string): model name returned by the server, used for
                push and pull ("" when not joined)
            model_version(string): model version returned by the server, used
                for push ("" when not joined)
        """
        print('[join api] start!!', flush=True)
        data = {
            "client_id": client_id,
            "model_name": model_name,
            "model_description": model_description,
        }
        ret = requests.get(self._url("/join"), params=data, timeout=self.timeout)
        ret = json.loads(ret.content)
        print('[join api] return: ', ret, flush=True)
        if ret['code'] != 1:
            print('join failed! ', ret['code'], ret['msg'], flush=True)
            return False, "", "", ""
        model_id = ret['model_id']
        model_name = ret['model_name']
        model_version = ret['model_version']
        print('join accept: ', model_id, model_name, model_version, flush=True)
        return True, model_id, model_name, model_version

    def request_check_model(self, voc_tag_maps, client_id, *,
                            class_name="BiLSTMCRF",
                            class_filename="base_model",
                            model_format="pytorch"):
        """
        POST the model script and local dictionaries to the /check_model api.
        Arguments:
            voc_tag_maps(dictionary): vocabulary and label dictionary, sent
                JSON-encoded as `reserve_param`.
                example: {'vocab': {'安':0, '平':1, '中':2, '国':3,},
                          'tag_map': {'O': 0, 'B-BODY': 5, 'E-BODY': 6, },
                          }
            client_id(string): the mark for client
            class_name(string): name of the model class. example: BiLSTMCRF
            class_filename(string): model script name without ".py".
                The file uploaded is <current working dir>/<class_filename>.py.
            model_format(string): sent as `format`, e.g. "pytorch", "tensorflow"
        Return:
            vob_dictionary(dictionary or None): {'vocab': ..., 'tag_map': ...}
                taken from the server's reply when code == 1, otherwise None.
        """
        print('[check api] start!!', flush=True)
        filename = os.path.join(os.getcwd(), class_filename + '.py')
        postdata = {"client_id": client_id,
                    "format": model_format,
                    "class_name": class_name,
                    "class_filename": class_filename,
                    "reserve_param": json.dumps(voc_tag_maps)}
        # Context manager closes the script handle even if the request raises.
        with open(filename, "rb") as model_script:
            ret = requests.post(self._url('/check_model'), data=postdata,
                                files={"model_binary": model_script},
                                timeout=self.timeout)
        ret = json.loads(ret.content)
        if ret['code'] != 1:
            print('check failed! ', ret['msg'], flush=True)
            return None
        vob_dictionary = {'vocab': ret['vocab'],
                          'tag_map': ret['tag_map'],
                          }
        print('check accept!', flush=True)
        return vob_dictionary

    def push_model_to_cloud(self, post_data,
                            modelfilename='model_cache/local_model.pth', *,
                            model_format="pytorch"):
        """
        Upload the local model file to the /update_model api.
        Arguments:
            post_data(dictionary):
                client_id (string): the mark for client
                model_id (string): obtained from request_join_federated
                model_name(string): obtained from request_join_federated
                model_version(string): the version to push; the caller
                    increments the version obtained from request_join_federated
            modelfilename(string): path of the model file to upload
            model_format(string): sent as `format`, e.g. "pytorch", "tensorflow"
        Return:
            None. The raw server response is printed.
        """
        print('[push api] start!!', flush=True)
        postdata = {"client_id": post_data["client_id"],
                    "model_id": post_data["model_id"],
                    "model_name": post_data["model_name"],
                    "model_version": post_data["model_version"],
                    "format": model_format,
                    "reserve_param": ""}
        # Context manager closes the model handle even if the request raises.
        with open(modelfilename, "rb") as model_file:
            r = requests.post(self._url('/update_model'), data=postdata,
                              files={"files": model_file},
                              timeout=self.timeout)
        print('push response: ', r.content, flush=True)

    def pull_model_from_cloud(self, post_data,
                              modelfilename='model_cache/fl_model.pth', *,
                              model_format="pytorch"):
        """
        Download the averaged model from the /download_model api.
        The response body is saved to `modelfilename`.
        Arguments:
            post_data(dictionary):
                client_id (string): the mark for client
                model_id (string): obtained from request_join_federated
                model_name(string): obtained from request_join_federated
            modelfilename(string): path the downloaded model is written to
            model_format(string): sent as `format`, e.g. "pytorch", "tensorflow"
        Return:
            (string) the last "_"-separated token of the attachment filename in
            the Content-Disposition header, extension removed. Presumably the
            global model version; confirm against the server.
        Raises:
            KeyError: if the response has no Content-Disposition header. In
                that case the existing model file is left untouched.
        """
        print('[pull request api] : start !!', flush=True)
        data = {"client_id": post_data["client_id"],
                "model_id": post_data["model_id"],
                "model_name": post_data["model_name"],
                "format": model_format,
                "source": "hangyeai",
                "reserve_param": ""}
        ret = requests.get(self._url('/download_model'), params=data,
                           timeout=self.timeout)
        # Read the header BEFORE writing. An error reply has no attachment header,
        # and its body must not overwrite the cached model.
        disposition = ret.headers['Content-Disposition']
        with open(modelfilename, "wb") as f:
            f.write(ret.content)
        print("[pull success!!!]", flush=True)
        return self._version_from_disposition(disposition)

    def upload_status(self, post_data):
        """Send `post_data` as query parameters to the /upload_status api (GET) and print the reply."""
        print('[upload metric api] start!!', flush=True)
        r = requests.get(self._url('/upload_status'), params=post_data,
                         timeout=self.timeout)
        print('upload status response: ', r.content, flush=True)

    def quit(self, post_data):
        """Send `post_data` as query parameters to the /quit api (GET) and print the reply."""
        print('[quit api] start!!', flush=True)
        r = requests.get(self._url('/quit'), params=post_data,
                         timeout=self.timeout)
        print('quit response: ', r.content, flush=True)