[BACK]Return to polmul.c CVS log [TXT][DIR] Up to [local] / OpenXM_contrib2 / asir2000 / fft

Annotation of OpenXM_contrib2/asir2000/fft/polmul.c, Revision 1.1.1.1

1.1       noro        1: /* $OpenXM: OpenXM/src/asir99/fft/polmul.c,v 1.1.1.1 1999/11/10 08:12:27 noro Exp $ */
                      2: #include "dft.h"
                      3: extern struct PrimesS Primes[];
                      4:
                      5: /*
                      6: #define TIMING
                      7: */
                      8:
                      9: #ifdef TIMING
                     10:
                     11: #define MAXTIMING 20
                     12:
                     13: #include <sys/time.h>
                     14: #include <sys/resource.h>
                     15: #if 0
                     16: #if define(hitm)
                     17: #include <machine/xclock.h>
                     18: #endif
                     19: #endif
                     20:
                     21: static struct rusage ru_time[MAXTIMING];
                     22:
                     23: struct timing {
                     24:   struct timeval user, sys;
                     25: };
                     26: static struct timing time_duration[MAXTIMING];
                     27:
                     28: #define RECORD_TIME(i) getrusage( RUSAGE_SELF, &ru_time[i] )
                     29: #endif /* TIMING */
                     30:
                     31: void FFT_primes(Num, p_prime, p_primroot, p_d)
                     32: int Num, *p_prime, *p_primroot, *p_d;
                     33: {
                     34:   *p_prime =  Primes[Num].prime;
                     35:   *p_primroot =  Primes[Num].primroot;
                     36:   *p_d = Primes[Num].d;
                     37: }
                     38:
                     39: int FFT_pol_square( d1, C1, Prod,  MinMod, wk)
                     40: unsigned int d1;
                     41: int MinMod;
                     42: unsigned int C1[], Prod[], wk[];
                     43: {
                     44:
                     45:   unsigned int  Low0bits;
                     46:   unsigned int Proot, Prime;
                     47:   double Pinv;
                     48:   void MNpol_square_DFT();
                     49:
                     50:    if ( MinMod < 0  || MinMod > NPrimes - 1 ) return 2;
                     51:    Prime =  Primes[MinMod].prime;
                     52:    Proot =  Primes[MinMod].primroot;
                     53:    Low0bits = Primes[MinMod].d;
                     54:    Pinv = (((double)1.0)/((double)Prime));
                     55:
                     56:     MNpol_square_DFT( d1, C1, Prod, Proot, Low0bits, Prime, Pinv, wk );
                     57:
                     58:     return 0;
                     59: }
                     60:
                     61: int FFT_pol_product( d1, C1, d2, C2, Prod,  MinMod, wk)
                     62: unsigned int d1, d2;
                     63: int MinMod;
                     64: unsigned int C1[], C2[], Prod[], wk[];
                     65: {
                     66:
                     67:   unsigned int  Low0bits;
                     68:   unsigned int Proot, Prime;
                     69:   double Pinv;
                     70:   void MNpol_product_DFT();
                     71:
                     72:    if ( MinMod < 0  || MinMod > NPrimes - 1 ) return 2;
                     73:    Prime =  Primes[MinMod].prime;
                     74:    Proot =  Primes[MinMod].primroot;
                     75:    Low0bits = Primes[MinMod].d;
                     76:    Pinv = (((double)1.0)/((double)Prime));
                     77:    MNpol_product_DFT( d1, C1, d2, C2, Prod, Proot, Low0bits, Prime, Pinv, wk );
                     78:    return 0;
                     79: }
                     80:
                     81: struct oEGT {
                     82:     double exectime,gctime;
                     83: };
                     84:
                     85: extern struct oEGT eg_fore,eg_back;
                     86:
                     87: void get_eg(struct oEGT *);
                     88: void add_eg(struct oEGT *,struct oEGT *, struct oEGT *);
                     89:
                     90: void MNpol_product_DFT( d1, C1, d2, C2, Prod, a, low0s, P, Pinv, wk )
                     91: unsigned int d1, d2, low0s;
                     92: unsigned int C1[], C2[], Prod[], a, P, wk[];
                     93: double Pinv;
                     94: /*
                     95:  *  The amount of space of wk[] must be >= (11/2)*2^{\lceil \log_2(d1+d2+1) \rceil}.
                     96:  */
                     97: {
                     98:   int i, d, n, log2n, halfn;
                     99:   unsigned int *dft1, *dft2, *dftprod, *powa, *true_wk, ninv;
                    100:   struct oEGT eg0,eg1;
                    101:
                    102:   d = d1 + d2;
                    103:   /**/
                    104:   for ( n = 2, log2n = 1; n <= d; n <<= 1) log2n++;
                    105:   halfn = n >> 1;
                    106:   dft1 = wk,  dft2 = &wk[n];
                    107:   dftprod = &dft2[n];
                    108:   powa = &dftprod[n];
                    109:   true_wk = &powa[halfn];      /* 2*n areas */
                    110:   /**/
                    111:   C_PREP_ALPHA( a, low0s, log2n, halfn -1, powa, P, Pinv );
                    112: #ifdef TIMING
                    113:   RECORD_TIME(1);
                    114: #endif
                    115:   /**/
                    116:   get_eg(&eg0);
                    117:   C_DFT_FORE( C1, d1 + 1, 1, log2n, powa, dft1, 1, P, Pinv, true_wk );
                    118: #ifdef TIMING
                    119:   RECORD_TIME(2);
                    120: #endif
                    121:   C_DFT_FORE( C2, d2 + 1, 1, log2n, powa, dft2, 1, P, Pinv, true_wk );
                    122:   get_eg(&eg1); add_eg(&eg_fore,&eg0,&eg1);
                    123: #ifdef TIMING
                    124:   RECORD_TIME(3);
                    125: #endif
                    126:   /**/
                    127:   for ( i = 0; i < n; i++ ) {
                    128:     AxBmodP( dftprod[i], unsigned int, dft1[i], dft2[i], P, Pinv );
                    129:   }
                    130: #ifdef TIMING
                    131:   RECORD_TIME(4);
                    132: #endif
                    133:   /**/
                    134:   ninv = P - (unsigned int)(((int)P) >> log2n);
                    135:   get_eg(&eg0);
                    136:   C_DFT_BACK( dftprod, n, 1, log2n, powa, Prod, 1, 0, d + 1, ninv, P, Pinv, true_wk );
                    137:   get_eg(&eg1); add_eg(&eg_back,&eg0,&eg1);
                    138: }
                    139:
                    140: void MNpol_square_DFT( d1, C1, Prod, a, low0s, P, Pinv, wk )
                    141: unsigned int d1, low0s;
                    142: unsigned int C1[], Prod[], a, P, wk[];
                    143: double Pinv;
                    144: /*
                    145:  *  The amount of space of wk[] must be >= (11/2)*2^{\lceil \log_2(d1+d2+1) \rceil}.
                    146:  */
                    147: {
                    148:   int i, d, n, log2n, halfn;
                    149:   unsigned int *dft1, *dft2, *dftprod, *powa, *true_wk, ninv;
                    150:   struct oEGT eg0,eg1;
                    151:
                    152:   d = d1 + d1;
                    153:   /**/
                    154:   for ( n = 2, log2n = 1; n <= d; n <<= 1) log2n++;
                    155:   halfn = n >> 1;
                    156:   dft1 = wk,  dft2 = &wk[n];
                    157:   dftprod = &dft2[n];
                    158:   powa = &dftprod[n];
                    159:   true_wk = &powa[halfn];      /* 2*n areas */
                    160:   /**/
                    161:   C_PREP_ALPHA( a, low0s, log2n, halfn -1, powa, P, Pinv );
                    162: #ifdef TIMING
                    163:   RECORD_TIME(1);
                    164: #endif
                    165:   /**/
                    166:   get_eg(&eg0);
                    167:   C_DFT_FORE( C1, d1 + 1, 1, log2n, powa, dft1, 1, P, Pinv, true_wk );
                    168:   get_eg(&eg1); add_eg(&eg_fore,&eg0,&eg1);
                    169:   /**/
                    170:   for ( i = 0; i < n; i++ ) {
                    171:     AxBmodP( dftprod[i], unsigned int, dft1[i], dft1[i], P, Pinv );
                    172:   }
                    173: #ifdef TIMING
                    174:   RECORD_TIME(4);
                    175: #endif
                    176:   /**/
                    177:   ninv = P - (unsigned int)(((int)P) >> log2n);
                    178:   get_eg(&eg0);
                    179:   C_DFT_BACK( dftprod, n, 1, log2n, powa, Prod, 1, 0, d + 1, ninv, P, Pinv, true_wk );
                    180:   get_eg(&eg1); add_eg(&eg_back,&eg0,&eg1);
                    181: }
                    182:
                    183: /******************  TEST CODE ****************/
                    184:
                    185: /*
                    186: #define TEST
                    187: */
                    188: #ifdef TEST
                    189:
                    190: #define RETI 100
                    191:
#include <stdio.h>
#include <stdlib.h>
                    193:
                    194: #define Prime 65537    /* 2^16 + 1 */
                    195: #define Low0bits 16
                    196: #define PrimRoot 3
                    197:
                    198: #define PInv (((double)1.0)/((double)Prime))
                    199:
                    200: #define MaxNPnts (1<<Low0bits)
                    201:
                    202: static unsigned int Pol1[MaxNPnts], Pol2[MaxNPnts], PolProdDFT[MaxNPnts], PolProd[MaxNPnts], wk[6*MaxNPnts];
                    203:
                    204: static void generate_pol(), dump_pol(), compare_vals();
                    205:
                    206: #ifdef TIMING
                    207: static void time_spent();
                    208:
                    209: #define pr_timing(pTiming) fprintf(stderr, "user:%3d.%06d + sys:%3d.%06d = %3d.%06d", \
                    210: (pTiming)->user.tv_sec,(pTiming)->user.tv_usec, \
                    211: (pTiming)->sys.tv_sec,(pTiming)->sys.tv_usec, \
                    212: (pTiming)->user.tv_sec+(pTiming)->sys.tv_sec, \
                    213: (pTiming)->user.tv_usec+(pTiming)->sys.tv_usec)
                    214:
                    215: #define Pr_timing(i) pr_timing(&time_duration[i])
                    216: #endif /* TIMING */
                    217: void main( argc, argv )
                    218: int argc;
                    219: char **argv;
                    220: {
                    221:   int i, ac = argc - 1, d1, d2, d;
                    222:
                    223:   int ii, *error;
                    224:
                    225:   char **av = &argv[1];
                    226:   unsigned int mnP = (unsigned int)Prime, mnProot = (unsigned int)PrimRoot;
                    227:   double Pinv = PInv;
                    228:
                    229:   for ( ; ac >= 2; ac -= 2, av += 2 ) {
                    230:     d = (d1 = atoi( av[0] )) + (d2 = atoi( av[1] ));
                    231:     if ( d1 <= 0 || d2 <= 0 || d >= MaxNPnts ) {
                    232:       fprintf( stderr, "*** invalid degrees %s & %s. (sum < %d)\n",
                    233:               av[0], av[1], MaxNPnts );
                    234:       continue;
                    235:     }
                    236:     generate_pol( d1, Pol1 );
                    237: /*    dump_pol( d1, Pol1 );   */
                    238:     generate_pol( d2, Pol2 );
                    239: /*  dump_pol( d2, Pol2 );    */
                    240:     /**/
                    241:
                    242:
                    243:
                    244:     for ( ii=0; ii < RETI; ii++ ){
                    245: #ifdef TIMING
                    246:     RECORD_TIME(0);
                    247: #endif
                    248:     MNpol_product_DFT( d1, Pol1, d2, Pol2, PolProdDFT, mnProot, Low0bits, mnP, Pinv, wk );
                    249: #ifdef TIMING
                    250:     RECORD_TIME(5);
                    251:     time_spent( 0, 1,  0 );
                    252:     time_spent( 1, 2,  1 );
                    253:     time_spent( 2, 3,  2 );
                    254:     time_spent( 3, 4,  3 );
                    255:     time_spent( 4, 5,  4 );
                    256:     time_spent( 0, 5,  5 );
                    257: #endif
                    258:     }
                    259:
                    260:
                    261: fprintf( stderr, "DFT mul done ( %d times )\n", RETI );
                    262:
                    263:
                    264:
                    265:
                    266:
                    267:     for ( ii=0; ii < RETI; ii++ ){
                    268: #ifdef TIMING
                    269:     RECORD_TIME(10);
                    270: #endif
                    271:     MNpol_product( d1, Pol1, d2, Pol2, PolProd, mnP, Pinv );
                    272: #ifdef TIMING
                    273:     RECORD_TIME(11);
                    274:     time_spent( 10, 11, 10 );
                    275: #endif
                    276:     }
                    277:
                    278:
                    279:
                    280: fprintf( stderr, "mul by classical alg. done ( %d times )\n", RETI );
                    281:     /**/
                    282:
                    283:
                    284:      /*     PolProdDFT[20] = 0.0;    */
                    285:
                    286:
                    287:     compare_vals( PolProdDFT, PolProd, d+1 , error);
                    288:     if ( *error == 0)
                    289:         fprintf( stderr, "******* Result is OK *******\n" );
                    290:     else
                    291:         fprintf( stderr, "******* Result is NG *******\n" );
                    292:
                    293:     /**/
                    294: #ifdef TIMING
                    295:     fprintf( stderr, "DFT     mul prep: " );  Pr_timing( 0 );
                    296:     fprintf( stderr, "\nDFT  transform 1: " );  Pr_timing( 1 );
                    297:     fprintf( stderr, "\nDFT  transform 2: " );  Pr_timing( 2 );
                    298:     fprintf( stderr, "\nDFT         mult: " );  Pr_timing( 3 );
                    299:     fprintf( stderr, "\nDFT    inv trans: " );  Pr_timing( 4 );
                    300:     fprintf( stderr, "\nDFT pol-mult Tot: " );  Pr_timing( 5 );
                    301:     /**/
                    302:     fprintf( stderr, "\nClassical mult:   " );  Pr_timing( 10 );
                    303:     fprintf( stderr, "\n" );
                    304: #endif /* TIMING */
                    305:   }
                    306: }
                    307:
                    308: static void generate_pol( d, cf )
                    309: int d;
                    310: unsigned int cf[];
                    311: {
                    312:   unsigned int mnP = (unsigned int)Prime;
                    313:   int i, c;
                    314:
                    315:   for ( i = 0; i < d; i++ ) cf[i] = random() % Prime;
                    316:   while ( (c = random() % Prime) == 0 ) ;
                    317:   cf[d] = c;
                    318: }
                    319:
                    320: static void dump_pol( d, cf )
                    321: int d;
                    322: unsigned int cf[];
                    323: {
                    324:   int i;
                    325:
                    326:   printf( "Pol of degree %d over Z/(%d) :\n", d, Prime );
                    327:   for ( i = 0; i <= d; i++ ) {
                    328:     if ( (i%5) == 0 ) printf( "\t" );
                    329:     printf( "%10d,", (int)cf[i] );
                    330:     if ( (i%5) == 4 ) printf( "\n" );
                    331:   }
                    332:   if ( (i%5) != 0 ) printf( "\n" );
                    333: }
                    334:
                    335: void compare_vals( v1, v2, n , error)
                    336: unsigned int v1[], v2[];
                    337: int n;
                    338: int *error;
                    339: {
                    340:   int i, tmp;
                    341:
                    342:   *error = 0;
                    343:
                    344:   for ( i = 0; i < n; i ++) {
                    345:         tmp = (int)v1[i] - (int)v2[i];
                    346:         tmp = abs(tmp);
                    347:         if ( tmp > *error ) *error = tmp;
                    348:   }
                    349:
                    350:
                    351:
                    352: /*
                    353:   int i, j;
                    354:
                    355:
                    356:
                    357:   for ( i = 0; i < n; i += 5 ) {
                    358:     printf( " %6d:%10d", i, (int)v1[i] );
                    359:     for ( j = 1; j < 5 && i + j < n; j++ ) printf( ",%10d", (int)v1[i+j] );
                    360:     printf( "\n" );
                    361:     if ( (int)v1[i] == (int)v2[i] && (int)v1[i+1] == (int)v2[i+1]
                    362:         && (int)v1[i+2] == (int)v2[i+2] && (int)v1[i+3] == (int)v2[i+3]
                    363:         && (int)v1[i+4] == (int)v2[i+4] ) continue;
                    364:     printf( "        " );
                    365:     for ( j = 0; j < 5 && i + j < n; j++ )
                    366:       if ( (int)v1[i+j] == (int)v2[i+j] ) printf( "           " );
                    367:       else printf( "%10d,", (int)v2[i+j] );
                    368:     printf( "\n" );
                    369:   }
                    370: */
                    371:
                    372:
                    373:
                    374: }
                    375:
                    376: #ifdef TIMING
                    377:
                    378: static void time_spent( s, e, n )
                    379: int s, e, n;
                    380: {
                    381:   struct rusage *rus = &ru_time[s], *rue = &ru_time[e];
                    382:   struct timeval *tms, *tme;
                    383:   long sec, usec;
                    384:   struct timing *pt = &time_duration[n];
                    385:
                    386:   tms = &rus->ru_utime,  tme = &rue->ru_utime;
                    387:   sec = tme->tv_sec - tms->tv_sec,
                    388:   usec = tme->tv_usec - tms->tv_usec;
                    389:   if ( usec < 0 ) usec += 1000000, sec--;
                    390: /*  pt->user.tv_sec = sec,  pt->user.tv_usec = usec;   */
                    391:   pt->user.tv_sec += sec,  pt->user.tv_usec += usec;
                    392:
                    393:
                    394:
                    395:   if ( pt->user.tv_usec > 999999 ){
                    396:   pt->user.tv_usec -= 1000000;
                    397:   pt->user.tv_sec++;
                    398:   }
                    399:
                    400:
                    401:
                    402:   /**/
                    403:   tms = &rus->ru_stime,  tme = &rue->ru_stime;
                    404:   sec = tme->tv_sec - tms->tv_sec,
                    405:   usec = tme->tv_usec - tms->tv_usec;
                    406:   if ( usec < 0 ) usec += 1000000, sec--;
                    407: /*  pt->sys.tv_sec = sec,  pt->sys.tv_usec = usec;   */
                    408:   pt->sys.tv_sec += sec,  pt->sys.tv_usec += usec;
                    409:
                    410:
                    411:
                    412:   if ( pt->sys.tv_usec > 999999 ){
                    413:   pt->sys.tv_usec -= 1000000;
                    414:   pt->sys.tv_sec++;
                    415:   }
                    416:
                    417:
                    418: }
                    419: #endif /* TIMING */
                    420:
                    421: #endif /* TEST */
                    422:
                    423: void MNpol_product( d1, C1, d2, C2, Prod, P, Pinv )
                    424: unsigned int d1, d2;
                    425: unsigned int C1[], C2[], Prod[], P;
                    426: double Pinv;
                    427: {
                    428:   unsigned int i, j;
                    429:   unsigned int c;
                    430:
                    431:   c = C1[0];
                    432:   AxBmodP( Prod[0], unsigned int, c, C2[0], P, Pinv );
                    433:   for ( i = 1; i <= d2; i++ ) {
                    434:     AxBmodPnostrchk( Prod[i], unsigned int, c, C2[i], P, Pinv );
                    435:   }
                    436:   c = C1[d1];
                    437:   if ( d1 > d2 ) {
                    438:     for ( i = d2 + 1; i < d1; i++ ) Prod[i] = (unsigned int)0;
                    439:     for ( i = 0; i < d2; i++ ) {
                    440:       AxBmodPnostrchk( Prod[i+d1], unsigned int, c, C2[i], P, Pinv );
                    441:     }
                    442:   } else {
                    443:     j = d2 - d1;
                    444:     for ( i = 0; i <= j; i++ ) {
                    445:       AxBplusCmodPnostrchk( Prod[i+d1], unsigned int, c, C2[i], Prod[i+d1], P, Pinv );
                    446:     }
                    447:     for ( ; i < d2; i++ ) {
                    448:       AxBmodPnostrchk( Prod[i+d1], unsigned int, c, C2[i], P, Pinv );
                    449:     }
                    450:   }
                    451:   AxBmodP( Prod[d1+d2], unsigned int, c, C2[d2], P, Pinv );
                    452:   /**/
                    453:   for ( j = 1; j < d1 - 1; j++ ) {
                    454:     c = C1[j];
                    455:     if ( c == (unsigned int)0 ) {
                    456:       Prod[j] = AstrictmodP( Prod[j], P );
                    457:       continue;
                    458:     }
                    459:     AxBplusCmodP( Prod[j], unsigned int, c, C2[0], Prod[j], P, Pinv );
                    460:     for ( i = 1; i <= d2; i++ ) {
                    461:       AxBplusCmodPnostrchk( Prod[i+j], unsigned int, c, C2[i], Prod[i+j], P, Pinv );
                    462:     }
                    463:   }
                    464:   c = C1[j = d1-1];
                    465:   for ( i = 0; i <= d2; i++ ) {
                    466:     AxBplusCmodP( Prod[i+j], unsigned int, c, C2[i], Prod[i+j], P, Pinv );
                    467:   }
                    468: }

FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>