Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com6年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

用Python的线程来解决生产者消费问题的示例

我们将使用Python线程来解决Python中的生产者—消费者问题。这个问题完全不像他们在学校中说的那么难。 如果你对生产者—消费者问题有了解,看这篇博客会更有意义。 为什么要关心生产者...

python和flask中返回JSON数据的方法

python和flask中返回JSON数据的方法

在python中可以使用json将数据格式化为JSON格式: 1.将字典转换成JSON数据格式: s=['张三','年龄','姓名'] t={} t['data']=s ret...

python如何解析配置文件并应用到项目中

配置文件的类型 通常自动化测试中的配置文件是以.ini 和 .conf 为后缀的文件 配置文件的组成 1.section 2.option 3.value 配置文件的格式 [s...

PyQt弹出式对话框的常用方法及标准按钮类型

PyQt弹出式对话框的常用方法及标准按钮类型

PyQt之弹出式对话框(QMessageBox)的常用方法及标准按钮类型 一、控件说明 QMessageBox是一种通用的弹出式对话框,用于显示消息,允许用户通过单击不同的标准按钮对消息...

Python实现KNN邻近算法

简介 邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用...