Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com6年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

在centos7中分布式部署pyspider

1.搭建环境: 系统版本:Linux centos-linux.shared 3.10.0-123.el7.x86_64 #1 SMP Mon Jun 30 12:09:22 UTC 2...

Django配置MySQL数据库的完整步骤

Django配置MySQL数据库的完整步骤

一、在settings.py中配置 DATABASES = { 'default': { 'ENGINE': 'django.db.backends.mysql',  # 数...

python不换行之end=与逗号的意思及用途

在python中我们偶尔会用到输出不换行的效果,python2中使用逗号,即可,而python3中使用end=''来实现的,这里简单为大家介绍一下,需要的朋友可以参考下 python输出...

用Python的Django框架编写从Google Adsense中获得报表的应用

 我完成了更新我们在 Neutron的实时收入统计。在我花了一周的时间完成并且更新了我们的PHP脚本之后,我最终认决定开始使用Python进行抓取,这是值得我去花费我的时间和精...

Python的Django框架中if标签的相关使用

{% if %} 标签检查(evaluate)一个变量,如果这个变量为真(即,变量存在,非空,不是布尔值假),系统会显示在 {% if %} 和 {% endif %} 之间的任何内容,...