=================================================================== RCS file: /home/cvs/OpenXM_contrib2/asir2000/engine/mat.c,v retrieving revision 1.3 retrieving revision 1.4 diff -u -p -r1.3 -r1.4 --- OpenXM_contrib2/asir2000/engine/mat.c 2000/08/22 05:04:05 1.3 +++ OpenXM_contrib2/asir2000/engine/mat.c 2002/01/04 17:08:23 1.4 @@ -45,33 +45,36 @@ * DEVELOPER SHALL HAVE NO LIABILITY IN CONNECTION WITH THE USE, * PERFORMANCE OR NON-PERFORMANCE OF THE SOFTWARE. * - * $OpenXM: OpenXM_contrib2/asir2000/engine/mat.c,v 1.2 2000/08/21 08:31:28 noro Exp $ + * $OpenXM: OpenXM_contrib2/asir2000/engine/mat.c,v 1.3 2000/08/22 05:04:05 noro Exp $ */ #include "ca.h" +#include "../parse/parse.h" +extern int StrassenSize; + void addmat(vl,a,b,c) VL vl; MAT a,b,*c; { int row,col,i,j; - MAT t; - pointer *ab,*bb,*tb; - - if ( !a ) - *c = b; - else if ( !b ) - *c = a; - else if ( (a->row != b->row) || (a->col != b->col) ) { - *c = 0; error("addmat : size mismatch"); - } else { - row = a->row; col = a->col; - MKMAT(t,row,col); - for ( i = 0; i < row; i++ ) - for ( j = 0, ab = BDY(a)[i], bb = BDY(b)[i], tb = BDY(t)[i]; - j < col; j++ ) - addr(vl,(Obj)ab[j],(Obj)bb[j],(Obj *)&tb[j]); - *c = t; - } + MAT t; + pointer *ab,*bb,*tb; + + if ( !a ) + *c = b; + else if ( !b ) + *c = a; + else if ( (a->row != b->row) || (a->col != b->col) ) { + *c = 0; error("addmat : size mismatch add"); + } else { + row = a->row; col = a->col; + MKMAT(t,row,col); + for ( i = 0; i < row; i++ ) + for ( j = 0, ab = BDY(a)[i], bb = BDY(b)[i], tb = BDY(t)[i]; + j < col; j++ ) + addr(vl,(Obj)ab[j],(Obj)bb[j],(Obj *)&tb[j]); + *c = t; + } } void submat(vl,a,b,c) @@ -79,24 +82,24 @@ VL vl; MAT a,b,*c; { int row,col,i,j; - MAT t; - pointer *ab,*bb,*tb; + MAT t; + pointer *ab,*bb,*tb; - if ( !a ) - chsgnmat(b,c); - else if ( !b ) - *c = a; - else if ( (a->row != b->row) || (a->col != b->col) ) { - *c = 0; error("submat : size mismatch"); - } else { - row = a->row; col = a->col; - MKMAT(t,row,col); - for ( i = 0; i < row; i++ ) - for ( j = 0, ab = BDY(a)[i], bb = BDY(b)[i], tb = BDY(t)[i]; - j < col; j++ ) - subr(vl,(Obj)ab[j],(Obj)bb[j],(Obj *)&tb[j]); - *c = t; - } + if ( !a ) + chsgnmat(b,c); + else if ( !b ) + *c = a; + else if ( (a->row != b->row) || (a->col != b->col) ) { + *c = 0; error("submat : size mismatch sub"); + } else { + row = a->row; col = a->col; + MKMAT(t,row,col); + for ( i = 0; i < row; i++ ) + for ( j = 0, ab = BDY(a)[i], bb = BDY(b)[i], tb = BDY(t)[i]; + j < col; j++ ) + subr(vl,(Obj)ab[j],(Obj)bb[j],(Obj *)&tb[j]); + *c = t; + } } void mulmat(vl,a,b,c) @@ -235,26 +238,295 @@ void mulmatmat(vl,a,b,c) VL vl; MAT a,b,*c; { +#if 0 int arow,bcol,i,j,k,m; MAT t; pointer s,u,v; pointer *ab,*tb; + /* 行列のcol,rowの数があわない場合 */ if ( a->col != b->row ) { *c = 0; error("mulmat : size mismatch"); } else { arow = a->row; m = a->col; bcol = b->col; - MKMAT(t,arow,bcol); + MKMAt(t,arow,bcol); for ( i = 0; i < arow; i++ ) for ( j = 0, ab = BDY(a)[i], tb = BDY(t)[i]; j < bcol; j++ ) { for ( k = 0, s = 0; k < m; k++ ) { - mulr(vl,(Obj)ab[k],(Obj)BDY(b)[k][j],(Obj *)&u); addr(vl,(Obj)s,(Obj)u,(Obj *)&v); s = v; + mulr(vl,(Obj)ab[k],(Obj)BDY(b)[k][j],(Obj *)&u); + addr(vl,(Obj)s,(Obj)u,(Obj *)&v); + s = v; } tb[j] = s; } *c = t; } } + +void Strassen(arg, c) +NODE arg; +Obj *c; +{ + MAT a,b; + VL vl; + + /* tomo */ + a = (MAT)ARG0(arg); + b = (MAT)ARG1(arg); + vl = CO; + strassen(CO, a, b, c); +} + +void strassen(vl,a,b,c) +VL vl; +MAT a,b,*c; +{ +#endif + int arow,bcol,i,j,k,m, h, arowh, bcolh; + MAT t, a11, a12, a21, a22; + MAT p, b11, b12, b21, b22; + MAT ans1, ans2, ans3, c11, c12, c21, c22; + MAT s1, s2, t1, t2, u1, v1, w1, aa, bb; + pointer s,u,v; + pointer *ab,*tb; + int a1row,a2row, a3row,a4row, a1col, a2col, a3col, a4col; + int b1row,b2row, b3row,b4row, b1col, b2col, b3col, b4col; + int pflag1, pflag2; + /* 行列のcol,rowの数があわない場合 */ + if ( a->col != b->row ) { + *c = 0; error("mulmat : size mismatch"); + } + else { + pflag1 = 0; pflag2 = 0; + arow = a->row; m = a->col; bcol = b->col; + arowh = arow/2; bcolh = bcol/2; + MKMAT(t,arow,bcol); + /* StrassenSize == 0 or matrix size less then StrassenSize, + then calc cannonical algorizm. */ + if((StrassenSize == 0)||(a->row<=StrassenSize || a->col <= StrassenSize)) { + for ( i = 0; i < arow; i++ ) + for ( j = 0, ab = BDY(a)[i], tb = BDY(t)[i]; j < bcol; j++ ) { + for ( k = 0, s = 0; k < m; k++ ) { + mulr(vl,(Obj)ab[k],(Obj)BDY(b)[k][j],(Obj *)&u); + addr(vl,(Obj)s,(Obj)u,(Obj *)&v); + s = v; + } + tb[j] = s; + } + *c = t; + return; + } + /* 行列が奇数次の場合は偶数次になるように0でpadding */ + i = arow/2; + j = arow - i; + if (i != j) { + arow++; + pflag1 = 1; + } + i = m/2; + j = m - i; + if (i != j) { + m++; + pflag2 = 1; + } + MKMAT(aa, arow, m); + for (i = 0; i < a->row; i++) { + for (j = 0; j < a->col; j++) { + aa->body[i][j] = a->body[i][j]; + } + } + i = bcol/2; + j = bcol - i; + if (i != j) { + bcol++; + } + MKMAT(bb, m, bcol); + for (i = 0; i < b->row; i++) { + for ( j = 0; j < b->col; j++) { + bb->body[i][j] = b->body[i][j]; + } + } + + /* 行列A,Bを分割 */ + a1row = aa->row/2; a1col = aa->col/2; + MKMAT(a11,a1row,a1col); + MKMAT(a21,a1row,a1col); + MKMAT(a12,a1row,a1col); + MKMAT(a22,a1row,a1col); + + b1row = bb->row/2; b1col = bb->col/2; + MKMAT(b11,b1row,b1col); + MKMAT(b21,b1row,b1col); + MKMAT(b12,b1row,b1col); + MKMAT(b22,b1row,b1col); + + /* a11の行列を作る */ + for (i = 0; i < a1row; i++) { + for (j = 0; j < a1col; j++) { + a11->body[i][j] = aa->body[i][j]; + } + } + + /* a21の行列を作る */ + for (i = a1row; i < aa->row; i++) { + for (j = 0; j < a1col; j++) { + a21->body[i-a1row][j] = aa->body[i][j]; + } + } + + /* a12の行列を作る */ + for (i = 0; i < a1row; i++) { + for (j = a1col; j < aa->col; j++) { + a12->body[i][j-a1col] = aa->body[i][j]; + } + } + + /* a22の行列を作る */ + for (i = a1row; i < aa->row; i++) { + for (j = a1col; j < aa->col; j++) { + a22->body[i-a1row][j-a1col] = aa->body[i][j]; + } + } + + + /* b11の行列を作る */ + for (i = 0; i < b1row; i++) { + for (j = 0; j < b1col; j++) { + b11->body[i][j] = bb->body[i][j]; + } + } + + /* b21の行列を作る */ + for (i = b1row; i < bb->row; i++) { + for (j = 0; j < b1col; j++) { + b21->body[i-b1row][j] = bb->body[i][j]; + } + } + + /* b12の行列を作る */ + for (i = 0; i < b1row; i++) { + for (j = b1col; j < bb->col; j++) { + b12->body[i][j-b1col] = bb->body[i][j]; + } + } + + /* b22の行列を作る */ + for (i = b1row; i < bb->row; i++) { + for (j = b1col; j < bb->col; j++) { + b22->body[i-b1row][j-b1col] = bb->body[i][j]; + } + } + /* Strassen-Winogradの方法で展開 */ + /* s1=A21+A22 */ + addmat(vl,a21,a22,&s1); + + /* s2=s1-A11 */ + submat(vl,s1,a11,&s2); + + /* t1=B12-B11 */ + submat(vl, b12, b11, &t1); + + /* t2=B22-t1 */ + submat(vl, b22, t1, &t2); + + /* u=(A11-A21)*(B22-B12) */ + submat(vl, a11, a21, &ans1); + submat(vl, b22, b12, &ans2); + mulmatmat(vl, ans1, ans2, &u1); +/* tomo + strassen(vl, ans1, ans2, &u1); +*/ + + /* v=s1*t1 */ + mulmatmat(vl, s1, t1, &v1); +/* tomo + strassen(vl, s1, t1, &v1); +*/ + + /* w=A11*B11+s2*t2 */ + mulmatmat(vl, a11, b11, &ans1); + mulmatmat(vl, s2, t2, &ans2); +/* tomo + strassen(vl, a11, b11, &ans1); + strassen(vl, s2, t2, &ans2); +*/ + addmat(vl, ans1, ans2, &w1); + + /* C11 = A11*B11+A12*B21 */ + mulmatmat(vl, a12, b21, &ans2); +/* tomo + strassen(vl, a12, b21, &ans2); +*/ + addmat(vl, ans1, ans2, &c11); + + /* C12 = w1+v1+(A12-s2)*B22 */ + submat(vl, a12, s2, &ans1); + mulmatmat(vl, ans1, b22, &ans2); +/* tomo + strassen(vl, ans1, b22, &ans2); +*/ + addmat(vl, w1, v1, &ans1); + addmat(vl, ans1, ans2, &c12); + + /* C21 = w1+u1+A22*(B21-t2) */ + submat(vl, b21, t2, &ans1); + mulmatmat(vl, a22, ans1, &ans2); +/* tomo + strassen(vl, a22, ans1, &ans2); +*/ + addmat(vl, w1, u1, &ans1); + addmat(vl, ans1, ans2, &c21); + + /* C22 = w1 + u1 + v1 */ + addmat(vl, ans1, v1, &c22); + + } + + /* 解の領域tに計算結果を戻す */ + for(i =0; irow; i++) { + for ( j=0; j < c11->col; j++) { + t->body[i][j] = c11->body[i][j]; + } + } + if (pflag1 == 0) { + k = c21->row; + } else { + k = c21->row - 1; + } + for(i =0; icol; j++) { + t->body[i+c11->row][j] = c21->body[i][j]; + } + } + if (pflag2 == 0) { + h = c12->col; + } else { + h = c12->col -1; + } + for(i =0; irow; i++) { + for ( j=0; j < k; j++) { + t->body[i][j+c11->col] = c12->body[i][j]; + } + } + if (pflag1 == 0) { + k = c22->row; + } else { + k = c22->row -1; + } + if (pflag2 == 0) { + h = c22->col; + } else { + h = c22->col - 1; + } + for(i =0; ibody[i+c11->row][j+c11->col] = c22->body[i][j]; + } + } + *c = t; +} + + void mulmatvect(vl,a,b,c) VL vl;