一,mnist数据集
形如上图的数字手写体就是mnist数据集。
二,GAN原理(生成对抗网络)
GAN网络一共由两部分组成:一个是伪造器(Generator,简称G),一个是判别器(Discrimniator,简称D)
一开始,G由服从某几个分布(如高斯分布)的噪音组成,生成的图片不断送给D判断是否正确,直到G生成的图片连D都判断以为是真的。D每一轮除了看过G生成的假图片以外,还要见数据集中的真图片,以前者和后者得到的损失函数值为依据更新D网络中的权值。因此G和D都在不停地更新权值。以下图为例:
在v1时的G只不过是 一堆噪声,见过数据集(real images)的D肯定能判断出G所生成的是假的。当然G也能知道D判断它是假的这个结果,因此G就会更新权值,到v2的时候,G就能生成更逼真的图片来让D判断,当然在v2时D也是会先看一次真图片,再去判断G所生成的图片。以此类推,不断循环就是GAN的思想。
三,训练代码
import argparse import os import numpy as np import math import torchvision.transforms as transforms from torchvision.utils import save_image from torch.utils.data import DataLoader from torchvision import datasets from torch.autograd import Variable import torch.nn as nn import torch.nn.functional as F import torch os.makedirs("images",exist_ok=True) parser = argparse.ArgumentParser() parser.add_argument("--n_epochs",type=int,default=200,help="number of epochs of training") parser.add_argument("--batch_size",default=64,help="size of the batches") parser.add_argument("--lr",type=float,default=0.0002,help="adam: learning rate") parser.add_argument("--b1",default=0.5,help="adam: decay of first order momentum of gradient") parser.add_argument("--b2",default=0.999,help="adam: decay of first order momentum of gradient") parser.add_argument("--n_cpu",default=8,help="number of cpu threads to use during batch generation") parser.add_argument("--latent_dim",default=100,help="dimensionality of the latent space") parser.add_argument("--img_size",default=28,help="size of each image dimension") parser.add_argument("--channels",default=1,help="number of image channels") parser.add_argument("--sample_interval",default=400,help="interval betwen image samples") opt = parser.parse_args() print(opt) img_shape = (opt.channels,opt.img_size,opt.img_size) # 确定图片输入的格式为(1,28,28),由于mnist数据集是灰度图所以通道为1 cuda = True if torch.cuda.is_available() else False class Generator(nn.Module): def __init__(self): super(Generator,self).__init__() def block(in_feat,out_feat,normalize=True): layers = [nn.Linear(in_feat,out_feat)] if normalize: layers.append(nn.BatchNorm1d(out_feat,0.8)) layers.append(nn.LeakyReLU(0.2,inplace=True)) return layers self.model = nn.Sequential( *block(opt.latent_dim,128,normalize=False),*block(128,256),*block(256,512),*block(512,1024),nn.Linear(1024,int(np.prod(img_shape))),nn.Tanh() ) def forward(self,z): img = self.model(z) img = img.view(img.size(0),*img_shape) return img class Discriminator(nn.Module): def __init__(self): super(Discriminator,self).__init__() self.model = nn.Sequential( nn.Linear(int(np.prod(img_shape)),nn.LeakyReLU(0.2,inplace=True),nn.Linear(512,nn.Linear(256,1),nn.Sigmoid(),) def forward(self,img): img_flat = img.view(img.size(0),-1) validity = self.model(img_flat) return validity # Loss function adversarial_loss = torch.nn.BCELoss() # Initialize generator and discriminator generator = Generator() discriminator = Discriminator() if cuda: generator.cuda() discriminator.cuda() adversarial_loss.cuda() # Configure data loader os.makedirs("../../data/mnist",exist_ok=True) DataLoader = torch.utils.data.DataLoader( datasets.MNIST( "../../data/mnist",train=True,download=True,transform=transforms.Compose( [transforms.Resize(opt.img_size),transforms.ToTensor(),transforms.Normalize([0.5],[0.5])] ),),batch_size=opt.batch_size,shuffle=True,) # Optimizers optimizer_G = torch.optim.Adam(generator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2)) optimizer_D = torch.optim.Adam(discriminator.parameters(),opt.b2)) Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor # ---------- # Training # ---------- if __name__ == '__main__': for epoch in range(opt.n_epochs): for i,(imgs,_) in enumerate(DataLoader): # print(imgs.shape) # Adversarial ground truths valid = Variable(Tensor(imgs.size(0),1).fill_(1.0),requires_grad=False) # 全1 fake = Variable(Tensor(imgs.size(0),1).fill_(0.0),requires_grad=False) # 全0 # Configure input real_imgs = Variable(imgs.type(Tensor)) # ----------------- # Train Generator # ----------------- optimizer_G.zero_grad() # 清空G网络 上一个batch的梯度 # Sample noise as generator input z = Variable(Tensor(np.random.normal(0,1,(imgs.shape[0],opt.latent_dim)))) # 生成的噪音,均值为0方差为1维度为(64,100)的噪音 # Generate a batch of images gen_imgs = generator(z) # Loss measures generator's ability to fool the discriminator g_loss = adversarial_loss(discriminator(gen_imgs),valid) g_loss.backward() # g_loss用于更新G网络的权值,g_loss于D网络的判断结果 有关 optimizer_G.step() # --------------------- # Train Discriminator # --------------------- optimizer_D.zero_grad() # 清空D网络 上一个batch的梯度 # Measure discriminator's ability to classify real from generated samples real_loss = adversarial_loss(discriminator(real_imgs),valid) fake_loss = adversarial_loss(discriminator(gen_imgs.detach()),fake) d_loss = (real_loss + fake_loss) / 2 d_loss.backward() # d_loss用于更新D网络的权值 optimizer_D.step() print( "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch,opt.n_epochs,i,len(DataLoader),d_loss.item(),g_loss.item()) ) batches_done = epoch * len(DataLoader) + i if batches_done % opt.sample_interval == 0: save_image(gen_imgs.data[:25],"images/%d.png" % batches_done,nrow=5,normalize=True) # 保存一个batchsize中的25张 if (epoch+1) %2 ==0: print('save..') torch.save(generator,'g%d.pth' % epoch) torch.save(discriminator,'d%d.pth' % epoch)
运行结果:
一开始时,G生成的全是杂音:
然后逐渐呈现数字的雏形:
最后一次生成的结果:
四,测试代码:
导入最后保存生成器的模型:
from gan import Generator,Discriminator import torch import matplotlib.pyplot as plt from torch.autograd import Variable import numpy as np from torchvision.utils import save_image device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') Tensor = torch.cuda.FloatTensor g = torch.load('g199.pth') #导入生成器Generator模型 #d = torch.load('d.pth') g = g.to(device) #d = d.to(device) z = Variable(Tensor(np.random.normal(0,(64,100)))) #输入的噪音 gen_imgs =g(z) #生产图片 save_image(gen_imgs.data[:25],"images.png",normalize=True)
生成结果:
以上这篇pytorch GAN伪造手写体mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。