PyTorch实现AlexNet示例

yipeiwu_com6年前Python基础

PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

import torch
import torch.nn as nn
import torchvision

class AlexNet(nn.Module):
  def __init__(self,num_classes=1000):
    super(AlexNet,self).__init__()
    self.feature_extraction = nn.Sequential(
      nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2,bias=False),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
      nn.Conv2d(in_channels=96,out_channels=192,kernel_size=5,stride=1,padding=2,bias=False),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
      nn.Conv2d(in_channels=192,out_channels=384,kernel_size=3,stride=1,padding=1,bias=False),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
    )
    self.classifier = nn.Sequential(
      nn.Dropout(p=0.5),
      nn.Linear(in_features=256*6*6,out_features=4096),
      nn.ReLU(inplace=True),
      nn.Dropout(p=0.5),
      nn.Linear(in_features=4096, out_features=4096),
      nn.ReLU(inplace=True),
      nn.Linear(in_features=4096, out_features=num_classes),
    )
  def forward(self,x):
    x = self.feature_extraction(x)
    x = x.view(x.size(0),256*6*6)
    x = self.classifier(x)
    return x


if __name__ =='__main__':
  # model = torchvision.models.AlexNet()
  model = AlexNet()
  print(model)

  input = torch.randn(8,3,224,224)
  out = model(input)
  print(out.shape)

以上这篇PyTorch实现AlexNet示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python遍历文件夹下所有excel文件

大数据处理经常要用到一堆表格,然后需要把数据导入一个list中进行各种算法分析,简单讲一下自己的做法: 1.如何读取excel文件 网上的版本很多,在xlrd模块基础上,找到一些源码...

pycharm设置注释颜色的方法

操作方法如下所示: File-->Settings-->Editor-->Color&Fonts-->LanguageDefaults-->Linecomm...

python hough变换检测直线的实现方法

python hough变换检测直线的实现方法

1 原理  2 检测步骤 将参数空间(ρ,θ) 量化成m*n(m为ρ的等份数,n为θ的等份数)个单元,并设置累加器矩阵,初始值为0; 对图像边界上的每一个点(x,y)带入ρ=...

Django uwsgi Nginx 的生产环境部署详解

配置生产环境 #setting.py 文件中 DEBUG = False # 生产环境 # 允许访问的域名,域名前加一个点表示允许访问该域名下的子域名,比如 www.zmre...

理解Python垃圾回收机制

一.垃圾回收机制 Python中的垃圾回收是以引用计数为主,分代收集为辅。引用计数的缺陷是循环引用的问题。 在Python中,如果一个对象的引用数为0,Python虚拟机就会回收这个对象...