作者:郎嬅不绘画_875 | 来源:互联网 | 2023-01-23 14:02
参考链接:
墙裂推荐:https://cloud.tencent.com/developer/article/1049579
英文版原文:https://machinelearningmastery.com/check-point-deep-learning-models-keras/
keras文档回调函数:http://keras-cn.readthedocs.io/en/latest/other/callbacks/#modelcheckpoint
先看一下ModelCheckpoint的参数:
keras.callbacks.ModelCheckpoint(
filepath,
monitor='val_loss',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode='auto',
period=1
)
1. filename:字符串,保存模型的路径
2. monitor:需要监视的值,val_acc或这val_loss
3. verbose:信息展示模式,0为不打印输出信息,1打印
4. save_best_only:当设置为True时,将只保存在验证集上性能最好的模型
5. mode:‘auto’,‘min’,‘max’之一,在save_best_Only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。
6. save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)
7. period:CheckPoint之间的间隔的epoch数
代码实现过程:
① 从keras.callbacks导入ModelCheckpoint类
from keras.callbacks import ModelCheckpoint
② 在训练阶段的model.compile之后加入下列代码实现每一次epoch(period=1)保存最好的参数
checkpoint = ModelCheckpoint(filepath,
monitor='val_loss', save_weights_only=True,verbose=1,save_best_only=True, period=1)
提醒:filepath为保存参数的路径,我这里是"logs/000/trained_best_weights.h5"
③ 在训练阶段的model.fit之前加载先前保存的参数
if os.path.exists(filepath):
model.load_weights(filepath)
# 若成功加载前面保存的参数,输出下列信息
print("checkpoint_loaded")
④ 在model.fit添加callbacks=[checkpoint]实现回调
model.fit_generator(data_generator_wrap(lines[:num_train], batch_size, input_shape, anchors, num_classes),
steps_per_epoch=max(1, num_train//batch_size),
validation_data=data_generator_wrap(lines[num_train:], batch_size, input_shape, anchors, num_classes),
validation_steps=max(1, num_val//batch_size),
epochs=3,
initial_epoch=0,
callbacks=[checkpoint])
测试输出:
① 第一次输出,没有参数可以加载,不会打印“checkpoint_loaded”,输出如下(测试epoch=3)
② 再次执行train.py,利用刚才的代码可以直接在model.fit之前加载保存前一次训练的参数,继续训练(loss的变化)。注意输出了“checkpoint_load”表示成功加载前面保存的参数
提示:参考链接中有简单的测试代码,以上仅在我的训练数据上的所做的测试,更多详细内容请阅读参考链接
The end.