pytorch标签转onehot形式实例

yipeiwu_com6年前Python基础

代码:

import torch

class_num = 10
batch_size = 4
label = torch.LongTensor(batch_size, 1).random_() % class_num
print(label.size())

one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)
print(one_hot)

输出:

torch.Size([4, 1])
tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

注意:

label的形状必须是[n,1]的,也就是必须是二维的,且第二个维度长度为1,如果是一维度的,则需要升维度,代码如下:

import torch

class_num = 10
batch_size = 4
label = torch.LongTensor(batch_size).random_() % class_num
print(label.size())
label = torch.unsqueeze(label,dim=1)
print(label.size())

以上这篇pytorch标签转onehot形式实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Linux下python制作名片示例

Linux下python制作名片示例

建立cards_main文件: # _*_ coding:utf-8 _*_ """ file: cards_main.py date: 2018-07-18 19:47 auth...

Python动态参数/命名空间/函数嵌套/global和nonlocal

1. 函数的动态参数    1.1 *args 位置参数动态传参 def chi(*food): print("我要吃", food) chi("大米饭", "小米饭")...

Python的Flask开发框架简单上手笔记

最简单的hello world #!/usr/bin/env python # encoding: utf-8 from flask import Flask app = Fla...

Flask框架Flask-Login用法分析

本文实例讲述了Flask框架Flask-Login用法。分享给大家供大家参考,具体如下: Flask-Login插件中带了6种信号,可以基于其中的信号做一些额外工作,比如user_log...

Mac安装python3的方法步骤

Mac安装python3的方法步骤

Python有两个版本,一个是2.x版,一个是3.x版,这两个版本是不兼容的。 现在 Mac 上默认安装的 python 版本为 2.7 版本,若 安装 新版本需要 通过 该地址进行下载...