Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com6年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python正则表达式re之compile函数解析

re正则表达式模块还包括一些有用的操作正则表达式的函数。下面主要介绍compile函数。 定义: compile(pattern[,flags] ) 根据包含正则表达式的字符串创...

python通过正则查找微博@(at)用户的方法

本文实例讲述了python通过正则查找微博@(at)用户的方法。分享给大家供大家参考。具体如下: 这段代码用到了python正则的findall方法,查找所有被@的用户,使用数组形式返回...

Python中join和split用法实例

join用来连接字符串,split恰好相反,拆分字符串的。 不用多解释,看完代码,其意自现了。 复制代码 代码如下: >>>li = ['my','name','is'...

python str与repr的区别

尽管str(),repr()和``运算在特性和功能方面都非常相似,事实上repr()和``做的是完全一样的事情,它们返回的是一个对象的“官方”字符串表示,也就是说绝大多数情况下可以通过求...

在django中自定义字段Field详解

Django的Field类中方法有: to_python() # 把数据库数据转成python数据 from_db_value() # 把数据库数据转成python数据 get_pre_...