通用矩阵乘法(General Matrix Multiply,GEMM)是线性代数中的基本运算,也是深度学习和高性能计算的核心操作。
对于矩阵 A∈Rm×k 和 B∈Rk×n,GEMM 计算:
C=αAB+βC
其中:
- C∈Rm×n 是结果矩阵
- α 和 β 是标量系数
- 当 α=1,β=0 时,即为标准矩阵乘法 C=AB
标准矩阵乘法的时间复杂度:
O(m⋅k⋅n)
对于方阵(m=k=n=N):
O(N3)
将大矩阵分解为适合缓存的小块:
for i in 0..M-1 step BM:
for j in 0..N-1 step BN:
for k in 0..K-1 step BK:
C[i:i+BM, j:j+BN] += A[i:i+BM, k:k+BK] * B[k:k+BK, j:j+BN]
通过分治降低复杂度:
O(Nlog27)≈O(N2.81)
| 硬件 |
库/框架 |
特点 |
| CPU |
BLAS/MKL/OpenBLAS |
多线程、SIMD |
| GPU |
cuBLAS/rocBLAS |
Tensor Core、大规模并行 |
| TPU |
XLA |
脉动阵列、专用架构 |
y=Wx+b
其中 Wx 就是 GEMM。
卷积通过 im2col 转换为 GEMM:
im2col(X)⋅reshape(W)
注意力机制中的 Q/K/V 计算:
Attention(Q,K,V)=softmax(dkQKT)V
其中 QKT 是 GEMM。
- GEMM 是底层数值计算原语
- 量化 PM 通常关注策略逻辑,不深入底层实现
- spectrally-aware preconditioner 需要修改 GEMM 的核心计算流程
- 这导致"数学推导过不了投委会"
| 操作 |
计算量 |
典型时间(A100) |
| 矩阵乘法 40963 |
68.7 G FLOPs |
~1 ms |
| 前向传播(ResNet-50) |
4 G FLOPs |
~1 ms |
| 反向传播 |
~2× 前向 |
~2 ms |
创建于:2026-06-11
*来源:栀染《量化交易的深度学习困境》