pytorch自定义二值化网络层方式

yipeiwu_com5年前Python基础

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python映射列表实例分析

本文实例讲述了python映射列表。分享给大家供大家参考。具体分析如下: 列表映射是个非常有用的方法,通过对列表的每个元素应用一个函数来转换数据,可以使用一种策略或者方法来遍历计算每个元...

基于Python实现拆分和合并GIF动态图

基于Python实现拆分和合并GIF动态图

“表情包”是当前社交软件上不可或缺的交流方式,难以用文字表达的意思,发一个“表情包”,对方就能心领神会。下面是小派制作的一个表情包,准确地讲,是在已有表情包的基础上,二次加工而成的。 下...

python调用Moxa PCOMM Lite通过串口Ymodem协议实现发送文件

本文实例讲述python调用Moxa PCOMM Lite通过串口Ymodem协议实现发送文件的方法,该程序采用python 2.7编写。主要内容如下: 经过长期搜寻,终于找到了Moxa...

把JSON数据格式转换为Python的类对象方法详解(两种方法)

把JSON数据格式转换为Python的类对象方法详解(两种方法)

JOSN字符串转换为自定义类实例对象 有时候我们有这种需求就是把一个JSON字符串转换为一个具体的Python类的实例,比如你接收到这样一个JSON字符串如下: {"Name": "...

PyCharm+Qt Designer+PyUIC安装配置教程详解

PyCharm+Qt Designer+PyUIC安装配置教程详解

Qt Designer用于像VC++的MFC一样拖放、设计控件 PyUIC用于将Qt Designer生成的.ui文件转换成.py文件 Qt Designer和PyUIC都包含在PyQt...