Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com6年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python+numpy实现矩阵的行列扩展方式

Python+numpy实现矩阵的行列扩展方式

对于numpy矩阵,行列扩展有三种比较常用的方法: 1、使用矩阵对象的c_方法扩展列,使用矩阵对象的r_方法扩展行。 2、使用numpy扩展库提供的insert()函数,使用axis参数...

Python探索之Metaclass初步了解

先以一个大牛的一段关于Python Metapgramming的著名的话来做开头: Metaclasses are deeper magic than 99% of users sho...

Python实现对象转换为xml的方法示例

本文实例讲述了Python实现对象转换为xml的方法。分享给大家供大家参考,具体如下: # -*- coding:UTF-8 -*- ''''' Created on 2010-4-...

python实现12306登录并保存cookie的方法示例

经过倒腾12306的登录,还是实现了,请求头很重要...各位感兴趣的可以继续写下去..... import sys import time import requests from...

Python实现SSH远程登陆,并执行命令的方法(分享)

在自动化测试过程中,比较常用的操作就是对远程主机进行操作,如何操作呢?使用SSH远程登陆到主机,然后执行相应的command即可。 使用Python来实现这些操作就相当简单了。下面是测试...