pytorch 数据集图片显示方法

yipeiwu_com6年前Python基础

图片显示

pytorch 载入的数据集是元组tuple 形式,里面包括了数据及标签(train_data,label),其中的train_data数据可以转换为torch.Tensor形式,方便后面计算使用。

同样给一些刚入门的同学在使用载入的数据显示图片的时候带来一些难以理解的地方,这里主要是将Tensor与numpy转换的过程,理解了这些就可以就行转换了

CIAFA10数据集

首先载入数据集,这里做了一些数据处理,包括图片尺寸、数据归一化等

import torch
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
import torchvision.datasets as dset
import torchvision.transforms as transforms
from autoencoder import AutoEncoder
import torch.nn as nn
import torchvision
import numpy as np
dataset = dset.CIFAR10(root='../train/data', download=True, 
    transform=transforms.Compose([
    transforms.Scale(200),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.Gray()
    ]))

在这里 dataset 是一个CIFAR10对象,(大家可以查看一下他的源代码)

方式一

dataset[1] = ([torch.FloatTensor of size 1x200x200],9)

载入的第二个数据是个tensor格式,包含一个标签 9

这里我们做的就是将torch.FloatTensor 转换为numpy,然后显示

b = dataset[1][0].numpy()
#取数据,不取标签

因为这里的b仍然是1*200*200的大小,所以要重新reshape一下,适合输出图像

plt.imshow(b.reshape(200,200),cmap = 'gray')
plt.show()

然后可以显示图像了

方式二

利用torch的接口

img = torchvision.utils.make_grid(dataset[1][0]).numpy()
plt.imshow(np.transpose(img,(1,2,0)))
plt.show()

这用np.transpose 是因为plt.imshow在显示 时候输入的是(imgsize,imgsieze,channels),而这里得到的img是(3,200,200)的格式,所以进行了转换,才能显示

以上这篇pytorch 数据集图片显示方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python 字典修改键(key)的几种方法

python 字典修改键(key)的几种方法

python中获取字典的key列表和value列表 # -*- coding: utf-8 -*- # 定义一个字典 dic = {'剧情': 11, '犯罪': 10, '动作...

Python图片裁剪实例代码(如头像裁剪)

Python图片裁剪实例代码(如头像裁剪)

今天就来说个常用的功能,图片裁剪,可用于头像裁剪啊之类的。用的还是我们之前用的哪个模块pillow 1. 安装pillow 用pip安装 pip install pillow 2....

python使用tensorflow保存、加载和使用模型的方法

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:...

给我一面国旗 python帮你实现

给我一面国旗 python帮你实现

本文实例为大家分享了Python之给我一面国旗的具体代码,供大家参考,具体内容如下 1、“给我一面国旗@微信官方” 今天“给我一面国旗@微信官方”刷爆了朋友圈,我也蹭波热度,出个Pyth...

基于Python Numpy的数组array和矩阵matrix详解

基于Python Numpy的数组array和矩阵matrix详解

NumPy的主要对象是同种元素的多维数组。这是一个所有的元素都是一种类型、通过一个正整数元组索引的元素表格(通常是元素是数字)。 在NumPy中维度(dimensions)叫做轴(axe...