pytorch自定义二值化网络层方式

yipeiwu_com5年前Python基础

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

面向初学者的Python编辑器Mu

面向初学者的Python编辑器Mu

Meet Mu,一个开放源码编辑器,使学生们更容易学习编写Python代码。 Mu一个开源编辑器,是满足学生可以轻松学习编写Python代码的工具。作为初学程序员的Python编辑器,旨...

Python的时间模块datetime详解

datetime模块用于是date和time模块的合集,datetime有两个常量,MAXYEAR和MINYEAR,分别是9999和1. datetime模块定义了5个类,分别是 1.d...

Python 3中print函数的使用方法总结

前言 Python 思想:“一切都是对象!”,最近发现python3和python2中print的用法有很多不同,python3中需要使用括号,缩进要使用4个空格(这不是必须的,但你最好...

在Python的Tornado框架中实现简单的在线代理的教程

实现代理的方式很多种,流行的web服务器也大都有代理的功能,比如http://www.tornadoweb.cn用的就是nginx的代理功能做的tornadoweb官网的镜像。 最近,我...

python数据结构之列表和元组的详解

python数据结构之 列表和元组 序列:序列是一种数据结构,它包含的元素都进行了编号(从0开始)。典型的序列包括列表、字符串和元组。其中,列表是可变的(可以进行修改),而元组和字符串...