Strassen 矩阵算法是一种分治策略,能帮助我们更高效地计算两个矩阵(尺寸为 n X n)的乘积。
我们可以先了解一下 Strassen 矩阵法的相关知识。但是,这种方法通常需要硬记一些复杂的方程,所以下面我将分享最简单的记忆方法:
!<a href="https://media.geeksforgeeks.org/wp-content/uploads/stressenformulanewnew1.png">stressenformulanewnew
我们只需要记住 4 条规则:
- AHED (读作 ‘Ahead‘)
- Diagonal (对角线)
- Last CR (最后一行/列)
- First CR (第一行/列)
> 另外,请将 X 矩阵看作(行 +),将 Y 矩阵看作(列 -)。
让我们按照以下步骤操作:
- 写下 P1 = A; P2 = H; P3 = E; P4 = D
- 关于 P5:我们将使用 对角线规则,即(矩阵 X 对角元素之和) (矩阵 Y 对角元素之和)。我们得到 P5 = (A + D) (E + H)
- 关于 P6:我们将使用 Last CR 规则,即 X 的最后一列和 Y 的最后一行。请记住“行加列减”的原则,即 (B – D) (G + H)。我们得到 P6 = (B – D) (G + H)
- 关于 P7:我们将使用 First CR 规则,即 X 的第一列和 Y 的第一行。同样记住“行加列减”,即 (A – C) (E + F)。我们得到 P7 = (A – C) (E + F)
- 回到 P1:这里我们有 A,它在 Y 矩阵中的相邻元素是 E。因为 Y 是列矩阵,我们需要在 Y 中选择一列使得 E 不出现。我们找到了 F H 列,所以将 A 乘以 (F – H)。最终 P1 = A * (F – H)
- 回到 P2:这里我们有 H,它在 X 矩阵中的相邻元素是 D。因为 X 是行矩阵,我们需要在 X 中选择一行使得 D 不出现。我们找到了 A B 列,所以将 H 乘以 (A + B)。最终 P2 = (A + B) * H
- 回到 P3:这里我们有 E,它在 X 矩阵中的相邻元素是 A。因为 X 是行矩阵,我们需要在 X 中选择一行使得 A 不出现。我们找到了 C D 列,所以将 E 乘以 (C + D)。最终 P3 = (C + D) * E
- 回到 P4:这里我们有 D,它在 Y 矩阵中的相邻元素是 H。因为 Y 是列矩阵,我们需要在 Y 中选择一列使得 H 不出现。我们找到了 G E 列,所以将 D 乘以 (G – E)。最终 P4 = D * (G – E)
- 记住计数顺序:在 C2 位置写下 P1 + P2
- 在对角线位置(即 C3)写下 P3 + P4
- 在第 1 个位置写下 P4 + P5 + P6 并减去 P2,即 C1 = P4 + P5 + P6 – P2
- 在最后的位置写下奇数值,并使用交替的 – 和 + 号,即 P1 P3 P5 P7 变为 C4 = P1 – P3 + P5 – P7
代码实现:
C++
“`
#include
#include
#define vi vector
#define vii vector
using namespace std;
/ 查找下一个 2 的次幂 /
int nextPowerOf2(int k)
{
return pow(2, int(ceil(log2(k))));
}
// 打印矩阵
void display(vii C, int m, int n)
{
for (int i = 0; i < m; i++)
{
cout << "|"
<< " ";
for (int j = 0; j < n; j++)
{
cout << C[i][j] << " ";
}
cout << "|" << endl;
}
}
//! 矩阵加法和减法
void add(vii &A, vii &B, vii &C, int size)
{
for (int i = 0; i < size; i++)
{
for (int j = 0; j < size; j++)
{
C[i][j] = A[i][j] + B[i][j];
}
}
}
void sub(vii &A, vii &B, vii &C, int size)
{
for (int i = 0; i < size; i++)
{
for (int j = 0; j < size; j++)
{
C[i][j] = A[i][j] – B[i][j];
}
}
}
//!—————————–
void Strassen_algorithm(vii &A, vii &B, vii &C, int size)
{
if (size == 1)
{
C[0][0] = A[0][0] * B[0][0];
return;
}
else
{
int newSize = size / 2;
vi z(newSize);
vii a(newSize, z), b(newSize, z), c(newSize, z), d(newSize, z),
e(newSize, z), f(newSize, z), g(newSize, z), h(newSize, z),
c11(newSize, z), c12(newSize, z), c21(newSize, z), c22(newSize, z),
p1(newSize, z), p2(newSize, z), p3(newSize, z), p4(newSize, z),
p5(newSize, z), p6(newSize, z), p7(newSize, z), fResult(newSize, z),
sResult(newSize, z);
int i, j;
//! 将矩阵分成相等的部分
for (i = 0; i < newSize; i++)
{
for (j = 0; j < newSize; j++)
{
a[i][j] = A[i][j];
b[i][j] = A[i][j + newSize];
c[i][j] = A[i + newSize][j];
d[i][j] = A[i + newSize][j + newSize];
e[i][j] = B[i][j];
f[i][j] = B[i][j + newSize];
g[i][j] = B[i + newSize][j];
h[i][j] = B[i + newSize][j + newSize];
}
}
/*
A B C
[a b] * [e f] = [c11 c12]
[c d] [g h] [c21 c22]
p1,p2,p3,p4=AHED 对应:A:行(+) 且 B:列(-)
p5=对角线 :两者均为正
p6=Last CR :A:行(-) B:列(+)
p7=First CR :A:行(-) B:列(+)
*/
//! 计算所有的 Stra