pytorch自定义二值化网络层方式

yipeiwu_com6年前Python基础

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Django项目主urls导入应用中views的红线问题解决

Django项目主urls导入应用中views的红线问题解决

使用PyCharm学习Django框架,从项目的主urls中导入app中的views的时候,导入的包中下面有红线报错,但是却能正常使用。要是这样也就没什么事了,但是导入之后的提示功能就丧...

wxPython窗口中文乱码解决方法

本文实例讲述了wxPython窗口中文乱码解决方法,分享给大家供大家参考。具体方法如下: 文件保存为 utf-8 文件开头添加 # -*- coding: utf-8 -*- 在有中文字...

python中装饰器级连的使用方法示例

前言 最近在学习python,学会了为什么要使用装饰器,也明白了装饰器是什么了,但是你也许会问,是否可以在装饰器前面再添加一层装饰器,会怎么样呢?就像大楼一样,一层一层地叠在一起。其实是...

python+splinter自动刷新抢票功能

抢票脚本,python +splinter自动刷新抢票,可以成功抢到(依赖自己的网络环境太厉害,还有机器的好坏),但是感觉不是很完美。 有大神请指导完善一下(或者有没有别的好点的思路),...

python绘制圆柱体的方法

python绘制圆柱体的方法

本文实例为大家分享了python绘制圆柱体示的具体代码,供大家参考,具体内容如下 #!/usr/bin/env python import vtk # 参考的C++版本源码及解释...