Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com6年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python 访问限制 private public的详细介绍

 一、知识点 在一个模块中,我们可能会定义很多函数和变量。但有的函数和变量我们希望能给别人使用,有的函数和变量我们希望仅仅在模块内部使用,so? 我们可以通过定义该函...

在python tkinter中Canvas实现进度条显示的方法

在python tkinter中Canvas实现进度条显示的方法

如下所示: from tkinter import * import time #更新进度条函数 def change_schedule(now_schedule,all_sch...

pyqt5 使用label控件实时显示时间的实例

pyqt5 使用label控件实时显示时间的实例

如下所示: import sys from PyQt5 import QtGui, QtCore, QtWidgets from PyQt5.QtWidgets import * f...

使用python将mysql数据库的数据转换为json数据的方法

使用python将mysql数据库的数据转换为json数据的方法

由于产品运营部需要采用第三方个推平台,来推送消息。如果手动一个个键入字段和字段值,容易出错,且非常繁琐,需要将mysql的数据转换为json数据,直接复制即可。 本文将涉及到如何使用Py...

Python实现找出数组中第2大数字的方法示例

本文实例讲述了Python实现找出数组中第2大数字的方法。分享给大家供大家参考,具体如下: 题目比较简单直接看实现即可,具体的注释在代码中都有: #!usr/bin/env pyth...