pytorch 常用线性函数详解

yipeiwu_com6年前Python基础

Pytorch的线性函数主要封装了Blas和Lapack,其用法和接口都与之类似。

常用的线性函数如下:

函数 功能
trace 对角线元素之和(矩阵的迹)
diag 对角线元素
triu/tril 矩阵的上三角/下三角,可指定偏移量
mm/bmm 矩阵乘法,batch的矩阵乘法
t 转置
dot/cross 内积/外积
inverse 求逆矩阵
svd 奇异值分解

注意:矩阵的转置会使存储空间不连续,需调用它的.contiguous方法转为连续。

例如:

import torch as t
b=a.t()
b.is_contiguous()
 
输出:False
 
b=b.contiguous()
b.is_contiguous()
 
输出:True

以上这篇pytorch 常用线性函数详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python实现的字典值比较功能示例

Python实现的字典值比较功能示例

本文实例讲述了Python实现的字典值比较功能。分享给大家供大家参考,具体如下: #coding=utf8 import logging import os from Lib.Dea...

python 实现批量替换文本中的某部分内容

一、介绍 在做YOLOv3项目时,会需要将文本文件中的某部分内容进行批量替换和修改,所以编写了python程序批量替换所有文本文件中特定部分的内容。 二、代码实现 import re...

在python中pandas的series合并方法

如下所示: In [3]: import pandas as pd In [4]: a = pd.Series([1,2,3]) In [5]: b = pd.Series(...

Django REST framework 如何实现内置访问频率控制

对匿名用户采用 IP 控制访问频率,对登录用户采用 用户名 控制访问频率。 from rest_framework.throttling import SimpleRateThrot...

在python中用print()输出多个格式化参数的方法

不废话,直接贴代码: disroot = math.sqrt(deta) root1 = (-b + disroot)/(2*a) root2 = (-b - disroot)/(2...