这篇文章主要介绍“Pytorch如何继承Subset类完成自定义数据拆分”,在日常操作中,相信很多人在Pytorch如何继承Subset类完成自定义数据拆分问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”Pytorch如何继承Subset类完成自定义数据拆分”的疑惑有所帮助!接下来,请跟着小编一起来学习吧!

下面是加载内置训练数据集的常见操作:

fromtorchvision.datasetsimportFashionMNISTfromtorchvision.transformsimportCompose,ToTensor,NormalizeRAW_DATA_PATH='./rawdata'transform=Compose([ToTensor(),Normalize((0.1307,),(0.3081,))])train_data=FashionMNIST(root=RAW_DATA_PATH,download=True,train=True,transform=transform)

这里的train_data 做为 dataset 对象,它拥有许多熟悉,我们可以通过以下方法获取样本数据的分类类别集合、样本的特征维度、样本的标签集合等信息。

classes=train_data.classesnum_features=train_data.data[0].shape[0]train_labels=train_data.targetsprint(classes)print(num_features)print(train_labels)

输出如下:

['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
28
tensor([9, 0, 0, ..., 3, 0, 5])

但是,我们常常会在训练集的基础上拆分出验证集(或者只用部分数据来进行训练)。我们想到的第一个方法是使用 torch.utils.data.random_splitdataset 进行划分,下面我们假设划分10000个样本做为训练集,其余样本做为验证集:

fromtorch.utils.dataimportrandom_splitk=10000train_data,valid_data=random_split(train_data,[k,len(train_data)-k])

注意我们如果打印 train_data 和 valid_data 的类型,可以看到显示:

<class'torch.utils.data.dataset.Subset'>

已经不再是torchvision.datasets.mnist.FashionMNIST 对象,而是一个所谓的 Subset 对象!此时 Subset 对象虽然仍然还存有 data 属性,但是内置的 target classes 属性已经不复存在,

比如如果我们强行访问 valid_data 的 target 属性:

valid_target=valid_data.target

就会报如下错误:

'Subset' object has no attribute 'target'

但如果我们在后续的代码中常常会将拆分后的数据集也默认为 dataset 对象,那么该如何做到代码的一致性呢?

这里有一个trick,那就是以继承 SubSet 类的方式的方式定义一个新的 CustomSubSet 类,使新类在保持 SubSet 类的基本属性的基础上,拥有和原本数据集类相似的属性,如 targets classes 等:

fromtorch.utils.dataimportSubsetclassCustomSubset(Subset):'''Acustomsubsetclass'''def__init__(self,dataset,indices):super().__init__(dataset,indices)self.targets=dataset.targets#保留targets属性self.classes=dataset.classes#保留classes属性def__getitem__(self,idx):#同时支持索引访问操作x,y=self.dataset[self.indices[idx]]returnx,ydef__len__(self):#同时支持取长度操作returnlen(self.indices)

然后就引出了第二种划分方法,即通过初始化 CustomSubset 对象的方式直接对数据集进行划分(这里为了简化省略了shuffle的步骤):

importnumpyasnpfromcopyimportdeepcopyorigin_data=deepcopy(train_data)train_data=CustomSubset(origin_data,np.arange(k))valid_data=CustomSubset(origin_data,np.arange(k,len(origin_data))-k)

注意: CustomSubset 类的初始化方法的第二个参数 indices 为样本索引,我们可以通过 np.arange() 的方法来创建。

然后,我们再访问 valid_data 对应的 classes 和 targes 属性:

print(valid_data.classes)print(valid_data.targets)

此时,我们发现可以成功访问这些属性了:

['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankleboot']tensor([9,0,0,...,3,0,5])

当然, CustomSubset 的作用并不只是添加数据集的属性,我们还可以自定义一些数据预处理操作。

我们将类的结构修改如下:

classCustomSubset(Subset):'''Acustomsubsetclasswithcustomizabledatatransformation'''def__init__(self,dataset,indices,subset_transform=None):super().__init__(dataset,indices)self.targets=dataset.targetsself.classes=dataset.classesself.subset_transform=subset_transformdef__getitem__(self,idx):x,y=self.dataset[self.indices[idx]]ifself.subset_transform:x=self.subset_transform(x)returnx,ydef__len__(self):returnlen(self.indices)

我们可以在使用样本前设置好数据预处理算子:

fromtorchvisionimporttransformsvalid_data.subset_transform=transforms.Compose(\[transforms.RandomRotation((180,180))])

这样,我们再像下列这样用索引访问取出数据集样本时,就会自动调用算子完成预处理操作:

print(valid_data[0])

打印结果缩略如下:

(tensor([[[-0.4242, -0.4242, -0.4242, ......-0.4242, -0.4242, -0.4242, -0.4242, -0.4242]]]), 9)

到此,关于“Pytorch如何继承Subset类完成自定义数据拆分”的学习就结束了,希望能够解决大家的疑惑。理论与实践的搭配能更好的帮助大家学习,快去试试吧!若想继续学习更多相关知识,请继续关注亿速云网站,小编会继续努力为大家带来更多实用的文章!