• 售前

  • 售后

热门帖子
入门百科

PyTorch数据读取的实现示例

[复制链接]
a18945178687 显示全部楼层 发表于 2021-10-26 13:24:23 |阅读模式 打印 上一主题 下一主题
前言
  1. PyTorch
复制代码
作为一款深度学习框架,已经帮助我们实现了许多许多的功能了,包括数据的读取和转换了,那么这一章节就先容一下
  1. PyTorch
复制代码
内置的数据读取模块吧
模块先容

      
  • pandas 用于方便操作含有字符串的表文件,如csv  
  • zipfile python内置的文件解压包  
  • cv2 用于图片处置惩罚的模块,读入的图片模块为BGR,N H W C  
  • torchvision.transforms 用于图片的操作库,好比随机裁剪、缩放、模糊等等,可用于数据的增广,但也不仅限于内置的图片操作,也可以自行进行图片数据的操作,这章也会解说  
  • torch.utils.data.Dataset torch内置的对象类型  
  • torch.utils.data.DataLoader 和Dataset共同使用可以实现数据的加速读取和随机读取等等功能
  1. import zipfile # 解压
  2. import pandas as pd # 操作数据
  3. import os # 操作文件或文件夹
  4. import cv2 # 图像操作库
  5. import matplotlib.pyplot as plt # 图像展示库
  6. from torch.utils.data import Dataset # PyTorch内置对象
  7. from torchvision import transforms # 图像增广转换库 PyTorch内置
  8. import torch
复制代码
初步读取数据

数据下载到此处
我们先初步编写一个脚本来实现图片的展示
  1. # 解压文件到指定目录
  2. def unzip_file(root_path, filename):
  3.   full_path = os.path.join(root_path, filename)
  4.   file = zipfile.ZipFile(full_path)
  5.   file.extractall(root_path)
  6. unzip_file(root_path, zip_filename)
  7. # 读入csv文件
  8. face_landmarks = pd.read_csv(os.path.join(extract_path, csv_filename))
  9. # pandas读出的数据如想要操作索引 使用iloc
  10. image_name = face_landmarks.iloc[:,0]
  11. landmarks = face_landmarks.iloc[:,1:]
  12. # 展示
  13. def show_face(extract_path, image_file, face_landmark):
  14.   plt.imshow(plt.imread(os.path.join(extract_path, image_file)), cmap='gray')
  15.   point_x = face_landmark.to_numpy()[0::2]
  16.   point_y = face_landmark.to_numpy()[1::2]
  17.   plt.scatter(point_x, point_y, c='r', s=6)
  18.   
  19. show_face(extract_path, image_name.iloc[1], landmarks.iloc[1])
复制代码

使用内置库来实现

实现MyDataset
使用内置库是我们的代码更加的规范,而且可读性也大大增长
继承Dataset,须要我们实现的有两个地方:
      
  • 实现
    1. __len__
    复制代码
    返回数据的长度,实例化调用
    1. len()
    复制代码
    时返回  
    1. __getitem__
    复制代码
    给定命据的索引返回对应索引的数据如:a[0]  
    1. transform
    复制代码
    数据的额外操作时调用
  1. class FaceDataset(Dataset):
  2.   def __init__(self, extract_path, csv_filename, transform=None):
  3.     super(FaceDataset, self).__init__()
  4.     self.extract_path = extract_path
  5.     self.csv_filename = csv_filename
  6.     self.transform = transform
  7.     self.face_landmarks = pd.read_csv(os.path.join(extract_path, csv_filename))
  8.   def __len__(self):
  9.     return len(self.face_landmarks)
  10.   def __getitem__(self, idx):
  11.     image_name = self.face_landmarks.iloc[idx,0]
  12.     landmarks = self.face_landmarks.iloc[idx,1:].astype('float32')
  13.     point_x = landmarks.to_numpy()[0::2]
  14.     point_y = landmarks.to_numpy()[1::2]
  15.     image = plt.imread(os.path.join(self.extract_path, image_name))
  16.     sample = {'image':image, 'point_x':point_x, 'point_y':point_y}
  17.     if self.transform is not None:
  18.       sample = self.transform(sample)
  19.     return sample
复制代码
测试功能是否正常
  1. face_dataset = FaceDataset(extract_path, csv_filename)
  2. sample = face_dataset[0]
  3. plt.imshow(sample['image'], cmap='gray')
  4. plt.scatter(sample['point_x'], sample['point_y'], c='r', s=2)
  5. plt.title('face')
复制代码

实现本身的数据处置惩罚模块

内置的在
  1. torchvision.transforms
复制代码
模块下,由于我们的数据结构不能满意内置模块的要求,我们就必须本身实现
图片的缩放,由于缩放后人脸的标注位置也应该发生对应的变革,以是要本身实现对应的变革
  1. class Rescale(object):
  2.   def __init__(self, out_size):
  3.     assert isinstance(out_size,tuple) or isinstance(out_size,int), 'out size isinstance int or tuple'
  4.     self.out_size = out_size
  5.   def __call__(self, sample):
  6.     image, point_x, point_y = sample['image'], sample['point_x'], sample['point_y']
  7.     new_h, new_w = self.out_size if isinstance(self.out_size,tuple) else (self.out_size, self.out_size)
  8.     new_image = cv2.resize(image,(new_w, new_h))
  9.     h, w = image.shape[0:2]
  10.     new_y = new_h / h * point_y
  11.     new_x = new_w / w * point_x
  12.     return {'image':new_image, 'point_x':new_x, 'point_y':new_y}
复制代码
将数据转换为
  1. torch
复制代码
熟悉的数据格式因此,就必须转换为
  1. tensor
复制代码
  1. 注意
复制代码
:
  1. cv2
复制代码
  1. matplotlib
复制代码
读出的图片默认的shape为
  1. N H W C
复制代码
,而
  1. torch
复制代码
默认担当的是
  1. N C H W
复制代码
因此使用
  1. tanspose
复制代码
转换维度,
  1. torch
复制代码
转换多维度使用
  1. permute
复制代码
  1. class ToTensor(object):
  2.   def __call__(self, sample):
  3.     image, point_x, point_y = sample['image'], sample['point_x'], sample['point_y']
  4.     new_image = image.transpose((2,0,1))
  5.     return {'image':torch.from_numpy(new_image), 'point_x':torch.from_numpy(point_x), 'point_y':torch.from_numpy(point_y)}
复制代码
测试
  1. transform = transforms.Compose([Rescale((1024, 512)), ToTensor()])
  2. face_dataset = FaceDataset(extract_path, csv_filename, transform=transform)
  3. sample = face_dataset[0]
  4. plt.imshow(sample['image'].permute((1,2,0)), cmap='gray')
  5. plt.scatter(sample['point_x'], sample['point_y'], c='r', s=2)
  6. plt.title('face')
复制代码

使用Torch内置的loader加速读取数据
  1. data_loader = DataLoader(face_dataset, batch_size=4, shuffle=True, num_workers=0)
  2. for i in data_loader:
  3.   print(i['image'].shape)
  4.   break
复制代码
  1. torch.Size([4, 3, 1024, 512])
复制代码
  1. 注意
复制代码
:
  1. windows
复制代码
环境尽量不使用
  1. num_workers
复制代码
会发生报错
总结

这节使用内置的数据读取模块,帮助我们规范代码,也帮助我们简化代码,加速读取数据也可以加速训练,数据的增广可以大大的增长我们的训练精度,以是本节也是训练中比较紧张环节
到此这篇关于PyTorch数据读取的实现示例的文章就先容到这了,更多相干PyTorch数据读取内容请搜索草根技术分享从前的文章或继续浏览下面的相干文章渴望各人以后多多支持草根技术分享!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有帐号?立即注册

x

帖子地址: 

回复

使用道具 举报

分享
推广
火星云矿 | 预约S19Pro,享500抵1000!
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

草根技术分享(草根吧)是全球知名中文IT技术交流平台,创建于2021年,包含原创博客、精品问答、职业培训、技术社区、资源下载等产品服务,提供原创、优质、完整内容的专业IT技术开发社区。
  • 官方手机版

  • 微信公众号

  • 商务合作