Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com5年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

详解基于python的多张不同宽高图片拼接成大图

详解基于python的多张不同宽高图片拼接成大图

半年前写过一篇将多张图片拼接成大图的博客,是讲的把所有图片先转换为256×256的图片后再进行拼接,今天看到一个朋友的评论说如何拼接非正方形图片,如47×57,之前有个朋友也问过这个,我...

python实现批量文件重命名

python实现批量文件重命名

本文实例为大家分享了python批量文件重命名的具体代码,供大家参考,具体内容如下 问题描述 最近遇到朋友求助,如何将大量文件名前面的某些字符删除。 即将图中文件前的编号删除。 P...

Python调用服务接口的实例

如下所示: #! /usr/bin/env python # coding=utf-8 ###############################################...

Python 实现自动导入缺失的库

Python 实现自动导入缺失的库

在写 Python 项目的时候,我们可能经常会遇到导入模块失败的错误:ImportError: No module named 'xxx' 或者 ModuleNotFoundError:...

python 随机打乱 图片和对应的标签方法

如下所示: # -*- coding: utf-8 -*- import os import numpy as np import pandas as pd import h5p...