C语言实现Strassen算法计算两个矩阵积

        如果你学过线性代数,说到计算矩阵乘法,那么我们一般常规操作一般公式是\sum_{i,j=1}^nAij*Bji;即A的每一行乘B的每一列,但是需要注意的是A的列数要与B的行数要相等,否则A与B不可做计算操作,对应的项乘积进行相加处理,就可得到C_{ii} 储存在在C中对应的位置,直到每一项被计算完成。这便是矩阵乘法的定义。

        上面是对矩阵成积计算方法定义的陈述,但是如果要把它转为程序,计算机读懂的语言,我们首先会想到使用三层for循环,一层用于C_{ii} 赋值操作,两层用于计算\sum_{i,j=1}^nAij*Bji ,那么程序基本会定义为如下程序:

for(i=0;i<n;i++)

    for(j=0;j<n;j++)

        for(m=0;m<n;m++)

            C_{[i][j]} +=A_{[j][k]}*B_{[k][j]}

        由于有三个for循环,那么这个计算方法时间复杂度O (n^3),这个计算方法效率相对有些低了,为此出现了Strassen计算方法,但是这个算法的要求是两个矩阵的size n*n ,n要满足:n=2^m,m\in *N。Strassen 算法类似于分治的思想,先将矩阵分解为更小的矩阵,直到n=1,即矩阵的size 1*1,再计算两个数乘积返回数值即可。

        根据Strassen 计算法则可以知道先将矩阵A,B进行拆分,使得子矩阵size满足2*2,如下所示:


1.A&B二分逻辑图

根据Strassen计算法则有以下公式要计算.

P_{1}=A_{11} *(B_{12}-B_{22}  )

P_{2}=(A_{11} +A_{12} ) *B_{22}

P_{3} =(A_{21}+A_{22} )*B_{11}

P_{4}=A_{22}*(B_{21}-B_{11}  ) 

P_{5}=(A_{11}+A_{22} )*(B_{11} +B_{22} )

P_{6}=(A_{12}-A_{22}  )*(B_{21} +B_{22} )

P_{7}=(A_{11}-A_{21}  )*(B_{11}+B_{12}  )

C_{11}=P_{5}+P_{4} -P_{2} +P_{6} 

C_{12} =P_{1} +P_{2}

C_{21} =P_{3} +P_{4}

C_{22} =P_{5} +P_{1} -P_{3} -P_{7}

        根据上诉计算规则,我们只要对分解到符合条件的矩阵,都可使用上诉公式进行计算.分解示例图如下:


2.矩阵二分解逻辑示例图

        说到这里,想必你应该对Strassen计算矩阵规则有所了解了,其实就是对矩阵分解后作加,减,乘三种运算,那么接下去就是用程序实现了。源程序如下所示:

#include <stdio.h>

#include<stdlib.h>


/******************************************

*

* function add_matrix()

*

* args sub_ma1,sub_ma2 inttype **pointers

* n the size of sub_ma1 sub_ma2 n*n

*

* return matrix **inttype

*

* *****************************************/

int** add_matrix(int** sub_ma1, int** sub_ma2, int n) {

    int** temp = init_matrix(n);

    for(int i=0; i<n; i++)

        for(int j=0; j<n; j++)

            temp[i][j] = sub_ma1[i][j] + sub_ma2[i][j];

    return temp;

}

/******************************************

*

* function subtract_matrix()

*

* args sub_ma1,sub_ma2 inttype **pointers

* n the size of sub_ma1 sub_ma2 n*n

*

* return matrix **inttype

*

* *****************************************/

int** subtract_matrix(int** sub_ma1, int** sub_ma2, int n) {

    int** temp = init_matrix(n);

    for(int i=0; i<n; i++)

        for(int j=0; j<n; j++)

            temp[i][j] = sub_ma1[i][j] - sub_ma2[i][j];

    return temp;

}

/**********************************************

*

* function square_matrix_mutiply_recursive()

*

* args

* A,B, inttype ** pointer

* n inttype matrix size n*n

*

* return C inttype array

*

************************************************/

int** square_matrix_strassen_mutiply(int **A,int **B,int n){

    int i,j;

    //only one element

    if (n == 1) {

        int** C = init_matrix(1);

        C[0][0] = A[0][0] * B[0][0];

        return C;

    }

    else{

        //init C,A,B

        int** C = init_matrix(n);

        int k = n/2;

        int** A11 = init_matrix(k);

        int** A12 = init_matrix(k);

        int** A21 = init_matrix(k);

        int** A22 = init_matrix(k);

        int** B11 = init_matrix(k);

        int** B12 = init_matrix(k);

        int** B21 = init_matrix(k);

        int** B22 = init_matrix(k);

        //resolve A,B matrixs to A11...A22,B11...B22

        for(i=0; i<k; i++)

            for(j=0; j<k; j++) {

                A11[i][j] = A[i][j];

                A12[i][j] = A[i][k+j];

                A21[i][j] = A[k+i][j];

                A22[i][j] = A[k+i][k+j];

                B11[i][j] = B[i][j];

                B12[i][j] = B[i][k+j];

                B21[i][j] = B[k+i][j];

                B22[i][j] = B[k+i][k+j];

            }

        //calculate P[1-7]

        int** P1 = square_matrix_strassen_mutiply(A11, subtract_matrix(B12, B22, k), k);

        int** P2 = square_matrix_strassen_mutiply(add_matrix(A11, A12, k), B22, k);

        int** P3 = square_matrix_strassen_mutiply(add_matrix(A21, A22, k), B11, k);

        int** P4 = square_matrix_strassen_mutiply(A22, subtract_matrix(B21, B11, k), k);

        int** P5 = square_matrix_strassen_mutiply(add_matrix(A11, A22, k), add_matrix(B11, B22, k), k);

        int** P6 = square_matrix_strassen_mutiply(subtract_matrix(A12, A22, k), add_matrix(B21, B22, k), k);

        int** P7 = square_matrix_strassen_mutiply(subtract_matrix(A11, A21, k), add_matrix(B11, B12, k), k);

        //calculate C11.....C22

        int** C11 = subtract_matrix(add_matrix(add_matrix(P5, P4, k), P6, k), P2, k);

        int** C12 = add_matrix(P1, P2, k);

        int** C21 = add_matrix(P3, P4, k);

        int** C22 = subtract_matrix(subtract_matrix(add_matrix(P5, P1, k), P3, k), P7, k);

        //copy C11,C12,C13,C14 to C

        for(i=0; i<k; i++)

            for(j=0; j<k; j++) {

                C[i][j] = C11[i][j];

                C[i][j+k] = C12[i][j];

                C[k+i][j] = C21[i][j];

                C[k+i][k+j] = C22[i][j];

            }

        //free memory

        for(i=0;i<k;i++){

            // free subarrays of A,B

            free(A11[i]);

            free(A12[i]);

            free(A21[i]);

            free(A22[i]);

            free(B11[i]);

            free(B12[i]);

            free(B21[i]);

            free(B22[i]);

            //free subarray of P

            free(P1[i]);

            free(P2[i]);

            free(P3[i]);

            free(P4[i]);

            free(P5[i]);

            free(P6[i]);

            free(P7[i]);

            //free subarray of C

            free(C11[i]);

            free(C12[i]);

            free(C21[i]);

            free(C22[i]);

        }

        //free rows of A,B

        free(A11);

        free(A12);

        free(A21);

        free(A22);

        free(B11);

        free(B12);

        free(B21);

        free(B22);

        //free rows of P

        free(P1);

        free(P2);

        free(P3);

        free(P4);

        free(P5);

        free(P6);

        free(P7);

        //free rows of C

        free(C11);

        free(C12);

        free(C21);

        free(C22);

        //make NULL

        A11=NULL;

        A12=NULL;

        A21=NULL;

        A22=NULL;

        B11=NULL;

        B12=NULL;

        B21=NULL;

        B22=NULL;

        P1=NULL;

        P2=NULL;

        P3=NULL;

        P4=NULL;

        P5=NULL;

        P6=NULL;

        P7=NULL;

        C11=NULL;

        C12=NULL;

        C21=NULL;

        C22=NULL;

        return C;

    }

}

int main()

{

    int n=2;

    int i,j;

    // init A,B

    int**A=init_matrix(2);

    int**B=init_matrix(2);

    A[0][0]=1;

    A[0][1]=3;

    A[1][0]=7;

    A[1][1]=5;

    B[0][0]=6;

    B[0][1]=8;

    B[1][0]=4;

    B[1][1]=2;

    //use square_matrix_strassen_mutiply()

    int **C=square_matrix_strassen_mutiply(A,B,n);

    //output C

    printf("the result of C is:\n");

    for(i=0;i<n;i++){

        for(j=0;j<n;j++)

            printf("%d ",*(*(C+i)+j));

        printf("\n"); //next line

    }

    return 0;

}

/****************************************

*

* input A={{1,3},{7,5}} B={{6,8},{4,2}}

*

* output C={{18,14},{62,66}};

*

* **************************************/

以上就是对Strassen算法实现的C源程序。

运行结果截图如下:


3.运行结果截图

有兴趣的朋友,可以一起交流,共同进步。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。
禁止转载,如需转载请通过简信或评论联系作者。

相关阅读更多精彩内容

  • 今天感恩节哎,感谢一直在我身边的亲朋好友。感恩相遇!感恩不离不弃。 中午开了第一次的党会,身份的转变要...
    余生动听阅读 10,752评论 0 11
  • 彩排完,天已黑
    刘凯书法阅读 4,425评论 1 3
  • 表情是什么,我认为表情就是表现出来的情绪。表情可以传达很多信息。高兴了当然就笑了,难过就哭了。两者是相互影响密不可...
    Persistenc_6aea阅读 128,862评论 2 7

友情链接更多精彩内容