Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com6年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python 上下文管理器及自定义原理解析

这篇文章主要介绍了python 上下文管理器原理解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 Python 提供了 with 语...

python多线程http压力测试脚本

本文实例为大家分享了python多线程http压力测试的具体代码,供大家参考,具体内容如下 #coding=utf-8 import sys import time import...

python 正则式使用心得

1.match() 从开始位置开始匹配 2.search() 任意位置匹配,如果有多个匹配,只返回第一个 3.finditer() 返回所有匹配 4.每次匹配,都是尽量最大匹配。例如:...

Python实现的排列组合、破解密码算法示例

本文实例讲述了Python实现的排列组合、破解密码算法。分享给大家供大家参考,具体如下: 排列组合(破解密码) 1.排列 itertools.permutations(iterabl...

python在html中插入简单的代码并加上时间戳的方法

python在html中插入简单的代码并加上时间戳的方法

建议用pycharm,使用比较方便,并且可以直接编辑html文件 import time locatime = time.strftime("%Y-%m-%d" ) report =...