pytorch 利用lstm做mnist手写数字识别分类的实例

yipeiwu_com6年前Python基础

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)    

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

78行Python代码实现现微信撤回消息功能

78行Python代码实现现微信撤回消息功能

Python曾经对我说:"时日不多,赶紧用Python"。于是看到了一个基于python的微信开源库:itchat,玩了一天,做了一个程序,把私聊撤回的信息可以收集起来并发送到个人微信的...

Python实现获取磁盘剩余空间的2种方法

Python实现获取磁盘剩余空间的2种方法

本文实例讲述了Python实现获取磁盘剩余空间的2种方法。分享给大家供大家参考,具体如下: 方法1: import ctypes import os import platform...

python发送伪造的arp请求

复制代码 代码如下:#!/usr/bin/env pythonimport socket s = socket.socket(socket.AF_PACKET, socket.SOCK_...

python 单线程和异步协程工作方式解析

在python3.4之后新增了asyncio模块,可以帮我们检测IO(只能是网络IO【HTTP连接就是网络IO操作】),实现应用程序级别的切换(异步IO)。注意:asyncio只能发tc...

pycharm配置git(图文教程)

pycharm配置git(图文教程)

下载git客户端  FileàDefault Settingà Version Controlà Git Path to Git executable 填写git客户端的git...