矩阵链乘法问题
问题描述
在科学计算中,我们经常需要计算多个矩阵的乘积。由于矩阵乘法满足结合律,而各个矩阵的规模可能差异很大,通过改变乘法的计算顺序(即添加括号),对实际的计算量会有显著的影响。
对于矩阵 $A \in \mathbb{R}^{r \times s}, B \in \mathbb{R}^{s \times t}$,计算乘积 $A B \in \mathbb{R}^{r \times t}$ 的每一项都需要 $s$ 次乘法,因此乘法计算量为 $r s t$。
例如计算 $A B C$,A 的尺寸为 500×2,B 的尺寸为 2×500,C 的尺寸为 500×2,下面两种运算顺序的乘法运算量有巨大差异:
- $(A B) C$ 的运算量为 $O(500\times 2 \times 500) + O(500 \times 500 \times 2)$;
- $A (B C)$ 的运算量为 $O(2\times 500 \times 2) + O(500 \times 2 \times 2)$。
我们希望使用最优的一种做法来加速计算,这就是矩阵链乘法(Matrix-chain multiplication)问题,这是通过动态规划求解的一个典型问题。
记矩阵序列为 $A_1, A_2, \dots, A_m$,其中 $A_i \in \mathbb{R}^{P_{i-1} \times P_i}$,求乘法顺序使得运算需要的乘法计算量最少。
输入:矩阵个数 $m$,以及尺寸参数数组 $(P_i)_{i = 0}^m$。
输出:一棵表示最优乘法顺序的有序二叉树。叶子节点对应矩阵 $A_i$;非叶子节点对应一次矩阵乘法,记录其覆盖的区间 $A_i \cdots A_j$、最优切分点和该区间的最小计算量,左右孩子分别表示最优切分得到的左右子区间。
动态规划
令 $M[i, j]$ 表示计算矩阵链 $A_i A_{i + 1} \cdots A_j$ 所需的最少标量乘法次数,其中 $1 \leq i \leq j \leq m$。
当 $i = j$ 时,不需要乘法,显然有 $M[i, i] = 0$。
当 $i < j$ 时,最后一次乘法一定是将矩阵链切成左右两段:
$$
(A_i \cdots A_k)(A_{k + 1} \cdots A_j), \quad i \leq k < j.
$$
左侧结果矩阵的规模是 $P_{i - 1} \times P_k$,右侧结果矩阵的规模是 $P_k \times P_j$,所以最后一次乘法的计算量是 $P_{i - 1} P_k P_j$。
因此状态转移方程为
$$
M[i, j] = \min_{i \leq k < j}
{M[i, k] + M[k + 1, j] + P_{i - 1} P_k P_j}.
$$
为了恢复最优乘法顺序,需要额外记录达到最小值时的切分位置:
$$
S[i, j] = \arg \min_{i \leq k < j}
{M[i, k] + M[k + 1, j] + P_{i - 1} P_k P_j}.
$$
具体计算时,需要保存两张 $m \times m$ 的二维表:
- $M[i, j]$:区间 $A_i \cdots A_j$ 的最小计算量,其中 $1 \leq i \leq j \leq m$;
- $S[i, j]$:区间 $A_i \cdots A_j$ 取得最优值时的切分点,其中 $1 \leq i < j \leq m$。
对于 $i > j$ 的位置没有意义,不需要使用。$M$ 真正有效的是带主对角线的上三角部分,即 $1 \leq i \leq j \leq m$;$S$ 真正有效的是不带主对角线的上三角部分,即 $1 \leq i < j \leq m$。主对角线 $M[i, i]$ 全部初始化为 $0$,表示单个矩阵不需要计算,而 $S[i, i]$ 没有切分点,不需要定义。
从表的角度看,填表顺序就是沿着上三角部分逐条次对角线计算。主对角线对应 $j - i = 0$,已经初始化为 $0$;接下来计算 $j - i = 1$ 的第一条次对角线,也就是所有长度为 $2$ 的矩阵链;再计算 $j - i = 2$ 的次对角线,也就是所有长度为 $3$ 的矩阵链;依次类推,直到右上角的 $M[1, m]$。
也就是令偏移量 $d = j - i$,按照 $d$ 从小到大的顺序计算并填表:
1 | for d = 1, 2, ..., m - 1: |
在计算 $M[i, j]$ 时,所有可能用到的 $M[i, k]$ 和 $M[k + 1, j]$ 都是更短的区间,因此已经准备好了。
分析易得,动态规划中总的状态数为 $O(m^2)$,每个状态最多枚举 $O(m)$ 个切分点,因此时间复杂度为 $O(m^3)$,空间复杂度为 $O(m^2)$。
Python 实现
下面给出 Python 实现。节点中的矩阵编号从 $1$ 开始,与上面的公式保持一致,但是两个表的编号仍然从 0 开始。
1 | from dataclasses import dataclass |
其中:
- 数据结构
cost[i][j]对应公式中的 $M[i + 1, j + 1]$;split[i][j]对应公式中的 $S[i + 1, j + 1]$;
- 后处理
parenthesize递归生成对应的括号表达式print_tree递归打印二叉树信息
测试代码
1 | p = [30, 35, 15, 5, 10, 20, 25] |
输出如下
1 | 15125 |
