pytorch自定义二值化网络层方式

yipeiwu_com6年前Python基础

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python 实现简单的FTP程序

python 实现简单的FTP程序

FTP即文件传输协议;它基于客户机-服务器模型体系结构,应用广泛。它有两个通道:一个命令通道和一个数据通道。命令通道用于控制通信,数据通道用于文件的实际传输。使用FTP可以做很多事情,比...

python利用正则表达式排除集合中字符的功能示例

前言 我们在之前学习过通过集合枚举的功能,把所有需要出现的字符列出来,保存在集合里面,这样正则表达式就可以根据集合里的字符是否存在来判断是否匹配成功,如果在集合里,就匹配成功,否则不成功...

python中将字典转换成其json字符串

#这是Python中的一个字典 dic = { 'str': 'this is a string', 'list': [1, 2, 'a', 'b'], 'sub_dic': {...

关于python的list相关知识(推荐)

如下所示,一起跟随小编过来看看吧! list01 = ['alex',12,65,'xiaodong',100,'chen',5] list02 = [67,7,'jinjiao_daw...

python如何使用socketserver模块实现并发聊天

这篇文章主要介绍了python如何使用socketserver模块实现并发聊天,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 利用so...