python中pytorch图像识别的示例分析
这篇文章将为大家详细讲解有关python中pytorch图像识别的示例分析,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。
一、数据集爬取现在的深度学习对数据集量的需求越来越大了,也有了许多现成的数据集可供大家查找下载,但是如果你只是想要做一下深度学习的实例以此熟练一下或者找不到好的数据集,那么你也可以尝试自己制作数据集——自己从网上爬取图片,下面是通过百度图片爬取数据的示例。
importosimporttimeimportrequestsimportredefimgdata_set(save_path,word,epoch):q=0#停止爬取图片条件a=0#图片名称while(True):time.sleep(1)url="https://image.baidu.com/search/flip?tn=baiduimage&ie=utf-8&word={}&pn={}&ct=&ic=0&lm=-1&width=0&height=0".format(word,q)#word=需要搜索的名字headers={'User-Agent':'Mozilla/5.0(WindowsNT10.0;Win64;x64)AppleWebKit/537.36(KHTML,likeGecko)Chrome/88.0.4324.96Safari/537.36Edg/88.0.705.56'}response=requests.get(url,headers=headers)#print(response.request.headers)html=response.text#print(html)urls=re.findall('"objURL":"(.*?)"',html)#print(urls)forurlinurls:print(a)#图片的名字response=requests.get(url,headers=headers)image=response.contentwithopen(os.path.join(save_path,"{}.jpg".format(a)),'wb')asf:f.write(image)a=a+1q=q+20if(q/20)>=int(epoch):breakif__name__=="__main__":save_path=input('你想保存的路径:')word=input('你想要下载什么图片?请输入:')epoch=input('你想要下载几轮图片?请输入(一轮为60张左右图片):')#需要迭代几次图片imgdata_set(save_path,word,epoch)
通过上述的代码可以自行选择自己需要保存的图片路径、图片种类和图片数目。如我下面做的几种常见的盆栽植物的图片爬取,只需要执行六次代码,改变相应的盆栽植物的名称就可以了。下面是爬取盆栽芦荟的输入示例,输入完成后按Enter执行即可,会自动爬取图片保存到指定文件夹,
如图即为爬取后的图片。
可以看到图片中出现了一些无法打开的图片,同时因为是直接爬取的网络上的图片,可能会出现一些相同的图片,这些都需要进行删除,这就需要我们进行第二步处理了。
二、数据处理由于上面直接爬取到的图片有一些瑕疵,这就需要对图片进行进一步的处理了,对图片进行去重处理
通过重复图片去重处理,将自己需要的数据集按照种类分别保存在各自的文件夹里。同样,由于数据集可能存在无法打开的图片,这就需要对数据集进行下一步处理了。
首先将上面去重处理后的文件夹统一保存在同一个文件夹里面,如下图所示。
记住此文件夹路径,我这里是‘C:\Users\Lenovo\Desktop\data’,将此路径输入到下面代码中。
importosfromPILimportImageroot_path=r"C:\Users\Lenovo\Desktop\data"#待处理文件夹绝对路径(可按‘Ctrl+Shift+c'复制)root_names=os.listdir(root_path)forroot_nameinroot_names:path=os.path.join(root_path,root_name)print("正在删除文件夹:",path)names=os.listdir(path)names_path=[]fornameinnames:#print(name)img=Image.open(os.path.join(path,name))name_path=os.path.join(path,name)ifimg==None:#筛选无法打开的图片names_path.append(name_path)print('成功保存错误图片路径:{}'.format(name))else:w,h=img.sizeifw<50orh<50:#筛选错误图片names_path.append(name_path)print('成功保存特小图片路径:{}'.format(name))print("开始删除需删除的图片")forrinnames_path:os.remove(r)print("已删除:",r)
经过上述处理即完成了图片数据集的处理。最后,也可以对图片数据集进行图片名称的处理,使图片的名称重新从零开始依次排列,方便计数(注意下面代码中的rename将会删除掉原文件夹中的图片)。
importosroot_dir=r"C:\Users\Lenovo\Desktop\pzlh"#原文件夹路径save_path=r"C:\Users\Lenovo\Desktop\pzlh3"#新建文件夹路径img_path=os.listdir(root_dir)a=0foriinimg_path:a+=1i=os.path.join(os.path.abspath(root_dir),i)new_name=os.path.join(os.path.abspath(save_path),str(a)+'_pzlh.jpg')#此处可以修改图片名称os.rename(i,new_name)#特别注意:rename会删除原图
最后,我们可以得到一个将完整的常见盆栽植物的数据集。如果此时数据集的图片数量不多,我们还可以采用数据增强的方法,如旋转,加噪等步骤,都可以在网上找到相应的教程。最后,我们可以得到数据集如下图所示。
三、开始识别首先,先为上面的图片数据集生成对应的标签文件,运行下面代码可以自动生成对应的标签文件。
importosroot_path=r"C:\Users\Lenovo\Desktop\data"save_path=r"C:\Users\Lenovo\Desktop\data_label"#对应的label文件夹下也要建好相应的空子文件夹names=os.listdir(root_path)#得到images文件夹下的子文件夹的名称fornameinnames:path=os.path.join(root_path,name)img_names=os.listdir(path)#得到子文件夹下的图片的名称forimg_nameinimg_names:save_name=img_name.split(".jpg")[0]+'.txt'#得到相应的lable名称txt_path=os.path.join(save_path,name)#得到label的子文件夹的路径withopen(os.path.join(txt_path,save_name),"w")asf:#结合子文件夹路径和相应子文件夹下图片的名称生成相应的子文件夹txt文件f.write(name)#将label写入对应txt文件夹print(f.name)
然后,将上面已经准备好的数据集按照7:3(其他比例也可以)分为训练数据集和验证数据集(图片和标签一定要完全对应即对应图片和标签应该都处于训练集或者数据集),并如下图所示放置。
最后,数据集准备好后,即可导入到模型开始训练,运行下列代码
importtimefromtorch.utils.tensorboardimportSummaryWriterfromtorchvision.datasetsimportImageFolderfromtorchvisionimporttransformsfromtorch.utils.dataimportDataLoaderimporttorchvision.modelsasmodelsimporttorch.nnasnnimporttorchprint("是否使用GPU训练:{}".format(torch.cuda.is_available()))#打印是否采用gpu训练iftorch.cuda.is_available:print("GPU名称为:{}".format(torch.cuda.get_device_name()))#打印相应的gpu信息#数据增强太多也可能造成训练出不好的结果,而且耗时长,宜增强两三倍即可。normalize=transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])#规范化transform=transforms.Compose([#数据处理transforms.Resize((64,64)),transforms.ToTensor(),normalize])dataset_train=ImageFolder('data/train',transform=transform)#训练数据集#print(dataset_tran[0])dataset_valid=ImageFolder('data/valid',transform=transform)#验证或测试数据集#print(dataset_train.classer)#返回类别print(dataset_train.class_to_idx)#返回类别及其索引#print(dataset_train.imgs)#返回图片路径print(dataset_valid.class_to_idx)train_data_size=len(dataset_train)#放回数据集长度test_data_size=len(dataset_valid)print("训练数据集的长度为:{}".format(train_data_size))print("测试数据集的长度为:{}".format(test_data_size))#torch自带的标准数据集加载函数dataloader_train=DataLoader(dataset_train,batch_size=4,shuffle=True,num_workers=0,drop_last=True)dataloader_test=DataLoader(dataset_valid,batch_size=4,shuffle=True,num_workers=0,drop_last=True)#2.模型加载model_ft=models.resnet18(pretrained=True)#使用迁移学习,加载预训练权重#print(model_ft)in_features=model_ft.fc.in_featuresmodel_ft.fc=nn.Sequential(nn.Linear(in_features,36),nn.Linear(36,6))#将最后的全连接改为(36,6),使输出为六个小数,对应六种植物的置信度#冻结卷积层函数#fori,parainenumerate(model_ft.parameters()):#ifi<18:#para.requires_grad=False#print(model_ft)#model_ft.half()#可改为半精度,加快训练速度,在这里不适用model_ft=model_ft.cuda()#将模型迁移到gpu#3.优化器loss_fn=nn.CrossEntropyLoss()loss_fn=loss_fn.cuda()#将loss迁移到gpulearn_rate=0.01#设置学习率optimizer=torch.optim.SGD(model_ft.parameters(),lr=learn_rate,momentum=0.01)#可调超参数total_train_step=0total_test_step=0epoch=50#迭代次数writer=SummaryWriter("logs_train_yaopian")best_acc=-1ss_time=time.time()foriinrange(epoch):start_time=time.time()print("--------第{}轮训练开始---------".format(i+1))model_ft.train()fordataindataloader_train:imgs,targets=data#iftorch.cuda.is_available():#imgs.float()#imgs=imgs.float()#为上述改为半精度操作,在这里不适用imgs=imgs.cuda()targets=targets.cuda()#imgs=imgs.half()outputs=model_ft(imgs)loss=loss_fn(outputs,targets)optimizer.zero_grad()#梯度归零loss.backward()#反向传播计算梯度optimizer.step()#梯度优化total_train_step=total_train_step+1iftotal_train_step%100==0:#一轮时间过长可以考虑加一个end_time=time.time()print("使用GPU训练100次的时间为:{}".format(end_time-start_time))print("训练次数:{},loss:{}".format(total_train_step,loss.item()))#writer.add_scalar("valid_loss",loss.item(),total_train_step)model_ft.eval()total_test_loss=0total_accuracy=0withtorch.no_grad():#验证数据集时禁止反向传播优化权重fordataindataloader_test:imgs,targets=data#iftorch.cuda.is_available():#imgs.float()#imgs=imgs.float()imgs=imgs.cuda()targets=targets.cuda()#imgs=imgs.half()outputs=model_ft(imgs)loss=loss_fn(outputs,targets)total_test_loss=total_test_loss+loss.item()accuracy=(outputs.argmax(1)==targets).sum()total_accuracy=total_accuracy+accuracyprint("整体测试集上的loss:{}(越小越好,与上面的loss无关此为测试集的总loss)".format(total_test_loss))print("整体测试集上的正确率:{}(越大越好)".format(total_accuracy/len(dataset_valid)))writer.add_scalar("valid_loss",(total_accuracy/len(dataset_valid)),(i+1))#选择性使用哪一个total_test_step=total_test_step+1iftotal_accuracy>best_acc:#保存迭代次数中最好的模型print("已修改模型")best_acc=total_accuracytorch.save(model_ft,"best_model_yaopian.pth")ee_time=time.time()zong_time=ee_time-ss_timeprint("训练总共用时:{}h:{}m:{}s".format(int(zong_time//3600),int((zong_time%3600)//60),int(zong_time%60)))#打印训练总耗时writer.close()
上述采用的迁移学习直接使用resnet18的模型进行训练,只对全连接的输出进行修改,是一种十分方便且实用的方法,同样,你也可以自己编写模型,然后使用自己的模型进行训练,但是这种方法显然需要训练更长的时间才能达到拟合。如图所示,只需要修改矩形框内部分,将‘model_ft=models.resnet18(pretrained=True)'改为自己的模型‘model_ft=model’即可。
四、模型测试经过上述的步骤后,我们将会得到一个‘best_model_yaopian.pth’的模型权重文件,最后运行下列代码就可以对图片进行识别了
importosimporttorchimporttorchvisionfromPILimportImagefromtorchimportnni=0#识别图片计数root_path="测试_data"#待测试文件夹names=os.listdir(root_path)fornameinnames:print(name)i=i+1data_class=['滴水观音','发财树','非洲茉莉','君子兰','盆栽芦荟','文竹']#按文件索引顺序排列image_path=os.path.join(root_path,name)image=Image.open(image_path)print(image)transforms=torchvision.transforms.Compose([torchvision.transforms.Resize((64,64)),torchvision.transforms.ToTensor()])image=transforms(image)print(image.shape)model_ft=torchvision.models.resnet18()#需要使用训练时的相同模型#print(model_ft)in_features=model_ft.fc.in_featuresmodel_ft.fc=nn.Sequential(nn.Linear(in_features,36),nn.Linear(36,6))#此处也要与训练模型一致model=torch.load("best_model_yaopian.pth",map_location=torch.device("cpu"))#选择训练后得到的模型文件#print(model)image=torch.reshape(image,(1,3,64,64))#修改待预测图片尺寸,需要与训练时一致model.eval()withtorch.no_grad():output=model(image)print(output)#输出预测结果#print(int(output.argmax(1)))print("第{}张图片预测为:{}".format(i,data_class[int(output.argmax(1))]))#对结果进行处理,使直接显示出预测的植物种类
最后,通过上述步骤我们可以得到一个简单的盆栽植物智能识别程序,对盆栽植物进行识别,如下图是识别结果说明。
关于“python中pytorch图像识别的示例分析”这篇文章就分享到这里了,希望以上内容可以对大家有一定的帮助,使各位可以学到更多知识,如果觉得文章不错,请把它分享出去让更多的人看到。
声明:本站所有文章资源内容,如无特殊说明或标注,均为采集网络资源。如若本站内容侵犯了原著者的合法权益,可联系本站删除。