Federated Learning for Image Classification

https://www.tensorflow.org/federated/tutorials/federated_learning_for_image_classification

1. Installation

This step is slow the first time it runs.

# tensorflow_federated_nightly also brings in tf_nightly, which
# can cause a duplicate tensorboard install, leading to errors.
!pip3 uninstall --yes tensorboard tb-nightly

!pip3 install --quiet --upgrade tensorflow-federated-nightly
!pip3 install --quiet --upgrade nest-asyncio
!pip3 install --quiet --upgrade tb-nightly  # or tensorboard, but not both

import nest_asyncio
nest_asyncio.apply()
# Load the TensorBoard notebook extension
%load_ext tensorboard

2. Test the installation

import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()

If the following is printed, the installation succeeded:

b'Hello, World!'

3. Prepare the dataset

# Load the data.
# Returns a pair of tff.simulation.ClientData instances.
# https://www.tensorflow.org/federated/api_docs/python/tff/simulation/ClientData
# Each client's local dataset is represented as a tf.data.Dataset
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
# Number of clients in the dataset.
len(emnist_train.client_ids)
# Structure of the examples in the training dataset.
emnist_train.element_type_structure
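
For reference, each example should come back as an OrderedDict with a scalar int32 'label' and a 28x28 float32 'pixels' tensor; the exact element_type_structure output may vary slightly across TFF versions.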

4. Visualize the data

# Returns a tf.data.Dataset instance.
# https://www.tensorflow.org/api_docs/python/tf/data/Dataset
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

# Look at the label of the first example.
example_element = next(iter(example_dataset))
example_element['label'].numpy()
# Look at the image of the first example.
!pip3 install matplotlib
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()
# Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0
for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1
# Number of examples per label for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.
for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

5. Preprocess the input data

# Number of local training epochs (each client's dataset is repeated this many times).
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  # flatten the 28x28 images into 784-element arrays
  # shuffle the individual examples, organize them into batches
  # rename the features from pixels and label to x and y for use with Keras
  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        # Flatten each 28x28 image into a 784-element row: shape [-1, 784].
        x=tf.reshape(element['pixels'], [-1, 784]),
        # Reshape each label into a single-column array: shape [-1, 1].
        y=tf.reshape(element['label'], [-1, 1]))

  # Repeat the dataset so each original example is seen NUM_EPOCHS times,
  # then randomly shuffle, batch, reformat, and prefetch.
  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)
preprocessed_example_dataset = preprocess(example_dataset)
# Convert every tensor in the batch to a numpy array.
# The .numpy() method explicitly converts a Tensor to a numpy array
sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))
sample_batch
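
A quick sanity check on the preprocessed batch (a sketch, not in the original tutorial; the shapes follow from the constants above):

# With BATCH_SIZE = 20, each batch should contain 20 flattened 784-pixel
# images and 20 single-integer labels.
assert sample_batch['x'].shape == (20, 784)
assert sample_batch['y'].shape == (20, 1)
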
# One of the ways to feed federated data to TFF in a simulation is simply as a Python list
# with each element of the list holding the data of an individual user
def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

What we'll do instead is sample the set of clients once, and reuse the same set across rounds to speed up convergence (intentionally over-fitting to these few users' data).

We leave it as an exercise for the reader to modify this tutorial to simulate random sampling; a minimal sketch follows the code below.

NUM_CLIENTS = 10
# Take the first NUM_CLIENTS client ids (a fixed, not random, sample).
sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]
# The data for these ten clients.
federated_train_data = make_federated_data(emnist_train, sample_clients)
print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
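
One minimal sketch of the random-sampling exercise mentioned above (not part of the original tutorial; it draws a fresh, distinct subset of client ids with numpy):

def sample_random_clients(num_clients):
  # Draw num_clients distinct client ids uniformly at random.
  chosen_ids = np.random.choice(
      emnist_train.client_ids, size=num_clients, replace=False)
  return make_federated_data(emnist_train, chosen_ids)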

6. Create the model

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

7. Train the model

# Keep in mind that the argument needs to be a constructor (such as model_fn
# above), not an already-constructed instance.
# Returns a tff.templates.IterativeProcess.
# https://www.tensorflow.org/federated/api_docs/python/tff/templates/IterativeProcess
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    # The client optimizer is only used to compute local model updates on each client.
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    # The server optimizer applies the averaged update to the global model at the server.
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
# Returns the TFF type of this object (an instance of tff.Type).
str(iterative_process.initialize.type_signature)
# Invoke the initialize computation to construct the server state.
state = iterative_process.initialize()

The second of the pair of federated computations, next, represents a single round of Federated Averaging, which consists of pushing the server state (including the model parameters) to the clients, on-device training on their local data, collecting and averaging model updates, and producing a new updated model at the server.

Conceptually, you can think of next as having a functional type signature that looks as follows.

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

In particular, one should think about next() not as being a function that runs on a server, but rather as a declarative functional representation of the entire decentralized computation: some of the inputs are provided by the server (SERVER_STATE), but each participating device contributes its own local dataset.
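
To see the concrete version of this signature, you can print it; the exact string depends on the TFF version:

str(iterative_process.next.type_signature)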

# run a single round of training and visualize the results.
state, metrics = iterative_process.next(state, federated_train_data)
print(f'round  1, metrics={metrics}')
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13950618), ('loss', 2.9957898)])), ('stat', OrderedDict([('num_examples', 4860)]))])
NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print(f'round {round_num}, metrics={metrics}')
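
The metrics object is a nested OrderedDict (see the round-1 output above), so individual scalars can be pulled out directly, for example:

# A small sketch: extract the training loss and accuracy from the last round.
train_metrics = metrics['train']
print(train_metrics['loss'], train_metrics['sparse_categorical_accuracy'])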

Trying random clients

import random

def randomClients(NUM_CLIENTS):
    # Assumes NUM_CLIENTS <= 300 so startId + NUM_CLIENTS stays within the
    # available client ids.
    startId = random.randint(0, 3000)
    # Take NUM_CLIENTS consecutive client ids starting at a random offset.
    sample_clients = emnist_train.client_ids[startId:NUM_CLIENTS + startId]
    # The data for those clients.
    federated_train_data = make_federated_data(emnist_train, sample_clients)
    print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
    print('First dataset: {d}'.format(d=federated_train_data[0]))
    return federated_train_data
NUM_CLIENTS = 100
federated_train_data = randomClients(NUM_CLIENTS)
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
    federated_train_data = randomClients(NUM_CLIENTS)
    state, metrics = iterative_process.next(state, federated_train_data)
    print(f'round {round_num}, state={state}, metrics={metrics}')

8. Display in TensorBoard

logdir = "/tmp/logs/scalars/training/"
# Creates a summary file writer for the given log directory.
summary_writer = tf.summary.create_file_writer(logdir)
# Retrain from scratch.
state = iterative_process.initialize()
with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics['train'].items():
      tf.summary.scalar(name, value, step=round_num)
# Launch TensorBoard.
!ls {logdir}
%tensorboard --logdir {logdir} --port=0

9. Customizing the model

Defining model variables, forward pass, and metrics

This will include variables such as weights and bias that we will train, as well as variables that will hold various cumulative statistics and counters we will update during training, such as loss_sum, accuracy_sum, and num_examples.

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

Wrapping variable initializers as lambdas is a requirement imposed by resource variables.

'''
tf.Variable(
    initial_value=None, trainable=None, validate_shape=True, caching_device=None,
    name=None, variable_def=None, dtype=None, import_scope=None, constraint=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.compat.v1.VariableAggregation.NONE, shape=None
)
The Variable() constructor requires an initial value for the variable, which can be a Tensor of any type and shape. This initial value defines the type and shape of the variable. After construction, the type and shape of the variable are fixed. The value can be changed using one of the assign methods.
'''
def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

We can now define the forward pass method that computes loss, emits predictions, and updates the cumulative statistics for a single batch of input data.

def mnist_forward_pass(variables, batch):
  y = tf.nn.softmax(tf.matmul(batch['x'], variables.weights) + variables.bias)
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions
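
A quick smoke test of the forward pass on one preprocessed batch (a sketch, not part of the original tutorial; with zero-initialized weights the predicted distribution is uniform, so the loss should be close to ln(10) ≈ 2.30):

mnist_vars = create_mnist_variables()
example_batch = next(iter(preprocessed_example_dataset))
loss, predictions = mnist_forward_pass(mnist_vars, example_batch)
print(loss.numpy(), mnist_vars.num_examples.numpy())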

Next, we define a function that returns a set of local metrics, again using TensorFlow.

Here, we simply return the average loss and accuracy, as well as the num_examples, which we'll need to correctly weight the contributions from different users when computing federated aggregates.
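
As a quick numeric illustration of why num_examples matters (hypothetical numbers): if client A reports an average loss of 2.0 over 100 examples and client B reports 1.0 over 300 examples, the example-weighted mean is (2.0*100 + 1.0*300) / 400 = 1.25, whereas the unweighted mean would be 1.5.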

def get_local_mnist_metrics(variables):
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)

Finally, we need to determine how to aggregate the local metrics emitted by each device via get_local_mnist_metrics.

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))

The input metrics argument corresponds to the OrderedDict returned by get_local_mnist_metrics above, but critically the values are no longer tf.Tensors - they are "boxed" as tff.Values, to make it clear you can no longer manipulate them using TensorFlow, but only using TFF's federated operators like tff.federated_mean and tff.federated_sum. The returned dictionary of global aggregates defines the set of metrics which will be available on the server.
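
To get a feel for these operators in isolation, here is a minimal standalone federated computation (a sketch; get_average_temperature is a made-up example name):

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_values):
  # Averages one float32 value contributed by each client.
  return tff.federated_mean(client_values)

print(get_average_temperature([68.5, 70.3, 69.8]))  # approximately 69.53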

Constructing an instance of tff.learning.Model

With all of the above in place, we are ready to construct a model representation for use with TFF similar to one that's generated for you when you let TFF ingest a Keras model.

class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_examples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_examples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients

As you can see, the abstract methods and properties defined by tff.learning.Model correspond to the code snippets in the preceding section that introduced the variables and defined the loss and statistics.

Simulating federated training with the new model

Just replace the model constructor with the constructor of our new model class, and use the two federated computations in the iterative process you created to cycle through training rounds.

iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))

10. Evaluation

To perform evaluation on federated data, you can construct another federated computation designed for just this purpose, using the tff.learning.build_federated_evaluation function, and passing in your model constructor as an argument. It suffices to pass MnistModel: evaluation doesn't perform gradient descent, so there's no need to construct optimizers.

For experimentation and research, when a centralized test dataset is available, Federated Learning for Text Generation demonstrates another evaluation option: taking the trained weights from federated learning, applying them to a standard Keras model, and then simply calling tf.keras.models.Model.evaluate() on a centralized dataset.
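
A minimal sketch of that centralized option, assuming the trained state from the training section above (assign_weights_to is the pattern used in the text-generation tutorial; details may vary across TFF versions, and the "centralized" dataset here is just one client's data for illustration):

keras_model = create_keras_model()
keras_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
# Copy the trained server weights into the Keras model.
state.model.assign_weights_to(keras_model)
centralized_test = preprocess(
    emnist_test.create_tf_dataset_for_client(emnist_test.client_ids[0])
).map(lambda batch: (batch['x'], batch['y']))
keras_model.evaluate(centralized_test)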

evaluation = tff.learning.build_federated_evaluation(MnistModel)
str(evaluation.type_signature)
train_metrics = evaluation(state.model, federated_train_data)
str(train_metrics)

Now, let's compile a test sample of federated data and rerun evaluation on the test data. The data will come from the same sample of real users, but from a distinct held-out data set.

federated_test_data = make_federated_data(emnist_test, sample_clients)
len(federated_test_data), federated_test_data[0]
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)

This concludes the tutorial. We encourage you to play with the parameters (e.g., batch sizes, number of users, epochs, learning rates, etc.), to modify the code above to simulate training on random samples of users in each round, and to explore the other tutorials we've developed.
