https://www.tensorflow.org/federated/tutorials/federated_learning_for_image_classification
1. 安装
此过程第一次运行会很慢
# Install / upgrade the packages this notebook needs. These lines are IPython
# shell escapes (`!`) and line magics (`%`), not plain Python.
# tensorflow_federated_nightly also brings in tf_nightly, which
# can cause a duplicate TensorBoard install, leading to errors,
# so any existing TensorBoard packages are removed first.
!pip3 uninstall --yes tensorboard tb-nightly
!pip3 install --quiet --upgrade tensorflow-federated-nightly
!pip3 install --quiet --upgrade nest-asyncio
!pip3 install --quiet --upgrade tb-nightly # or tensorboard, but not both
# Allow nested asyncio event loops (presumably required to run TFF inside the
# notebook's already-running loop — TODO confirm).
import nest_asyncio
nest_asyncio.apply()
# Load the TensorBoard notebook extension
%load_ext tensorboard
2. 测试
# Smoke test: import the libraries and run a trivial federated computation.
import collections
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
# Seed NumPy's global RNG so NumPy-based randomness is reproducible.
np.random.seed(0)
# A federated computation with no arguments that returns a constant; if the
# install works, the cell displays b'Hello, World!'.
tff.federated_computation(lambda: 'Hello, World!')()
如果成功输出,说明安装成功
b'Hello, World!'
3. 准备数据集
# Load the federated EMNIST data.
# load_data() returns a (train, test) pair of tff.simulation.ClientData objects
# https://www.tensorflow.org/federated/api_docs/python/tff/simulation/ClientData
# Each client's local dataset is represented as a tf.data.Dataset
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
# Number of clients in the training data
# (only displayed when this is the last expression of a notebook cell)
len(emnist_train.client_ids)
# Structure of a single training element
emnist_train.element_type_structure
4. 可视化数据
# Build the tf.data.Dataset holding the first client's examples
# https://www.tensorflow.org/api_docs/python/tf/data/Dataset
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])
# Label of the first example
example_element = next(iter(example_dataset))
example_element['label'].numpy()
# Show the first example's image
# (IPython shell escape that makes sure matplotlib is installed)
!pip3 install matplotlib
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
# Assigning to `_` just discards the return value of plt.show().
_ = plt.show()
# Example MNIST digits for one client: the first 40 examples drawn on a
# 4 x 10 grid of subplots.
figure = plt.figure(figsize=(20, 4))
for j, example in enumerate(example_dataset.take(40)):
    # subplot positions are 1-based, enumerate() is 0-based
    plt.subplot(4, 10, j + 1)
    plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
    plt.axis('off')
# Number of examples per label for a sample of clients (the first 6).
# Each label gets its own hist() call so the bars are drawn in different
# colours instead of one colour per plot.
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
    client_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[i])
    client_labels = [example['label'].numpy() for example in client_dataset]
    plt.subplot(2, 3, i + 1)
    plt.title('Client {}'.format(i))
    for j in range(10):
        plt.hist(
            [label for label in client_labels if label == j],
            density=False,
            bins=list(range(11)))
# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.
# For each of the first 5 clients, plot the per-pixel mean image of every label.
for i in range(5):
    client_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[i])
    plot_data = collections.defaultdict(list)
    for example in client_dataset:
        plot_data[example['label'].numpy()].append(example['pixels'].numpy())
    # Let matplotlib allocate a new figure: passing `i` as the figure number
    # would re-activate (and draw over) an already-open figure with that number.
    f = plt.figure(figsize=(12, 5))
    f.suptitle("Client #{}'s Mean Image Per Label".format(i))
    for j in range(10):
        plt.subplot(2, 5, j + 1)
        # A client may have no examples of label j: np.mean([], 0) is a NaN
        # scalar that cannot be reshaped to 28x28, so leave that cell blank.
        if plot_data[j]:
            mean_img = np.mean(plot_data[j], 0)
            plt.imshow(mean_img.reshape((28, 28)))
        plt.axis('off')
5. 准备数据的输入
# Input-pipeline settings used by preprocess().
NUM_EPOCHS = 5  # how many times each client's dataset is repeated
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10


def preprocess(dataset):
    """Turn a raw EMNIST client dataset into Keras-ready batches.

    The dataset is repeated NUM_EPOCHS times, shuffled with a buffer of
    SHUFFLE_BUFFER elements, grouped into batches of BATCH_SIZE, reformatted
    so every batch is an OrderedDict with 'x' (pixels flattened to 784
    columns) and 'y' (labels as a single column), and finally prefetched.
    """

    def batch_format_fn(element):
        """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
        flat_pixels = tf.reshape(element['pixels'], [-1, 784])
        label_column = tf.reshape(element['label'], [-1, 1])
        return collections.OrderedDict(x=flat_pixels, y=label_column)

    repeated = dataset.repeat(NUM_EPOCHS)
    shuffled = repeated.shuffle(SHUFFLE_BUFFER)
    batched = shuffled.batch(BATCH_SIZE)
    formatted = batched.map(batch_format_fn)
    return formatted.prefetch(PREFETCH_BUFFER)
# Preprocess the first client's dataset and pull out a single batch.
preprocessed_example_dataset = preprocess(example_dataset)
# Convert every tensor in the batch to a NumPy array
# The .numpy() method explicitly converts a Tensor to a numpy array
sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))
sample_batch
# One of the ways to feed federated data to TFF in a simulation is simply as a Python list
# with each element of the list holding the data of an individual user
def make_federated_data(client_data, client_ids):
    """Return one preprocessed tf.data.Dataset per client id, in the given order.

    Args:
      client_data: object providing create_tf_dataset_for_client(client_id)
        (a tff.simulation.ClientData in this notebook).
      client_ids: the client ids whose datasets should be built.
    """

    def _prepare(client_id):
        # Build the raw per-client dataset, then run the shared pipeline on it.
        return preprocess(client_data.create_tf_dataset_for_client(client_id))

    return [_prepare(client_id) for client_id in client_ids]
In a realistic federated setting, a new random subset of clients would be sampled for every round. What we'll do instead is sample the set of clients once, and reuse the same set across rounds to speed up convergence (intentionally over-fitting to these few users' data).
We leave it as an exercise for the reader to modify this tutorial to simulate random sampling.
# Fixed client sample reused in every round: the first NUM_CLIENTS ids in
# client_ids (not a random choice).
NUM_CLIENTS = 10
sample_clients = emnist_train.client_ids[:NUM_CLIENTS]
# One preprocessed dataset per sampled client
federated_train_data = make_federated_data(emnist_train, sample_clients)
print(f'Number of client datasets: {len(federated_train_data)}')
print(f'First dataset: {federated_train_data[0]}')
6. 创建模型
def create_keras_model():
    """Build a fresh, uncompiled single-layer softmax classifier.

    The model takes 784-element inputs and has one zero-initialised Dense
    layer with 10 units followed by a Softmax layer.
    """
    layers = [
        tf.keras.layers.InputLayer(input_shape=(784,)),
        tf.keras.layers.Dense(10, kernel_initializer='zeros'),
        tf.keras.layers.Softmax(),
    ]
    return tf.keras.models.Sequential(layers)
def model_fn():
    """Wrap a newly built Keras model as a tff.learning.Model.

    We _must_ create a new model here, and _not_ capture it from an external
    scope: TFF will call this within different graph contexts. The input spec
    is taken from the preprocessed example dataset.
    """
    return tff.learning.from_keras_model(
        create_keras_model(),
        input_spec=preprocessed_example_dataset.element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
7. 模型训练
# Build the Federated Averaging process.
# Keep in mind that the argument needs to be a constructor (such as model_fn above),
# not an already-constructed instance.
# Returns a tff.templates.IterativeProcess.
# https://www.tensorflow.org/federated/api_docs/python/tff/templates/IterativeProcess
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    # The client optimizer is only used to compute local model updates on each client.
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    # The server optimizer applies the averaged update to the global model at the server.
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
# The TFF type signature (a tff.Type) of the `initialize` computation, as a string
# (only displayed when this is the last expression of a notebook cell).
str(iterative_process.initialize.type_signature)
# Invoke the initialize computation to construct the server state.
state = iterative_process.initialize()
The second of the pair of federated computations, `next`, represents a single round of Federated Averaging, which consists of pushing the server state (including the model parameters) to the clients, on-device training on their local data, collecting and averaging model updates, and producing a new updated model at the server.
Conceptually, you can think of `next` as having a functional type signature that looks as follows:
    SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS
In particular, one should think about `next()` not as being a function that runs on a server, but rather as a declarative functional representation of the entire decentralized computation: some of the inputs are provided by the server (`SERVER_STATE`), but each participating device contributes its own local dataset.
# Run one round of Federated Averaging on the fixed client sample and show
# the metrics it reports.
state, metrics = iterative_process.next(state, federated_train_data)
print('round 1, metrics={}'.format(metrics))
round 1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13950618), ('loss', 2.9957898)])), ('stat', OrderedDict([('num_examples', 4860)]))])
# Continue training on the same clients for rounds 2 .. NUM_ROUNDS - 1.
NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    print('round {}, metrics={}'.format(round_num, metrics))
随机尝试
import random


def ranfomClients(NUM_CLIENTS):
    """Sample NUM_CLIENTS distinct training clients uniformly at random.

    Returns a list with one preprocessed dataset per sampled client (built
    with make_federated_data) and prints a short summary of it.

    The previous implementation took a contiguous run of client ids starting
    at random.randint(0, 3000). That is not a uniform random sample, and
    whenever start + NUM_CLIENTS exceeded len(client_ids) the slice silently
    returned fewer clients than requested.

    Raises:
      ValueError: if NUM_CLIENTS is negative or larger than the number of
        training clients (raised by random.sample).
    """
    sample_clients = random.sample(list(emnist_train.client_ids), NUM_CLIENTS)
    # One preprocessed dataset per sampled client
    federated_train_data = make_federated_data(emnist_train, sample_clients)
    print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
    if federated_train_data:
        # Guard against NUM_CLIENTS == 0, where indexing [0] would raise.
        print('First dataset: {d}'.format(d=federated_train_data[0]))
    return federated_train_data
# Random-sampling experiment: continue training from the current `state`,
# drawing a fresh random set of clients every round.
# Dedicated names are used so this cell does not overwrite the fixed-sample
# `federated_train_data` / `NUM_CLIENTS` that the TensorBoard, custom-model
# and evaluation cells below rely on (otherwise train metrics would be
# computed on random clients while test metrics use `sample_clients`).
NUM_RANDOM_CLIENTS = 100
NUM_ROUNDS = 11
for round_num in range(1, NUM_ROUNDS):
    random_train_data = ranfomClients(NUM_RANDOM_CLIENTS)
    state, metrics = iterative_process.next(state, random_train_data)
    # Print metrics only (as in round 1 and the fixed-sample loop); dumping
    # the full server state every round floods the output with model weights.
    print(f'round {round_num}, metrics={metrics}')
8. 在TensorBoard展示
# Retrain from scratch and log the 'train' metrics of every round to TensorBoard.
logdir = "/tmp/logs/scalars/training/"
# Creates a summary file writer for the given log directory.
summary_writer = tf.summary.create_file_writer(logdir)
# Start again from a freshly initialized server state.
state = iterative_process.initialize()
with summary_writer.as_default():
    for round_num in range(1, NUM_ROUNDS):
        state, metrics = iterative_process.next(state, federated_train_data)
        for name, value in metrics['train'].items():
            tf.summary.scalar(name, value, step=round_num)
# The file writer buffers events; flush so every logged scalar is on disk
# before the next cell lists the directory and starts TensorBoard.
summary_writer.flush()
# Open TensorBoard (IPython shell escape + tensorboard line magic)
# List the event files written to the log directory
!ls {logdir}
# NOTE(review): --port=0 presumably asks for an unused port — confirm.
%tensorboard --logdir {logdir} --port=0
9. 自定义模型
Defining model variables, forward pass, and metrics
This will include variables such as `weights` and `bias` that we will train, as well as variables that will hold various cumulative statistics and counters we will update during training, such as `loss_sum`, `accuracy_sum`, and `num_examples`.
# Container for all model state: the trainable parameters (weights, bias)
# and the running counters updated during training
# (num_examples, loss_sum, accuracy_sum).
MnistVariables = collections.namedtuple(
    'MnistVariables',
    ['weights', 'bias', 'num_examples', 'loss_sum', 'accuracy_sum'])
Wrapping variable initializers as lambdas is a requirement imposed by resource variables.
'''
tf.Variable(
initial_value=None, trainable=None, validate_shape=True, caching_device=None,
name=None, variable_def=None, dtype=None, import_scope=None, constraint=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.compat.v1.VariableAggregation.NONE, shape=None
)
The Variable() constructor requires an initial value for the variable, which can be a Tensor of any type and shape. This initial value defines the type and shape of the variable. After construction, the type and shape of the variable are fixed. The value can be changed using one of the assign methods.
'''
def create_mnist_variables():
    """Create the model's tf.Variables and bundle them in an MnistVariables.

    weights (784 x 10) and bias (10) start at zero and are trainable. The
    counters num_examples, loss_sum and accuracy_sum start at 0.0 and are not
    trainable. The initial values of weights and bias are given as lambdas,
    a requirement imposed by resource variables.
    """

    def zeros_initializer(shape):
        # Deferred float32 zeros of the given shape.
        return lambda: tf.zeros(dtype=tf.float32, shape=shape)

    def counter(name):
        # Non-trainable float accumulator starting at 0.0.
        return tf.Variable(0.0, name=name, trainable=False)

    weights = tf.Variable(
        zeros_initializer((784, 10)), name='weights', trainable=True)
    bias = tf.Variable(zeros_initializer((10,)), name='bias', trainable=True)
    return MnistVariables(
        weights=weights,
        bias=bias,
        num_examples=counter('num_examples'),
        loss_sum=counter('loss_sum'),
        accuracy_sum=counter('accuracy_sum'))
We can now define the forward pass method that computes the loss, emits predictions, and updates the cumulative statistics for a single batch of input data.
def mnist_forward_pass(variables, batch):
    """Run the model on one batch, update the running statistics, return (loss, predictions).

    Args:
      variables: an MnistVariables; its num_examples, loss_sum and
        accuracy_sum counters are incremented in place.
      batch: mapping with 'x' (inputs multiplied by the 784 x 10 weights) and
        'y' (integer labels; flattened to 1-D here).

    Returns:
      loss: mean softmax cross-entropy over the batch (scalar).
      predictions: int32 index of the highest-scoring class for each example.
    """
    logits = tf.matmul(batch['x'], variables.weights) + variables.bias
    # argmax over the logits is the same as argmax over softmax(logits).
    predictions = tf.cast(tf.argmax(logits, 1), tf.int32)
    flat_labels = tf.reshape(batch['y'], [-1])
    # Compute cross-entropy directly from the logits. The previous
    # -sum(one_hot * log(softmax(logits))) became -inf/NaN as soon as any
    # softmax probability underflowed to 0 (0 * log(0) is NaN).
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=flat_labels, logits=logits))
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(predictions, flat_labels), tf.float32))
    num_examples = tf.cast(tf.size(batch['y']), tf.float32)
    # Accumulate example-weighted sums so averages can be formed later.
    variables.num_examples.assign_add(num_examples)
    variables.loss_sum.assign_add(loss * num_examples)
    variables.accuracy_sum.assign_add(accuracy * num_examples)
    return loss, predictions
Next, we define a function that returns a set of local metrics, again using TensorFlow. Here, we simply return the average `loss` and `accuracy`, as well as `num_examples`, which we'll need to correctly weight the contributions from different users when computing federated aggregates.
def get_local_mnist_metrics(variables):
    """Return this client's metrics computed from the running counters.

    loss and accuracy are per-example averages. num_examples is returned as
    well so the server can weight each client's contribution.

    tf.math.divide_no_nan yields 0 instead of NaN when num_examples is 0.
    A client that saw no data then gets weight 0 with a finite value; a NaN
    value (NaN * 0 is still NaN) could otherwise poison the example-weighted
    federated mean.
    """
    return collections.OrderedDict(
        num_examples=variables.num_examples,
        loss=tf.math.divide_no_nan(variables.loss_sum, variables.num_examples),
        accuracy=tf.math.divide_no_nan(
            variables.accuracy_sum, variables.num_examples))
Finally, we need to determine how to aggregate the local metrics emitted by each device via `get_local_mnist_metrics`.
@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
    """Aggregate per-client metrics into the global metrics seen on the server.

    num_examples is summed across clients; loss and accuracy are averaged
    across clients, weighted by each client's num_examples.
    """
    return collections.OrderedDict(
        num_examples=tff.federated_sum(metrics.num_examples),
        loss=tff.federated_mean(metrics.loss, metrics.num_examples),
        accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))
The input `metrics` argument corresponds to the `OrderedDict` returned by `get_local_mnist_metrics` above, but critically the values are no longer `tf.Tensor`s: they are "boxed" as `tff.Value`s. This makes it clear that you can no longer manipulate them using TensorFlow, only with TFF's federated operators like `tff.federated_mean` and `tff.federated_sum`. The returned dictionary of global aggregates defines the set of metrics which will be available on the server.
Constructing an instance of tff.learning.Model
With all of the above in place, we are ready to construct a model representation for use with TFF similar to one that's generated for you when you let TFF ingest a Keras model.
class MnistModel(tff.learning.Model):
    """A tff.learning.Model built directly on the MnistVariables helpers.

    The weights and bias are the trainable variables, the running counters
    are the local variables, and there are no non-trainable variables.
    """

    def __init__(self):
        self._variables = create_mnist_variables()

    @property
    def trainable_variables(self):
        v = self._variables
        return [v.weights, v.bias]

    @property
    def non_trainable_variables(self):
        return []

    @property
    def local_variables(self):
        v = self._variables
        return [v.num_examples, v.loss_sum, v.accuracy_sum]

    @property
    def input_spec(self):
        # Shapes match the batches built by preprocess(): 784 flattened pixels
        # per example and a single label column.
        return collections.OrderedDict(
            x=tf.TensorSpec([None, 784], tf.float32),
            y=tf.TensorSpec([None, 1], tf.int32))

    @tf.function
    def forward_pass(self, batch, training=True):
        # The model computes the same thing whether or not it is training.
        del training
        loss, predictions = mnist_forward_pass(self._variables, batch)
        batch_size = tf.shape(batch['x'])[0]
        return tff.learning.BatchOutput(
            loss=loss, predictions=predictions, num_examples=batch_size)

    @tf.function
    def report_local_outputs(self):
        return get_local_mnist_metrics(self._variables)

    @property
    def federated_output_computation(self):
        return aggregate_mnist_metrics_across_clients
As you can see, the abstract methods and properties defined by `tff.learning.Model` correspond to the code snippets in the preceding section that introduced the variables and defined the loss and statistics.
Simulating federated training with the new model
Just replace the model constructor with the constructor of our new model class, and use the two federated computations in the iterative process you created to cycle through training rounds.
# Federated Averaging with the hand-written MnistModel; only a client
# optimizer is passed here.
iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
for round_num in range(1, 11):
    state, metrics = iterative_process.next(state, federated_train_data)
    if round_num == 1:
        print('round 1, metrics={}'.format(metrics))
    else:
        print(f'round {round_num:2d}, metrics={metrics}')
10. 评价
To perform evaluation on federated data, you can construct another federated computation designed for just this purpose, using the `tff.learning.build_federated_evaluation` function, and passing in your model constructor as an argument. Unlike Federated Averaging, evaluation doesn't perform gradient descent, so there's no need to construct optimizers: passing the `MnistModel` constructor is enough. (Older versions of this tutorial trained with a separate `MnistTrainableModel` class; in this notebook the same `MnistModel` is used for both training and evaluation.)
For experimentation and research, when a centralized test dataset is available, Federated Learning for Text Generation demonstrates another evaluation option: taking the trained weights from federated learning, applying them to a standard Keras model, and then simply calling `tf.keras.models.Model.evaluate()` on a centralized dataset.
# Build a federated evaluation computation for MnistModel (no optimizers needed).
evaluation = tff.learning.build_federated_evaluation(MnistModel)
# Its TFF type signature as a string
# (only displayed when this is the last expression of a notebook cell)
str(evaluation.type_signature)
# Evaluate the trained global model (state.model) on the training clients' data
train_metrics = evaluation(state.model, federated_train_data)
str(train_metrics)
Now, let's compile a test sample of federated data and rerun evaluation on the test data. The data will come from the same sample of real users, but from a distinct held-out data set.
# Test datasets for the same sample_clients, built from the held-out emnist_test split
federated_test_data = make_federated_data(emnist_test, sample_clients)
len(federated_test_data), federated_test_data[0]
# Evaluate the same trained model on the test data
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
This concludes the tutorial. We encourage you to play with the parameters (e.g., batch sizes, number of users, epochs, learning rates, etc.), to modify the code above to simulate training on random samples of users in each round, and to explore the other tutorials we've developed.