【学习tensorflow2】有用的API汇总

  1. 动态显存分配
from tensorflow.compat.v1 import ConfigProto, InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
  1. 随机数种子: 为
tf.random.set_seed(2317)
  1. 混合精度
opt = Adam()
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
model.compile(optimizer=opt, loss="...")
  1. 中断训练与继续训练
reloaded = False
# 参数为键值对, 如global_epoch=global_epoch, 等式左边是key(自行定义), 右边是value(tf的变量, 模型, 优化器等).
checkpoint = tf.train.Checkpoint(global_epoch=global_epoch, model=model)
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
if reloaded:
    checkpoint.restore(manager.latest_checkpoint)
while True:
    # Train.
    manager.save()
  1. 日志可视化
log_writer = tf.summary.create_file_writer(log_dir)
def write_log(l, name):
    with log_writer.as_default():
        tf.summary.scalar(name, l, step=global_epoch)
    log_writer.flush()
# 使用tensorboard --logdir [log_dir]可视化日志.
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。