The training loop is implemented in a function named train_on_batch_TB, with two helpers, validate and write_log, defined below:
import numpy as np
import tensorflow as tf
from numbers import Number
from keras.callbacks import TensorBoard


def train_on_batch_TB(nb_epoch, batch_size, model,
                      X, Y, validation_x, validation_y,
                      model_file_path, TB_log_dir='./logs'):
    """
    Train the model with train_on_batch and log the metrics to TensorBoard.

    :param nb_epoch: number of epochs
    :param batch_size: number of samples per batch
    :param model: a compiled Keras model
    :param X: training inputs, shape (samples, timesteps, features)
    :param Y: training targets, shape (samples, timesteps)
    :param validation_x: validation inputs, same layout as X
    :param validation_y: validation targets, same layout as Y
    :param model_file_path: path where the model with the lowest validation loss is saved
    :param TB_log_dir: folder of logs for TensorBoard
    """
    callback = TensorBoard(TB_log_dir)
    callback.set_model(model)
    min_val_loss = float('inf')  # smallest validation loss seen so far
    for epoch_no in range(nb_epoch):
        print(f'Epoch {epoch_no}')
        n_batch_per_epoch = len(X) // batch_size
        # collect the metrics of every batch in this epoch
        batches_logs = np.zeros((n_batch_per_epoch, len(model.metrics_names)))
        for batch_no in range(n_batch_per_epoch):
            batch_start_index = batch_no * batch_size
            _x = X[batch_start_index:batch_start_index + batch_size]
            _y = Y[batch_start_index:batch_start_index + batch_size]
            # train on one batch
            logs = model.train_on_batch(_x, _y)
            batches_logs[batch_no] = logs if isinstance(logs, list) else [logs]
        # end of the epoch: log the averaged training metrics and the validation loss
        write_log(callback, model.metrics_names, np.mean(batches_logs, axis=0), epoch_no)
        val_loss = validate(model, batch_size, validation_x, validation_y)
        write_log(callback, ['val_loss'], val_loss, epoch_no)
        # save the model whenever the validation loss improves
        if isinstance(val_loss, Number) and val_loss < min_val_loss:
            print(f"val_loss {val_loss:.4f} < {min_val_loss:.4f}, saving the model.")
            min_val_loss = val_loss
            model.save(model_file_path, overwrite=True, include_optimizer=True)
        model.reset_states()


def validate(model, batch_size, validation_x, validation_y):
    """Evaluate the model on the validation set batch by batch and return the mean loss."""
    n_batch = len(validation_x) // batch_size
    batches_logs = np.zeros((n_batch, len(model.metrics_names)))
    for batch_no in range(n_batch):
        batch_start_index = batch_no * batch_size
        _x = validation_x[batch_start_index:batch_start_index + batch_size]
        _y = validation_y[batch_start_index:batch_start_index + batch_size]
        # evaluate on one batch without updating the weights
        logs = model.test_on_batch(_x, _y)
        batches_logs[batch_no] = logs if isinstance(logs, list) else [logs]
    # the first entry of model.metrics_names is the loss
    return np.mean(batches_logs, axis=0)[0]


def write_log(callback, names, logs, batch_no):
    """Write named scalar values to the TensorBoard writer held by the callback."""
    if isinstance(logs, Number):  # a single scalar was passed; wrap it so it zips with names
        logs = [logs]
    for name, value in zip(names, logs):
        summary = tf.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value
        summary_value.tag = name
        callback.writer.add_summary(summary, batch_no)
    callback.writer.flush()
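
A minimal usage sketch follows. It assumes the TensorFlow 1.x era Keras API that the functions above rely on (tf.Summary, callback.writer); the LSTM architecture, the array shapes, the random stand-in data, and the file name best_model.h5 are illustrative assumptions, not part of the code above. After training, the curves can be inspected with tensorboard --logdir ./logs.

# Usage sketch: hypothetical model and data, TF 1.x era Keras assumed.
from keras.models import Sequential
from keras.layers import LSTM, TimeDistributed, Dense

timesteps, features = 20, 8
model = Sequential([
    LSTM(32, return_sequences=True, input_shape=(timesteps, features)),
    TimeDistributed(Dense(1)),
])
model.compile(optimizer='adam', loss='mse')

# Random arrays standing in for real data; the trailing axis of the targets
# matches the Dense(1) output (the batching code above only slices axis 0).
X = np.random.rand(640, timesteps, features)
Y = np.random.rand(640, timesteps, 1)
val_x = np.random.rand(128, timesteps, features)
val_y = np.random.rand(128, timesteps, 1)

train_on_batch_TB(nb_epoch=5, batch_size=64, model=model,
                  X=X, Y=Y, validation_x=val_x, validation_y=val_y,
                  model_file_path='best_model.h5')
# then inspect the logs with: tensorboard --logdir ./logs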