Tensorflow模块:tf.train.Checkpoint

tf.train.Checkpoint :变量的保存与恢复

  Tensorflow的Checkpoint机制将可追踪变量以二进制的方式储存成一个.ckpt文件,储存了变量的名称及对应张量的值。

  Checkpoint 只保存模型的参数,不保存模型的计算过程,因此一般用于在具有模型源代码的时候恢复之前训练好的模型参数。如果需要导出模型(无需源代码也能运行模型)。
  很多时候,我们希望在模型训练完成后能将训练好的参数(变量)保存起来。在需要使用模型的其他地方载入模型和参数,就能直接得到训练好的模型。可能你第一个想到的是用 Python 的序列化模块 pickle 存储 model.variables。但不幸的是,TensorFlow 的变量类型 ResourceVariable 并不能被序列化。

  好在 TensorFlow 提供了 tf.train.Checkpoint 这一强大的变量保存与恢复类,可以使用其 save()restore() 方法将 TensorFlow 中所有包含 Checkpointable State 的对象进行保存和恢复。具体而言,tf.keras.optimizertf.Variabletf.keras.Layer 或者 tf.keras.Model 实例都可以被保存。其使用方法非常简单,我们首先声明一个 Checkpoint:

checkpoint = tf.train.Checkpoint(model=model)

  这里 tf.train.Checkpoint() 接受的初始化参数比较特殊,是一个 **kwargs 。具体而言,是一系列的键值对,键名可以随意取,值为需要保存的对象。例如,如果我们希望保存一个继承 tf.keras.Model 的模型实例 model 和一个继承 tf.train.Optimizer 的优化器 optimizer ,我们可以这样写:

checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)

  这里 myAwesomeModel 是我们为待保存的模型 model 所取的任意键名。注意,在恢复变量的时候,我们还将使用这一键名。

接下来,当模型训练完成需要保存的时候,使用:

checkpoint.save(save_path_with_prefix)

就可以。 save_path_with_prefix 是保存文件的目录 + 前缀。

  • 注解

  例如,在源代码目录建立一个名为 save 的文件夹并调用一次 checkpoint.save('./save/model.ckpt') ,我们就可以在可以在 save 目录下发现名为 checkpointmodel.ckpt-1.indexmodel.ckpt-1.data-00000-of-00001 的三个文件,这些文件就记录了变量信息。checkpoint.save() 方法可以运行多次,每运行一次都会得到一个. index 文件和. data 文件,序号依次累加。
  当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个 checkpoint,同时保持键名的一致。再调用 checkpoint 的 restore 方法。就像下面这样:

model_to_be_restored = MyModel()                                        # 待恢复参数的同一模型
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)   # 键名保持为“myAwesomeModel”
checkpoint.restore(save_path_with_prefix_and_index)

即可恢复模型变量。 save_path_with_prefix_and_index 是之前保存的文件的目录 + 前缀 + 编号。例如,调用 checkpoint.restore('./save/model.ckpt-1') 就可以载入前缀为 model.ckpt ,序号为 1 的文件来恢复模型。

  当保存了多个文件时,我们往往想载入最近的一个。可以使用 tf.train.latest_checkpoint(save_path) 这个辅助函数f。例如如果 save 目录下有 model.ckpt-1.indexmodel.ckpt-10.index 的 10 个保存文件, tf.train.latest_checkpoint('./save') 即返回 ./save/model.ckpt-10

总体而言,恢复与保存变量的典型代码框架如下:

# train.py 模型训练阶段

model = MyModel()
# 实例化Checkpoint,指定保存对象为model(如果需要保存Optimizer的参数也可加入)
checkpoint = tf.train.Checkpoint(myModel=model)
# ...(模型训练代码)
# 模型训练完毕后将参数保存到文件(也可以在模型训练过程中每隔一段时间就保存一次)
checkpoint.save('./save/model.ckpt')
# test.py 模型使用阶段

model = MyModel()
checkpoint = tf.train.Checkpoint(myModel=model)             # 实例化Checkpoint,指定恢复对象为model
checkpoint.restore(tf.train.latest_checkpoint('./save'))    # 从文件恢复模型参数
# 模型使用代码
  • 注解

  tf.train.Checkpoint 与以前版本常用的 tf.train.Saver 相比,强大之处在于其支持在即时执行模式下 “延迟” 恢复变量。具体而言,当调用了 checkpoint.restore() ,但模型中的变量还没有被建立的时候,Checkpoint 可以等到变量被建立的时候再进行数值的恢复。即时执行模式下,模型中各个层的初始化和变量的建立是在模型第一次被调用的时候才进行的(好处在于可以根据输入的张量形状而自动确定变量形状,无需手动指定)。这意味着当模型刚刚被实例化的时候,其实里面还一个变量都没有,这时候使用以往的方式去恢复变量数值是一定会报错的。比如,你可以试试在 train.py 调用 tf.keras.Modelsave_weight() 方法保存 model 的参数,并在 test.py 中实例化 model 后立即调用 load_weight() 方法,就会出错,只有当调用了一遍 model 之后再运行 load_weight() 方法才能得到正确的结果。可见, tf.train.Checkpoint 在这种情况下可以给我们带来相当大的便利。另外, tf.train.Checkpoint 同时也支持图执行模式。

在代码目录下建立 save 文件夹并运行代码进行训练后,save 文件夹内将会存放每隔 100 个 batch 保存一次的模型变量数据。在命令行参数中加入 --mode=test 并再次运行代码,将直接使用最后一次保存的变量值恢复模型并在测试集上测试模型性能,可以直接获得 95% 左右的准确率。

使用 tf.train.CheckpointManager 删除旧的 Checkpoint 以及自定义文件编号

在模型的训练过程中,我们往往每隔一定步数保存一个 Checkpoint 并进行编号。不过很多时候我们会有这样的需求:

  • 在长时间的训练后,程序会保存大量的 Checkpoint,但我们只想保留最后的几个 Checkpoint;

  • Checkpoint 默认从 1 开始编号,每次累加 1,但我们可能希望使用别的编号方式(例如使用当前 Batch 的编号作为文件编号)。

这时,我们可以使用 TensorFlow 的 tf.train.CheckpointManager 来实现以上需求。具体而言,在定义 Checkpoint 后接着定义一个 CheckpointManager:

checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory='./save', checkpoint_name='model.ckpt', max_to_keep=k)

此处, directory 参数为文件保存的路径, checkpoint_name 为文件名前缀(不提供则默认为 ckpt ), max_to_keep 为保留的 Checkpoint 数目。

在需要保存模型的时候,我们直接使用 manager.save() 即可。如果我们希望自行指定保存的 Checkpoint 的编号,则可以在保存时加入 checkpoint_number 参数。例如 manager.save(checkpoint_number=100)

以下是一个基于CIFAR10数据集的一个示例,读者可进行参考。


GPU环境测试

import tensorflow  as tf
# 使用显卡进行时,将GPU的显存使用策略设置为 “仅在需要时申请显存空间”,不然会申请所有显存空间,报错
gpu_devices = tf.config.experimental.list_physical_devices("GPU")
for device in gpu_devices:
    tf.config.experimental.set_memory_growth(device, True)

返回运行时可见的物理设备列表,默认情况下,所有发现的CPU和GPU设备都被视为可见的。

tf.config.experimental.list_physical_devices(device_type=None)
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:XLA_CPU:0', device_type='XLA_CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'),
 PhysicalDevice(name='/physical_device:XLA_GPU:0', device_type='XLA_GPU')]

查看GPU设备信息

!nvidia-smi
Tue May 19 23:17:57 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.64       Driver Version: 440.64       CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce RTX 2070    Off  | 00000000:01:00.0  On |                  N/A |
|  0%   51C    P8    19W / 175W |   4923MiB /  7979MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0       903      G   /usr/lib/xorg/Xorg                            66MiB |
|    0      1566      G   /usr/bin/gnome-shell                          85MiB |
|    0      5359      C   /home/wcjb/anaconda3/bin/python             4759MiB |
+-----------------------------------------------------------------------------+

检查GPU是否可用

tf.test.is_gpu_available()
True

查看GPU是否可用

tf.config.experimental.list_physical_devices('GPU')
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
# 启用设备放置日志记录将导致打印任何张量分配或操作
tf.debugging.set_log_device_placement(True)
tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)
VirtualDeviceConfiguration(memory_limit=1024)

数据处理

数据载入

  • CIFAR-10数据集

  CIFAR-10数据集是一个用于识别普适物体的小型数据集,它一共包含10个类别的RGB彩色图片:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、蛙类(frog)、马(horse)、船(ship)和卡车(truck)。图片的尺寸为32x32,该数据集一共有50000张训练图片和10000张测试图片。

  1个10000x3072大小的uint8s数组。数组的每行存储1张32*32的图像,第1个1024包含红色通道值,下1个包含绿色,最后的1024包含蓝色。图像存储以行顺序为主,所以数组的前32列为图像第1行的红色通道值。

import pickle
import os
from PIL import Image
from tqdm import tqdm
import numpy as np
class CIFAR10(object):
    
    def __init__(self,path='/home/wcjb/Code/Dataset/cifar-10-batches-py/'):
        self.trainpath = [os.path.join(path,'data_batch_'+str(i+1)) for i in range(5)]
        self.testpath = [os.path.join(path,'test_batch')]
    
    def unpickle(self,file):
        with open(file, 'rb') as fo:
            dict = pickle.load(fo, encoding = 'iso-8859-1') # 
        return dict
    
    def load_batch(self,file):
        
        with open(file, 'rb')as f:
            datadict = self.unpickle(file)
            data = datadict['data']
            label = datadict['labels']
            data = data.reshape(10000, 3, 32, 32)
            label = np.array(label)
            return data,label
        
    def toimg(self,data):
        img = []
        for i in range(data.shape[0]):
            imgs = data[i - 1]
            r = imgs[0]
            g = imgs[1]
            b = imgs[2]
            R = Image.fromarray(r)
            G = Image.fromarray(g)
            B = Image.fromarray(b)
            
            img.append(Image.merge("RGB",(R,G,B)))
        return img
            
        
    def cif2img(self):
        
        train_img,test_img = [],[]
    
        for tp in tqdm(self.trainpath,desc='Train-img'):
            
            data,label = self.load_batch(tp)
            train_img.append(self.toimg(data))
            
        for tp in tqdm(self.testpath,desc='Test-img '):
            
            data,label = self.load_batch(tp)
            test_img.append(self.toimg(data))
            
            
        return train_img,test_img
        
    def cif2data(self):
        
        x_train,y_train,x_test,y_test = [],[],[],[]
        for tp in tqdm(self.trainpath,desc='Train'):
            data,label = self.load_batch(tp)
            x_train.append(data)
            y_train.append(label)
        for tp in tqdm(self.testpath,desc='Test '):
            data,label = self.load_batch(tp)
            x_test.append(data)
            y_test.append(label)
        x_train,y_train = np.array(x_train).reshape(-1,3,32,32),np.array(y_train).reshape(-1,)
        x_test,y_test = np.array(x_test).reshape(-1,3,32,32),np.array(y_test).reshape(-1,)
        x_train,x_test = np.rollaxis(x_train, 1,4),np.rollaxis(x_test,1, 4)
        
        return x_train,y_train,x_test,y_test
cif = CIFAR10()
# 将CIFAR10数据集加载为图片数据
train_img,test_img = cif.cif2img()
# 将将CIFAR10数据集加载为多维数据用于训练
x_train,y_train,x_test,y_test = cif.cif2data()
Train-img: 100%|██████████| 5/5 [00:02<00:00,  2.39it/s]
Test-img : 100%|██████████| 1/1 [00:00<00:00,  2.41it/s]
Train: 100%|██████████| 5/5 [00:00<00:00, 40.06it/s]
Test : 100%|██████████| 1/1 [00:00<00:00, 40.89it/s]

查看数据集样本的图片

import matplotlib.pyplot as plt
%matplotlib inline
plt.imshow(train_img[0][0])
<matplotlib.image.AxesImage at 0x7fdfd8cbe2d0>
plt.imshow(test_img[0][0])
<matplotlib.image.AxesImage at 0x7fdfd8bf4250>

查看数据集的样本的数组形态

x_train.shape
(50000, 32, 32, 3)

数据增强处理函数

  • 直方图均衡化

  直方图均衡化通常用来增加许多图像的全局对比度,尤其是当图像的有用数据的对比度相当接近的时候。通过这种方法,亮度可以更好地在直方图上分布。这样就可以用于增强局部的对比度而不影响整体的对比度,直方图均衡化通过有效地扩展常用的亮度来实现这种功能。这种方法对于背景和前景都太亮或者太暗的图像非常有用,这种方法尤其是可以带来X光图像中更好的骨骼结构显示以及曝光过度或者曝光不足照片中更好的细节。这种方法的一个主要优势是它是一个相当直观的技术并且是可逆操作,如果已知均衡化函数,那么就可以恢复原始的直方图,并且计算量也不大。这种方法的一个缺点是它对处理的数据不加选择,它可能会增加背景噪声的对比度并且降低有用信号的对比度。

import shutil
from PIL import Image
import sys
import cv2
from tqdm import notebook
class DataAugumentation(object):
    
    def __init__(self,num=10):
        
        self.num = num

    def CLAHE(self,img):
        grayimg = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # 局部直方图均值化
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        cl1 = clahe.apply(grayimg)
        return cl1

    def Histograms_Equalization(self,img):
        
        grayimg = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # 直方图均值化
        equ = cv2.equalizeHist(grayimg)
        return equ

    def make_one_hot(self,data):
        return (np.arange(self.num)==data[:,None]).astype(np.int64)
    
    def augument(self,imgs,labels):
        '''
        使用图像处理方法进行数据增强,直方图均值化和局部直方图均值化
        再加上灰度图和原图片可以将数据集增大三倍
        '''
        x_data,y_data = [],[]
        
        for img,label in notebook.tqdm(zip(imgs,labels),desc='数据增强进度'):
            
            imggray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            x_data.append(imggray.astype('float32') / 255.0)
            y_data.append(label)

            he_image = self.Histograms_Equalization(img)
            x_data.append(he_image.astype('float32') / 255.0)
            y_data.append(label)

            clahe_img = self.CLAHE(img)
            x_data.append(clahe_img.astype('float32') / 255.0)
            y_data.append(label)
            
        return np.array(x_data),np.array(y_data)

处理训练集

da = DataAugumentation()
x_new_train,y_new_train = da.augument(x_train,y_train)
x_new_test,y_new_test = da.augument(x_test,y_test)
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='数据增强进度', max=1.0, style=ProgressStyle(d…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='数据增强进度', max=1.0, style=ProgressStyle(d…

  • 扩展数据维度,计算卷积
x_new_train = np.expand_dims(x_new_train, 3)
x_new_test = np.expand_dims(x_new_test,3)

可以看到,数据集增强之后比较大,所以可以把增强后的数据集保存在本地方便再次复用。

pickle.dump(x_new_train, open('./CifaData/x_new_train.p', 'wb'))
pickle.dump(y_new_train, open('./CifaData/y_new_train.p', 'wb'))
pickle.dump(x_new_test, open('./CifaData/x_new_test.p', 'wb'))
pickle.dump(y_new_test, open('./CifaData/y_new_test.p', 'wb'))
!cd CifaData && ls -hl
总用量 705M
-rw-rw-r-- 1 wcjb wcjb 118M 5月  19 22:17 x_new_test.p
-rw-rw-r-- 1 wcjb wcjb 586M 5月  19 22:17 x_new_train.p
-rw-rw-r-- 1 wcjb wcjb 235K 5月  19 22:17 y_new_test.p
-rw-rw-r-- 1 wcjb wcjb 1.2M 5月  19 22:17 y_new_train.p
# with open('./CifaData/y_new_train.p', 'rb') as fo:
#     y_n_train = pickle.load(fo, encoding = 'iso-8859-1')

搭建模型

import tensorflow as tf
import datetime
import time

MODEL_DIR = "./models"

class network(tf.keras.Model):
    
    def __init__(self,n_class=10,learning_rate=1e-4):
        
        super(network,self).__init__()
        
        # 定义网络结构
        self.conv2d_01 = tf.keras.layers.Convolution2D (kernel_size = (5, 5),input_shape=(32,32,1), filters = 100, activation='relu')
        self.maxpool2d_01 = tf.keras.layers.MaxPool2D()
        self.conv2d_02 = tf.keras.layers.Convolution2D (kernel_size = (3, 3), filters = 150, activation='relu')
        self.maxpool2d_02 = tf.keras.layers.MaxPool2D()
        self.conv2d_03 = tf.keras.layers.Convolution2D (kernel_size = (3, 3), filters = 250, padding='same', activation='relu')
        self.maxpool2d_03 = tf.keras.layers.MaxPool2D()
        self.flatten = tf.keras.layers.Flatten()
        self.dense_01 = tf.keras.layers.Dense(512, activation='relu')
        self.dense_02 = tf.keras.layers.Dense(300, activation='relu')
        self.dense_03 = tf.keras.layers.Dense(10,activation='softmax')
        
        # 优化器
        self.optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
        # 确认模型日志目录是否存在,若不存在则创建
        if not tf.io.gfile.exists(MODEL_DIR):
            tf.io.gfile.makedirs(MODEL_DIR)
        # 申明训练和测试日志路径
        train_dir = os.path.join(MODEL_DIR, 'summaries', 'train')
        test_dir = os.path.join(MODEL_DIR, 'summaries', 'eval')
        
        # 根据给定文件在当前上下文环境中创建日志记录器,记录数据摘要,便于可视化及分析并且每个10000刷新
        self.train_summary_writer = tf.summary.create_file_writer(train_dir, flush_millis=10000)
        self.test_summary_writer = tf.summary.create_file_writer(test_dir, flush_millis=10000, name='test')
        
        # 将可追踪变量以二进制的方式储存成一个checkpoint 档(.ckpt),
        # 即储存变量的名字和对应的张量的数值。
        checkpoint_dir = os.path.join(MODEL_DIR, 'checkpoints')
        self.checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
        
        self.checkpoint = tf.train.Checkpoint(model=self, optimizer=self.optimizer)
        # 只保存最近10个模型文件
        tf.train.CheckpointManager(self.checkpoint, directory=checkpoint_dir, checkpoint_name='network.ckpt', max_to_keep=10)
        # 返回目录下最近一次checkpoint的文件名,并恢复模型参数
        self.checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
        
    def call(self,inputs):
        
        x = self.conv2d_01(inputs)
        x = self.maxpool2d_01(x)
        
        x = self.conv2d_02(x)
        x = self.maxpool2d_02(x)
        
        x = self.conv2d_03(x)
        x = self.maxpool2d_03(x)
        
        x = self.flatten(x)
      
        x = self.dense_01(x)
        x = self.dense_02(x)
        x = self.dense_03(x)
        
        return x
        
    @tf.function()
    def loss(self, logits, labels):
        return tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True))

    @tf.function()
    def accuracy(self, logits, labels):
        return tf.keras.metrics.sparse_categorical_accuracy(labels, logits)
    
    @tf.function(experimental_relax_shapes=True)
    def train_step(self, images, labels):
        
        with tf.device('/GPU:0'):
            with tf.GradientTape() as tape:
                # 前向计算
                logits = self.call(images)
                # 计算当前批次模型的损失函数
                loss = self.loss(logits, labels)
                # 计算当前批次的模型准确率
                accuracy = self.accuracy(logits, labels)
            #=====================反向过程=====================
            # 计算梯度
            grads = tape.gradient(loss, self.trainable_variables)
            # 使用梯度更新可训练集合的变量
            self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
            
        return loss, accuracy, logits

    def train(self, train_dataset, test_dataset, epochs=1, log_freq=50):
        
        for i in range(epochs):
            
            train_start = time.time()
            # 在该上下文环境中记录可追踪变量
            with self.train_summary_writer.as_default():
                
                start = time.time()
               # metrics指标是有状态的。当调用.result()时,会计算累计值并返回累计的结果。使用.reset_states()可以清除累积值
                avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)  
                avg_accuracy = tf.keras.metrics.Mean('accuracy', dtype=tf.float32)  

                for images, labels in train_dataset:

                    loss, accuracy, logits = self.train_step(images, labels)
                    # 持续纪律损失值
                    avg_loss(loss)
                    # 持续记录分类正确率
                    avg_accuracy(accuracy)
                    # 在训练log_freq次后,记录变量,并计算累计指标值
                    # optimizer.iterations 记录了优化器运行的训练步数
                    if tf.equal(self.optimizer.iterations % log_freq, 0):
                        # 在日志中写入变量的摘要
                        tf.summary.scalar('loss', avg_loss.result(), step=self.optimizer.iterations)
                        tf.summary.scalar('accuracy', avg_accuracy.result(), step=self.optimizer.iterations)
                        # 计算完成一个批次训练所需要的时间
                        rate = log_freq / (time.time() - start)
                        print('Step{} Loss: {:0.4f} accuracy: {:0.2f}% ({:0.2f} steps/sec)'.format(self.optimizer.iterations.numpy(), loss, (avg_accuracy.result() * 100), rate))
                        
                        # 清除当前训练批次的指标累计值,进入下一训练批次
                        avg_loss.reset_states()
                        avg_accuracy.reset_states()
                        start = time.time()

            train_end = time.time()
            
            print('\nTrain time for epoch: {} ({} total steps): {}'.format(i + 1, self.optimizer.iterations.numpy(), train_end - train_start))
            
            with self.test_summary_writer.as_default():
                self.test(test_dataset, self.optimizer.iterations)
            # 保存当前epoch的模型参数
            self.checkpoint.save(self.checkpoint_prefix)
            #在训练后保存模型会报错,暂时没有解决
#         self.export_path = os.path.join(MODEL_DIR, 'export')
#         tf.saved_model.save(self, self.export_path)
        
    
    def test(self, test_dataset, step_num):
        """
        评估模型在验证集上的正确率
        """
        
        avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)
        avg_accuracy = tf.keras.metrics.Mean('accuracy', dtype=tf.float32)
        # 只需要计算前向过程,需要计算相应指标
        for (images, labels) in test_dataset:
            logits = self.call(images)
            avg_loss(self.loss(logits, labels))
            avg_accuracy(self.accuracy(logits, labels))

        print('Test-Loss: {:0.4f} Test-Accuracy: {:0.2f}%'.format(avg_loss.result(), avg_accuracy.result() * 100))
        tf.summary.scalar('loss', avg_loss.result(), step=step_num)
        tf.summary.scalar('accuracy', avg_accuracy.result(), step=step_num)
            
    def evaluat(self, test_dataset):
        # 模型保存报错,暂未解决,故无法读取
#         restored_model = tf.saved_model.restore(self.export_path)
#         y_predict = restored_model(x_test)
        avg_accuracy = tf.keras.metrics.Mean('accuracy', dtype=tf.float32)

        for (images, labels) in test_dataset:
            logits = self.call(images)
            avg_accuracy(self.accuracy(logits, labels))

        print('Model accuracy: {:0.2f}%'.format(avg_accuracy.result() * 100))

    def forward(self, xs):
        """
        完成模型的前向计算,用于实际预测
        """
        predictions = self.call(xs)
        logits = tf.nn.softmax(predictions)

        return logits
    

使用tf.data.Dataset创建可迭代访问的数据集,便于按批次进行训练

# 由于用显卡进行训练,不是大显存请使用较小的Batch Size
val_dataset = tf.data.Dataset.from_tensor_slices((x_new_test.astype(np.float32), y_new_test))
val_dataset = val_dataset.shuffle(10000).batch(1024)

dataset = tf.data.Dataset.from_tensor_slices((x_new_train.astype(np.float32), y_new_train))
dataset = dataset.shuffle(5000).batch(1024)  
net = network()
net.train(dataset, val_dataset,1)
Step72550 Loss: 1.5211 accuracy: 94.14% (13.43 steps/sec)
Step72600 Loss: 1.5097 accuracy: 94.07% (15.38 steps/sec)
Step72650 Loss: 1.5188 accuracy: 94.38% (15.44 steps/sec)

Train time for epoch: 1 (72666 total steps): 11.80623173713684
Test-Loss: 1.7723 Test-Accuracy: 68.68%
net.forward(x_new_train[:1])
<tf.Tensor: shape=(1, 10), dtype=float32, numpy=
array([[0.0853368 , 0.0853368 , 0.0853368 , 0.08533724, 0.0853368 ,
        0.0853368 , 0.23196831, 0.0853368 , 0.0853368 , 0.0853368 ]],
      dtype=float32)>
y_new_train[:1]
array([6])

CIFAR-10数据集
如果有读者想自己在本地复现,可以参考我的代码:
CIFAR10-Tensorflow
由于文件较大,可能需要一定时间加载,欢迎大家尽情Star、Fork。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,657评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,662评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,143评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,732评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,837评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,036评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,126评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,868评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,315评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,641评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,773评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,470评论 4 333
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,126评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,859评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,095评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,584评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,676评论 2 351