Pytorch提取模型特征向量保存至csv的例子

yipeiwu_com6年前Python基础

Pytorch提取模型特征向量

# -*- coding: utf-8 -*-
"""
dj
"""
import torch
import torch.nn as nn
import os
from torchvision import models, transforms
from torch.autograd import Variable 
import numpy as np
from PIL import Image 
import torchvision.models as models
import pretrainedmodels
import pandas as pd
class FCViewer(nn.Module):
 def forward(self, x):
  return x.view(x.size(0), -1)
class M(nn.Module):
 def __init__(self, backbone1, drop, pretrained=True):
  super(M,self).__init__()
  if pretrained:
   img_model = pretrainedmodels.__dict__[backbone1](num_classes=1000, pretrained='imagenet') 
  else:
   img_model = pretrainedmodels.__dict__[backbone1](num_classes=1000, pretrained=None)  
  self.img_encoder = list(img_model.children())[:-2]
  self.img_encoder.append(nn.AdaptiveAvgPool2d(1))
  self.img_encoder = nn.Sequential(*self.img_encoder)
  if drop > 0:
   self.img_fc = nn.Sequential(FCViewer())         
  else:
   self.img_fc = nn.Sequential(
    FCViewer())
 def forward(self, x_img):
  x_img = self.img_encoder(x_img)
  x_img = self.img_fc(x_img)
  return x_img 
model1=M('resnet18',0,pretrained=True)
features_dir = '/home/cc/Desktop/features' 
transform1 = transforms.Compose([
  transforms.Resize(256),
  transforms.CenterCrop(224),
  transforms.ToTensor()]) 
file_path='/home/cc/Desktop/picture'
names = os.listdir(file_path)
print(names)
for name in names:
 pic=file_path+'/'+name
 img = Image.open(pic)
 img1 = transform1(img)
 x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
 y = model1(x)
 y = y.data.numpy()
 y = y.tolist()
 #print(y)
 test=pd.DataFrame(data=y)
 #print(test)
 test.to_csv("/home/cc/Desktop/features/3.csv",mode='a+',index=None,header=None)

jiazaixunlianhaodemoxing

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
class ResidualBlock(nn.Module):
 def __init__(self, inchannel, outchannel, stride=1):
  super(ResidualBlock, self).__init__()
  self.left = nn.Sequential(
   nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
   nn.BatchNorm2d(outchannel),
   nn.ReLU(inplace=True),
   nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
   nn.BatchNorm2d(outchannel)
  )
  self.shortcut = nn.Sequential()
  if stride != 1 or inchannel != outchannel:
   self.shortcut = nn.Sequential(
    nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
    nn.BatchNorm2d(outchannel)
   )

 def forward(self, x):
  out = self.left(x)
  out += self.shortcut(x)
  out = F.relu(out)
  return out

class ResNet(nn.Module):
 def __init__(self, ResidualBlock, num_classes=10):
  super(ResNet, self).__init__()
  self.inchannel = 64
  self.conv1 = nn.Sequential(
   nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
   nn.BatchNorm2d(64),
   nn.ReLU(),
  )
  self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1)
  self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
  self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
  self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
  self.fc = nn.Linear(512, num_classes)

 def make_layer(self, block, channels, num_blocks, stride):
  strides = [stride] + [1] * (num_blocks - 1) #strides=[1,1]
  layers = []
  for stride in strides:
   layers.append(block(self.inchannel, channels, stride))
   self.inchannel = channels
  return nn.Sequential(*layers)

 def forward(self, x):
  out = self.conv1(x)
  out = self.layer1(out)
  out = self.layer2(out)
  out = self.layer3(out)
  out = self.layer4(out)
  out = F.avg_pool2d(out, 4)
  out = out.view(out.size(0), -1)
  out = self.fc(out)
  return out


def ResNet18():

 return ResNet(ResidualBlock)

import os
from torchvision import models, transforms
from torch.autograd import Variable 
import numpy as np
from PIL import Image 
import torchvision.models as models
import pretrainedmodels
import pandas as pd
class FCViewer(nn.Module):
 def forward(self, x):
  return x.view(x.size(0), -1)
class M(nn.Module):
 def __init__(self, backbone1, drop, pretrained=True):
  super(M,self).__init__()
  if pretrained:
   img_model = pretrainedmodels.__dict__[backbone1](num_classes=1000, pretrained='imagenet') 
  else:
   img_model = ResNet18()
   we='/home/cc/Desktop/dj/model1/incption--7'
   # 模型定义-ResNet
   #net = ResNet18().to(device)
   img_model.load_state_dict(torch.load(we))#diaoyong  
  self.img_encoder = list(img_model.children())[:-2]
  self.img_encoder.append(nn.AdaptiveAvgPool2d(1))
  self.img_encoder = nn.Sequential(*self.img_encoder)
  if drop > 0:
   self.img_fc = nn.Sequential(FCViewer())         
  else:
   self.img_fc = nn.Sequential(
    FCViewer())
 def forward(self, x_img):
  x_img = self.img_encoder(x_img)
  x_img = self.img_fc(x_img)
  return x_img 
model1=M('resnet18',0,pretrained=None)
features_dir = '/home/cc/Desktop/features' 
transform1 = transforms.Compose([
  transforms.Resize(56),
  transforms.CenterCrop(32),
  transforms.ToTensor()]) 
file_path='/home/cc/Desktop/picture'
names = os.listdir(file_path)
print(names)
for name in names:
 pic=file_path+'/'+name
 img = Image.open(pic)
 img1 = transform1(img)
 x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
 y = model1(x)
 y = y.data.numpy()
 y = y.tolist()
 #print(y)
 test=pd.DataFrame(data=y)
 #print(test)
 test.to_csv("/home/cc/Desktop/features/3.csv",mode='a+',index=None,header=None)

以上这篇Pytorch提取模型特征向量保存至csv的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python中模块与包有相同名字的处理方法

前言 在编程开发中,个人觉得,只要按照规范去做,很少会出问题。刚开始学习一门技术时,的确会遇到很多的坑。踩的坑多了,这是好事,会学到更多东西,也会越来越觉得按照规范做的重要性,规范的制定...

Python While循环语句实例演示及原理解析

Python While循环语句实例演示及原理解析

这篇文章主要介绍了Python While循环语句实例演示及原理解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 Python 编程...

python备份文件以及mysql数据库的脚本代码

复制代码 代码如下: #!/usr/local/python import os import time import string source=['/var/www/html/xxx...

深入理解Python中字典的键的使用

字典的键        字典中的值没有任何限制, 可以是任意Python对象,即从标准对象到用户自定义对象皆可,但是字典中的键...

Python使用Pandas库实现MySQL数据库的读写

Python使用Pandas库实现MySQL数据库的读写

本次分享将介绍如何在Python中使用Pandas库实现MySQL数据库的读写。首先我们需要了解点ORM方面的知识 ORM技术 对象关系映射技术,即ORM(Object-Relatio...