1. Introduction
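Strassen’s matrix multiplication is an algorithm for multiplying two square matrices. The conventional method multiplies two n×n matrices with O(n^3) scalar operations, whereas Strassen’s divide-and-conquer method runs in O(n^(log2 7)) ≈ O(n^2.81) time. The key idea is to split each input matrix into four (n/2)×(n/2) sub-matrices and form the product with only seven recursive multiplications of sub-matrices, instead of the eight that a straightforward block decomposition needs. The price is a constant number of extra matrix additions and subtractions at each level.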
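The exponent 2.81 comes from counting multiplications. Each level of recursion replaces one n×n product with seven (n/2)×(n/2) products, so the count satisfies M(n) = 7·M(n/2) with M(1) = 1, which gives M(n) = 7^(log2 n) = n^(log2 7). Adding the Θ(n^2) work for additions at each level, the running time T(n) = 7·T(n/2) + Θ(n^2) solves to the same Θ(n^(log2 7)) bound. The following minimal sketch tabulates only the multiplication counts for powers of two. The class and method names are illustrative and are not part of the program in Section 3.

```java
// Hypothetical illustration: counts scalar multiplications only and
// ignores the additions/subtractions performed at each level.
public class MultiplicationCount {

    // Conventional method: one multiplication per (i, j, k) triple, n^3 in total
    static long naiveCount(long n) {
        return n * n * n;
    }

    // Strassen: M(1) = 1 and M(n) = 7 * M(n / 2), assuming n is a power of 2
    static long strassenCount(long n) {
        if (n == 1) {
            return 1;
        }
        return 7 * strassenCount(n / 2);
    }

    public static void main(String[] args) {
        for (long n = 1; n <= 1024; n *= 2) {
            System.out.println("n=" + n + "  naive=" + naiveCount(n)
                    + "  strassen=" + strassenCount(n));
        }
        // Growth exponent of the Strassen count: log2(7), about 2.807
        System.out.println("log2(7) = " + Math.log(7) / Math.log(2));
    }
}
```

For n = 1024 this reports 1073741824 multiplications for the conventional method against 282475249 for Strassen's recursion. In practice the gap is smaller than these counts suggest, because the extra additions and memory traffic are not free.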
2. Program Steps
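1. Divide each of the matrices A and B into four equal sub-matrices (quadrants) of size (n/2)×(n/2), arranged as a 2×2 grid.
2. Compute seven products, each a recursive multiplication of sums or differences of these sub-matrices.
3. Combine the seven products using only additions and subtractions to obtain the four quadrants of the final product matrix.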
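In the smallest case each quadrant is a single number, so the three steps can be checked by hand. The sketch below uses a hypothetical helper, `strassen2x2`, written only for illustration. It applies the same seven products p1 to p7 as the program in Section 3 to two 2×2 integer matrices and compares the result with the ordinary eight-multiplication formula.

```java
import java.util.Arrays;

// Illustrative sketch: one Strassen step on 2 x 2 integer matrices, where each
// "sub-matrix" is a single number. Same seven products as the article's program.
public class StrassenTwoByTwo {

    static int[][] strassen2x2(int[][] a, int[][] b) {
        int a11 = a[0][0], a12 = a[0][1], a21 = a[1][0], a22 = a[1][1];
        int b11 = b[0][0], b12 = b[0][1], b21 = b[1][0], b22 = b[1][1];

        // Step 2: seven multiplications of sums/differences
        int p1 = a11 * (b12 - b22);
        int p2 = (a11 + a12) * b22;
        int p3 = (a21 + a22) * b11;
        int p4 = a22 * (b21 - b11);
        int p5 = (a11 + a22) * (b11 + b22);
        int p6 = (a12 - a22) * (b21 + b22);
        int p7 = (a11 - a21) * (b11 + b12);

        // Step 3: combine using additions and subtractions only
        return new int[][] {
            { p5 + p4 - p2 + p6, p1 + p2 },
            { p3 + p4, p5 + p1 - p3 - p7 }
        };
    }

    // Ordinary row-by-column product: eight multiplications
    static int[][] naive2x2(int[][] a, int[][] b) {
        return new int[][] {
            { a[0][0] * b[0][0] + a[0][1] * b[1][0], a[0][0] * b[0][1] + a[0][1] * b[1][1] },
            { a[1][0] * b[0][0] + a[1][1] * b[1][0], a[1][0] * b[0][1] + a[1][1] * b[1][1] }
        };
    }

    public static void main(String[] args) {
        int[][] a = { { 1, 2 }, { 3, 4 } };
        int[][] b = { { 2, 0 }, { 1, 3 } };
        System.out.println(Arrays.deepToString(strassen2x2(a, b))); // [[4, 6], [10, 12]]
        System.out.println(Arrays.deepToString(naive2x2(a, b)));    // [[4, 6], [10, 12]]
    }
}
```

With these inputs p1 to p7 evaluate to -3, 9, 14, -4, 25, -8 and -4. For example, the top-left entry is 25 + (-4) - 9 + (-8) = 4. Because the identities rely only on ring arithmetic, the same formulas remain valid when the scalars are replaced by matrix blocks.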
3. Code Program
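public class StrassenMatrixMultiplication {
    // Multiplies two N x N matrices a and b using Strassen's method.
    // N must be a power of 2 (1, 2, 4, 8, ...) so that every split is exact.
    static int[][] multiply(int[][] a, int[][] b, int N) {
        if (N <= 0 || (N & (N - 1)) != 0) {
            throw new IllegalArgumentException("N must be a power of 2, but was " + N);
        }
        int[][] result = new int[N][N];
        // Base case: a 1 x 1 product is a single scalar multiplication
        if (N == 1) {
            result[0][0] = a[0][0] * b[0][0];
            return result;
        }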
        // Split each matrix into four (N/2) x (N/2) quadrants
        int[][] a11 = new int[N / 2][N / 2];
        int[][] a12 = new int[N / 2][N / 2];
        int[][] a21 = new int[N / 2][N / 2];
        int[][] a22 = new int[N / 2][N / 2];
        int[][] b11 = new int[N / 2][N / 2];
        int[][] b12 = new int[N / 2][N / 2];
        int[][] b21 = new int[N / 2][N / 2];
        int[][] b22 = new int[N / 2][N / 2];
        // Copy the entries of each quadrant out of a and b
        for (int i = 0; i < N / 2; i++) {
            for (int j = 0; j < N / 2; j++) {
                a11[i][j] = a[i][j];
                a12[i][j] = a[i][j + N / 2];
                a21[i][j] = a[i + N / 2][j];
                a22[i][j] = a[i + N / 2][j + N / 2];
                b11[i][j] = b[i][j];
                b12[i][j] = b[i][j + N / 2];
                b21[i][j] = b[i + N / 2][j];
                b22[i][j] = b[i + N / 2][j + N / 2];
            }
        }
        // Seven recursive multiplications (a plain block split would need eight)
        int[][] p1 = multiply(a11, subtract(b12, b22, N / 2), N / 2);
        int[][] p2 = multiply(add(a11, a12, N / 2), b22, N / 2);
        int[][] p3 = multiply(add(a21, a22, N / 2), b11, N / 2);
        int[][] p4 = multiply(a22, subtract(b21, b11, N / 2), N / 2);
        int[][] p5 = multiply(add(a11, a22, N / 2), add(b11, b22, N / 2), N / 2);
        int[][] p6 = multiply(subtract(a12, a22, N / 2), add(b21, b22, N / 2), N / 2);
        int[][] p7 = multiply(subtract(a11, a21, N / 2), add(b11, b12, N / 2), N / 2);
        // Combine the seven products into the four quadrants of the result
        for (int i = 0; i < N / 2; i++) {
            for (int j = 0; j < N / 2; j++) {
                result[i][j] = p5[i][j] + p4[i][j] - p2[i][j] + p6[i][j];                 // C11
                result[i][j + N / 2] = p1[i][j] + p2[i][j];                               // C12
                result[i + N / 2][j] = p3[i][j] + p4[i][j];                               // C21
                result[i + N / 2][j + N / 2] = p5[i][j] + p1[i][j] - p3[i][j] - p7[i][j]; // C22
            }
        }
        return result;
    }
    // Helper function to add two matrices
    static int[][] add(int[][] a, int[][] b, int N) {
        int[][] result = new int[N][N];
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                result[i][j] = a[i][j] + b[i][j];
            }
        }
        return result;
    }
    // Helper function to subtract two matrices
    static int[][] subtract(int[][] a, int[][] b, int N) {
        int[][] result = new int[N][N];
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                result[i][j] = a[i][j] - b[i][j];
            }
        }
        return result;
    }
    public static void main(String[] args) {
        int N = 2;
        int[][] a = { { 1, 2 }, { 3, 4 } };
        int[][] b = { { 2, 0 }, { 1, 3 } };
        int[][] result = multiply(a, b, N);
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                System.out.print(result[i][j] + " ");
            }
            System.out.println();
        }
    }
}
Output:
4 6
10 12
4. Step By Step Explanation
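1. The function multiply computes the product of the N×N matrices a and b with Strassen's method, recursing until N = 1, where the product is a single scalar multiplication.
2. Each of a and b is split into four (N/2)×(N/2) quadrants: a11, a12, a21, a22 and b11, b12, b21, b22.
3. Seven products, p1 to p7, are computed by recursive calls to multiply on sums and differences of these quadrants.
4. The four quadrants of the result are assembled from the seven products using only additions and subtractions: C11 = p5 + p4 - p2 + p6, C12 = p1 + p2, C21 = p3 + p4, and C22 = p5 + p1 - p3 - p7.
5. add and subtract are helper functions that perform element-wise matrix addition and subtraction, respectively.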
6. In the main method, we initialize two matrices, a and b, call the multiply function, and print the result.
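The program above assumes that N is a power of two. For other sizes the quadrant split does not cover the whole matrix: with N = 3, for example, the last row and column of the result are never filled. One possible extension, which is not part of the original program, is to zero-pad the inputs to the next power of two. This works because [A 0; 0 0]·[B 0; 0 0] = [AB 0; 0 0], so the top-left n×n block of the padded product is exactly AB. The sketch below also switches to the conventional triple loop below a leaf size, since for small blocks the extra additions tend to outweigh the saved multiplication. The names `PaddedStrassen`, `multiplyAnySize` and `CUTOFF` are hypothetical, and the cutoff of 32 is an arbitrary, untuned assumption. The inputs are assumed to be square and of equal size.

```java
import java.util.Arrays;

// Sketch under stated assumptions: square n x n int inputs of any size n,
// zero-padded to the next power of two, multiplied with Strassen's seven
// products, then cropped back to n x n.
public class PaddedStrassen {
    static final int CUTOFF = 32; // arbitrary, untuned leaf size

    // Conventional O(n^3) product, used for small blocks and as a reference
    static int[][] naive(int[][] a, int[][] b, int n) {
        int[][] c = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int k = 0; k < n; k++) {
                for (int j = 0; j < n; j++) {
                    c[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        return c;
    }

    // Copies the h x h block of m whose top-left corner is (r, c)
    static int[][] block(int[][] m, int r, int c, int h) {
        int[][] out = new int[h][h];
        for (int i = 0; i < h; i++) {
            System.arraycopy(m[r + i], c, out[i], 0, h);
        }
        return out;
    }

    // Element-wise x + sign * y (sign is +1 for addition, -1 for subtraction)
    static int[][] combine(int[][] x, int[][] y, int sign) {
        int h = x.length;
        int[][] out = new int[h][h];
        for (int i = 0; i < h; i++) {
            for (int j = 0; j < h; j++) {
                out[i][j] = x[i][j] + sign * y[i][j];
            }
        }
        return out;
    }

    // Strassen recursion; n must be a power of two (multiplyAnySize ensures it)
    static int[][] strassen(int[][] a, int[][] b, int n) {
        if (n <= CUTOFF) {
            return naive(a, b, n);
        }
        int h = n / 2;
        int[][] a11 = block(a, 0, 0, h), a12 = block(a, 0, h, h);
        int[][] a21 = block(a, h, 0, h), a22 = block(a, h, h, h);
        int[][] b11 = block(b, 0, 0, h), b12 = block(b, 0, h, h);
        int[][] b21 = block(b, h, 0, h), b22 = block(b, h, h, h);
        int[][] p1 = strassen(a11, combine(b12, b22, -1), h);
        int[][] p2 = strassen(combine(a11, a12, 1), b22, h);
        int[][] p3 = strassen(combine(a21, a22, 1), b11, h);
        int[][] p4 = strassen(a22, combine(b21, b11, -1), h);
        int[][] p5 = strassen(combine(a11, a22, 1), combine(b11, b22, 1), h);
        int[][] p6 = strassen(combine(a12, a22, -1), combine(b21, b22, 1), h);
        int[][] p7 = strassen(combine(a11, a21, -1), combine(b11, b12, 1), h);
        int[][] c = new int[n][n];
        for (int i = 0; i < h; i++) {
            for (int j = 0; j < h; j++) {
                c[i][j] = p5[i][j] + p4[i][j] - p2[i][j] + p6[i][j];
                c[i][j + h] = p1[i][j] + p2[i][j];
                c[i + h][j] = p3[i][j] + p4[i][j];
                c[i + h][j + h] = p5[i][j] + p1[i][j] - p3[i][j] - p7[i][j];
            }
        }
        return c;
    }

    // Pads a and b with zeros to m x m (m = next power of two), multiplies, crops
    static int[][] multiplyAnySize(int[][] a, int[][] b) {
        int n = a.length;
        int m = 1;
        while (m < n) {
            m *= 2;
        }
        int[][] pa = new int[m][m], pb = new int[m][m];
        for (int i = 0; i < n; i++) {
            System.arraycopy(a[i], 0, pa[i], 0, n);
            System.arraycopy(b[i], 0, pb[i], 0, n);
        }
        int[][] pc = strassen(pa, pb, m);
        int[][] c = new int[n][];
        for (int i = 0; i < n; i++) {
            c[i] = Arrays.copyOf(pc[i], n);
        }
        return c;
    }

    public static void main(String[] args) {
        int[][] a = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
        int[][] b = { { 9, 8, 7 }, { 6, 5, 4 }, { 3, 2, 1 } };
        System.out.println(Arrays.deepToString(multiplyAnySize(a, b)));
    }
}
```

With the 3×3 inputs in main, this prints [[30, 24, 18], [84, 69, 54], [138, 114, 90]]. That small case is handled entirely by the triple loop. The Strassen path only runs once the padded size exceeds CUTOFF, for example n = 33, which is padded to 64. Padding can nearly double the working size when n is just above a power of two, so this is a simple approach rather than an efficient one.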