Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com5年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python实现读取Properties配置文件的方法

本文实例讲述了Python实现读取Properties配置文件的方法。分享给大家供大家参考,具体如下: JAVA本身提供了对于Properties文件操作的类,项目中的很多配置信息都是放...

Python网络编程之TCP套接字简单用法示例

本文实例讲述了Python网络编程之TCP套接字简单用法。分享给大家供大家参考,具体如下: 上学期学的计算机网络,因为之前还未学习python,而java则一知半解,C写起来又麻烦,所以...

Python Web框架Flask信号机制(signals)介绍

信号(signals) Flask信号(signals, or event hooking)允许特定的发送端通知订阅者发生了什么(既然知道发生了什么,那我们可以知道接下来该做什么了)。...

在django模板中实现超链接配置

django中的超链接,在template中可以用{% url 'app_name:url_name' param%} 其中app_name在工程urls中配置的namespace取值,...

python提取xml里面的链接源码详解

因群里朋友需要提取xml地图里面的链接,就写了这个程序。 代码: #coding=utf-8 import urllib import urllib.request import r...