200字范文,内容丰富有趣,生活中的好帮手!
200字范文 > PyTorch:保存/加载训练好的模型测试

PyTorch:保存/加载训练好的模型测试

时间:2021-01-10 14:15:16

相关推荐

PyTorch:保存/加载训练好的模型测试

保存

# Save only the model's state_dict (its parameter/buffer tensors) to ./cnn.pth.
torch.save(model.state_dict(), './cnn.pth')

加载

model = VGG16()  # an instance of the model must exist before weights can be loaded into it

# map_location="cpu" lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine; load_state_dict then copies the values into the model's own
# parameters, whatever device they live on.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load("./cnn.pth", map_location="cpu"))

例子

import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from tqdm import tqdm
from PIL import Image

# Path of the checkpoint written by torch.save(model.state_dict(), './cnn.pth').
CHECKPOINT_PATH = "./cnn.pth"

# NOTE(review): the original script printed classes[pred] without ever defining
# `classes`. The 10 outputs and 32x32 inputs suggest CIFAR-10; these are the
# torchvision CIFAR-10 label names -- confirm against the labels used in training.
classes = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)

# Output channels of the 13 conv layers; "M" marks a 2x2 max-pool.
_VGG16_LAYOUT = (
    64, 64, "M",
    128, 128, "M",
    256, 256, 256, "M",
    512, 512, 512, "M",
    512, 512, 512, "M",
)


def _make_features():
    """Build the convolutional trunk: Conv3x3 -> BatchNorm -> ReLU blocks with max-pools.

    The layer order (and therefore each layer's index inside the Sequential)
    is the same as in the hand-written original, so state_dict keys such as
    ``features.0.weight`` still match existing checkpoints.
    """
    layers = []
    in_channels = 3
    for item in _VGG16_LAYOUT:
        if item == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        layers.append(nn.Conv2d(in_channels, item, kernel_size=3, padding=1))
        layers.append(nn.BatchNorm2d(item))
        layers.append(nn.ReLU(True))
        in_channels = item
    # Kept from the original (kernel 1 / stride 1 leaves the tensor unchanged).
    layers.append(nn.AvgPool2d(kernel_size=1, stride=1))
    return nn.Sequential(*layers)


class VGG16(nn.Module):
    """VGG16-style network with batch normalisation.

    The first classifier layer is ``nn.Linear(512, 4096)``, so the flattened
    feature map must hold exactly 512 values per sample. With five stride-2
    max-pools that corresponds to 3x32x32 inputs (32 / 2**5 == 1).

    Args:
        num_classes: size of the final linear layer's output.
    """

    def __init__(self, num_classes=10):
        super(VGG16, self).__init__()
        self.features = _make_features()
        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return the class scores of shape (batch, num_classes) for input batch ``x``."""
        out = self.features(x)
        out = out.view(out.size(0), -1)  # flatten everything except the batch dimension
        return self.classifier(out)


# Create the model instance and move it to the GPU when one is available.
model = VGG16()
use_gpu = torch.cuda.is_available()
if use_gpu:
    model = model.cuda()
model.eval()  # inference mode: Dropout disabled, BatchNorm uses running statistics


def prediect(img_path):
    """Classify the image at ``img_path`` with the weights stored in ./cnn.pth.

    Prints the predicted label and also returns it (the original returned
    ``None``; callers that ignore the result are unaffected).

    Args:
        img_path: path of an image file readable by PIL.

    Returns:
        The predicted entry of ``classes``.

    Raises:
        FileNotFoundError: if the checkpoint or the image does not exist.
    """
    # The weights are (re)loaded on every call, as in the original.
    # map_location="cpu" allows a GPU-saved checkpoint to load on a CPU-only
    # machine; load_state_dict copies the values onto the model's own device.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))

    # Resize defaults to bilinear interpolation, which is what the original's
    # deprecated integer form `interpolation=2` selected.
    transform_valid = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])

    # Force 3 channels: the first conv layer expects RGB, while PNG/GIF files
    # may be grayscale, palette-based or carry an alpha channel.
    with Image.open(img_path) as raw:
        img = raw.convert("RGB")

    batch = transform_valid(img).unsqueeze(0)  # add the batch dimension
    batch = batch.to(next(model.parameters()).device)

    # no_grad only has an effect as a context manager (or decorator); the
    # original's bare `torch.no_grad()` call did nothing.
    with torch.no_grad():
        out = model(batch)

    pred = out.argmax(dim=1).item()  # index of the highest score, as a plain int
    label = classes[pred]
    print('this picture maybe :', label)
    return label


if __name__ == '__main__':
    prediect('./Test_Image/dog.jpg')

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。