Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com5年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Flask使用Pyecharts在单个页面展示多个图表的方法

Flask使用Pyecharts在单个页面展示多个图表的方法

在Flask页面展示echarts,主要有两种方法: 方法1、原生echarts方法 自己在前端引入echarts.js文件、自己创建div、自己初始化echarts对象、自己从官网复制...

python tensorflow学习之识别单张图片的实现的示例

python tensorflow学习之识别单张图片的实现的示例

假设我们已经安装好了tensorflow。 一般在安装好tensorflow后,都会跑它的demo,而最常见的demo就是手写数字识别的demo,也就是mnist数据集。 然而我们仅仅是...

python对excel文档去重及求和的实例

废话不多说,估计只有我这个菜鸟废了2个小时才搞出来,主要是我想了太多方法来实现,最后都因为这因为那的原因失败了 间接说明自己对可变与不可变类型的了解,还是不够透彻 最后就用了个笨方法解决...

Python 3.x读写csv文件中数字的方法示例

前言 本文主要给大家介绍了关于Python3.x读写csv文件中数字的相关内容,分享出来供大家参考学习,下面话不多说了,来一起看看详细的介绍吧。 读写csv文件 读文件时先产生str的列...

Python 利用高德地图api实现经纬度与地址的批量转换

我们都知道,可以使用高德地图api实现经纬度与地址的转换。那么,当我们有很多个地址与经纬度,需要批量转换的时候,应该怎么办呢? 在这里,选用高德Web服务的API,其中的地址/逆地址编码...