张量操作python中的内存和时间

前端之家收集整理的这篇文章主要介绍了张量操作python中的内存和时间 前端之家小编觉得挺不错的,现在分享给大家,也给大家做个参考。

目标
我的目标是计算公式的张量,您可以在下面看到.索引i,j,k,l从0到40以及p,m,x从0到80.

The formula for the tensor

Tensordot方法这个总和只是收缩6个巨大张量的索引.我试图通过允许这种计算的张量点来做到这一点,但是即使我先执行一个张量点然后再执行另一个张量点,我的问题还是内存. (我在colab中工作,所以我有12GB的可用内存)

嵌套循环法但是控制B矩阵有一些附加对称性,即B {ijpx}的唯一非零元素是i j = p x.因此,我能够根据x(p = i jx,m = k lx)来写p和m,然后我做了5个循环,分别是i,l,x,但是另一方面这是一个问题,因为计算需要136秒,我想重复很多次.

嵌套循环方法中的定时目标将时间减少10倍将是令人满意的,但如果可以将其减少100倍,则绰绰有余.

您是否有解决内存问题或减少时间的想法?您如何处理带有附加约束的求和?

(注:矩阵A是对称的,到目前为止我还没有使用过这一事实.不再有对称性.)

这是嵌套循环的代码

for i in range (0,40):
  for j in range (0,40):
    for k in range (0,40):
      for l in range (0,40):
            Sum=0
            for x in range (0,80):
              p=i+j-x
              m=k+l-x
              if p>=0 and p<80 and m>=0 and m<80:
                Sum += A[p,m]*B[i,p,x]*B[k,x]
            T[i,l]= Sum

以及张量点方法代码

P=np.tensordot(A,B,axes=((0),(2)))
T=np.tensordot(P,axes=((0,3),(2,3)))
最佳答案
Numba可能是您最好的选择.我根据您的代码组合了此功能.为了避免不必要的迭代和if块,我对其进行了一些更改:

import numpy as np
import numba as nb

@nb.njit(parallel=True)
def my_formula_nb(A,B):
    di,dj,dx,_ = B.shape
    T = np.zeros((di,di,dj),dtype=A.dtype)
    for i in nb.prange (di):
        for j in nb.prange (dj):
            for k in nb.prange (di):
                for l in nb.prange (dj):
                    sum = 0
                    x_start = max(0,i + j - dx + 1,k + l - dx + 1)
                    x_end = min(dx,i + j + 1,k + l + 1)
                    for x in range(x_start,x_end):
                        p = i + j - x
                        m = k + l - x
                        sum += A[p,m] * B[i,x] * B[k,x]
                    T[i,l] = sum
    return T

让我们来看看它的作用:

import numpy as np

def make_problem(di,dx):
    a = np.random.rand(dx,dx)
    a = a + a.T
    b = np.random.rand(di,dx)
    b_ind = np.indices(b.shape)
    b_mask = b_ind[0] + b_ind[1] != b_ind[2] + b_ind[3]
    b[b_mask] = 0
    return a,b

# Generate a problem
np.random.seed(100)
a,b = make_problem(15,20,25)
# Solve with Numba function
t1 = my_formula_nb(a,b)
# Solve with einsum
t2 = np.einsum('pm,ijpx,klmx->ijkl',a,b,b)
# Check result
print(np.allclose(t1,t2))
# True

# Benchmark (IPython)
%timeit np.einsum('pm,b)
# 4.5 s ± 39.2 ms per loop (mean ± std. dev. of 7 runs,1 loop each)
%timeit my_formula_nb(a,b)
# 6.06 ms ± 20.4 µs per loop (mean ± std. dev. of 7 runs,100 loops each)

如您所见,Numba解决方案的速度提高了大约三个数量级,并且不应占用过多的内存.

猜你在找的Python相关文章