在numpy中乘以对数概率矩阵的数值稳定方法

 何其何从丶 发布于 2023-01-15 13:41

我需要获取包含对数概率的两个NumPy矩阵(或其他2d数组)的矩阵乘积.np.log(np.dot(np.exp(a), np.exp(b)))出于显而易见的原因,天真的方式不是优选的.

运用

from scipy.misc import logsumexp
res = np.zeros((a.shape[0], b.shape[1]))
for n in range(b.shape[1]):
    # broadcast b[:,n] over rows of a, sum columns
    res[:, n] = logsumexp(a + b[:, n].T, axis=1) 

工作但运行速度比慢100倍 np.log(np.dot(np.exp(a), np.exp(b)))

运用

logsumexp((tile(a, (b.shape[1],1)) + repeat(b.T, a.shape[0], axis=0)).reshape(b.shape[1],a.shape[0],a.shape[1]), 2).T

或者其他瓦片和重塑的组合也起作用,但是比上面的循环运行得更慢,因为实际大小的输入矩阵需要非常大量的存​​储器.

我目前正在考虑在C中编写一个NumPy扩展来计算它,但当然我宁愿避免这种情况.是否有既定的方法来执行此操作,或者是否有人知道执行此计算的内存密集程度较低的方法?

编辑: 感谢larsmans提供此解决方案(参见下面的推导):

def logdot(a, b):
    max_a, max_b = np.max(a), np.max(b)
    exp_a, exp_b = a - max_a, b - max_b
    np.exp(exp_a, out=exp_a)
    np.exp(exp_b, out=exp_b)
    c = np.dot(exp_a, exp_b)
    np.log(c, out=c)
    c += max_a + max_b
    return c

logdot_old使用iPython的magic %timeit函数快速比较此方法与上面发布的方法()会产生以下结果:

In  [1] a = np.log(np.random.rand(1000,2000))

In  [2] b = np.log(np.random.rand(2000,1500))

In  [3] x = logdot(a, b)

In  [4] y = logdot_old(a, b) # this takes a while

In  [5] np.any(np.abs(x-y) > 1e-14)
Out [5] False

In  [6] %timeit logdot_old(a, b)
1 loops, best of 3: 1min 18s per loop

In  [6] %timeit logdot(a, b)
1 loops, best of 3: 264 ms per loop

显然larsmans的方法抹杀了我的!

撰写答案
今天,你开发时遇到什么问题呢?
立即提问
热门标签
PHP1.CN | 中国最专业的PHP中文社区 | PNG素材下载 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有