热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

一个函数打天下,einsum

作者丨永远在你身后知乎来源丨https:zhuanlan.zhihu.comp71639781编辑丨极市平台einsum全称Einsteinsummationconvention&

作者丨永远在你身后@知乎

来源丨https://zhuanlan.zhihu.com/p/71639781

编辑丨极市平台

einsum全称Einstein summation convention(爱因斯坦求和约定),又称为爱因斯坦标记法,是爱因斯坦1916年提出的一种标记约定,简单的说就是省去求和式中的求和符号,例如下面的公式:

以einsum的写法就是:

后者将  符号给省去了,显得更加简洁;再比如:

 (1)

 (2)

上面两个栗子换成einsum的写法就变成:

 (1)

 (2)

在实现一些算法时,数学表达式已经求出来了,需要将之转换为代码实现,简单的一些还好,有时碰到例如矩阵转置、矩阵乘法、求迹、张量乘法、数组求和等等,若是以分别以transopse、sum、trace、tensordot等函数实现的话,不但复杂,还容易出错

现在,这些问题你统统可以一个函数搞定,没错,就是einsum,einsum函数就是根据上面的标记法实现的一种函数,可以根据给定的表达式进行运算,可以替代但不限于以下函数:

矩阵求迹:trace
求矩阵对角线:diag
张量(沿轴)求和:sum
张量转置:transopose
矩阵乘法:dot
张量乘法:tensordot
向量内积:inner
外积:outer

该函数在numpy、tensorflow、pytorch上都有实现,用法基本一样,定义如下:

einsum(equation, *operands)

equation是字符串的表达式,operands是操作数,是一个元组参数,并不是只能有两个,所以只要是能够通过einsum标记法表示的乘法求和公式,都可以用一个einsum解决,下面以numpy举几个栗子:

# 沿轴计算张量元素之和:
c = a.sum(axis=0)

上面的以sum函数的实现代码,设  为三维张量,上面代码用公式来表达的话就是:

换成einsum标记法:

然后根据此式使用einsum函数实现等价功能:

c = np.einsum('ijk->jk', a)
# 作用与 c = a.sum(axis=0) 一样

更进一步的,如果  不止是三维,可以将下标  换成省略号,以表示剩下的所有维度:

c = np.einsum('i...->...', a)

这种写法pytorch与tensorflow同样支持,如果不是很理解的话,可以查看其对应的公式:

# 矩阵乘法
c = np.dot(a, b)

矩阵乘法的公式为:

然后是einsum对应的实现:

c = np.einsum('ij,jk->ik', a, b)

最后再举一个张量乘法栗子:

# 张量乘法
c = np.tensordot(a, b, ([0, 1], [0, 1]))

如果  是三维的,对应的公式为:

对应的einsum实现:

c = np.einsum('ijk,ijl->kl', a, b)

下面以numpy做一下测试,对比einsum与各种函数的速度,这里使用python内建的timeit模块进行时间测试,先测试(四维)两张量相乘然后求所有元素之和,对应的公式为:

然后是测试代码:

from timeit import Timer
import numpy as np# 定义两个全局变量
a = np.random.rand(64, 128, 128, 64)
b = np.random.rand(64, 128, 128, 64)# 定义使用einsum与sum的函数
def einsum():temp = np.einsum('ijkl,ijkl->', a, b)def npsum():temp = (a * b).sum()# 打印运行时间
print("einsum cost:", Timer("einsum()", "from __main__ import einsum").timeit(20))
print("npsum cost:", Timer("npsum()", "from __main__ import npsum").timeit(20))

上面Timer是timeit模块内的一个类

Timer(stmt, setup).timeit(number)# stmt: 要测试的语句# setup: 传入stmt的运行环境,比如stmt中要导入的模块等。# 可以写一行语句,也可以写多行语句,写多行语句时要用分号;隔开语句# number: 执行次数

将两个函数各执行20遍,最后的结果为,单位为秒:

einsum cost: 1.5560735
npsum cost: 8.0874927

可以看到,einsum比sum快了几乎一个量级,接下来测试单个张量求和:

将上面的代码改一下:

def einsum():temp = np.einsum('ijkl->', a)def npsum():temp = a.sum()

相应的运行时间为:

einsum cost: 3.2716003
npsum cost: 6.7865246

还是einsum更快,所以哪怕是单个张量求和,numpy上也可以用einsum替代,同样,求均值(mean)、方差(var)、标准差(std)也是一样

接下来测试einsum与dot函数,首先列一下矩阵乘法的公式以以及einsum表达式:

然后是测试代码:

a = np.random.rand(2024, 2024)
b = np.random.rand(2024, 2024)# einsum与dot比较
def einsum():res = np.einsum('ik,kj->ij', a, b)def dot():res = np.dot(a, b)print("einsum cost:", Timer("einsum()", "from __main__ import einsum").timeit(20))
print("dot cost:", Timer("dot()", "from __main__ import dot").timeit(20))# einsum cost: 80.2403851
# dot cost: 2.0842243

这就很尴尬了,比dot慢了40倍(并且差距随着矩阵规模的平方增加),这还怎么打天下?不过在numpy的实现里,einsum是可以进行优化的,去掉不必要的中间结果,减少不必要的转置、变形等等,可以提升很大的性能,将einsum的实现改一下:

def einsum():res = np.einsum('ik,kj->ij', a, b, optimize=True)

加了一个参数optimize=True,官方文档上该参数是可选参数,接受4个值:

optimize : {False, True, ‘greedy’, ‘optimal’}, optional

optimize默认为False,如果设为True,这默认选择‘greedy(贪心)’方式,再看看速度:

einsum cost: 2.0330937
dot cost: 1.9866218

可以看到,通过优化,虽然还是稍慢一些,但是einsum的速度与dot达到了一个量级;不过numpy官方手册上有个einsum_path,说是可以进一步提升速度,但是我在自己电脑上(i7-9750H)测试效果并不稳定,这里简单的介绍一下该函数的用法为:

path = np.einsum_path('ik,kj->ij', a, b)[0]
np.einsum('ik,kj->ij', a, b, optimize=path)

einsum_path返回一个einsum可使用的优化路径列表,一般使用第一个优化路径;另外,optimize及einsum_path函数只有numpy实现了,tensorflow和pytorch上至少现在没有

最后,再测试einsum与另一个常用的函数tensordot,首先定义两个四维张量的及tensordot函数:

a = np.random.rand(128, 128, 64, 64)
b = np.random.rand(128, 128, 64, 64)def tensordot():res = np.tensordot(a, b, ([0, 1], [0, 1]))

该实现对应的公式为:

所以einsum函数的实现为:

def einsum():res = np.einsum('ijkl,ijmn->klmn', a, b, optimize=True)

tensordot也是链接到BLAS实现的函数,所以不加optimize肯定比不了,最后结果为:

print("einsum cost:", Timer("einsum()", "from __main__ import einsum").timeit(1))
print("tensordot cost:", Timer("tensordot()", "from __main__ import tensordot").timeit(1))# einsum cost: 4.2361331
# tensordot cost: 4.2580409

测试了10多次,基本上速度一样,einsum表现好一点的;不过说是一个函数打天下,肯定是做不到的,还有一些数组的分割、合并、指数、对数等功能没法实现,需要使用别的函数,其他的基本都可以用einsum来实现,简单而又高效

之后经过进一步测试发现,优化反而出现速度降低的情况,例如:

def einsum():temp = einsum('...->', a, optimize=True)def test():temp = a.sum()

上面两中对数组求和的方法,当a是一维向量时,或者a是多维但是规模很小是,优化的einsum反而更慢,但是去掉optimize参数后表现比内置的sum函数稍好,我认为优化是有一个固定的成本

还有一个坑需要注意的是,有些情况的省略号不加optimize会报错,就拿上面的栗子而言:

np.einsum('...->', a, optimize=True) # 正常运行
np.einsum('...->', a) # 报错

很无奈,试了很多次,不加optimize就是会报错,但是并不是所有的省略号写法都需要加optimize,例如:

使用省略号实现上面两个公式并不需要加optimize,能够正常运行

np.einsum('i...->...', a) # 正常
np.einsum('...,...->...', a, b) # 正常

但是如果碰到下面的公式:

上式表示将a除第一个维度之外,剩下的维度全部累加,这种实现就必须要加optimize

np.einsum('i...->i', a, optimize=True) # 必须加optimize,不然报错

再举一个栗子:

c = (a * b).sum()
# 如果不知道a, b的维数,使用einsum实现上面的功能也必须要加optimize
c = einsum('...,...->', a, b, optimize=True)

总结一下,在计算量很小时,优化因为有一定的成本,所以速度会慢一些;但是,既然计算量小,慢一点又怎样呢,而且使用优化之后,可以更加肆意的使用省略号写表达式,变量的维数也不用考虑了,所以建议无脑使用优化。

觉得有用麻烦给个在看啦~  



推荐阅读
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • 本文介绍了P1651题目的描述和要求,以及计算能搭建的塔的最大高度的方法。通过动态规划和状压技术,将问题转化为求解差值的问题,并定义了相应的状态。最终得出了计算最大高度的解法。 ... [详细]
  • STL迭代器的种类及其功能介绍
    本文介绍了标准模板库(STL)定义的五种迭代器的种类和功能。通过图表展示了这几种迭代器之间的关系,并详细描述了各个迭代器的功能和使用方法。其中,输入迭代器用于从容器中读取元素,输出迭代器用于向容器中写入元素,正向迭代器是输入迭代器和输出迭代器的组合。本文的目的是帮助读者更好地理解STL迭代器的使用方法和特点。 ... [详细]
  • [翻译]PyCairo指南裁剪和masking
    裁剪和masking在PyCairo指南的这个部分,我么将讨论裁剪和masking操作。裁剪裁剪就是将图形的绘制限定在一定的区域内。这样做有一些效率的因素࿰ ... [详细]
  • 本文为Codeforces 1294A题目的解析,主要讨论了Collecting Coins整除+不整除问题。文章详细介绍了题目的背景和要求,并给出了解题思路和代码实现。同时提供了在线测评地址和相关参考链接。 ... [详细]
  • Java容器中的compareto方法排序原理解析
    本文从源码解析Java容器中的compareto方法的排序原理,讲解了在使用数组存储数据时的限制以及存储效率的问题。同时提到了Redis的五大数据结构和list、set等知识点,回忆了作者大学时代的Java学习经历。文章以作者做的思维导图作为目录,展示了整个讲解过程。 ... [详细]
  • 本文主要解析了Open judge C16H问题中涉及到的Magical Balls的快速幂和逆元算法,并给出了问题的解析和解决方法。详细介绍了问题的背景和规则,并给出了相应的算法解析和实现步骤。通过本文的解析,读者可以更好地理解和解决Open judge C16H问题中的Magical Balls部分。 ... [详细]
  • 第四章高阶函数(参数传递、高阶函数、lambda表达式)(python进阶)的讲解和应用
    本文主要讲解了第四章高阶函数(参数传递、高阶函数、lambda表达式)的相关知识,包括函数参数传递机制和赋值机制、引用传递的概念和应用、默认参数的定义和使用等内容。同时介绍了高阶函数和lambda表达式的概念,并给出了一些实例代码进行演示。对于想要进一步提升python编程能力的读者来说,本文将是一个不错的学习资料。 ... [详细]
  • 本文讨论了一个数列求和问题,该数列按照一定规律生成。通过观察数列的规律,我们可以得出求解该问题的算法。具体算法为计算前n项i*f[i]的和,其中f[i]表示数列中有i个数字。根据参考的思路,我们可以将算法的时间复杂度控制在O(n),即计算到5e5即可满足1e9的要求。 ... [详细]
  • Java SE从入门到放弃(三)的逻辑运算符详解
    本文详细介绍了Java SE中的逻辑运算符,包括逻辑运算符的操作和运算结果,以及与运算符的不同之处。通过代码演示,展示了逻辑运算符的使用方法和注意事项。文章以Java SE从入门到放弃(三)为背景,对逻辑运算符进行了深入的解析。 ... [详细]
  • Android自定义控件绘图篇之Paint函数大汇总
    本文介绍了Android自定义控件绘图篇中的Paint函数大汇总,包括重置画笔、设置颜色、设置透明度、设置样式、设置宽度、设置抗锯齿等功能。通过学习这些函数,可以更好地掌握Paint的用法。 ... [详细]
  • 开源Keras Faster RCNN模型介绍及代码结构解析
    本文介绍了开源Keras Faster RCNN模型的环境需求和代码结构,包括FasterRCNN源码解析、RPN与classifier定义、data_generators.py文件的功能以及损失计算。同时提供了该模型的开源地址和安装所需的库。 ... [详细]
  • Python教学练习二Python1-12练习二一、判断季节用户输入月份,判断这个月是哪个季节?3,4,5月----春 ... [详细]
  • 很多时候在注册一些比较重要的帐号,或者使用一些比较重要的接口的时候,需要使用到随机字符串,为了方便,我们设计这个脚本需要注意 ... [详细]
  • 假设我有两个数组A和B,其中A和B都是mxn.我现在的目标是,对于A和B的每一行,找到我应该在B的相应行中插入A的第i行元素的位置.也就是说,我希望将np.digitize或np. ... [详细]
author-avatar
小帅哥小羊儿_309
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有