pytorch怎么实现变分自动编码器
小编给大家分享一下pytorch怎么实现变分自动编码器,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!
这个例子是用MNIST数据集生成为例子#-*-coding:utf-8-*-"""CreatedonFriOct1211:42:192018@author:www"""importosimporttorchfromtorch.autogradimportVariableimporttorch.nn.functionalasFfromtorchimportnnfromtorch.utils.dataimportDataLoaderfromtorchvision.datasetsimportMNISTfromtorchvisionimporttransformsastfsfromtorchvision.utilsimportsave_imageim_tfs=tfs.Compose([tfs.ToTensor(),tfs.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])#标准化])train_set=MNIST('E:data',transform=im_tfs)train_data=DataLoader(train_set,batch_size=128,shuffle=True)classVAE(nn.Module):def__init__(self):super(VAE,self).__init__()self.fc1=nn.Linear(784,400)self.fc21=nn.Linear(400,20)#meanself.fc22=nn.Linear(400,20)#varself.fc3=nn.Linear(20,400)self.fc4=nn.Linear(400,784)defencode(self,x):h2=F.relu(self.fc1(x))returnself.fc21(h2),self.fc22(h2)defreparametrize(self,mu,logvar):std=logvar.mul(0.5).exp_()eps=torch.FloatTensor(std.size()).normal_()iftorch.cuda.is_available():eps=Variable(eps.cuda())else:eps=Variable(eps)returneps.mul(std).add_(mu)defdecode(self,z):h4=F.relu(self.fc3(z))returnF.tanh(self.fc4(h4))defforward(self,x):mu,logvar=self.encode(x)#编码z=self.reparametrize(mu,logvar)#重新参数化成正态分布returnself.decode(z),mu,logvar#解码,同时输出均值方差net=VAE()#实例化网络iftorch.cuda.is_available():net=net.cuda()x,_=train_set[0]x=x.view(x.shape[0],-1)iftorch.cuda.is_available():x=x.cuda()x=Variable(x)_,mu,var=net(x)print(mu)#可以看到,对于输入,网络可以输出隐含变量的均值和方差,这里的均值方差还没有训练#下面开始训练reconstruction_function=nn.MSELoss(size_average=False)defloss_function(recon_x,x,mu,logvar):"""recon_x:generatingimagesx:originimagesmu:latentmeanlogvar:latentlogvariance"""MSE=reconstruction_function(recon_x,x)#loss=0.5*sum(1+log(sigma^2)-mu^2-sigma^2)KLD_element=mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)KLD=torch.sum(KLD_element).mul_(-0.5)#KLdivergencereturnMSE+KLDoptimizer=torch.optim.Adam(net.parameters(),lr=1e-3)defto_img(x):'''定义一个函数将最后的结果转换回图片'''x=0.5*(x+1.)x=x.clamp(0,1)x=x.view(x.shape[0],1,28,28)returnxforeinrange(100):forim,_intrain_data:im=im.view(im.shape[0],-1)im=Variable(im)iftorch.cuda.is_available():im=im.cuda()recon_im,mu,logvar=net(im)loss=loss_function(recon_im,im,mu,logvar)/im.shape[0]#将loss平均optimizer.zero_grad()loss.backward()optimizer.step()if(e+1)%20==0:print('epoch:{},Loss:{:.4f}'.format(e+1,loss.item()))save=to_img(recon_im.cpu().data)ifnotos.path.exists('./vae_img'):os.mkdir('./vae_img')save_image(save,'./vae_img/image_{}.png'.format(e+1))
补充:PyTorch 深度学习快速入门——变分自动编码器
变分编码器是自动编码器的升级版本,其结构跟自动编码器是类似的,也由编码器和解码器构成。
回忆一下,自动编码器有个问题,就是并不能任意生成图片,因为我们没有办法自己去构造隐藏向量,需要通过一张图片输入编码我们才知道得到的隐含向量是什么,这时我们就可以通过变分自动编码器来解决这个问题。
其实原理特别简单,只需要在编码过程给它增加一些限制,迫使其生成的隐含向量能够粗略的遵循一个标准正态分布,这就是其与一般的自动编码器最大的不同。
这样我们生成一张新图片就很简单了,我们只需要给它一个标准正态分布的随机隐含向量,这样通过解码器就能够生成我们想要的图片,而不需要给它一张原始图片先编码。
一般来讲,我们通过 encoder 得到的隐含向量并不是一个标准的正态分布,为了衡量两种分布的相似程度,我们使用 KL divergence,利用其来表示隐含向量与标准正态分布之间差异的 loss,另外一个 loss 仍然使用生成图片与原图片的均方误差来表示。
KL divergence 的公式如下
重参数 为了避免计算 KL divergence 中的积分,我们使用重参数的技巧,不是每次产生一个隐含向量,而是生成两个向量,一个表示均值,一个表示标准差,这里我们默认编码之后的隐含向量服从一个正态分布的之后,就可以用一个标准正态分布先乘上标准差再加上均值来合成这个正态分布,最后 loss 就是希望这个生成的正态分布能够符合一个标准正态分布,也就是希望均值为 0,方差为 1
所以最后我们可以将我们的 loss 定义为下面的函数,由均方误差和 KL divergence 求和得到一个总的 loss
defloss_function(recon_x,x,mu,logvar):"""recon_x:generatingimagesx:originimagesmu:latentmeanlogvar:latentlogvariance"""MSE=reconstruction_function(recon_x,x)#loss=0.5*sum(1+log(sigma^2)-mu^2-sigma^2)KLD_element=mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)KLD=torch.sum(KLD_element).mul_(-0.5)#KLdivergencereturnMSE+KLD
用 mnist 数据集来简单说明一下变分自动编码器
importosimporttorchfromtorch.autogradimportVariableimporttorch.nn.functionalasFfromtorchimportnnfromtorch.utils.dataimportDataLoaderfromtorchvision.datasetsimportMNISTfromtorchvisionimporttransformsastfsfromtorchvision.utilsimportsave_imageim_tfs=tfs.Compose([tfs.ToTensor(),tfs.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])#标准化])train_set=MNIST('./mnist',transform=im_tfs)train_data=DataLoader(train_set,batch_size=128,shuffle=True)classVAE(nn.Module):def__init__(self):super(VAE,self).__init__()self.fc1=nn.Linear(784,400)self.fc21=nn.Linear(400,20)#meanself.fc22=nn.Linear(400,20)#varself.fc3=nn.Linear(20,400)self.fc4=nn.Linear(400,784)defencode(self,x):h2=F.relu(self.fc1(x))returnself.fc21(h2),self.fc22(h2)defreparametrize(self,mu,logvar):std=logvar.mul(0.5).exp_()eps=torch.FloatTensor(std.size()).normal_()iftorch.cuda.is_available():eps=Variable(eps.cuda())else:eps=Variable(eps)returneps.mul(std).add_(mu)defdecode(self,z):h4=F.relu(self.fc3(z))returnF.tanh(self.fc4(h4))defforward(self,x):mu,logvar=self.encode(x)#编码z=self.reparametrize(mu,logvar)#重新参数化成正态分布returnself.decode(z),mu,logvar#解码,同时输出均值方差net=VAE()#实例化网络iftorch.cuda.is_available():net=net.cuda()x,_=train_set[0]x=x.view(x.shape[0],-1)iftorch.cuda.is_available():x=x.cuda()x=Variable(x)_,mu,var=net(x)print(mu)Variablecontaining:Columns0to9-0.0307-0.1439-0.04350.34720.0368-0.03390.0274-0.56080.02800.2742Columns10to19-0.6221-0.0894-0.09330.42410.16110.32670.5755-0.02370.2714-0.2806[torch.cuda.FloatTensorofsize1x20(GPU0)]
可以看到,对于输入,网络可以输出隐含变量的均值和方差,这里的均值方差还没有训练 下面开始训练
reconstruction_function=nn.MSELoss(size_average=False)defloss_function(recon_x,x,mu,logvar):"""recon_x:generatingimagesx:originimagesmu:latentmeanlogvar:latentlogvariance"""MSE=reconstruction_function(recon_x,x)#loss=0.5*sum(1+log(sigma^2)-mu^2-sigma^2)KLD_element=mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)KLD=torch.sum(KLD_element).mul_(-0.5)#KLdivergencereturnMSE+KLDoptimizer=torch.optim.Adam(net.parameters(),lr=1e-3)defto_img(x):'''定义一个函数将最后的结果转换回图片'''x=0.5*(x+1.)x=x.clamp(0,1)x=x.view(x.shape[0],1,28,28)returnxforeinrange(100):forim,_intrain_data:im=im.view(im.shape[0],-1)im=Variable(im)iftorch.cuda.is_available():im=im.cuda()recon_im,mu,logvar=net(im)loss=loss_function(recon_im,im,mu,logvar)/im.shape[0]#将loss平均optimizer.zero_grad()loss.backward()optimizer.step()if(e+1)%20==0:print('epoch:{},Loss:{:.4f}'.format(e+1,loss.data[0]))save=to_img(recon_im.cpu().data)ifnotos.path.exists('./vae_img'):os.mkdir('./vae_img')save_image(save,'./vae_img/image_{}.png'.format(e+1))epoch:20,Loss:61.5803epoch:40,Loss:62.9573epoch:60,Loss:63.4285epoch:80,Loss:64.7138epoch:100,Loss:63.3343
变分自动编码器虽然比一般的自动编码器效果要好,而且也限制了其输出的编码 (code) 的概率分布,但是它仍然是通过直接计算生成图片和原始图片的均方误差来生成 loss,这个方式并不好,生成对抗网络中,我们会讲一讲这种方式计算 loss 的局限性,然后会介绍一种新的训练办法,就是通过生成对抗的训练方式来训练网络而不是直接比较两张图片的每个像素点的均方误差。
以上是“pytorch怎么实现变分自动编码器”这篇文章的所有内容,感谢各位的阅读!相信大家都有了一定的了解,希望分享的内容对大家有所帮助,如果还想学习更多知识,欢迎关注亿速云行业资讯频道!
声明:本站所有文章资源内容,如无特殊说明或标注,均为采集网络资源。如若本站内容侵犯了原著者的合法权益,可联系本站删除。