pytorch中如何实现手写数字图片识别
小编给大家分享一下pytorch中如何实现手写数字图片识别,希望大家阅读完这篇文章之后都有所收获,下面让我们一起去探讨吧!
具体内容如下
数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:
importtorchfromtorchimportnnfromtorch.nnimportfunctionalasFfromtorchimportoptimimporttorchvisionfrommatplotlibimportpyplotaspltfromutilsimportplot_image,plot_curve,one_hot#step1loaddatasetbatch_size=512train_loader=torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data',train=True,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size=batch_size,shuffle=True)test_loader=torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/',train=False,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size=batch_size,shuffle=False)x,y=next(iter(train_loader))print(x.shape,y.shape,x.min(),x.max())plot_image(x,y,"image_sample")classNet(nn.Module):def__init__(self):super(Net,self).__init__()self.fc1=nn.Linear(28*28,256)self.fc2=nn.Linear(256,64)self.fc3=nn.Linear(64,10)defforward(self,x):#x:[b,1,28,28]#h2=relu(xw1+b1)x=F.relu(self.fc1(x))#h3=relu(h2w2+b2)x=F.relu(self.fc2(x))#h4=h3w3+b3x=self.fc3(x)returnxnet=Net()optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)train_loss=[]forepochinrange(3):forbatch_idx,(x,y)inenumerate(train_loader):#加载进来的图片是一个四维的tensor,x:[b,1,28,28],y:[512]#但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作x=x.view(x.size(0),28*28)#[b,10]out=net(x)y_onehot=one_hot(y)#loss=mse(out,y_onehot)loss=F.mse_loss(out,y_onehot)optimizer.zero_grad()loss.backward()#w'=w-lr*gradoptimizer.step()train_loss.append(loss.item())ifbatch_idx%10==0:print(epoch,batch_idx,loss.item())plot_curve(train_loss)#wegetoptimal[w1,b1,w2,b2,w3,b3]total_correct=0forx,yintest_loader:x=x.view(x.size(0),28*28)out=net(x)#out:[b,10]pred=out.argmax(dim=1)correct=pred.eq(y).sum().float().item()total_correct+=correcttotal_num=len(test_loader.dataset)acc=total_correct/total_numprint("acc:",acc)x,y=next(iter(test_loader))out=net(x.view(x.size(0),28*28))pred=out.argmax(dim=1)plot_image(x,pred,"test")
主程序中调用的函数(注意命名为utils):
importtorchfrommatplotlibimportpyplotaspltdefplot_curve(data):fig=plt.figure()plt.plot(range(len(data)),data,color='blue')plt.legend(['value'],loc='upperright')plt.xlabel('step')plt.ylabel('value')plt.show()defplot_image(img,label,name):fig=plt.figure()foriinrange(6):plt.subplot(2,3,i+1)plt.tight_layout()plt.imshow(img[i][0]*0.3081+0.1307,cmap='gray',interpolation='none')plt.title("{}:{}".format(name,label[i].item()))plt.xticks([])plt.yticks([])plt.show()defone_hot(label,depth=10):out=torch.zeros(label.size(0),depth)idx=torch.LongTensor(label).view(-1,1)out.scatter_(dim=1,index=idx,value=1)returnout
看完了这篇文章,相信你对“pytorch中如何实现手写数字图片识别”有了一定的了解,如果想了解更多相关知识,欢迎关注亿速云行业资讯频道,感谢各位的阅读!
声明:本站所有文章资源内容,如无特殊说明或标注,均为采集网络资源。如若本站内容侵犯了原著者的合法权益,可联系本站删除。