1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
/// @brief Element-wise sum of two integer matrices.
///
/// The result has exactly the shape of @p A (row count and per-row length),
/// so square and rectangular matrices are both handled.
///
/// @param A Left operand; determines the shape of the result.
/// @param B Right operand. Precondition: B[i][j] exists for every element
///          A[i][j]; this is not checked, and violating it is undefined
///          behaviour.
/// @return A new matrix C with C[i][j] == A[i][j] + B[i][j].
std::vector<std::vector<int>> add(const std::vector<std::vector<int>>& A,
                                  const std::vector<std::vector<int>>& B) {
    using Index = std::vector<int>::size_type;

    // Start from a copy of A so the result inherits A's shape row by row
    // instead of assuming the matrix is square.
    std::vector<std::vector<int>> C(A);
    for (Index i = 0; i < C.size(); ++i) {
        for (Index j = 0; j < C[i].size(); ++j) {
            C[i][j] += B[i][j];
        }
    }
    return C;
}
/// @brief Element-wise difference of two integer matrices.
///
/// The result has exactly the shape of @p A (row count and per-row length),
/// so square and rectangular matrices are both handled.
///
/// @param A Minuend; determines the shape of the result.
/// @param B Subtrahend. Precondition: B[i][j] exists for every element
///          A[i][j]; this is not checked, and violating it is undefined
///          behaviour.
/// @return A new matrix C with C[i][j] == A[i][j] - B[i][j].
std::vector<std::vector<int>> subtract(const std::vector<std::vector<int>>& A,
                                       const std::vector<std::vector<int>>& B) {
    using Index = std::vector<int>::size_type;

    // Start from a copy of A so the result inherits A's shape row by row
    // instead of assuming the matrix is square.
    std::vector<std::vector<int>> C(A);
    for (Index i = 0; i < C.size(); ++i) {
        for (Index j = 0; j < C[i].size(); ++j) {
            C[i][j] -= B[i][j];
        }
    }
    return C;
}
namespace strassen_detail {

using Matrix = std::vector<std::vector<int>>;
using Index = Matrix::size_type;

/// Top-left corner of a square block inside a larger matrix.
struct Corner {
    Index row;
    Index col;
};

/// Copies the size x size block of `m` whose top-left corner is `at`.
inline Matrix block(const Matrix& m, Corner at, Index size) {
    Matrix out(size, std::vector<int>(size));
    for (Index i = 0; i < size; ++i) {
        for (Index j = 0; j < size; ++j) {
            out[i][j] = m[at.row + i][at.col + j];
        }
    }
    return out;
}

/// Element-wise sum of the two size x size blocks of `m` located at `x` and `y`.
inline Matrix block_sum(const Matrix& m, Corner x, Corner y, Index size) {
    Matrix out(size, std::vector<int>(size));
    for (Index i = 0; i < size; ++i) {
        for (Index j = 0; j < size; ++j) {
            out[i][j] = m[x.row + i][x.col + j] + m[y.row + i][y.col + j];
        }
    }
    return out;
}

/// Element-wise difference (block at `x` minus block at `y`) of two
/// size x size blocks of `m`.
inline Matrix block_diff(const Matrix& m, Corner x, Corner y, Index size) {
    Matrix out(size, std::vector<int>(size));
    for (Index i = 0; i < size; ++i) {
        for (Index j = 0; j < size; ++j) {
            out[i][j] = m[x.row + i][x.col + j] - m[y.row + i][y.col + j];
        }
    }
    return out;
}

}  // namespace strassen_detail

/// @brief Multiplies two square integer matrices with Strassen's algorithm.
///
/// Each recursion level splits the operands into four half-size quadrants and
/// forms the product from seven half-size multiplications. Any size n >= 0 is
/// accepted: an empty input yields an empty result, and an odd n is zero-padded
/// by one row and one column before splitting, with the padding cropped from
/// the product.
///
/// @param A Left operand, n x n.
/// @param B Right operand. Precondition: at least n x n (only the leading
///          n x n part is read); this is not checked, and violating it is
///          undefined behaviour.
/// @return The n x n matrix product A * B.
std::vector<std::vector<int>> multiply(const std::vector<std::vector<int>>& A,
                                       const std::vector<std::vector<int>>& B) {
    using strassen_detail::block;
    using strassen_detail::block_diff;
    using strassen_detail::block_sum;
    using strassen_detail::Corner;
    using strassen_detail::Index;
    using strassen_detail::Matrix;

    const Index n = A.size();

    // Without this guard an empty matrix would split into empty quadrants and
    // recurse without ever reaching the 1x1 base case.
    if (n == 0) {
        return Matrix{};
    }
    if (n == 1) {
        return Matrix(1, std::vector<int>(1, A[0][0] * B[0][0]));
    }

    // Odd size: the quadrant split needs an even dimension. Padding with a
    // zero row and column leaves the leading n x n part of the product
    // unchanged, because [A 0; 0 0] * [B 0; 0 0] == [A*B 0; 0 0].
    if (n % 2 != 0) {
        const Index padded = n + 1;
        Matrix a_pad(padded, std::vector<int>(padded, 0));
        Matrix b_pad(padded, std::vector<int>(padded, 0));
        for (Index i = 0; i < n; ++i) {
            for (Index j = 0; j < n; ++j) {
                a_pad[i][j] = A[i][j];
                b_pad[i][j] = B[i][j];
            }
        }
        Matrix product = multiply(a_pad, b_pad);
        product.resize(n);
        for (auto& row : product) {
            row.resize(n);
        }
        return product;
    }

    const Index half = n / 2;
    const Corner q11{0, 0};
    const Corner q12{0, half};
    const Corner q21{half, 0};
    const Corner q22{half, half};

    // The seven Strassen products. Quadrant sums/differences are built
    // straight from the parent matrices, so no standalone quadrant copies are
    // needed for them.
    const Matrix m1 = multiply(block_sum(A, q11, q22, half), block_sum(B, q11, q22, half));
    const Matrix m2 = multiply(block_sum(A, q21, q22, half), block(B, q11, half));
    const Matrix m3 = multiply(block(A, q11, half), block_diff(B, q12, q22, half));
    const Matrix m4 = multiply(block(A, q22, half), block_diff(B, q21, q11, half));
    const Matrix m5 = multiply(block_sum(A, q11, q12, half), block(B, q22, half));
    const Matrix m6 = multiply(block_diff(A, q21, q11, half), block_sum(B, q11, q12, half));
    const Matrix m7 = multiply(block_diff(A, q12, q22, half), block_sum(B, q21, q22, half));

    // Recombine the products into the four quadrants of the result:
    //   C11 = M1 + M4 - M5 + M7    C12 = M3 + M5
    //   C21 = M2 + M4              C22 = M1 + M3 - M2 + M6
    Matrix C(n, std::vector<int>(n));
    for (Index i = 0; i < half; ++i) {
        for (Index j = 0; j < half; ++j) {
            C[i][j] = m1[i][j] + m4[i][j] - m5[i][j] + m7[i][j];
            C[i][j + half] = m3[i][j] + m5[i][j];
            C[i + half][j] = m2[i][j] + m4[i][j];
            C[i + half][j + half] = m1[i][j] + m3[i][j] - m2[i][j] + m6[i][j];
        }
    }
    return C;
}
|