Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com5年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python障碍式期权定价公式

早期写的python障碍式期权的定价脚本,供大家参考,具体内容如下 #coding:utf-8 ''' 障碍期权 q=x/s H = h/x H 障碍价格 [1] Down-and-...

彻彻底底地理解Python中的编码问题

Python处理文本的功能非常强大,但是如果是初学者,没有搞清楚python中的编码机制,也经常会遇到乱码或者decode error。本文的目的是简明扼要地说明python的编码机制,...

Python查找相似单词的方法

本文实例讲述了Python查找相似单词的方法。分享给大家供大家参考。具体分析如下: 问题: 给你一个单词a,如果通过交换单词中字母的顺序可以得到另外的单词b,那么定义b是a的兄弟单词。现...

Python3基础之基本数据类型概述

本文针对Python3中基本数据类型进行实例介绍,这些对于Python初学者而言是必须掌握的知识,具体内容如下: 首先,Python中的变量不需要声明。每个变量在使用前都必须赋值,变量赋...

PyTorch的自适应池化Adaptive Pooling实例

PyTorch的自适应池化Adaptive Pooling实例

简介 自适应池化Adaptive Pooling是PyTorch含有的一种池化层,在PyTorch的中有六种形式: 自适应最大池化Adaptive Max Pooling: torch....