admin 管理员组文章数量: 887006
4.分类训练之训练脚本(Pytorch)
前言
解析好数据,构建好网络之后,那就可以进行训练了。
训练
训练有一下步骤:
1)导入之前构建好的网络和自定义的数据集
2)配置相应的参数(学习率,训练次数等)
3)创建网络对象,定义损失函数,定义优化器,动态调整学习率
4)加载tain_loader进行模型训练,加载test_loader进行模型验证
每一步都有注释,不懂的留言,写错的也可以留言,一起进步。
import torch
import torch.nn as nn
import torchvision
from vggnet import VGGNet
from load_cifar10 import train_loader, test_loader
import os# 是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 遍历样本200次
epoch_num = 200
# 学习率定义0.1
lr = 0.1
# 每次读取128张图片
batch_size = 128
# 定义网络
net = VGGNet().to(device)#定义损失loss
loss_func = nn.CrossEntropyLoss()#定义优化器optimizer
optimizer = torch.optim.Adam(net.parameters(), lr= lr)# optimizer = torch.optim.SGD(net.parameters(), lr = lr,momentum=0.9, weight_decay=5e-4)
# 调整学习率,指数衰减的方式,10个epoch进行学习率的衰减,变成原来的0.9
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=10,gamma=0.9)for epoch in range(epoch_num):net.train() #若网络中出现BN dropout层,训练过程中使用net.train(),测试过程中使用net.eval()for i, data in enumerate(train_loader):# 获取训练图片的数据和类别inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 将数据传入网络中outputs = net(inputs)# 计算lossloss = loss_func(outputs, labels)# 初始化优化器optimizer.zero_grad()# 反向传播lossloss.backward()# 更新参数optimizer.step()# 10维向量对应类别的概率分布,取概率值最高的地方对应的索引_, pred = torch.max(outputs.data, dim=1)# 预测正确的样本数量当pred和labels的结果一致时,统计结果correct = pred.eq(labels.data).cpu().sum()#print('epoch is ',epoch," step is ",i,'train loss is:',loss.item(),'train mini_batch correct is :',1.0 * correct / batch_size)if not os.path.exists('models'):os.mkdir('models')# 保存模型torch.save(net.state_dict(),'models/{}.pth'.format(epoch + 1))# 更新学习率scheduler.step()print('lr is:',optimizer.state_dict()['param_groups'][0]['lr'])sum_loss = 0sum_correct = 0sum_acc = 0for i, data in enumerate(test_loader):net.eval()# 获取训练图片的数据和类别inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 将数据传入网络中outputs = net(inputs)# 计算lossloss = loss_func(outputs, labels)sum_loss += loss.item()# 10维向量对应类别的概率分布,取概率值最高的地方对应的索引_, pred = torch.max(outputs.data, dim=1)# 预测正确的样本数量当pred和labels的结果一致时,统计结果correct = pred.eq(labels.data).cpu().sum()sum_correct += correct.item()#print('epoch is ',epoch," step is ",i,'test loss is:',loss.item(),'test mini_batch correct is :',1.0 * correct / batch_size)test_loss = 1.0 * sum_loss / len(test_loader)test_correct = 100.0 * sum_correct / (len(test_loader)*batch_size)print('test_loss',test_loss ,'test_loss',test_correct)
本文标签: 4分类训练之训练脚本(Pytorch)
版权声明:本文标题:4.分类训练之训练脚本(Pytorch) 内容由网友自发贡献,该文观点仅代表作者本人, 转载请联系作者并注明出处:http://www.freenas.com.cn/jishu/1732355573h1534336.html, 本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容,一经查实,本站将立刻删除。
发表评论