admin 管理员组

文章数量: 887032

Tensorflow 2.* 网络训练(二) fit(x, y, batch

在完成数据集合,网络搭建、以及训练编译设置以后,最后就是要开始训练(拟合)网络

tf.keras.Model.fit

如下fit的参数是相对比较多的,且参数间相互关系较为复杂

fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None,validation_split=0.0, validation_data=None, shuffle=True, class_weight=None,sample_weight=None, initial_epoch=0, steps_per_epoch=None,validation_steps=None, validation_batch_size=None, validation_freq=1,max_queue_size=10, workers=1, use_multiprocessing=False
)

x,y

分别对应网络的input和target data,支持如下的数据输入格式,但是必须保证x和y的格式一致

  1. Numpy array,或者 Numpy array列表
  2. Tensorflow tensor ,或者Tensorflow tenso列表
  3. 字典,name为网络Input层的name,数据为input layer name对应的Numpy array或Tensorflow tensor
  4. tf.data.dataset, 返回(inputs, targets) or (inputs, targets, sample_weights) 如何创建tf.data.dataset参考博文
  5. 生成器generator 或者 keras.utils.Sequence,返回(inputs, targets) or (inputs, targets, sample_weights)如何创建生成器参考博文

注:前三种情况适用于数据集合不大,可一次性读入内存的情况,后两种适用于大样本数据集合。

batch_size

对于x,y非5的数据集格式,定义每次训练的批量数(整数型),默认为32。

epochs

训练模型的次数,一个epoch表示x,y中完整数据迭代完一次。
注意在与initial_epoch配合适用时,epoch表示训练到第几个epoch结束,即“final epoch”,而非需要训练多少eopochs

verbose

verbose:日志显示
verbose = 0 为不在标准输出流输出日志信息
verbose = 1 为输出进度条记录(默认)
verbose = 2 为每个epoch输出一行记录

validation_split

对于x,y非4和5的数据集格式,0到1之间的浮点数,即此部分的数据被划分验证集。数据集中序列靠后的数据进行划分

validation_data

优先级高于validation_split,定义验证集合,用于每组epoch训练集之后,在验证集上评价loss和metrics。格式(x_val, y_val) or (x_val, y_val, val_sample_weights)元组。
对于数据1和2,必须定义batch_size;
对于数据5,允许定义validation_steps ;
暂不支持数据3和5(适用fit_generator);

shuffle

布尔型是否打乱数据集;
当数据类型为generator时,shuffle失效;
当step_per_epoch=None时,shuffle失效;

class_weight

告诉网络那类标签对于计算loss更为重要,需要pay more attention
字典类型,类索引(整数):权重值(浮点型)

sample_weight

作用类似class_weight,对输入样本赋予权重
可以是与输入数据长度相同的1D numpy array ,也可以是 2D array with shape (samples, sequence_length),其中sequence_length表示每个样本在时刻内的时间戳

initial_epoch

从第几epooch开始训练,适用于重启之前的训练

steps_per_epoch

网络训练过程中,每个epoch的step数(每个step表示代入batch_size的样本后网络,权值更新的过程)
默认值等于样本总数除以batch_size,对于无限迭代的数据集必须定义
numpy array类型不支持

validation_steps

效果类似steps_per_epoch,网络验证过程中,每次验证多少steps(多少组batch_size的数据代入网络验证)
默认值等于样本总数除以batch_size,只适用于设置validation_data参数且数据类型tf.data.dataset
注意,如果用于验证的集合只是总的验证集合的一部分,这一部分验证集是从开始切片,保证每次验证适用相同的。

  • 适用generator每次都是随机的,无法保证从头

validation_batch_size

验证集中的batch_size,默认为batch_size设置

validation_freq

仅当validation_data设置时有效,表示训练完几组epoch后,进行验证。
也可以通过一个列表,表示在特定epoch后进行验证。

max_queue_size

对于数据类型5,运行最大队列数,默认为10。

  • generator queue指?

wokers

对于数据类型5,允许运行最大线程数,默认为1,0表述generator在主线程运行。

  • using process-based threading硬件线程数

use_multiprocessing

对于数据类型5,布尔型,True表示using process-based threading, 默认False

返回值history

history.history可以记录训练过程中,每个epoch中的training loss 和 metriic values,以及验证过程中 validation loss 和 metriic values

参考

官网API介绍

本文标签: Tensorflow 2*网络训练(二) fit(x Y batch