• 售前

  • 售后

热门帖子
入门百科

超详细PyTorch实现手写数字辨认器的示例代码

[复制链接]
小野妹子868 显示全部楼层 发表于 2021-10-26 14:10:25 |阅读模式 打印 上一主题 下一主题
前言

深度学习中有很多玩具数据,
  1. mnist
复制代码
就是此中一个,一个人可否入门深度学习往往就是以可否玩转
  1. mnist
复制代码
数据来判断的,在前面很多基础先容后我们就可以来实现一个简单的手写数字辨认的网络了
数据的处理惩罚

我们使用pytorch自带的包进行数据的预处理惩罚
  1. import torch
  2. import torchvision
  3. import torchvision.transforms as transforms
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. transform = transforms.Compose([
  7.   transforms.ToTensor(),
  8.   transforms.Normalize((0.5), (0.5))
  9. ])
  10. trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
  11. trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True,num_workers=2)
复制代码
  1. 注释
复制代码
:
  1. transforms.Normalize
复制代码
用于数据的标准化,具体实现
  1. mean
复制代码
:均值 总和后除个数
  1. std
复制代码
:方差 每个元素减去均值再平方再除个数
  1. norm_data = (tensor - mean) / std
复制代码
这里就直接将图片标准化到了-1到1的范围,标准化的原因就是由于如果某个数在数据中很大很大,就导致其权重较大,从而影响到其他数据,而自己我们的数据都是平等的,所以标准化后将数据分布到-1到1的范围,使得全部数据都不会有太大的权重导致网络出现巨大的波动
  1. trainloader
复制代码
如今是一个可迭代的对象,那么我们可以使用
  1. for
复制代码
循环进行遍历了,由于是使用yield返回的数据,为了节省内存
观察一下数据
  1. def imshow(img):
  2.    img = img / 2 + 0.5 # unnormalize
  3.    npimg = img.numpy()
  4.    plt.imshow(np.transpose(npimg, (1, 2, 0)))
  5.    plt.show()
  6. # torchvision.utils.make_grid 将图片进行拼接
  7. imshow(torchvision.utils.make_grid(iter(trainloader).next()[0]))
复制代码

构建网络
  1. from torch import nn
  2. import torch.nn.functional as F
  3. class Net(nn.Module):
  4.   def __init__(self):
  5.     super(Net, self).__init__()
  6.     self.conv1 = nn.Conv2d(in_channels=1, out_channels=28, kernel_size=5) # 14
  7.     self.pool = nn.MaxPool2d(kernel_size=2, stride=2) # 无参数学习因此无需设置两个
  8.     self.conv2 = nn.Conv2d(in_channels=28, out_channels=28*2, kernel_size=5) # 7
  9.     self.fc1 = nn.Linear(in_features=28*2*4*4, out_features=1024)
  10.     self.fc2 = nn.Linear(in_features=1024, out_features=10)
  11.   def forward(self, inputs):
  12.     x = self.pool(F.relu(self.conv1(inputs)))
  13.     x = self.pool(F.relu(self.conv2(x)))
  14.     x = x.view(inputs.size()[0],-1)
  15.     x = F.relu(self.fc1(x))
  16.     return self.fc2(x)
复制代码
下面是卷积的动态演示

  1. in_channels
复制代码
:为输入通道数 彩色图片有3个通道 好坏有1个通道
  1. out_channels
复制代码
:输出通道数
  1. kernel_size
复制代码
:卷积核的巨细
  1. stride
复制代码
:卷积的步长
  1. padding
复制代码
:外边距巨细

输出的size盘算公式
      
  • h = (h - kernel_size + 2*padding)/stride + 1  
  • w = (w - kernel_size + 2*padding)/stride + 1
  1. MaxPool2d
复制代码
:是没有参数进行运算的


实例化网络优化器,并且使用GPU进行训练
  1. net = Net()
  2. opt = torch.optim.Adam(params=net.parameters(), lr=0.001)
  3. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  4. net.to(device)
复制代码
  1. Net(
  2. (conv1): Conv2d(1, 28, kernel_size=(5, 5), stride=(1, 1))
  3. (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  4. (conv2): Conv2d(28, 56, kernel_size=(5, 5), stride=(1, 1))
  5. (fc1): Linear(in_features=896, out_features=1024, bias=True)
  6. (fc2): Linear(in_features=1024, out_features=10, bias=True)
  7. )
复制代码
训练紧张代码
  1. for epoch in range(50):
  2.   for images, labels in trainloader:
  3.     images = images.to(device)
  4.     labels = labels.to(device)
  5.     pre_label = net(images)
  6.     loss = F.cross_entropy(input=pre_label, target=labels).mean()
  7.     pre_label = torch.argmax(pre_label, dim=1)
  8.     acc = (pre_label==labels).sum()/torch.tensor(labels.size()[0], dtype=torch.float32)
  9.     net.zero_grad()
  10.     loss.backward()
  11.     opt.step()
  12.   print(acc.detach().cpu().numpy(), loss.detach().cpu().numpy())
复制代码
  1. F.cross_entropy
复制代码
交错熵函数



源码中已经资助我们实现了
  1. softmax
复制代码
因此不需要自己进行
  1. softmax
复制代码
操作了
  1. torch.argmax
复制代码
盘算最大数所在索引值
  1. acc = (pre_label==labels).sum()/torch.tensor(labels.size()[0], dtype=torch.float32)
  2. # pre_label==labels 相同维度进行比较相同返回True不同的返回False,True为1 False为0, 即可获取到相等的个数,再除总个数,就得到了Accuracy准确度了
复制代码
预测
  1. testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
  2. testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=True,num_workers=2)
  3. images, labels = iter(testloader).next()
  4. images = images.to(device)
  5. labels = labels.to(device)
  6. with torch.no_grad():
  7.   pre_label = net(images)
  8.   pre_label = torch.argmax(pre_label, dim=1)
  9.   acc = (pre_label==labels).sum()/torch.tensor(labels.size()[0], dtype=torch.float32)
  10.   print(acc)
复制代码
总结

本节我们相识了
  1. 标准化数据·
复制代码
  1. 卷积的原理
复制代码
  1. 简答的构建了一个网络
复制代码
,并让它去辨认手写体,也是对前面章节的总汇了
到此这篇关于超具体PyTorch实现手写数字辨认器的示例代码的文章就先容到这了,更多相关PyTorch 手写数字辨认器内容请搜索草根技术分享从前的文章或继续欣赏下面的相关文章渴望大家以后多多支持草根技术分享!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有帐号?立即注册

x

帖子地址: 

回复

使用道具 举报

分享
推广
火星云矿 | 预约S19Pro,享500抵1000!
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

草根技术分享(草根吧)是全球知名中文IT技术交流平台,创建于2021年,包含原创博客、精品问答、职业培训、技术社区、资源下载等产品服务,提供原创、优质、完整内容的专业IT技术开发社区。
  • 官方手机版

  • 微信公众号

  • 商务合作