pytorch 数据集图片显示方法

yipeiwu_com5年前Python基础

图片显示

pytorch 载入的数据集是元组tuple 形式,里面包括了数据及标签(train_data,label),其中的train_data数据可以转换为torch.Tensor形式,方便后面计算使用。

同样给一些刚入门的同学在使用载入的数据显示图片的时候带来一些难以理解的地方,这里主要是将Tensor与numpy转换的过程,理解了这些就可以就行转换了

CIAFA10数据集

首先载入数据集,这里做了一些数据处理,包括图片尺寸、数据归一化等

import torch
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
import torchvision.datasets as dset
import torchvision.transforms as transforms
from autoencoder import AutoEncoder
import torch.nn as nn
import torchvision
import numpy as np
dataset = dset.CIFAR10(root='../train/data', download=True, 
    transform=transforms.Compose([
    transforms.Scale(200),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.Gray()
    ]))

在这里 dataset 是一个CIFAR10对象,(大家可以查看一下他的源代码)

方式一

dataset[1] = ([torch.FloatTensor of size 1x200x200],9)

载入的第二个数据是个tensor格式,包含一个标签 9

这里我们做的就是将torch.FloatTensor 转换为numpy,然后显示

b = dataset[1][0].numpy()
#取数据,不取标签

因为这里的b仍然是1*200*200的大小,所以要重新reshape一下,适合输出图像

plt.imshow(b.reshape(200,200),cmap = 'gray')
plt.show()

然后可以显示图像了

方式二

利用torch的接口

img = torchvision.utils.make_grid(dataset[1][0]).numpy()
plt.imshow(np.transpose(img,(1,2,0)))
plt.show()

这用np.transpose 是因为plt.imshow在显示 时候输入的是(imgsize,imgsieze,channels),而这里得到的img是(3,200,200)的格式,所以进行了转换,才能显示

以上这篇pytorch 数据集图片显示方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python 点击指定位置验证码破解的实现代码

思路: 创建浏览器驱动对象 加载登录页面 等待页面加载完毕 切换到用户名和密码登录模式 输入手机号, 注意此处需要等待并获取输入框 输入密码 点击验证按钮 获取弹出验证图...

Django日志模块logging的配置详解

前言 Django对于日志输出的信息是很完善的,request的信息,setting配置,trackback的信息,一应俱全,足够我们调试了。但是在线上环境,如果让用户看到这些信息,是很...

win10系统下Anaconda3安装配置方法图文教程

win10系统下Anaconda3安装配置方法图文教程

本文主要介绍在 windows 10 系统中安装 Anaconda3 的详细过程。 下载 Anaconda 官网下载地址 目前最新版本是 python 3.6,默认下载也是 Python...

Python常用特殊方法实例总结

本文实例讲述了Python常用特殊方法。分享给大家供大家参考,具体如下: 1 __init__和__new__ __init__方法用来初始化类实例;__new__方法用来创建类实例。...

python分批定量读取文件内容,输出到不同文件中的方法

一、文件内容的分发 应用场景:分批读取共有358086行内容的txt文件,每取1000条输出到一个文件当中 # coding=utf-8 # 分批读取共有358086行内容的txt...