pytorch 利用lstm做mnist手写数字识别分类的实例

yipeiwu_com6年前Python基础

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)    

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

django 在原有表格添加或删除字段的实例

一、如果models.py文件为时: timestamp = models.DateTimeField('保存日期') 会提示: (env8) D:\Desktop\env8\...

Python中的闭包总结

前几天又有人在我的这篇文章 python项目练习一:即时标记 下留言,关于其中一个闭包和re.sub的使用不太清楚。我在自己的博客上搜索了下,发现没有写过闭包相关的东西,所以决定总结一下...

详解Python的Django框架中的通用视图

详解Python的Django框架中的通用视图

通用视图 1. 前言 回想一下,在Django中view层起到的作用是相当于controller的角色,在view中实施的 动作,一般是取得请求参数,再从model中得到数据,再通过数据...

python距离测量的方法

之所以写这个,其实就是希望能对距离有一些概念,当然这个也是很基础的,不过千里之行始于足下嘛,各种路径算法,比如a*什么的都会用到这个 距离测量有三种方式 1、欧式距离,这个是最常用的距离...

Python删除空文件和空文件夹的方法

本文实例讲述了Python删除空文件和空文件夹的方法。分享给大家供大家参考。具体实现方法如下: #-*- coding:cp936 -*- """ os.walk() 函数声明:wa...