高度な einsum 演算
これで基本的なeinsum
演算に慣れたので、もっと高度な応用例をいくつか見てみましょう。これらの演算は、einsum
関数の真の力と柔軟性を示しています。
対角成分の抽出
行列の対角成分を抽出することは、線形代数で一般的な演算です。行列 A について、その対角成分はベクトル d を形成し、次のようになります。
d_i = A_{ii}
einsum
を使って対角成分を抽出する方法は次の通りです。
## Create a random square matrix
A = np.random.rand(4, 4)
print("Matrix A:")
print(A)
## Extract diagonal using einsum
diagonal = np.einsum('ii->i', A)
print("\nDiagonal elements using einsum:")
print(diagonal)
## Verify with NumPy's diagonal function
numpy_diagonal = np.diagonal(A)
print("\nDiagonal elements using np.diagonal():")
print(numpy_diagonal)
表記 'ii->i'
は次の意味を持ちます。
ii
は A の対角成分の繰り返しインデックスを表します
i
はこれらの要素を 1 次元配列に抽出することを意味します
行列の跡
行列の跡は、その対角成分の和です。行列 A について、その跡は次のようになります。
\text{trace}(A) = \sum_i A_{ii}
einsum
を使って跡を計算する方法は次の通りです。
## Using the same matrix A from above
trace = np.einsum('ii->', A)
print("Trace of matrix A using einsum:", trace)
## Verify with NumPy's trace function
numpy_trace = np.trace(A)
print("Trace of matrix A using np.trace():", numpy_trace)
表記 'ii->'
は次の意味を持ちます。
ii
は対角成分の繰り返しインデックスを表します
- 空の出力インデックスは、すべての対角成分を合計してスカラーを得ることを意味します
バッチ行列乗算
einsum
は、多次元配列に対する演算を行う際に本当に威力を発揮します。たとえば、バッチ行列乗算では、2 つのバッチから行列のペアを乗算します。
形状が (n, m, p) の行列のバッチ A と、形状が (n, p, q) の行列のバッチ B がある場合、バッチ行列乗算によって形状が (n, m, q) の結果 C が得られます。
C_{ijk} = \sum_l A_{ijl} \times B_{ilk}
einsum
を使ってバッチ行列乗算を行う方法は次の通りです。
## Create batches of matrices
n, m, p, q = 5, 3, 4, 2 ## Batch size and matrix dimensions
A = np.random.rand(n, m, p) ## Batch of 5 matrices, each 3x4
B = np.random.rand(n, p, q) ## Batch of 5 matrices, each 4x2
print("Shape of batch A:", A.shape)
print("Shape of batch B:", B.shape)
## Batch matrix multiplication using einsum
C = np.einsum('nmp,npq->nmq', A, B)
print("\nShape of result batch C:", C.shape) ## Should be (5, 3, 2)
## Let's check the first matrix multiplication in the batch
print("\nFirst result matrix from batch using einsum:")
print(C[0])
## Verify with NumPy's matmul function
numpy_batch_matmul = np.matmul(A, B)
print("\nFirst result matrix from batch using np.matmul:")
print(numpy_batch_matmul[0])
表記 'nmp,npq->nmq'
は次の意味を持ちます。
nmp
はバッチ A のインデックスを表します(n はバッチ、m は行、p は列)
npq
はバッチ B のインデックスを表します(n はバッチ、p は行、q は列)
nmq
は出力バッチ C のインデックスを表します(n はバッチ、m は行、q は列)
- 繰り返されるインデックス
p
は合計されます(行列の乗算)
einsum を使う理由
NumPy がこれらの演算に対して専用の関数を提供しているのに、なぜeinsum
を使うのか疑問に思うかもしれません。以下にいくつかの利点を挙げます。
- 統一的なインターフェース:
einsum
は多くの配列演算に対して単一の関数を提供します。
- 柔軟性:他の方法では複数のステップが必要な演算を表現できます。
- 可読性:表記法を理解すれば、コードがより簡潔になります。
- パフォーマンス:多くの場合、
einsum
演算は最適化されており、効率的です。
複雑なテンソル演算において、einsum
はしばしば最も明確で直接的な実装を提供します。