admin 管理员组文章数量: 887032
Tensorflow 2.* 网络训练(二) fit(x, y, batch
在完成数据集合,网络搭建、以及训练编译设置以后,最后就是要开始训练(拟合)网络
tf.keras.Model.fit
如下fit的参数是相对比较多的,且参数间相互关系较为复杂
fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None,validation_split=0.0, validation_data=None, shuffle=True, class_weight=None,sample_weight=None, initial_epoch=0, steps_per_epoch=None,validation_steps=None, validation_batch_size=None, validation_freq=1,max_queue_size=10, workers=1, use_multiprocessing=False
)
x,y
分别对应网络的input和target data,支持如下的数据输入格式,但是必须保证x和y的格式一致
。
- Numpy array,或者 Numpy array列表
- Tensorflow tensor ,或者Tensorflow tenso列表
- 字典,name为网络Input层的name,数据为input layer name对应的Numpy array或Tensorflow tensor
- tf.data.dataset, 返回(inputs, targets) or (inputs, targets, sample_weights) 如何创建tf.data.dataset参考博文
- 生成器generator 或者 keras.utils.Sequence,返回(inputs, targets) or (inputs, targets, sample_weights)如何创建生成器参考博文
注:前三种情况适用于数据集合不大,可一次性读入内存的情况,后两种适用于大样本数据集合。
batch_size
对于x,y非5的数据集格式
,定义每次训练的批量数(整数型),默认为32。
epochs
训练模型的次数,一个epoch表示x,y中完整数据迭代完一次。
注意在与initial_epoch配合适用时,epoch表示训练到第几个epoch结束,即“final epoch”,而非需要训练多少eopochs
verbose
verbose:日志显示
verbose = 0 为不在标准输出流输出日志信息
verbose = 1 为输出进度条记录(默认)
verbose = 2 为每个epoch输出一行记录
validation_split
对于x,y非4和5的数据集格式
,0到1之间的浮点数,即此部分的数据被划分验证集。数据集中序列靠后的数据进行划分
validation_data
优先级高于validation_split,定义验证集合,用于每组epoch训练集之后,在验证集上评价loss和metrics。格式(x_val, y_val) or (x_val, y_val, val_sample_weights)元组。
对于数据1和2,必须定义batch_size;
对于数据5,允许定义validation_steps ;
暂不支持数据3和5(适用fit_generator);
shuffle
布尔型是否打乱数据集;
当数据类型为generator时,shuffle失效;
当step_per_epoch=None时,shuffle失效;
class_weight
告诉网络那类标签对于计算loss更为重要,需要pay more attention
字典类型,类索引(整数):权重值(浮点型)
sample_weight
作用类似class_weight,对输入样本赋予权重
可以是与输入数据长度相同的1D numpy array ,也可以是 2D array with shape (samples, sequence_length),其中sequence_length表示每个样本在时刻内的时间戳
initial_epoch
从第几epooch开始训练,适用于重启之前的训练
steps_per_epoch
网络训练过程中,每个epoch的step数(每个step表示代入batch_size的样本后网络,权值更新的过程)
默认值等于样本总数除以batch_size,对于无限迭代的数据集必须定义
numpy array
类型不支持
validation_steps
效果类似steps_per_epoch,网络验证过程中,每次验证多少steps(多少组batch_size的数据代入网络验证)
默认值等于样本总数除以batch_size,只适用于设置validation_data参数且数据类型tf.data.dataset
注意,如果用于验证的集合只是总的验证集合的一部分,这一部分验证集是从开始切片,保证每次验证适用相同的。
- 适用generator每次都是随机的,无法保证从头
validation_batch_size
验证集中的batch_size,默认为batch_size设置
validation_freq
仅当validation_data
设置时有效,表示训练完几组epoch后,进行验证。
也可以通过一个列表,表示在特定epoch后进行验证。
max_queue_size
对于数据类型5
,运行最大队列数,默认为10。
- generator queue指?
wokers
对于数据类型5
,允许运行最大线程数,默认为1,0表述generator在主线程运行。
- using process-based threading硬件线程数
use_multiprocessing
对于数据类型5
,布尔型,True表示using process-based threading, 默认False
返回值history
history.history可以记录训练过程中,每个epoch中的training loss 和 metriic values,以及验证过程中 validation loss 和 metriic values
参考
官网API介绍
本文标签: Tensorflow 2*网络训练(二) fit(x Y batch
版权声明:本文标题:Tensorflow 2.*网络训练(二) fit(x, y, batch 内容由网友自发贡献,该文观点仅代表作者本人, 转载请联系作者并注明出处:http://www.freenas.com.cn/jishu/1687905617h155784.html, 本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容,一经查实,本站将立刻删除。
发表评论