• 售前

  • 售后

热门帖子
入门百科

Pytorch数据读取与预处置惩罚该如何实现

[复制链接]
街角386 显示全部楼层 发表于 2021-10-26 13:50:00 |阅读模式 打印 上一主题 下一主题
  在炼丹时,数据的读取与预处理是关键一步。不同的模型所需要的数据以及预处理方式各不相同,如果每个轮子都我们自己写的话,是很浪费时间和精力的。Pytorch帮我们实现了方便的数据读取与预处理方法,下面记录两个DEMO,便于加速以后的代码效率。
  根据数据是否一次性读取完,将DEMO分为:
  1、串行式读取。也就是一次性读取完所有需要的数据到内存,模型训练时不会再访问外存。通常用在内存充足的环境下利用,速度更快。
  2、并行式读取。也就是边训练边读取数据。通常用在内存不敷的环境下利用,会占用盘算资源,如果分配的好的话,险些不损失速度。
  Pytorch官方的数据提取方式只管方便编码,但由于它提取数据方式比力枯燥,会浪费资源,下面对其举行分析。
1  串行式读取


1.1  DEMO代码
  1. import torch
  2. from torch.utils.data import Dataset,DataLoader
  3.   
  4. class MyDataSet(Dataset):# ————1————
  5. def __init__(self):  
  6.   self.data = torch.tensor(range(10)).reshape([5,2])
  7.   self.label = torch.tensor(range(5))
  8. def __getitem__(self, index):  
  9.   return self.data[index], self.label[index]
  10. def __len__(self):  
  11.   return len(self.data)
  12. my_data_set = MyDataSet()# ————2————
  13. my_data_loader = DataLoader(
  14. dataset=my_data_set,  # ————3————
  15. batch_size=2,     # ————4————
  16. shuffle=True,     # ————5————
  17. sampler=None,     # ————6————
  18. batch_sampler=None,  # ————7————
  19. num_workers=0 ,    # ————8————
  20. collate_fn=None,    # ————9————
  21. pin_memory=True,    # ————10————
  22. drop_last=True     # ————11————
  23. )
  24. for i in my_data_loader: # ————12————
  25. print(i)
复制代码
  注释处表明如下:
  1、重写数据集类,用于保存数据。除了 __init__() 外,必须实现 __getitem__() 和 __len__() 两个方法。前一个方法用于输出索引对应的数据。后一个方法用于获取数据集的长度。
  2~5、 2预备好数据集后,传入DataLoader来迭代天生数据。前三个参数分别是传入的数据集对象、每次获取的批量大小、是否打乱数据集输出。
  6、采样器,如果定义这个,shuffle只能设置为False。所谓采样器就是用于天生数据索引的可迭代对象,比如列表。因此,定义了采样器,采样都按它来,shuffle再打乱就没意义了。
  7、批量采样器,如果定义这个,batch_size、shuffle、sampler、drop_last都不能定义。实际上,如果没有特殊的数据天生顺序的要求,采样器并没有必要定义。torch.utils.data 中的各种 Sampler 就是采样器类,如果需要,可以利用它们来定义。
  8、用于天生数据的子进程数。默以为0,不并行。
  9、拼接多个样本的方法,默认是将每个batch的数据在第一维上举行拼接。这样可能说不清楚,并且由于这里可以探究一下获取数据的速度,背面再详细说明。
  10、是否利用锁页内存。用的话会更快,内存不充足最好别用。
  11、是否把最后小于batch的数据丢掉。
  12、迭代获取数据并输出。
1.2  速度探索


  首先看一下DEMO的输出:

  输出了两个batch的数据,每组数据中data和label都正确分列,符合我们的预期。那么DataLoader是怎么把数据整合起来的呢?首先,我们把collate_fn定义为直接映射(不用它默认的方法),来检察看每次DataLoader从MyDataSet中读取了什么,将上面部门代码修改如下:
  1. my_data_loader = DataLoader(
  2. dataset=my_data_set,  
  3. batch_size=2,      
  4. shuffle=True,      
  5. sampler=None,     
  6. batch_sampler=None,  
  7. num_workers=0 ,   
  8. collate_fn=lambda x:x, #修改处
  9. pin_memory=True,   
  10. drop_last=True     
  11. )
复制代码
  结果如下:

  输出还是两个batch,然而每个batch中,单个的data和label是在一个list中的。似乎可以看出,DataLoader是一个一个读取MyDataSet中的数据的,然后再举行相应数据的拼接。为了验证这点,代码修改如下:
  1. import torch
  2. from torch.utils.data import Dataset,DataLoader
  3.   
  4. class MyDataSet(Dataset):
  5. def __init__(self):  
  6.   self.data = torch.tensor(range(10)).reshape([5,2])
  7.   self.label = torch.tensor(range(5))
  8. def __getitem__(self, index):  
  9.   print(index)     #修改处2
  10.   return self.data[index], self.label[index]
  11. def __len__(self):  
  12.   return len(self.data)
  13. my_data_set = MyDataSet()
  14. my_data_loader = DataLoader(
  15. dataset=my_data_set,  
  16. batch_size=2,      
  17. shuffle=True,      
  18. sampler=None,     
  19. batch_sampler=None,  
  20. num_workers=0 ,   
  21. collate_fn=lambda x:x, #修改处1
  22. pin_memory=True,   
  23. drop_last=True     
  24. )
  25. for i in my_data_loader:
  26. print(i)
复制代码
  输出如下:

  验证了前面的猜想,的确是一个一个读取的。如果数据集定义的不是格式化的数据,那还好,但是我这里定义的是tensor,是可以直接通过列表来索引对应的tensor的。因此,DataLoader的操纵比直接索引多了拼接这一步,肯定是会慢许多的。一两次的读取还好,但在训练中,大量的读取累加起来,就会浪费许多时间了。
  自定义一个DataLoader可以证实这一点,代码如下:
  1. import torch
  2. from torch.utils.data import Dataset,DataLoader
  3. from time import time
  4.   
  5. class MyDataSet(Dataset):
  6. def __init__(self):  
  7.   self.data = torch.tensor(range(100000)).reshape([50000,2])
  8.   self.label = torch.tensor(range(50000))
  9. def __getitem__(self, index):  
  10.   return self.data[index], self.label[index]
  11. def __len__(self):  
  12.   return len(self.data)
  13. # 自定义DataLoader
  14. class MyDataLoader():
  15. def __init__(self, dataset,batch_size):
  16.   self.dataset = dataset
  17.   self.batch_size = batch_size
  18. def __iter__(self):
  19.   self.now = 0
  20.   self.shuffle_i = np.array(range(self.dataset.__len__()))
  21.   np.random.shuffle(self.shuffle_i)
  22.   return self
  23. def __next__(self):
  24.   self.now += self.batch_size
  25.   if self.now <= len(self.shuffle_i):
  26.    indexes = self.shuffle_i[self.now-self.batch_size:self.now]
  27.    return self.dataset.__getitem__(indexes)
  28.   else:
  29.    raise StopIteration
  30. # 使用官方DataLoader
  31. my_data_set = MyDataSet()
  32. my_data_loader = DataLoader(
  33. dataset=my_data_set,  
  34. batch_size=256,      
  35. shuffle=True,      
  36. sampler=None,     
  37. batch_sampler=None,  
  38. num_workers=0 ,   
  39. collate_fn=None,
  40. pin_memory=True,   
  41. drop_last=True     
  42. )
  43. start_t = time()
  44. for t in range(10):
  45. for i in my_data_loader:
  46.   pass
  47. print("官方:", time() - start_t)
  48. #自定义DataLoader
  49. my_data_set = MyDataSet()
  50. my_data_loader = MyDataLoader(my_data_set,256)
  51. start_t = time()
  52. for t in range(10):
  53. for i in my_data_loader:
  54.   pass
  55. print("自定义:", time() - start_t)
复制代码
运行结果如下:

  以上利用batch大小为256,仅各读取10 epoch的数据,都有30多倍的时间上的差距,更大的batch差距会更显着。别的,这里用于测试的每个数据只有两个浮点数,如果是图像,所需的时间可能会增加几百倍。因此,如果数据量和batch都比力大,并且数据是格式化的,最好自己写数据天生器。
2  并行式读取


2.1  DEMO代码
  1. import matplotlib.pyplot as plt
  2. from torch.utils.data import DataLoader
  3. from torchvision import transforms
  4. from torchvision.datasets import ImageFolder
  5. path = r'E:\DataSets\ImageNet\ILSVRC2012_img_train\10-19\128x128'
  6. my_data_set = ImageFolder(      #————1————
  7. root = path,            #————2————
  8. transform = transforms.Compose([  #————3————
  9.   transforms.ToTensor(),
  10.   transforms.CenterCrop(64)
  11. ]),
  12. loader = plt.imread         #————4————
  13. )
  14. my_data_loader = DataLoader(
  15. dataset=my_data_set,   
  16. batch_size=128,      
  17. shuffle=True,      
  18. sampler=None,      
  19. batch_sampler=None,   
  20. num_workers=0,      
  21. collate_fn=None,      
  22. pin_memory=True,      
  23. drop_last=True
  24. )      
  25. for i in my_data_loader:
  26. print(i)
复制代码
  注释处表明如下:
  1/2、ImageFolder类继续自DataSet类,因此可以按索引读取图像。路径必须包含文件夹,ImageFolder会给每个文件夹中的图像添加索引,并且每张图像会给予其地点文件夹的标签。举个例子,代码中my_data_set[0] 输出的是图像对象和它对应的标签构成的列表。
  3、图像到格式化数据的转换组合。更多的转换方法可以看 transform 模块。
  4、图像法的读取方式,默认是PIL.Image.open(),但我发现plt.imread()更快一些。
  由于是边训练边读取,transform会占用许多时间,因此可以先将图像转换为需要的情势存入外存再读取,从而避免重复操纵。
  其中transform.ToTensor()会把正常读取的图像转换为torch.tensor,并且像素值会映射至[0,1][0,1]。由于plt.imread()读取png图像时,像素值在[0,1][0,1],而读取jpg图像时,像素值却在[0,255][0,255],因此利用transform.ToTensor()能将图像像素区间统一化。
以上就是Pytorch数据读取与预处理该怎样实现的详细内容,更多关于Pytorch数据读取与预处理的资料请关注草根技术分享其它相干文章!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有帐号?立即注册

x

帖子地址: 

回复

使用道具 举报

分享
推广
火星云矿 | 预约S19Pro,享500抵1000!
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

草根技术分享(草根吧)是全球知名中文IT技术交流平台,创建于2021年,包含原创博客、精品问答、职业培训、技术社区、资源下载等产品服务,提供原创、优质、完整内容的专业IT技术开发社区。
  • 官方手机版

  • 微信公众号

  • 商务合作