200字范文,内容丰富有趣,生活中的好帮手!
200字范文 > 508任务一:用pytorch简单实现LeNet5网络对MNIST数据集训练

508任务一:用pytorch简单实现LeNet5网络对MNIST数据集训练

时间:2019-01-31 07:39:22

相关推荐

508任务一:用pytorch简单实现LeNet5网络对MNIST数据集训练

看了一些pytorch教学视频,结合别人的代码,按自己的喜好写出来了比较简单的实现,其实还可以把loss数据绘个表,还可以在训练时2加个循环,多训练几次。

import torchimport torchvisionfrom torch import nnfrom torch.utils.data import DataLoaderfrom torch.nn import functional as F#准备数据集train_data=torchvision.datasets.MNIST("./MNISTdata",train=True,transform=torchvision.transforms.ToTensor(),download=True)test_data=torchvision.datasets.MNIST("./MNISTdata",train=False,transform=torchvision.transforms.ToTensor(),download=True)#lengthtrain_data_size=len(train_data)test_data_size=len(test_data)print("训练数据集的长度为:{}".format(train_data_size))print("测试数据集的长度为:{}".format(test_data_size))#利用DataLoader加载数据集train_dataloader=DataLoader(train_data,batch_size=64)test_dataloader=DataLoader(test_data,batch_size=64)#创建神经网络class LeNet5(nn.Module):def __init__(self):super(LeNet5, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5, padding=2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2)x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xmodel = LeNet5()if torch.cuda.is_available():model=model.cuda()#损失函数loss_fn=nn.CrossEntropyLoss()if torch.cuda.is_available():loss_fn=loss_fn.cuda()#优化器optimzer = torch.optim.SGD(model.parameters(), lr=0.01)#训练次数total_train_step = 0#测试的次数total_test_step = 0#训练步骤开始for data in train_dataloader:imgs,targets=dataif torch.cuda.is_available():imgs = imgs.cuda()targets=targets.cuda()outputs=model(imgs)loss=loss_fn(outputs,targets)optimzer.zero_grad()loss.backward()optimzer.step()total_train_step=total_train_step+1if total_train_step%100==0:print("训练次数:{},loss:{}".format(total_train_step,loss.item()))total_test_loss=0total_accuracy=00with torch.no_grad():for data in train_dataloader:imgs,targets=dataif torch.cuda.is_available():imgs = imgs.cuda() # *********CUDAtargets = targets.cuda() # *********CUDAoutputs = model(imgs)loss = loss_fn(outputs,targets)total_test_loss=total_test_loss+loss.item()accuracy = (outputs.argmax(1)==targets).sum()total_accuracy=total_accuracy+accuracyprint("整体测试集上的loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))total_test_step=total_test_step+1

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。