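I have been trying to speed up the resampling step of a particle filter. Since Python offers many ways to speed code up, I thought I would try all of them. Unfortunately, the Numba version is extremely slow. Since Numba is supposed to give a speedup, I assume the mistake is mine.
I tried 4 different versions:
Numba
Python
NumPy
Cython
The code for each is below: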
    import numpy as np
    import scipy as sp
    import numba as nb
    from cython_resample import cython_resample

    @nb.autojit
    def numba_resample(qs, xs, rands):
        n = qs.shape[0]
        lookup = np.cumsum(qs)
        results = np.empty(n)

        for j in range(n):
            for i in range(n):
                if rands[j] < lookup[i]:
                    results[j] = xs[i]
                    break
        return results

    def python_resample(qs, xs, rands):
        n = qs.shape[0]
        lookup = np.cumsum(qs)
        results = np.empty(n)

        for j in range(n):
            for i in range(n):
                if rands[j] < lookup[i]:
                    results[j] = xs[i]
                    break
        return results

    def numpy_resample(qs, xs, rands):
        results = np.empty_like(qs)
        lookup = sp.cumsum(qs)
        for j, key in enumerate(rands):
            i = sp.argmax(lookup>key)
            results[j] = xs[i]
        return results

    #The following is the code for the cython module. It was compiled in a
    #separate file, but is included here to aid in the question.
    """
    import numpy as np
    cimport numpy as np
    cimport cython

    DTYPE = np.float64

    ctypedef np.float64_t DTYPE_t

    @cython.boundscheck(False)
    def cython_resample(np.ndarray[DTYPE_t, ndim=1] qs,
                        np.ndarray[DTYPE_t, ndim=1] xs,
                        np.ndarray[DTYPE_t, ndim=1] rands):
        if qs.shape[0] != xs.shape[0] or qs.shape[0] != rands.shape[0]:
            raise ValueError("Arrays must have same shape")
        assert qs.dtype == xs.dtype == rands.dtype == DTYPE

        cdef unsigned int n = qs.shape[0]
        cdef unsigned int i, j
        cdef np.ndarray[DTYPE_t, ndim=1] lookup = np.cumsum(qs)
        cdef np.ndarray[DTYPE_t, ndim=1] results = np.zeros(n, dtype=DTYPE)

        for j in range(n):
            for i in range(n):
                if rands[j] < lookup[i]:
                    results[j] = xs[i]
                    break
        return results
    """

    if __name__ == '__main__':
        n = 100
        xs = np.arange(n, dtype=np.float64)
        qs = np.array([1.0/n,]*n)
        rands = np.random.rand(n)

        print "Timing Numba Function:"
        %timeit numba_resample(qs, xs, rands)
        print "Timing Python Function:"
        %timeit python_resample(qs, xs, rands)
        print "Timing Numpy Function:"
        %timeit numpy_resample(qs, xs, rands)
        print "Timing Cython Function:"
        %timeit cython_resample(qs, xs, rands)
This produces the following output:

    Timing Numba Function:
    1 loops, best of 3: 8.23 ms per loop
    Timing Python Function:
    100 loops, best of 3: 2.48 ms per loop
    Timing Numpy Function:
    1000 loops, best of 3: 793 µs per loop
    Timing Cython Function:
    10000 loops, best of 3: 25 µs per loop

Any idea why the Numba code is so slow? I assumed it would at least be comparable to NumPy.
Note: if anyone has ideas on how to speed up the NumPy or Cython code samples, that would be nice too :) My main question is about Numba, though.
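On the side question about the NumPy variant, one possible direction is to replace the per-sample `argmax` scan with a single vectorized `np.searchsorted` call over the cumulative sum. The following is only a sketch under assumptions: `searchsorted_resample` is a hypothetical name that does not appear in the original code, and the clamping step is an added guard for the case where rounding leaves the last cumulative value slightly below 1.0. It has not been timed against the versions above.

```python
import numpy as np


def searchsorted_resample(qs, xs, rands):
    """Vectorized inverse-CDF resampling (hypothetical helper, not from the question)."""
    lookup = np.cumsum(qs)
    # side="right" returns the first index i with lookup[i] > rands[j],
    # which is the same condition as `rands[j] < lookup[i]` in the loops.
    idx = np.searchsorted(lookup, rands, side="right")
    # Assumption: clamp the index in case a draw is >= lookup[-1].
    idx = np.minimum(idx, len(xs) - 1)
    return xs[idx]


if __name__ == "__main__":
    n = 100
    xs = np.arange(n, dtype=np.float64)
    qs = np.full(n, 1.0 / n)
    rands = np.random.rand(n)
    print(searchsorted_resample(qs, xs, rands)[:5])
```

Because the cumulative sum is sorted, each lookup becomes a binary search rather than a linear scan, so the total work should scale roughly as n log n instead of n squared.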
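The problem is that Numba cannot infer the type of `lookup`. If you put a `print nb.typeof(lookup)` inside your function, you will see that Numba is treating it as an object, which is slow. Normally I would declare the type of `lookup` in a locals dict, but that gave me a strange error. Instead, I just wrote a small wrapper so that I could define the input and output types explicitly.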
    @nb.jit(nb.f8[:](nb.f8[:]))
    def numba_cumsum(x):
        return np.cumsum(x)

    @nb.autojit
    def numba_resample2(qs, xs, rands):
        n = qs.shape[0]
        #lookup = np.cumsum(qs)
        lookup = numba_cumsum(qs)
        results = np.empty(n)

        for j in range(n):
            for i in range(n):
                if rands[j] < lookup[i]:
                    results[j] = xs[i]
                    break
        return results
My timings are then:

    print "Timing Numba Function:"
    %timeit numba_resample(qs, xs, rands)
    print "Timing Revised Numba Function:"
    %timeit numba_resample2(qs, xs, rands)

    Timing Numba Function:
    100 loops, best of 3: 8.1 ms per loop
    Timing Revised Numba Function:
    100000 loops, best of 3: 15.3 µs per loop
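You can go a bit faster still if you use `jit` with an explicit signature instead of `autojit`:

    @nb.jit(nb.f8[:](nb.f8[:], nb.f8[:], nb.f8[:]))

For me, that lowers it from 15.3 microseconds to 12.5 microseconds, but it is still impressive how well `autojit` does.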
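For readers on a current Numba release: as far as I know, `autojit` is no longer available there, and `np.cumsum` on a 1-D array is supported in nopython mode, so the wrapper may not be needed. The following is a minimal sketch under those assumptions, not a measured result. `njit_resample` is a hypothetical name, and the `try`/`except` fallback is added only so the example still runs (uncompiled) when Numba is not installed.

```python
import numpy as np

try:
    import numba as nb
    njit = nb.njit  # nopython-mode JIT with lazy type inference
except ImportError:
    # Assumption: fall back to plain Python so the sketch remains runnable.
    def njit(func):
        return func


@njit
def njit_resample(qs, xs, rands):
    """Same nested-loop resampling as above (hypothetical name)."""
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        for i in range(n):
            # First cumulative weight that exceeds the draw selects the particle.
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results


if __name__ == "__main__":
    n = 100
    xs = np.arange(n, dtype=np.float64)
    qs = np.full(n, 1.0 / n)
    rands = np.random.rand(n)
    print(njit_resample(qs, xs, rands)[:5])
```

The first call includes compilation time, so it is worth calling the function once before timing it. As in the original, `results` comes from `np.empty`, so an entry stays uninitialized if a draw is not below any cumulative weight.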