pytorch 数据集图片显示方法

yipeiwu_com6年前Python基础

图片显示

pytorch 载入的数据集是元组tuple 形式,里面包括了数据及标签(train_data,label),其中的train_data数据可以转换为torch.Tensor形式,方便后面计算使用。

同样给一些刚入门的同学在使用载入的数据显示图片的时候带来一些难以理解的地方,这里主要是将Tensor与numpy转换的过程,理解了这些就可以就行转换了

CIAFA10数据集

首先载入数据集,这里做了一些数据处理,包括图片尺寸、数据归一化等

import torch
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
import torchvision.datasets as dset
import torchvision.transforms as transforms
from autoencoder import AutoEncoder
import torch.nn as nn
import torchvision
import numpy as np
dataset = dset.CIFAR10(root='../train/data', download=True, 
    transform=transforms.Compose([
    transforms.Scale(200),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.Gray()
    ]))

在这里 dataset 是一个CIFAR10对象,(大家可以查看一下他的源代码)

方式一

dataset[1] = ([torch.FloatTensor of size 1x200x200],9)

载入的第二个数据是个tensor格式,包含一个标签 9

这里我们做的就是将torch.FloatTensor 转换为numpy,然后显示

b = dataset[1][0].numpy()
#取数据,不取标签

因为这里的b仍然是1*200*200的大小,所以要重新reshape一下,适合输出图像

plt.imshow(b.reshape(200,200),cmap = 'gray')
plt.show()

然后可以显示图像了

方式二

利用torch的接口

img = torchvision.utils.make_grid(dataset[1][0]).numpy()
plt.imshow(np.transpose(img,(1,2,0)))
plt.show()

这用np.transpose 是因为plt.imshow在显示 时候输入的是(imgsize,imgsieze,channels),而这里得到的img是(3,200,200)的格式,所以进行了转换,才能显示

以上这篇pytorch 数据集图片显示方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

微信跳一跳python辅助脚本(总结)

微信跳一跳python辅助脚本(总结)

这段时间微信跳一跳这个游戏非常火爆,但是上分又非常的难,对于程序员来说第一个念头就是通过写一个辅助脚本外挂让上分变的容易,python现在比较火,我们一起来以python语言为基础总结以...

python中使用 xlwt 操作excel的常见方法与问题

前言 Python可以操作Excel的模块不止一种,我习惯使用的写入模块是xlwt(一般都是读写模块分开的) python中使用xlwt操作excel非常方,和Java使用调框架apac...

python多线程同步实例教程

python多线程同步实例教程

前言 进程之间通信与线程同步是一个历久弥新的话题,对编程稍有了解应该都知道,但是细说又说不清。一方面除了工作中可能用的比较少,另一方面就是这些概念牵涉到的东西比较多,而且相对较深。网络编...

pandas删除行删除列增加行增加列的实现

创建df: >>> df = pd.DataFrame(np.arange(16).reshape(4, 4), columns=list('ABCD'), ind...

python 字符串追加实例

通过一个for循环,将一个一个字符追加到字符串中: 方法一: string = '' str=u"追加字符" for i in range(len(str)): string+=...