200字范文,内容丰富有趣,生活中的好帮手!
200字范文 > GAN变种ACGAN利用手写数字识别mnist生成手写数字

GAN变种ACGAN利用手写数字识别mnist生成手写数字

时间:2023-11-28 15:47:35

相关推荐

GAN变种ACGAN利用手写数字识别mnist生成手写数字

1、摘要

本文主要讲解:GAN变种ACGAN利用手写数字识别mnist数据集进行训练,最终生成手写数字图片

主要思路:

Initialize generator and discriminatorInitialize weightsConfigure data loaderOptimizers AdamTrain GeneratorTrain DiscriminatorSaves a grid of generated digits ranging from 0 to 9

2、数据介绍

mnist手写数字识别数据集

MNIST数据集由Yann LeCun搜集,是一个大型的手写体数字数据库,通常用于训练各种图像处理系统,也被广泛用于机器学习领域的训练和测试。MNIST数字文字识别数据集数据量不会太多,而且是单色的图像,较简单,适合深度学习初学者练习建立模型、训练、预测。MNIST数据库中的图像集是NIST(National Institute of Standards and Technology)的两个数据库的组合:专用数据库1和特殊数据库3。数据集是有250人手写数字组成,一半是高中生,一半是美国人口普查局。

3、相关技术

本文主要使用pytorch实现ACGAN

pytorch在GitHub上的星星多余tensorflow了,tensorflow升级到2.0版导致以前的很多优秀库不兼容,这是硬伤

ACGAN是在CGAN基础上的进一步拓展,采用辅助分类器(Auxiliary Classifier)使得GAN获取的图像分类的功能。

CGAN通过结合标签信息来提高生成数据的质量,SGAN通过重建标签信息来提高生成数据的质量,那么我们可不可以两者都用,答案是显然的,因为ACGAN就是这样干的。更加详细的内容可以参见论文:Conditional Image Synthesis with Auxiliary Classifier GANs

————————————————

版权声明:本文为CSDN博主「时光碎了天」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。

原文链接:ACGAN 简介与代码实战

4、完整代码和步骤

ACGAN代码在4400个回合生成的手写数字如下:

主运行程序入口

import argparseimport osimport numpy as npimport torchimport torch.nn as nnimport torchvision.transforms as transformsfrom torch.autograd import Variablefrom torch.utils.data import DataLoaderfrom torchvision import datasetsfrom torchvision.utils import save_imageos.makedirs("images", exist_ok=True)parser = argparse.ArgumentParser()parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")parser.add_argument("--channels", type=int, default=1, help="number of image channels")parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")opt = parser.parse_args()print(opt)cuda = True if torch.cuda.is_available() else Falsedef weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)self.init_size = opt.img_size // 4 # Initial size before upsamplingself.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, noise, labels):gen_input = torch.mul(self.label_emb(labels), noise)out = self.l1(gen_input)out = out.view(out.shape[0], 128, self.init_size, self.init_size)img = self.conv_blocks(out)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()def discriminator_block(in_filters, out_filters, bn=True):"""Returns layers of each discriminator block"""block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))return blockself.conv_blocks = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# The height and width of downsampled imageds_size = opt.img_size // 2 ** 4# Output layersself.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax())def forward(self, img):out = self.conv_blocks(img)out = out.view(out.shape[0], -1)validity = self.adv_layer(out)label = self.aux_layer(out)return validity, label# Loss functionsadversarial_loss = torch.nn.BCELoss()auxiliary_loss = torch.nn.CrossEntropyLoss()# Initialize generator and discriminatorgenerator = Generator()discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()auxiliary_loss.cuda()# Initialize weightsgenerator.apply(weights_init_normal)discriminator.apply(weights_init_normal)# Configure data loaderos.makedirs("../../data/mnist", exist_ok=True)dataloader = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True,download=True,transform=pose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,)# Optimizersoptimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensorLongTensor = torch.cuda.LongTensor if cuda else torch.LongTensordef sample_image(n_row, batches_done):"""Saves a grid of generated digits ranging from 0 to n_classes"""# Sample noisez = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))# Get labels ranging from 0 to n_classes for n rowslabels = np.array([num for _ in range(n_row) for num in range(n_row)])labels = Variable(LongTensor(labels))gen_imgs = generator(z, labels)save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)for epoch in range(opt.n_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# Adversarial ground truthsvalid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)# Configure inputreal_imgs = Variable(imgs.type(FloatTensor))labels = Variable(labels.type(LongTensor))# -----------------# Train Generator# -----------------optimizer_G.zero_grad()# Sample noise and labels as generator inputz = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))# Generate a batch of imagesgen_imgs = generator(z, gen_labels)# Loss measures generator's ability to fool the discriminatorvalidity, pred_label = discriminator(gen_imgs)g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))g_loss.backward()optimizer_G.step()# ---------------------# Train Discriminator# ---------------------optimizer_D.zero_grad()# Loss for real imagesreal_pred, real_aux = discriminator(real_imgs)d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2# Loss for fake imagesfake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2# Total discriminator lossd_loss = (d_real_loss + d_fake_loss) / 2# Calculate discriminator accuracypred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:sample_image(n_row=10, batches_done=batches_done)

5、学习链接

既能生成图像又能进行分类的ACGAN

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。