pytorch自定义二值化网络层方式

yipeiwu_com6年前Python基础

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

对pycharm代码整体左移和右移缩进快捷键的介绍

在使用pycharm时,经常会需要多行代码同时缩进、左移,pycharm提供了快捷方式 1、pycharm使多行代码同时缩进 鼠标选中多行代码后,按下Tab键,一次缩进四个字符 2、py...

python批量将excel内容进行翻译写入功能

由于小编初来乍到,有很多地方不是很到位,还请见谅,但是很实用的哦! 1.首先是需要进行文件的读写操作,需要获取文件路径,方式使用os.listdir(路径)进行批量查找文件。 fil...

python关闭windows进程的方法

本文实例讲述了python关闭windows进程的方法。分享给大家供大家参考。具体如下: 下面的python代码根据进程的名字调用windows的taskkill命令关闭指定的进程...

Python命令行参数解析工具 docopt 安装和应用过程详解

什么是 docopt? 1、docopt 是一种 Python 编写的命令行执行脚本的交互语言。 它是一种语言! 它是一种语言! 它是一种语言! 2、使用这种语言可以在自己的脚本中,添...

详解python播放音频的三种方法

第一种 使用pygame模块 pygame.mixer.init() pygame.mixer.music.load(self.wav_file) pygame.mix...