200字范文,内容丰富有趣,生活中的好帮手!
200字范文 > pytorch用FCN语义分割手提包数据集(训练+预测单张输入图片代码)

pytorch用FCN语义分割手提包数据集(训练+预测单张输入图片代码)

时间:2023-10-21 22:13:11

相关推荐

pytorch用FCN语义分割手提包数据集(训练+预测单张输入图片代码)

一,手提包数据集

数据集下载:用pytorch写FCN进行手提包的语义分割。

training data(/yunlongdong/FCN-pytorch-easiest/tree/master/last),放到bag_data文件夹下

ground-truth label(/yunlongdong/FCN-pytorch-easiest/tree/master/last_msk),放到bag_data_mask文件夹下

项目目录结构:

训练数据:

训练label:

从这个手提包数据集可以看出,这是个二分类的,就是只分割出手提包和背景两个类别。所以label处黑色的表示手提包,白色的就是无关的背景。

二,训练代码(用来读取数据集,包括手提包图片和手提包图片的label)

2.1:数据集读取的代码

###BagData.pyimport osimport torchimport torch.nn as nnfrom torch.utils.data import DataLoader, Dataset, random_splitfrom torchvision import transformsimport numpy as npimport cv2#from onehot import onehottransform = pose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])def onehot(data, n):buf = np.zeros(data.shape + (n, ))nmsk = np.arange(data.size)*n + data.ravel()buf.ravel()[nmsk-1] = 1return bufclass BagDataset(Dataset):def __init__(self, transform=None):self.transform = transformdef __len__(self):return len(os.listdir('bag_data'))def __getitem__(self, idx):img_name = os.listdir('bag_data')[idx]imgA = cv2.imread('bag_data/'+img_name)imgA = cv2.resize(imgA, (160, 160))#print(imgA.shape)imgB = cv2.imread('bag_data_msk/'+img_name, 0)imgB = cv2.resize(imgB, (160, 160))#print(imgB.shape)imgB = imgB/255imgB = imgB.astype('uint8')imgB = onehot(imgB, 2) #因为此代码是二分类问题,即分割出手提包和背景两样就行,因此这里参数是2imgB = imgB.transpose(2,0,1) #imgB不经过transform处理,所以要手动把(H,W,C)转成(C,H,W)imgB = torch.FloatTensor(imgB)if self.transform:imgA = self.transform(imgA) #一转成向量后,imgA通道就变成(C,H,W)return imgA, imgBbag = BagDataset(transform)train_size = int(0.9 * len(bag)) #整个训练集中,百分之90为训练集test_size = len(bag) - train_sizetrain_dataset, test_dataset = random_split(bag, [train_size, test_size]) #划分训练集和测试集train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=True, num_workers=4)if __name__ =='__main__':for train_batch in train_dataloader:print(train_batch)for test_batch in test_dataloader:print(test_batch)

贴了数据集读取的代码后,我觉得有必要说一下onehot这个函数

1.就是数据集label的onehot化:

onehot化是把label化成一个一维向量。

onehot化的函数如下:

def onehot(data, n):buf = np.zeros(data.shape + (n, ))nmsk = np.arange(data.size)*n + data.ravel()buf.ravel()[nmsk-1] = 1return buf

输入的data为以灰度图形式读取的label,n为分割的类别数(此数据集是2)

buf = np.zeros(data.shape + (n, ))#设data的shape为(a,b),则构造一个全0矩阵,维度为(a,b,n)

因为n是2,所以意思就是,2层的(a,b)的全0矩阵,一层用来表示手提包的,剩下一层则是用来表示背景的。

nmsk = np.arange(data.size)*n + data.ravel()

这行则比较妙一点,首先设data的size为5,则arange(5)为,(0,1,2,3,4),其实就是表示data各个元素的位置。arange(5)*2为(0,2,4,6,8),其实这是变相表示原来长度x2的位置。而data因为是label,且归一化过的,所以data里的值要么是0要么是1,data.ravel()是把data展成一维数组,arange(5)*2+data.ravel()意思是在(0,2,4,6,8)中,表示手提包的则+1,表示背景的则+0。这里打个比方,例如第三个和第五个位置是表示手提包的,则是(0,2,5,6,9),到这里可能还看不出什么,结合下一句代码就明白了。

buf.ravel()[nmsk-1] = 1

用回刚刚的例子(0,2,5,6,9),nmsk-1后,是(9,1,4,5,8),与初始的(0,2,4,6,8)对比,若原来是1的位置会保持原样(因为+1后又-1了),而原本是0的,表示其位置就会-1。这样的结果就是把(a,b)的label投射到(a,b)*2的长度中。这样做的原因数据集是2分类的,所以网络输出肯定是(a,b,2)这样的,所以label必须要和网络输出维度形式一样才能比较,得出损失函数。

2.2,模型代码

#####FCN.pyimport torchimport torch.nn as nnfrom torchvision import modelsfrom torchvision.models.vgg import VGGclass FCNs(nn.Module):def __init__(self, pretrained_net, n_class):super().__init__()self.n_class = n_classself.pretrained_net = pretrained_netself.relu = nn.ReLU(inplace=True)self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)self.bn1= nn.BatchNorm2d(512)self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)self.bn2= nn.BatchNorm2d(256)self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)self.bn3= nn.BatchNorm2d(128)self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)self.bn4= nn.BatchNorm2d(64)self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)self.bn5= nn.BatchNorm2d(32)self.classifier = nn.Conv2d(32, n_class, kernel_size=1) # classifier is 1x1 conv, to reduce channels from 32 to n_classdef forward(self, x):output = self.pretrained_net(x)x5 = output['x5'] x4 = output['x4'] x3 = output['x3'] x2 = output['x2'] x1 = output['x1'] score = self.bn1(self.relu(self.deconv1(x5)))score = score + x4 score = self.bn2(self.relu(self.deconv2(score))) score = score + x3 score = self.bn3(self.relu(self.deconv3(score))) score = score + x2 score = self.bn4(self.relu(self.deconv4(score))) score = score + x1 score = self.bn5(self.relu(self.deconv5(score))) score = self.classifier(score)return score class VGGNet(VGG):def __init__(self, pretrained=False, model='vgg16', requires_grad=True, remove_fc=True, show_params=False):super().__init__(make_layers(cfg[model]))self.ranges = ranges[model]if pretrained:exec("self.load_state_dict(models.%s(pretrained=False).state_dict())" % model)if not requires_grad:for param in super().parameters():param.requires_grad = False# delete redundant fully-connected layer params, can save memory# 去掉vgg最后的全连接层(classifier)if remove_fc: del self.classifierif show_params:for name, param in self.named_parameters():print(name, param.size())def forward(self, x):output = {}# get the output of each maxpooling layer (5 maxpool in VGG net)for idx, (begin, end) in enumerate(self.ranges):#self.ranges = ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)) (vgg16 examples)for layer in range(begin, end):x = self.features[layer](x)output["x%d"%(idx+1)] = xreturn outputranges = {'vgg11': ((0, 3), (3, 6), (6, 11), (11, 16), (16, 21)),'vgg13': ((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)),'vgg16': ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)),'vgg19': ((0, 5), (5, 10), (10, 19), (19, 28), (28, 37))}# Vgg-Net config # Vgg网络结构配置cfg = {'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],}# make layers using Vgg-Net config(cfg)# 由cfg构建vgg-Netdef make_layers(cfg, batch_norm=False):layers = []in_channels = 3for v in cfg:if v == 'M':layers += [nn.MaxPool2d(kernel_size=2, stride=2)]else:conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)if batch_norm:layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]else:layers += [conv2d, nn.ReLU(inplace=True)]in_channels = vreturn nn.Sequential(*layers)'''VGG-16网络参数Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))'''if __name__ == "__main__":pass

2.3。train代码

########train.pyfrom datetime import datetimeimport matplotlib.pyplot as pltimport numpy as npimport torchimport torch.nn as nnimport torch.optim as optimimport visdomfrom BagData import test_dataloader, train_dataloaderfrom FCN import FCN8s, FCN16s, FCN32s, FCNs, VGGNetdef train(epo_num=50, show_vgg_params=False):#vis = visdom.Visdom()device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')vgg_model = VGGNet(requires_grad=True, show_params=show_vgg_params)fcn_model = FCNs(pretrained_net=vgg_model, n_class=2)fcn_model = fcn_model.to(device)criterion = nn.BCELoss().to(device)optimizer = optim.SGD(fcn_model.parameters(), lr=1e-2, momentum=0.7)all_train_iter_loss = []all_test_iter_loss = []# start timingprev_time = datetime.now()for epo in range(epo_num):train_loss = 0fcn_model.train()for index, (bag, bag_msk) in enumerate(train_dataloader):# bag.shape is torch.Size([4, 3, 160, 160])# bag_msk.shape is torch.Size([4, 2, 160, 160])bag = bag.to(device)bag_msk = bag_msk.to(device)optimizer.zero_grad()output = fcn_model(bag)output = torch.sigmoid(output) # output.shape is torch.Size([4, 2, 160, 160])# print(output)# print(bag_msk)loss = criterion(output, bag_msk)loss.backward()iter_loss = loss.item()all_train_iter_loss.append(iter_loss)train_loss += iter_lossoptimizer.step()output_np = output.cpu().detach().numpy().copy() # output_np.shape = (4, 2, 160, 160) output_np = np.argmin(output_np, axis=1)bag_msk_np = bag_msk.cpu().detach().numpy().copy() # bag_msk_np.shape = (4, 2, 160, 160) bag_msk_np = np.argmin(bag_msk_np, axis=1)test_loss = 0fcn_model.eval()with torch.no_grad():for index, (bag, bag_msk) in enumerate(test_dataloader):bag = bag.to(device)bag_msk = bag_msk.to(device)optimizer.zero_grad()output = fcn_model(bag)output = torch.sigmoid(output) # output.shape is torch.Size([4, 2, 160, 160])loss = criterion(output, bag_msk)iter_loss = loss.item()all_test_iter_loss.append(iter_loss)test_loss += iter_lossoutput_np = output.cpu().detach().numpy().copy() # output_np.shape = (4, 2, 160, 160) output_np = np.argmin(output_np, axis=1)bag_msk_np = bag_msk.cpu().detach().numpy().copy() # bag_msk_np.shape = (4, 2, 160, 160) bag_msk_np = np.argmin(bag_msk_np, axis=1)cur_time = datetime.now()h, remainder = divmod((cur_time - prev_time).seconds, 3600)m, s = divmod(remainder, 60)time_str = "Time %02d:%02d:%02d" % (h, m, s)prev_time = cur_timeprint('epoch train loss = %f, epoch test loss = %f, %s'%(train_loss/len(train_dataloader), test_loss/len(test_dataloader), time_str))if np.mod(epo, 5) == 0:torch.save(fcn_model, 'checkpoints/fcn_model_{}.pt'.format(epo))print('saveing checkpoints/fcn_model_{}.pt'.format(epo))if __name__ == "__main__":train(epo_num=100, show_vgg_params=False)

三,预测代码:

#######t.pyimport matplotlib.pyplot as pltimport numpy as npimport torchfrom torchvision import transformsimport osimport cv2import matplotlib.pyplot as pltfrom BagData import test_dataloader, train_dataloaderfrom FCN import FCN8s, FCN16s, FCN32s, FCNs, VGGNetdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = torch.load('checkpoints/fcn_model_95.pt') # 加载模型model = model.to(device)transform = pose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])if __name__ =='__main__':img_name =r'bag3.jpg' #预测的图片imgA = cv2.imread(img_name)imgA = cv2.resize(imgA, (160, 160))imgA = transform(imgA)imgA = imgA.to(device)imgA = imgA.unsqueeze(0)output = model(imgA)output = torch.sigmoid(output)output_np = output.cpu().detach().numpy().copy() # output_np.shape = (4, 2, 160, 160)print(output_np.shape) #(1, 2, 160, 160)output_np = np.argmin(output_np, axis=1)print(output_np.shape) #(1,160, 160)plt.subplot(1, 2, 1)#plt.imshow(np.squeeze(bag_msk_np[0, ...]), 'gray')#plt.subplot(1, 2, 2)plt.imshow(np.squeeze(output_np[0, ...]), 'gray')plt.pause(3)

四,效果展示

输入图片:

输出效果:

项目代码下载:/download/u014453898/11244794

运行时,直接运行train.py得到模型后,再运行t.py则可以进行预测

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。