tf.keras为我们提供了易用的TF API,其中keras.Model是最重要的API之一,它封装了模型的参数、结构等信息及训练、测试等过程。为了让用户能够更好地定制训练的过程,TF 2.2为该API引入了新的可扩展接口。

在TensorFlow开发者峰会2020(TF Dev Summit '20)中,相关人员介绍了TF 2.2为keras.Model引入的自定义训练过程接口train_step:

在之前的版本中,虽然tf.keras的keras.Model模型封装了模型的训练过程,但由于这种封装过于黑盒,使得许多开发者并不愿意使用keras.Model自带的训练功能,而选择显式地调用tf.GradientTape等来进行反向传播和参数更新。一般,开发者会定义如下的训练过程:``` def train_step(images, labels): with tf.GradientTape() as tape: logits = mnist_model(images, training=True)

Add asserts to check the shape of the output.

tf.debugging.assert_equal(logits.shape, (32, 10))

loss_value = loss_object(labels, logits)

loss_history.append(loss_value.numpy().mean()) grads = tape.gradient(loss_value, mnist_model.trainable_variables) optimizer.apply_gradients(zip(grads, mnist_model.trainable_variables))




然后通过循环来手动调度训练过程:```
        def train(epochs):
 for epoch in range(epochs):
 for(batch, (images, labels)) in enumerate(dataset):
 train_step(images, labels)
 print('Epoch {} finished'.format(epoch))

keras.Model自带了许多非常好用的功能,例如进度显示、基于回调的TensorBoard日志、基于回调的Early Stop等。一般需要使用keras.Model自带的训练机制才可以享受到这些便捷的功能,上面这种手动调用的方法虽然能够让开发者对训练过程有着完全的掌控,但也使得他们不能享受部分keras.Model自带的便捷功能。

TF 2.2在keras.Model类中直接引入了train_step方法,这样开发者只需要在继承keras.Model模型时用自定义的方法覆盖父类中train_step的方法,就可以自定义可控的训练过程,并使用keras.Model自带的调度机制来进行训练:

参考链接:

成为VIP会员查看完整内容
35

相关内容

Google发布的第二代深度学习系统TensorFlow
【IJCAI2020-华为诺亚】面向深度强化学习的策略迁移框架
专知会员服务
25+阅读 · 2020年5月25日
《强化学习—使用 Open AI、TensorFlow和Keras实现》174页pdf
专知会员服务
136+阅读 · 2020年3月1日
【模型泛化教程】标签平滑与Keras, TensorFlow,和深度学习
专知会员服务
20+阅读 · 2019年12月31日
【干货】谷歌Joshua Gordon 《TensorFlow 2.0讲解》,63页PPT
专知会员服务
24+阅读 · 2019年11月2日
tf.GradientTape 详解
TensorFlow
117+阅读 · 2020年2月21日
使用 Keras Tuner 调节超参数
TensorFlow
15+阅读 · 2020年2月6日
TF Boys必看!一文搞懂TensorFlow 2.0新架构!
引力空间站
18+阅读 · 2019年1月16日
基于Keras进行迁移学习
论智
12+阅读 · 2018年5月6日
keras系列︱深度学习五款常用的已训练模型
数据挖掘入门与实战
10+阅读 · 2018年3月27日
Arxiv
6+阅读 · 2019年7月29日
Physical Primitive Decomposition
Arxiv
4+阅读 · 2018年9月13日
Arxiv
3+阅读 · 2018年8月17日
VIP会员
相关VIP内容
相关资讯
tf.GradientTape 详解
TensorFlow
117+阅读 · 2020年2月21日
使用 Keras Tuner 调节超参数
TensorFlow
15+阅读 · 2020年2月6日
TF Boys必看!一文搞懂TensorFlow 2.0新架构!
引力空间站
18+阅读 · 2019年1月16日
基于Keras进行迁移学习
论智
12+阅读 · 2018年5月6日
keras系列︱深度学习五款常用的已训练模型
数据挖掘入门与实战
10+阅读 · 2018年3月27日
微信扫码咨询专知VIP会员