Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com6年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python Pandas对缺失值的处理方法

Pandas使用这些函数处理缺失值: isnull和notnull:检测是否是空值,可用于df和series dropna:丢弃、删除缺失值 axis : 删除行...

Flask框架各种常见装饰器示例

Flask框架各种常见装饰器示例

本文实例讲述了Flask框架各种常见装饰器。分享给大家供大家参考,具体如下: 效果类似django的process_request的装饰器 @app.before_request d...

使用Filter过滤python中的日志输出的实现方法

事情是这样的,我写了一个tornado的服务,过程当中我用logging记录一些内容,由于一开始并没有仔细观察tornado自已的日志管理,所以我就一般用debug来记录普通日志,err...

在Python下使用Txt2Html实现网页过滤代理的教程

在撰写本 developerWorks 系列文章的过程中,我曾遇到过以最佳格式进行撰写的问题。文字处理程序格式都是专用的,在格式之间转换总不能尽如人意,也很麻烦(而且每种格式都会各自将文...

用python制作游戏外挂

玩过电脑游戏的同学对于外挂肯定不陌生,但是你在用外挂的时候有没有想过如何做一个外挂呢?(当然用外挂不是那么道义哈,呵呵),那我们就来看一下如何用python来制作一个外挂。。。。 我打开...