pytorch自定义二值化网络层方式

yipeiwu_com6年前Python基础

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

用python + openpyxl处理excel2007文档思路以及心得

寻觅工具 确定任务之后第一步就是找个趁手的库来干活。 Python Excel上列出了xlrd、xlwt、xlutils这几个包,但是 它们都比较老,xlwt甚至不支持07版以后的exc...

从列表或字典创建Pandas的DataFrame对象的方法

从列表或字典创建Pandas的DataFrame对象的方法

介绍 每当我使用pandas进行分析时,我的第一个目标是使用众多可用选项中的一个将数据导入Pandas的DataFrame 。 对于绝大多数情况下,我使用的 read_excel ,...

python机器学习理论与实战(一)K近邻法

python机器学习理论与实战(一)K近邻法

机器学习分两大类,有监督学习(supervised learning)和无监督学习(unsupervised learning)。有监督学习又可分两类:分类(classification...

解决python3中cv2读取中文路径的问题

如下所示: python3: img_path =  ' ' im = cv2.imdecode(np.fromfile(img_path,dtype = np.uin...

解读! Python在人工智能中的作用

人工智能是一种未来性的技术,目前正在致力于研究自己的一套工具。一系列的进展在过去的几年中发生了:无事故驾驶超过300000英里并在三个州合法行驶迎来了自动驾驶的一个里程碑;IBM Was...