Numba code slower than pure Python

 神烟醉_263 posted on 2023-01-30 12:03

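I've been trying to speed up the resampling calculation for a particle filter. Since Python has many ways to speed things up, I thought I'd try all of them. Unfortunately, the Numba version is very slow. Since Numba is supposed to give a speedup, I assume the mistake is mine.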

I tried 4 different versions:

    Numba

    Python

    NumPy

    Cython

The code for each is below:

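import numpy as np
import numba as nb
from cython_resample import cython_resample  # compiled from the Cython source shown below

# autojit is from older Numba releases; current Numba uses @nb.jit / @nb.njit instead
@nb.autojit
def numba_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)  # cumulative weights used as the lookup table
    results = np.empty(n)

    # For each random draw, take the first particle whose cumulative weight exceeds it
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def python_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def numpy_resample(qs, xs, rands):
    results = np.empty_like(qs)
    lookup = np.cumsum(qs)  # was sp.cumsum; SciPy no longer re-exports NumPy functions
    for j, key in enumerate(rands):
        i = np.argmax(lookup > key)  # index of the first True, i.e. first lookup[i] > key
        results[j] = xs[i]
    return results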

#The following is the code for the cython module. It was compiled in a
#separate file, but is included here to aid in the question.
"""
import numpy as np
cimport numpy as np
cimport cython

DTYPE = np.float64

ctypedef np.float64_t DTYPE_t

@cython.boundscheck(False)
def cython_resample(np.ndarray[DTYPE_t, ndim=1] qs, 
             np.ndarray[DTYPE_t, ndim=1] xs, 
             np.ndarray[DTYPE_t, ndim=1] rands):
    if qs.shape[0] != xs.shape[0] or qs.shape[0] != rands.shape[0]:
        raise ValueError("Arrays must have same shape")
    assert qs.dtype == xs.dtype == rands.dtype == DTYPE

    cdef unsigned int n = qs.shape[0]
    cdef unsigned int i, j 
    cdef np.ndarray[DTYPE_t, ndim=1] lookup = np.cumsum(qs)
    cdef np.ndarray[DTYPE_t, ndim=1] results = np.zeros(n, dtype=DTYPE)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
"""

if __name__ == '__main__':
    n = 100
    xs = np.arange(n, dtype=np.float64)
    qs = np.array([1.0/n,]*n)
    rands = np.random.rand(n)

    # %timeit is an IPython magic, so run this block inside IPython (e.g. with %run)
    print("Timing Numba Function:")
    %timeit numba_resample(qs, xs, rands)
    print("Timing Python Function:")
    %timeit python_resample(qs, xs, rands)
    print("Timing Numpy Function:")
    %timeit numpy_resample(qs, xs, rands)
    print("Timing Cython Function:")
    %timeit cython_resample(qs, xs, rands)

This results in the following output:

Timing Numba Function:
1 loops, best of 3: 8.23 ms per loop
Timing Python Function:
100 loops, best of 3: 2.48 ms per loop
Timing Numpy Function:
1000 loops, best of 3: 793 µs per loop
Timing Cython Function:
10000 loops, best of 3: 25 µs per loop
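The %timeit lines above only work when the script runs inside IPython. A minimal sketch of a plain-Python equivalent using the standard-library timeit module; the helper name best_per_call is hypothetical, and taking the minimum over trials mirrors the "best of 3" figure that %timeit reports:

```python
import timeit

import numpy as np


def best_per_call(fn, *args, number=1000, repeat=3):
    """Hypothetical helper: best per-call time in seconds for fn(*args).

    timeit.repeat runs `number` calls per trial and `repeat` trials;
    the fastest trial is the least noisy estimate.
    """
    trials = timeit.repeat(lambda: fn(*args), number=number, repeat=repeat)
    return min(trials) / number


if __name__ == "__main__":
    qs = np.full(100, 1.0 / 100)
    print(f"np.cumsum: {best_per_call(np.cumsum, qs) * 1e6:.2f} µs per call")
```

With the question's functions defined in the same module, a call such as best_per_call(python_resample, qs, xs, rands) would replace the corresponding %timeit line.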

Any idea why the Numba code is so slow? I thought it would at least be comparable to NumPy.

Note: if anyone has ideas on how to speed up the NumPy or Cython code samples, that would be nice too :) My main question is about Numba.
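For the NumPy version, one option (not from the original thread) is to drop the Python loop entirely: np.searchsorted with side="right" returns, for all draws at once, the first index whose cumulative weight is strictly greater than the draw, which is the same index the loops search for. A minimal sketch; the name searchsorted_resample and the clamp to the last index are assumptions:

```python
import numpy as np


def searchsorted_resample(qs, xs, rands):
    """Hypothetical vectorized resampler: for each draw, pick the first
    particle i with rands[j] < cumsum(qs)[i]."""
    lookup = np.cumsum(qs)
    # side="right" gives the first index whose value is > the draw
    idx = np.searchsorted(lookup, rands, side="right")
    # Rounding can leave lookup[-1] slightly below 1.0; clamp draws at or
    # above the last cumulative weight to the last particle.
    idx = np.minimum(idx, len(xs) - 1)
    return xs[idx]
```

This replaces the O(n²) double loop with n binary searches, so the cost grows as O(n log n).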

1 answer
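  • The problem is that Numba can't infer the type of lookup. If you put a print nb.typeof(lookup) in your method, you'll see that Numba treats it as an object, which is slow. Normally I would just declare the type of lookup in a locals dict, but I got a strange error. Instead, I created a small wrapper so that I can explicitly define the input and output types.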

    @nb.jit(nb.f8[:](nb.f8[:]))
    def numba_cumsum(x):
        return np.cumsum(x)
    
    @nb.autojit
    def numba_resample2(qs, xs, rands):
        n = qs.shape[0]
        #lookup = np.cumsum(qs)
        lookup = numba_cumsum(qs)
        results = np.empty(n)
    
        for j in range(n):
            for i in range(n):
                if rands[j] < lookup[i]:
                    results[j] = xs[i]
                    break
        return results
    

    Then my timings are:

    print("Timing Numba Function:")
    %timeit numba_resample(qs, xs, rands)
    
    print("Timing Revised Numba Function:")
    %timeit numba_resample2(qs, xs, rands)
    

    Timing Numba Function:
    100 loops, best of 3: 8.1 ms per loop
    Timing Revised Numba Function:
    100000 loops, best of 3: 15.3 µs per loop
    

    You can get it a little faster still if you use jit instead of autojit:

    @nb.jit(nb.f8[:](nb.f8[:], nb.f8[:], nb.f8[:]))
    

    For me that brings it down from 15.3 µs to 12.5 µs, but it's still impressive how well autojit does.

    answered 2023-01-30 12:06