# general config
dir_output = "results/test/"
dir_model = dir_output + "model.weights/"
path_log = dir_output + "log.txt"
这是设置的模型保存地址。
def save_session(self):
"""Saves session = weights"""
if not os.path.exists(self.config.dir_model):
os.makedirs(self.config.dir_model)
self.saver.save(self.sess, self.config.dir_model)
这里模型的保存用到两个参数:session对象和存储位置。
接下来是模型的读取。
model.restore_session(config.dir_model)
def restore_session(self, dir_model):
"""Reload weights into session
Args:
sess: tf.Session()
dir_model: dir with weights
"""
self.logger.info("Reloading the latest trained model...")
self.saver.restore(self.sess, dir_model)