Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com6年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

详解Python3迁移接口变化采坑记

1、除法相关 在python3之前, print 13/4 #result=3 然而在这之后,却变了! print(13 / 4) #result=3.25 "/”符号运算...

Python中对元组和列表按条件进行排序的方法示例

在python中对一个元组排序 我的同事Axel Hecht 给我展示了一些我所不知道的关于python排序的东西。 在python里你可以对一个元组进行排序。例子是最好的说明: &...

Python django使用多进程连接mysql错误的解决方法

Python django使用多进程连接mysql错误的解决方法

问题 mysql 查询出现错误 error: (2014, "Commands out of sync; you can't run this command now")1 查询 m...

Python的函数的一些高阶特性

高阶函数英文叫Higher-order function。什么是高阶函数?我们以实际代码为例子,一步一步深入概念。 变量可以指向函数 以Python内置的求绝对值的函数abs()为例,调...

Python中使用第三方库xlrd来读取Excel示例

本篇文章介绍如何使用xlrd来读取Excel表格中的内容,xlrd是第三方库,所以在使用前我们需要安装xlrd。另外我们一般会使用xlwt来写Excel,所以下一篇文章我们会来介绍如何使...