pytorch自定义二值化网络层方式

yipeiwu_com6年前Python基础

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

在pytorch中查看可训练参数的例子

pytorch中我们有时候可能需要设定某些变量是参与训练的,这时候就需要查看哪些是可训练参数,以确定这些设置是成功的。 pytorch中model.parameters()函数定义如下:...

python使用Image处理图片常用技巧分析

本文实例讲述了python使用Image处理图片常用技巧。分享给大家供大家参考。具体分析如下: 使用python来处理图片是非常方便的,下面提供一小段python处理图片的代码,需要安装...

Python调用C语言的实现

Python中的ctypes模块可能是Python调用C方法中最简单的一种。ctypes模块提供了和C语言兼容的数据类型和函数来加载dll文件,因此在调用时不需对源文件做任何的修改。也正...

python中文件变化监控示例(watchdog)

在python中文件监控主要有两个库,一个是pyinotify ( https://github.com/seb-m/pyinotify/wiki ),一个是watchdog(http:...

TensorFlow打印tensor值的实现方法

TensorFlow打印tensor值的实现方法

最近一直在用TF做CNN的图像分类,当softmax层得到预测结果后,我希望能够看到预测结果,以便和标签之间进行比较。特此补上,以便自己记忆。 我现在通过softmax层得到变量trai...