[BACK]Return to polmul.c CVS log [TXT][DIR] Up to [local] / OpenXM_contrib2 / asir2000 / fft

Annotation of OpenXM_contrib2/asir2000/fft/polmul.c, Revision 1.1

1.1     ! noro        1: /* $OpenXM: OpenXM/src/asir99/fft/polmul.c,v 1.1.1.1 1999/11/10 08:12:27 noro Exp $ */
        !             2: #include "dft.h"
        !             3: extern struct PrimesS Primes[];
        !             4:
        !             5: /*
        !             6: #define TIMING
        !             7: */
        !             8:
        !             9: #ifdef TIMING
        !            10:
        !            11: #define MAXTIMING 20
        !            12:
        !            13: #include <sys/time.h>
        !            14: #include <sys/resource.h>
        !            15: #if 0
        !            16: #if define(hitm)
        !            17: #include <machine/xclock.h>
        !            18: #endif
        !            19: #endif
        !            20:
        !            21: static struct rusage ru_time[MAXTIMING];
        !            22:
        !            23: struct timing {
        !            24:   struct timeval user, sys;
        !            25: };
        !            26: static struct timing time_duration[MAXTIMING];
        !            27:
        !            28: #define RECORD_TIME(i) getrusage( RUSAGE_SELF, &ru_time[i] )
        !            29: #endif /* TIMING */
        !            30:
        !            31: void FFT_primes(Num, p_prime, p_primroot, p_d)
        !            32: int Num, *p_prime, *p_primroot, *p_d;
        !            33: {
        !            34:   *p_prime =  Primes[Num].prime;
        !            35:   *p_primroot =  Primes[Num].primroot;
        !            36:   *p_d = Primes[Num].d;
        !            37: }
        !            38:
        !            39: int FFT_pol_square( d1, C1, Prod,  MinMod, wk)
        !            40: unsigned int d1;
        !            41: int MinMod;
        !            42: unsigned int C1[], Prod[], wk[];
        !            43: {
        !            44:
        !            45:   unsigned int  Low0bits;
        !            46:   unsigned int Proot, Prime;
        !            47:   double Pinv;
        !            48:   void MNpol_square_DFT();
        !            49:
        !            50:    if ( MinMod < 0  || MinMod > NPrimes - 1 ) return 2;
        !            51:    Prime =  Primes[MinMod].prime;
        !            52:    Proot =  Primes[MinMod].primroot;
        !            53:    Low0bits = Primes[MinMod].d;
        !            54:    Pinv = (((double)1.0)/((double)Prime));
        !            55:
        !            56:     MNpol_square_DFT( d1, C1, Prod, Proot, Low0bits, Prime, Pinv, wk );
        !            57:
        !            58:     return 0;
        !            59: }
        !            60:
        !            61: int FFT_pol_product( d1, C1, d2, C2, Prod,  MinMod, wk)
        !            62: unsigned int d1, d2;
        !            63: int MinMod;
        !            64: unsigned int C1[], C2[], Prod[], wk[];
        !            65: {
        !            66:
        !            67:   unsigned int  Low0bits;
        !            68:   unsigned int Proot, Prime;
        !            69:   double Pinv;
        !            70:   void MNpol_product_DFT();
        !            71:
        !            72:    if ( MinMod < 0  || MinMod > NPrimes - 1 ) return 2;
        !            73:    Prime =  Primes[MinMod].prime;
        !            74:    Proot =  Primes[MinMod].primroot;
        !            75:    Low0bits = Primes[MinMod].d;
        !            76:    Pinv = (((double)1.0)/((double)Prime));
        !            77:    MNpol_product_DFT( d1, C1, d2, C2, Prod, Proot, Low0bits, Prime, Pinv, wk );
        !            78:    return 0;
        !            79: }
        !            80:
        !            81: struct oEGT {
        !            82:     double exectime,gctime;
        !            83: };
        !            84:
        !            85: extern struct oEGT eg_fore,eg_back;
        !            86:
        !            87: void get_eg(struct oEGT *);
        !            88: void add_eg(struct oEGT *,struct oEGT *, struct oEGT *);
        !            89:
        !            90: void MNpol_product_DFT( d1, C1, d2, C2, Prod, a, low0s, P, Pinv, wk )
        !            91: unsigned int d1, d2, low0s;
        !            92: unsigned int C1[], C2[], Prod[], a, P, wk[];
        !            93: double Pinv;
        !            94: /*
        !            95:  *  The amount of space of wk[] must be >= (11/2)*2^{\lceil \log_2(d1+d2+1) \rceil}.
        !            96:  */
        !            97: {
        !            98:   int i, d, n, log2n, halfn;
        !            99:   unsigned int *dft1, *dft2, *dftprod, *powa, *true_wk, ninv;
        !           100:   struct oEGT eg0,eg1;
        !           101:
        !           102:   d = d1 + d2;
        !           103:   /**/
        !           104:   for ( n = 2, log2n = 1; n <= d; n <<= 1) log2n++;
        !           105:   halfn = n >> 1;
        !           106:   dft1 = wk,  dft2 = &wk[n];
        !           107:   dftprod = &dft2[n];
        !           108:   powa = &dftprod[n];
        !           109:   true_wk = &powa[halfn];      /* 2*n areas */
        !           110:   /**/
        !           111:   C_PREP_ALPHA( a, low0s, log2n, halfn -1, powa, P, Pinv );
        !           112: #ifdef TIMING
        !           113:   RECORD_TIME(1);
        !           114: #endif
        !           115:   /**/
        !           116:   get_eg(&eg0);
        !           117:   C_DFT_FORE( C1, d1 + 1, 1, log2n, powa, dft1, 1, P, Pinv, true_wk );
        !           118: #ifdef TIMING
        !           119:   RECORD_TIME(2);
        !           120: #endif
        !           121:   C_DFT_FORE( C2, d2 + 1, 1, log2n, powa, dft2, 1, P, Pinv, true_wk );
        !           122:   get_eg(&eg1); add_eg(&eg_fore,&eg0,&eg1);
        !           123: #ifdef TIMING
        !           124:   RECORD_TIME(3);
        !           125: #endif
        !           126:   /**/
        !           127:   for ( i = 0; i < n; i++ ) {
        !           128:     AxBmodP( dftprod[i], unsigned int, dft1[i], dft2[i], P, Pinv );
        !           129:   }
        !           130: #ifdef TIMING
        !           131:   RECORD_TIME(4);
        !           132: #endif
        !           133:   /**/
        !           134:   ninv = P - (unsigned int)(((int)P) >> log2n);
        !           135:   get_eg(&eg0);
        !           136:   C_DFT_BACK( dftprod, n, 1, log2n, powa, Prod, 1, 0, d + 1, ninv, P, Pinv, true_wk );
        !           137:   get_eg(&eg1); add_eg(&eg_back,&eg0,&eg1);
        !           138: }
        !           139:
        !           140: void MNpol_square_DFT( d1, C1, Prod, a, low0s, P, Pinv, wk )
        !           141: unsigned int d1, low0s;
        !           142: unsigned int C1[], Prod[], a, P, wk[];
        !           143: double Pinv;
        !           144: /*
        !           145:  *  The amount of space of wk[] must be >= (11/2)*2^{\lceil \log_2(d1+d2+1) \rceil}.
        !           146:  */
        !           147: {
        !           148:   int i, d, n, log2n, halfn;
        !           149:   unsigned int *dft1, *dft2, *dftprod, *powa, *true_wk, ninv;
        !           150:   struct oEGT eg0,eg1;
        !           151:
        !           152:   d = d1 + d1;
        !           153:   /**/
        !           154:   for ( n = 2, log2n = 1; n <= d; n <<= 1) log2n++;
        !           155:   halfn = n >> 1;
        !           156:   dft1 = wk,  dft2 = &wk[n];
        !           157:   dftprod = &dft2[n];
        !           158:   powa = &dftprod[n];
        !           159:   true_wk = &powa[halfn];      /* 2*n areas */
        !           160:   /**/
        !           161:   C_PREP_ALPHA( a, low0s, log2n, halfn -1, powa, P, Pinv );
        !           162: #ifdef TIMING
        !           163:   RECORD_TIME(1);
        !           164: #endif
        !           165:   /**/
        !           166:   get_eg(&eg0);
        !           167:   C_DFT_FORE( C1, d1 + 1, 1, log2n, powa, dft1, 1, P, Pinv, true_wk );
        !           168:   get_eg(&eg1); add_eg(&eg_fore,&eg0,&eg1);
        !           169:   /**/
        !           170:   for ( i = 0; i < n; i++ ) {
        !           171:     AxBmodP( dftprod[i], unsigned int, dft1[i], dft1[i], P, Pinv );
        !           172:   }
        !           173: #ifdef TIMING
        !           174:   RECORD_TIME(4);
        !           175: #endif
        !           176:   /**/
        !           177:   ninv = P - (unsigned int)(((int)P) >> log2n);
        !           178:   get_eg(&eg0);
        !           179:   C_DFT_BACK( dftprod, n, 1, log2n, powa, Prod, 1, 0, d + 1, ninv, P, Pinv, true_wk );
        !           180:   get_eg(&eg1); add_eg(&eg_back,&eg0,&eg1);
        !           181: }
        !           182:
        !           183: /******************  TEST CODE ****************/
        !           184:
        !           185: /*
        !           186: #define TEST
        !           187: */
        !           188: #ifdef TEST
        !           189:
        !           190: #define RETI 100
        !           191:
        !           192: #include <stdio.h>
        !           193:
        !           194: #define Prime 65537    /* 2^16 + 1 */
        !           195: #define Low0bits 16
        !           196: #define PrimRoot 3
        !           197:
        !           198: #define PInv (((double)1.0)/((double)Prime))
        !           199:
        !           200: #define MaxNPnts (1<<Low0bits)
        !           201:
        !           202: static unsigned int Pol1[MaxNPnts], Pol2[MaxNPnts], PolProdDFT[MaxNPnts], PolProd[MaxNPnts], wk[6*MaxNPnts];
        !           203:
        !           204: static void generate_pol(), dump_pol(), compare_vals();
        !           205:
        !           206: #ifdef TIMING
        !           207: static void time_spent();
        !           208:
        !           209: #define pr_timing(pTiming) fprintf(stderr, "user:%3d.%06d + sys:%3d.%06d = %3d.%06d", \
        !           210: (pTiming)->user.tv_sec,(pTiming)->user.tv_usec, \
        !           211: (pTiming)->sys.tv_sec,(pTiming)->sys.tv_usec, \
        !           212: (pTiming)->user.tv_sec+(pTiming)->sys.tv_sec, \
        !           213: (pTiming)->user.tv_usec+(pTiming)->sys.tv_usec)
        !           214:
        !           215: #define Pr_timing(i) pr_timing(&time_duration[i])
        !           216: #endif /* TIMING */
        !           217: void main( argc, argv )
        !           218: int argc;
        !           219: char **argv;
        !           220: {
        !           221:   int i, ac = argc - 1, d1, d2, d;
        !           222:
        !           223:   int ii, *error;
        !           224:
        !           225:   char **av = &argv[1];
        !           226:   unsigned int mnP = (unsigned int)Prime, mnProot = (unsigned int)PrimRoot;
        !           227:   double Pinv = PInv;
        !           228:
        !           229:   for ( ; ac >= 2; ac -= 2, av += 2 ) {
        !           230:     d = (d1 = atoi( av[0] )) + (d2 = atoi( av[1] ));
        !           231:     if ( d1 <= 0 || d2 <= 0 || d >= MaxNPnts ) {
        !           232:       fprintf( stderr, "*** invalid degrees %s & %s. (sum < %d)\n",
        !           233:               av[0], av[1], MaxNPnts );
        !           234:       continue;
        !           235:     }
        !           236:     generate_pol( d1, Pol1 );
        !           237: /*    dump_pol( d1, Pol1 );   */
        !           238:     generate_pol( d2, Pol2 );
        !           239: /*  dump_pol( d2, Pol2 );    */
        !           240:     /**/
        !           241:
        !           242:
        !           243:
        !           244:     for ( ii=0; ii < RETI; ii++ ){
        !           245: #ifdef TIMING
        !           246:     RECORD_TIME(0);
        !           247: #endif
        !           248:     MNpol_product_DFT( d1, Pol1, d2, Pol2, PolProdDFT, mnProot, Low0bits, mnP, Pinv, wk );
        !           249: #ifdef TIMING
        !           250:     RECORD_TIME(5);
        !           251:     time_spent( 0, 1,  0 );
        !           252:     time_spent( 1, 2,  1 );
        !           253:     time_spent( 2, 3,  2 );
        !           254:     time_spent( 3, 4,  3 );
        !           255:     time_spent( 4, 5,  4 );
        !           256:     time_spent( 0, 5,  5 );
        !           257: #endif
        !           258:     }
        !           259:
        !           260:
        !           261: fprintf( stderr, "DFT mul done ( %d times )\n", RETI );
        !           262:
        !           263:
        !           264:
        !           265:
        !           266:
        !           267:     for ( ii=0; ii < RETI; ii++ ){
        !           268: #ifdef TIMING
        !           269:     RECORD_TIME(10);
        !           270: #endif
        !           271:     MNpol_product( d1, Pol1, d2, Pol2, PolProd, mnP, Pinv );
        !           272: #ifdef TIMING
        !           273:     RECORD_TIME(11);
        !           274:     time_spent( 10, 11, 10 );
        !           275: #endif
        !           276:     }
        !           277:
        !           278:
        !           279:
        !           280: fprintf( stderr, "mul by classical alg. done ( %d times )\n", RETI );
        !           281:     /**/
        !           282:
        !           283:
        !           284:      /*     PolProdDFT[20] = 0.0;    */
        !           285:
        !           286:
        !           287:     compare_vals( PolProdDFT, PolProd, d+1 , error);
        !           288:     if ( *error == 0)
        !           289:         fprintf( stderr, "******* Result is OK *******\n" );
        !           290:     else
        !           291:         fprintf( stderr, "******* Result is NG *******\n" );
        !           292:
        !           293:     /**/
        !           294: #ifdef TIMING
        !           295:     fprintf( stderr, "DFT     mul prep: " );  Pr_timing( 0 );
        !           296:     fprintf( stderr, "\nDFT  transform 1: " );  Pr_timing( 1 );
        !           297:     fprintf( stderr, "\nDFT  transform 2: " );  Pr_timing( 2 );
        !           298:     fprintf( stderr, "\nDFT         mult: " );  Pr_timing( 3 );
        !           299:     fprintf( stderr, "\nDFT    inv trans: " );  Pr_timing( 4 );
        !           300:     fprintf( stderr, "\nDFT pol-mult Tot: " );  Pr_timing( 5 );
        !           301:     /**/
        !           302:     fprintf( stderr, "\nClassical mult:   " );  Pr_timing( 10 );
        !           303:     fprintf( stderr, "\n" );
        !           304: #endif /* TIMING */
        !           305:   }
        !           306: }
        !           307:
        !           308: static void generate_pol( d, cf )
        !           309: int d;
        !           310: unsigned int cf[];
        !           311: {
        !           312:   unsigned int mnP = (unsigned int)Prime;
        !           313:   int i, c;
        !           314:
        !           315:   for ( i = 0; i < d; i++ ) cf[i] = random() % Prime;
        !           316:   while ( (c = random() % Prime) == 0 ) ;
        !           317:   cf[d] = c;
        !           318: }
        !           319:
        !           320: static void dump_pol( d, cf )
        !           321: int d;
        !           322: unsigned int cf[];
        !           323: {
        !           324:   int i;
        !           325:
        !           326:   printf( "Pol of degree %d over Z/(%d) :\n", d, Prime );
        !           327:   for ( i = 0; i <= d; i++ ) {
        !           328:     if ( (i%5) == 0 ) printf( "\t" );
        !           329:     printf( "%10d,", (int)cf[i] );
        !           330:     if ( (i%5) == 4 ) printf( "\n" );
        !           331:   }
        !           332:   if ( (i%5) != 0 ) printf( "\n" );
        !           333: }
        !           334:
        !           335: void compare_vals( v1, v2, n , error)
        !           336: unsigned int v1[], v2[];
        !           337: int n;
        !           338: int *error;
        !           339: {
        !           340:   int i, tmp;
        !           341:
        !           342:   *error = 0;
        !           343:
        !           344:   for ( i = 0; i < n; i ++) {
        !           345:         tmp = (int)v1[i] - (int)v2[i];
        !           346:         tmp = abs(tmp);
        !           347:         if ( tmp > *error ) *error = tmp;
        !           348:   }
        !           349:
        !           350:
        !           351:
        !           352: /*
        !           353:   int i, j;
        !           354:
        !           355:
        !           356:
        !           357:   for ( i = 0; i < n; i += 5 ) {
        !           358:     printf( " %6d:%10d", i, (int)v1[i] );
        !           359:     for ( j = 1; j < 5 && i + j < n; j++ ) printf( ",%10d", (int)v1[i+j] );
        !           360:     printf( "\n" );
        !           361:     if ( (int)v1[i] == (int)v2[i] && (int)v1[i+1] == (int)v2[i+1]
        !           362:         && (int)v1[i+2] == (int)v2[i+2] && (int)v1[i+3] == (int)v2[i+3]
        !           363:         && (int)v1[i+4] == (int)v2[i+4] ) continue;
        !           364:     printf( "        " );
        !           365:     for ( j = 0; j < 5 && i + j < n; j++ )
        !           366:       if ( (int)v1[i+j] == (int)v2[i+j] ) printf( "           " );
        !           367:       else printf( "%10d,", (int)v2[i+j] );
        !           368:     printf( "\n" );
        !           369:   }
        !           370: */
        !           371:
        !           372:
        !           373:
        !           374: }
        !           375:
        !           376: #ifdef TIMING
        !           377:
        !           378: static void time_spent( s, e, n )
        !           379: int s, e, n;
        !           380: {
        !           381:   struct rusage *rus = &ru_time[s], *rue = &ru_time[e];
        !           382:   struct timeval *tms, *tme;
        !           383:   long sec, usec;
        !           384:   struct timing *pt = &time_duration[n];
        !           385:
        !           386:   tms = &rus->ru_utime,  tme = &rue->ru_utime;
        !           387:   sec = tme->tv_sec - tms->tv_sec,
        !           388:   usec = tme->tv_usec - tms->tv_usec;
        !           389:   if ( usec < 0 ) usec += 1000000, sec--;
        !           390: /*  pt->user.tv_sec = sec,  pt->user.tv_usec = usec;   */
        !           391:   pt->user.tv_sec += sec,  pt->user.tv_usec += usec;
        !           392:
        !           393:
        !           394:
        !           395:   if ( pt->user.tv_usec > 999999 ){
        !           396:   pt->user.tv_usec -= 1000000;
        !           397:   pt->user.tv_sec++;
        !           398:   }
        !           399:
        !           400:
        !           401:
        !           402:   /**/
        !           403:   tms = &rus->ru_stime,  tme = &rue->ru_stime;
        !           404:   sec = tme->tv_sec - tms->tv_sec,
        !           405:   usec = tme->tv_usec - tms->tv_usec;
        !           406:   if ( usec < 0 ) usec += 1000000, sec--;
        !           407: /*  pt->sys.tv_sec = sec,  pt->sys.tv_usec = usec;   */
        !           408:   pt->sys.tv_sec += sec,  pt->sys.tv_usec += usec;
        !           409:
        !           410:
        !           411:
        !           412:   if ( pt->sys.tv_usec > 999999 ){
        !           413:   pt->sys.tv_usec -= 1000000;
        !           414:   pt->sys.tv_sec++;
        !           415:   }
        !           416:
        !           417:
        !           418: }
        !           419: #endif /* TIMING */
        !           420:
        !           421: #endif /* TEST */
        !           422:
        !           423: void MNpol_product( d1, C1, d2, C2, Prod, P, Pinv )
        !           424: unsigned int d1, d2;
        !           425: unsigned int C1[], C2[], Prod[], P;
        !           426: double Pinv;
        !           427: {
        !           428:   unsigned int i, j;
        !           429:   unsigned int c;
        !           430:
        !           431:   c = C1[0];
        !           432:   AxBmodP( Prod[0], unsigned int, c, C2[0], P, Pinv );
        !           433:   for ( i = 1; i <= d2; i++ ) {
        !           434:     AxBmodPnostrchk( Prod[i], unsigned int, c, C2[i], P, Pinv );
        !           435:   }
        !           436:   c = C1[d1];
        !           437:   if ( d1 > d2 ) {
        !           438:     for ( i = d2 + 1; i < d1; i++ ) Prod[i] = (unsigned int)0;
        !           439:     for ( i = 0; i < d2; i++ ) {
        !           440:       AxBmodPnostrchk( Prod[i+d1], unsigned int, c, C2[i], P, Pinv );
        !           441:     }
        !           442:   } else {
        !           443:     j = d2 - d1;
        !           444:     for ( i = 0; i <= j; i++ ) {
        !           445:       AxBplusCmodPnostrchk( Prod[i+d1], unsigned int, c, C2[i], Prod[i+d1], P, Pinv );
        !           446:     }
        !           447:     for ( ; i < d2; i++ ) {
        !           448:       AxBmodPnostrchk( Prod[i+d1], unsigned int, c, C2[i], P, Pinv );
        !           449:     }
        !           450:   }
        !           451:   AxBmodP( Prod[d1+d2], unsigned int, c, C2[d2], P, Pinv );
        !           452:   /**/
        !           453:   for ( j = 1; j < d1 - 1; j++ ) {
        !           454:     c = C1[j];
        !           455:     if ( c == (unsigned int)0 ) {
        !           456:       Prod[j] = AstrictmodP( Prod[j], P );
        !           457:       continue;
        !           458:     }
        !           459:     AxBplusCmodP( Prod[j], unsigned int, c, C2[0], Prod[j], P, Pinv );
        !           460:     for ( i = 1; i <= d2; i++ ) {
        !           461:       AxBplusCmodPnostrchk( Prod[i+j], unsigned int, c, C2[i], Prod[i+j], P, Pinv );
        !           462:     }
        !           463:   }
        !           464:   c = C1[j = d1-1];
        !           465:   for ( i = 0; i <= d2; i++ ) {
        !           466:     AxBplusCmodP( Prod[i+j], unsigned int, c, C2[i], Prod[i+j], P, Pinv );
        !           467:   }
        !           468: }

FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>