问题描述

在科学计算中,我们经常需要计算多个矩阵的乘积。由于矩阵乘法满足结合律,而各个矩阵的规模可能差异很大,通过改变乘法的计算顺序(即添加括号),对实际的计算量会有显著的影响。

对于矩阵 $A \in \mathbb{R}^{r \times s}, B \in \mathbb{R}^{s \times t}$,计算乘积 $A B \in \mathbb{R}^{r \times t}$ 的每一项都需要 $s$ 次乘法,因此乘法计算量为 $r s t$。

例如计算 $A B C$,A 的尺寸为 500×2,B 的尺寸为 2×500,C 的尺寸为 500×2,下面两种运算顺序的乘法运算量有巨大差异:

  • $(A B) C$ 的运算量为 $O(500\times 2 \times 500) + O(500 \times 500 \times 2)$;
  • $A (B C)$ 的运算量为 $O(2\times 500 \times 2) + O(500 \times 2 \times 2)$。

我们希望使用最优的一种做法来加速计算,这就是矩阵链乘法(Matrix-chain multiplication)问题,这是通过动态规划求解的一个典型问题。

记矩阵序列为 $A_1, A_2, \dots, A_m$,其中 $A_i \in \mathbb{R}^{P_{i-1} \times P_i}$,求乘法顺序使得运算需要的乘法计算量最少。

输入:矩阵个数 $m$,以及尺寸参数数组 $(P_i)_{i = 0}^m$。

输出:一棵表示最优乘法顺序的有序二叉树。叶子节点对应矩阵 $A_i$;非叶子节点对应一次矩阵乘法,记录其覆盖的区间 $A_i \cdots A_j$、最优切分点和该区间的最小计算量,左右孩子分别表示最优切分得到的左右子区间。

动态规划

令 $M[i, j]$ 表示计算矩阵链 $A_i A_{i + 1} \cdots A_j$ 所需的最少标量乘法次数,其中 $1 \leq i \leq j \leq m$。

当 $i = j$ 时,不需要乘法,显然有 $M[i, i] = 0$。

当 $i < j$ 时,最后一次乘法一定是将矩阵链切成左右两段:

$$
(A_i \cdots A_k)(A_{k + 1} \cdots A_j), \quad i \leq k < j.
$$

左侧结果矩阵的规模是 $P_{i - 1} \times P_k$,右侧结果矩阵的规模是 $P_k \times P_j$,所以最后一次乘法的计算量是 $P_{i - 1} P_k P_j$。

因此状态转移方程为

$$
M[i, j] = \min_{i \leq k < j}
{M[i, k] + M[k + 1, j] + P_{i - 1} P_k P_j}.
$$

为了恢复最优乘法顺序,需要额外记录达到最小值时的切分位置:

$$
S[i, j] = \arg \min_{i \leq k < j}
{M[i, k] + M[k + 1, j] + P_{i - 1} P_k P_j}.
$$

具体计算时,需要保存两张 $m \times m$ 的二维表:

  • $M[i, j]$:区间 $A_i \cdots A_j$ 的最小计算量,其中 $1 \leq i \leq j \leq m$;
  • $S[i, j]$:区间 $A_i \cdots A_j$ 取得最优值时的切分点,其中 $1 \leq i < j \leq m$。

对于 $i > j$ 的位置没有意义,不需要使用。$M$ 真正有效的是带主对角线的上三角部分,即 $1 \leq i \leq j \leq m$;$S$ 真正有效的是不带主对角线的上三角部分,即 $1 \leq i < j \leq m$。主对角线 $M[i, i]$ 全部初始化为 $0$,表示单个矩阵不需要计算,而 $S[i, i]$ 没有切分点,不需要定义。

从表的角度看,填表顺序就是沿着上三角部分逐条次对角线计算。主对角线对应 $j - i = 0$,已经初始化为 $0$;接下来计算 $j - i = 1$ 的第一条次对角线,也就是所有长度为 $2$ 的矩阵链;再计算 $j - i = 2$ 的次对角线,也就是所有长度为 $3$ 的矩阵链;依次类推,直到右上角的 $M[1, m]$。

也就是令偏移量 $d = j - i$,按照 $d$ 从小到大的顺序计算并填表:

1
2
3
4
for d = 1, 2, ..., m - 1:
for i = 1, 2, ..., m - d:
j = i + d
compute M[i, j] and S[i, j]

在计算 $M[i, j]$ 时,所有可能用到的 $M[i, k]$ 和 $M[k + 1, j]$ 都是更短的区间,因此已经准备好了。

分析易得,动态规划中总的状态数为 $O(m^2)$,每个状态最多枚举 $O(m)$ 个切分点,因此时间复杂度为 $O(m^3)$,空间复杂度为 $O(m^2)$。

Python 实现

下面给出 Python 实现。节点中的矩阵编号从 $1$ 开始,与上面的公式保持一致,但是两个表的编号仍然从 0 开始。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from dataclasses import dataclass


@dataclass
class Node:
i: int
j: int
cost: int
split: int | None = None
left: "Node | None" = None
right: "Node | None" = None

@property
def is_leaf(self) -> bool:
return self.i == self.j


def matrix_chain_order(p: list[int]) -> tuple[int, Node]:
"""Return minimal cost and an ordered binary tree for matrix-chain multiplication.

p has length m + 1, and matrix A_i has shape p[i - 1] * p[i].
"""
m = len(p) - 1
if m <= 0:
raise ValueError("at least one matrix is required")

cost: list[list[int]] = [[0] * m for _ in range(m)]
split: list[list[int | None]] = [[None] * m for _ in range(m)]

for length in range(2, m + 1):
for i in range(0, m - length + 1):
j = i + length - 1
best, best_k = min(
(cost[i][k] + cost[k + 1][j] + p[i] * p[k + 1] * p[j + 1], k)
for k in range(i, j)
)
cost[i][j] = best
split[i][j] = best_k

def build(i: int, j: int) -> Node:
if i == j:
return Node(i=i + 1, j=j + 1, cost=0)

k = split[i][j]
if k is None:
raise RuntimeError("invalid split table")

return Node(
i=i + 1,
j=j + 1,
cost=cost[i][j],
split=k + 1,
left=build(i, k),
right=build(k + 1, j),
)

root = build(0, m - 1)
return cost[0][m - 1], root


def parenthesize(node: Node) -> str:
if node.is_leaf:
return f"A{node.i}"

assert node.left is not None and node.right is not None
return f"({parenthesize(node.left)} {parenthesize(node.right)})"


def print_tree(node: Node, indent: str = "") -> None:
if node.is_leaf:
print(f"{indent}A{node.i}: cost={node.cost}")
return

print(f"{indent}A{node.i}..A{node.j}: cost={node.cost}, split={node.split}")
assert node.left is not None and node.right is not None
print_tree(node.left, indent + " ")
print_tree(node.right, indent + " ")

其中:

  • 数据结构
    • cost[i][j] 对应公式中的 $M[i + 1, j + 1]$;
    • split[i][j] 对应公式中的 $S[i + 1, j + 1]$;
  • 后处理
    • parenthesize 递归生成对应的括号表达式
    • print_tree 递归打印二叉树信息

测试代码

1
2
3
4
5
6
p = [30, 35, 15, 5, 10, 20, 25]
min_cost, tree = matrix_chain_order(p)

print(min_cost)
print(parenthesize(tree))
print_tree(tree)

输出如下

1
2
3
4
5
6
7
8
9
10
11
12
13
15125
((A1 (A2 A3)) ((A4 A5) A6))
A1..A6: cost=15125, split=3
A1..A3: cost=7875, split=1
A1: cost=0
A2..A3: cost=2625, split=2
A2: cost=0
A3: cost=0
A4..A6: cost=3500, split=5
A4..A5: cost=1000, split=4
A4: cost=0
A5: cost=0
A6: cost=0