Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com5年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python实现多进程共享数据的方法分析

本文实例讲述了Python实现多进程共享数据的方法。分享给大家供大家参考,具体如下: 示例一: # -*- coding:utf-8 -*- from multiprocessing...

深入解析python中的实例方法、类方法和静态方法

深入解析python中的实例方法、类方法和静态方法

1、实例方法/对象方法 实例方法或者叫对象方法,指的是我们在类中定义的普通方法。 只有实例化对象之后才可以使用的方法,该方法的第一个形参接收的一定是对象本身 2、静态方法 (1).格式...

学习python可以干什么

python是什么? python的中文名称是蟒蛇,是一种计算机程序设计语言;是一种动态的、面向对象的脚本语言。最初是用来编写自动化脚本的,随着版本的不断更新和语言新功能的添加,越来越多...

python 删除非空文件夹的实例

一般删除文件时使用os库,然后利用os.remove(path)即可完成删除,如果删除空文件夹则可使用os.removedirs(path)即可, 但是如果需要删除整个文件夹,且文件夹非...

Python中如何获取类属性的列表

前言 最近工作中遇到个需求是要得到一个类的静态属性,也就是说有个类 Type ,我要动态获取 Type.FTE 这个属性的值。 最简单的方案有两个: getattr(Type, 'F...