Java实现任意尺寸矩阵的Strassen算法
作者:SamYjy
这篇文章主要介绍了用Java实现任意尺寸矩阵Strassen算法的相关资料,需要的朋友可以参考。
本例的输入为两个任意尺寸的矩阵(分别为m × n和n × m),输出为两个矩阵的乘积。计算任意尺寸矩阵相乘时,使用了Strassen算法。程序为作者自编,经过测试,请放心使用。基本算法如下:
1. 对于方阵(正方形矩阵),设其边长为m,找到最大的l,使得l = 2^k(k为整数)且l ≤ m。左上角边长为l的子方阵采用Strassen算法计算;其余部分,以及该子方阵中遗漏的部分(即下标k ≥ l的那些乘积项),用蛮力法计算。
2. 对于非方阵,按照行、列相应地补0,使其成为方阵。
StrassenMethodTest.java
package matrixalgorithm; import java.util.Scanner; public class StrassenMethodTest { private StrassenMethod strassenMultiply; StrassenMethodTest(){ strassenMultiply = new StrassenMethod(); }//end cons public static void main(String[] args){ Scanner input = new Scanner(System.in); System.out.println("Input row size of the first matrix: "); int arow = input.nextInt(); System.out.println("Input column size of the first matrix: "); int acol = input.nextInt(); System.out.println("Input row size of the second matrix: "); int brow = input.nextInt(); System.out.println("Input column size of the second matrix: "); int bcol = input.nextInt(); double[][] A = new double[arow][acol]; double[][] B = new double[brow][bcol]; double[][] C = new double[arow][bcol]; System.out.println("Input data for matrix A: "); /*In all of the codes later in this project, r means row while c means column. */ for (int r = 0; r < arow; r++) { for (int c = 0; c < acol; c++) { System.out.printf("Data of A[%d][%d]: ", r, c); A[r][c] = input.nextDouble(); }//end inner loop }//end loop System.out.println("Input data for matrix B: "); for (int r = 0; r < brow; r++) { for (int c = 0; c < bcol; c++) { System.out.printf("Data of A[%d][%d]: ", r, c); B[r][c] = input.nextDouble(); }//end inner loop }//end loop StrassenMethodTest algorithm = new StrassenMethodTest(); C = algorithm.multiplyRectMatrix(A, B, arow, acol, brow, bcol); //Display the calculation result: System.out.println("Result from matrix C: "); for (int r = 0; r < arow; r++) { for (int c = 0; c < bcol; c++) { System.out.printf("Data of C[%d][%d]: %f\n", r, c, C[r][c]); }//end inner loop }//end outter loop }//end main //Deal with matrices that are not square: public double[][] multiplyRectMatrix(double[][] A, double[][] B, int arow, int acol, int brow, int bcol) { if (arow != bcol) //Invalid multiplicatio return new double[][]{{0}}; double[][] C = new double[arow][bcol]; if (arow < acol) { double[][] newA = new double[acol][acol]; double[][] newB = 
new double[brow][brow]; int n = acol; for (int r = 0; r < acol; r++) for (int c = 0; c < acol; c++) newA[r][c] = 0.0; for (int r = 0; r < brow; r++) for (int c = 0; c < brow; c++) newB[r][c] = 0.0; for (int r = 0; r < arow; r++) for (int c = 0; c < acol; c++) newA[r][c] = A[r][c]; for (int r = 0; r < brow; r++) for (int c = 0; c < bcol; c++) newB[r][c] = B[r][c]; double[][] C2 = multiplySquareMatrix(newA, newB, n); for(int r = 0; r < arow; r++) for(int c = 0; c < bcol; c++) C[r][c] = C2[r][c]; }//end if else if(arow == acol) C = multiplySquareMatrix(A, B, arow); else { int n = arow; double[][] newA = new double[arow][arow]; double[][] newB = new double[bcol][bcol]; for (int r = 0; r < arow; r++) for (int c = 0; c < arow; c++) newA[r][c] = 0.0; for (int r = 0; r < bcol; r++) for (int c = 0; c < bcol; c++) newB[r][c] = 0.0; for (int r = 0; r < arow; r++) for (int c = 0; c < acol; c++) newA[r][c] = A[r][c]; for (int r = 0; r < brow; r++) for (int c = 0; c < bcol; c++) newB[r][c] = B[r][c]; double[][] C2 = multiplySquareMatrix(newA, newB, n); for(int r = 0; r < arow; r++) for(int c = 0; c < bcol; c++) C[r][c] = C2[r][c]; }//end else return C; }//end method //Deal with matrices that are square matrices. 
public double[][] multiplySquareMatrix(double[][] A2, double[][] B2, int n){ double[][] C2 = new double[n][n]; for(int r = 0; r < n; r++) for(int c = 0; c < n; c++) C2[r][c] = 0; if(n == 1){ C2[0][0] = A2[0][0] * B2[0][0]; return C2; }//end if int exp2k = 2; while(exp2k <= (n / 2) ){ exp2k *= 2; }//end loop if(exp2k == n){ C2 = strassenMultiply.strassenMultiplyMatrix(A2, B2, n); return C2; }//end else //The "biggest" strassen matrix: double[][][] A = new double[6][exp2k][exp2k]; double[][][] B = new double[6][exp2k][exp2k]; double[][][] C = new double[6][exp2k][exp2k]; for(int r = 0; r < exp2k; r++){ for(int c = 0; c < exp2k; c++){ A[0][r][c] = A2[r][c]; B[0][r][c] = B2[r][c]; }//end inner loop }//end outter loop C[0] = strassenMultiply.strassenMultiplyMatrix(A[0], B[0], exp2k); for(int r = 0; r < exp2k; r++) for(int c = 0; c < exp2k; c++) C2[r][c] = C[0][r][c]; int middle = exp2k / 2; for(int r = 0; r < middle; r++){ for(int c = exp2k; c < n; c++){ A[1][r][c - exp2k] = A2[r][c]; B[3][r][c - exp2k] = B2[r][c]; }//end inner loop }//end outter loop for(int r = exp2k; r < n; r++){ for(int c = 0; c < middle; c++){ A[3][r - exp2k][c] = A2[r][c]; B[1][r - exp2k][c] = B2[r][c]; }//end inner loop }//end outter loop for(int r = middle; r < exp2k; r++){ for(int c = exp2k; c < n; c++){ A[2][r - middle][c - exp2k] = A2[r][c]; B[4][r - middle][c - exp2k] = B2[r][c]; }//end inner loop }//end outter loop for(int r = exp2k; r < n; r++){ for(int c = middle; c < n - exp2k + 1; c++){ A[4][r - exp2k][c - middle] = A2[r][c]; B[2][r - exp2k][c - middle] = B2[r][c]; }//end inner loop }//end outter loop for(int i = 1; i <= 4; i++) C[i] = multiplyRectMatrix(A[i], B[i], middle, A[i].length, A[i].length, middle); /* Calculate the final results of grids in the "biggest 2^k square, according to the rules of matrice multiplication. 
*/ for (int row = 0; row < exp2k; row++) { for (int col = 0; col < exp2k; col++) { for (int k = exp2k; k < n; k++) { C2[row][col] += A2[row][k] * B2[k][col]; }//end loop }//end inner loop }//end outter loop //Use brute force to solve the rest, will be improved later: for(int col = exp2k; col < n; col++){ for(int row = 0; row < n; row++){ for(int k = 0; k < n; k++) C2[row][col] += A2[row][k] * B2[k][row]; }//end inner loop }//end outter loop for(int row = exp2k; row < n; row++){ for(int col = 0; col < exp2k; col++){ for(int k = 0; k < n; k++) C2[row][col] += A2[row][k] * B2[k][row]; }//end inner loop }//end outter loop return C2; }//end method }//end class
StrassenMethod.java
package matrixalgorithm; import java.util.Scanner; public class StrassenMethod { private double[][][][] A = new double[2][2][][]; private double[][][][] B = new double[2][2][][]; private double[][][][] C = new double[2][2][][]; /*//Codes for testing this class: public static void main(String[] args) { Scanner input = new Scanner(System.in); System.out.println("Input size of the matrix: "); int n = input.nextInt(); double[][] A = new double[n][n]; double[][] B = new double[n][n]; double[][] C = new double[n][n]; System.out.println("Input data for matrix A: "); for (int r = 0; r < n; r++) { for (int c = 0; c < n; c++) { System.out.printf("Data of A[%d][%d]: ", r, c); A[r][c] = input.nextDouble(); }//end inner loop }//end loop System.out.println("Input data for matrix B: "); for (int r = 0; r < n; r++) { for (int c = 0; c < n; c++) { System.out.printf("Data of A[%d][%d]: ", r, c); B[r][c] = input.nextDouble(); }//end inner loop }//end loop StrassenMethod algorithm = new StrassenMethod(); C = algorithm.strassenMultiplyMatrix(A, B, n); System.out.println("Result from matrix C: "); for (int r = 0; r < n; r++) { for (int c = 0; c < n; c++) { System.out.printf("Data of C[%d][%d]: %f\n", r, c, C[r][c]); }//end inner loop }//end outter loop }//end main*/ public double[][] strassenMultiplyMatrix(double[][] A2, double B2[][], int n){ double[][] C2 = new double[n][n]; //Initialize the matrix: for(int rowIndex = 0; rowIndex < n; rowIndex++) for(int colIndex = 0; colIndex < n; colIndex++) C2[rowIndex][colIndex] = 0.0; if(n == 1) C2[0][0] = A2[0][0] * B2[0][0]; //"Slice matrices into 2 * 2 parts: else{ double[][][][] A = new double[2][2][n / 2][n / 2]; double[][][][] B = new double[2][2][n / 2][n / 2]; double[][][][] C = new double[2][2][n / 2][n / 2]; for(int r = 0; r < n / 2; r++){ for(int c = 0; c < n / 2; c++){ A[0][0][r][c] = A2[r][c]; A[0][1][r][c] = A2[r][n / 2 + c]; A[1][0][r][c] = A2[n / 2 + r][c]; A[1][1][r][c] = A2[n / 2 + r][n / 2 + c]; B[0][0][r][c] = B2[r][c]; 
B[0][1][r][c] = B2[r][n / 2 + c]; B[1][0][r][c] = B2[n / 2 + r][c]; B[1][1][r][c] = B2[n / 2 + r][n / 2 + c]; }//end loop }//end loop n = n / 2; double[][][] S = new double[10][n][n]; S[0] = minusMatrix(B[0][1], B[1][1], n); S[1] = addMatrix(A[0][0], A[0][1], n); S[2] = addMatrix(A[1][0], A[1][1], n); S[3] = minusMatrix(B[1][0], B[0][0], n); S[4] = addMatrix(A[0][0], A[1][1], n); S[5] = addMatrix(B[0][0], B[1][1], n); S[6] = minusMatrix(A[0][1], A[1][1], n); S[7] = addMatrix(B[1][0], B[1][1], n); S[8] = minusMatrix(A[0][0], A[1][0], n); S[9] = addMatrix(B[0][0], B[0][1], n); double[][][] P = new double[7][n][n]; P[0] = strassenMultiplyMatrix(A[0][0], S[0], n); P[1] = strassenMultiplyMatrix(S[1], B[1][1], n); P[2] = strassenMultiplyMatrix(S[2], B[0][0], n); P[3] = strassenMultiplyMatrix(A[1][1], S[3], n); P[4] = strassenMultiplyMatrix(S[4], S[5], n); P[5] = strassenMultiplyMatrix(S[6], S[7], n); P[6] = strassenMultiplyMatrix(S[8], S[9], n); C[0][0] = addMatrix(minusMatrix(addMatrix(P[4], P[3], n), P[1], n), P[5], n); C[0][1] = addMatrix(P[0], P[1], n); C[1][0] = addMatrix(P[2], P[3], n); C[1][1] = minusMatrix(minusMatrix(addMatrix(P[4], P[0], n), P[2], n), P[6], n); n *= 2; for(int r = 0; r < n / 2; r++){ for(int c = 0; c < n / 2; c++){ C2[r][c] = C[0][0][r][c]; C2[r][n / 2 + c] = C[0][1][r][c]; C2[n / 2 + r][c] = C[1][0][r][c]; C2[n / 2 + r][n / 2 + c] = C[1][1][r][c]; }//end inner loop }//end outter loop }//end else return C2; }//end method //Add two matrices according to matrix addition. private double[][] addMatrix(double[][] A, double[][] B, int n){ double C[][] = new double[n][n]; for(int r = 0; r < n; r++) for(int c = 0; c < n; c++) C[r][c] = A[r][c] + B[r][c]; return C; }//end method //Substract two matrices according to matrix addition. private double[][] minusMatrix(double[][] A, double[][] B, int n){ double C[][] = new double[n][n]; for(int r = 0; r < n; r++) for(int c = 0; c < n; c++) C[r][c] = A[r][c] - B[r][c]; return C; }//end method }//end class
希望本文所述对大家学习Java程序设计有所帮助。