Annotation of OpenXM_contrib2/asir2000/fft/polmul.c, Revision 1.1.1.1
1.1 noro 1: /* $OpenXM: OpenXM/src/asir99/fft/polmul.c,v 1.1.1.1 1999/11/10 08:12:27 noro Exp $ */
#include <limits.h>
#include "dft.h"
3: extern struct PrimesS Primes[];
4:
5: /*
6: #define TIMING
7: */
8:
#ifdef TIMING

/* Number of RECORD_TIME() snapshot slots (the TEST driver uses up to 11). */
#define MAXTIMING 20

#include <sys/time.h>
#include <sys/resource.h>
#if 0
/* Disabled: platform-specific clock header, kept for reference only. */
#if defined(hitm)
#include <machine/xclock.h>
#endif
#endif

/* Resource-usage snapshots taken by RECORD_TIME(i). */
static struct rusage ru_time[MAXTIMING];

/* Accumulated user and system CPU time for one measured interval. */
struct timing {
	struct timeval user, sys;
};
static struct timing time_duration[MAXTIMING];

/*
 * Snapshot this process's resource usage into slot i.  Expands to a single
 * void expression so it is safe in any statement context.
 */
#define RECORD_TIME(i) ((void)getrusage(RUSAGE_SELF, &ru_time[(i)]))
#endif /* TIMING */
30:
31: void FFT_primes(Num, p_prime, p_primroot, p_d)
32: int Num, *p_prime, *p_primroot, *p_d;
33: {
34: *p_prime = Primes[Num].prime;
35: *p_primroot = Primes[Num].primroot;
36: *p_d = Primes[Num].d;
37: }
38:
39: int FFT_pol_square( d1, C1, Prod, MinMod, wk)
40: unsigned int d1;
41: int MinMod;
42: unsigned int C1[], Prod[], wk[];
43: {
44:
45: unsigned int Low0bits;
46: unsigned int Proot, Prime;
47: double Pinv;
48: void MNpol_square_DFT();
49:
50: if ( MinMod < 0 || MinMod > NPrimes - 1 ) return 2;
51: Prime = Primes[MinMod].prime;
52: Proot = Primes[MinMod].primroot;
53: Low0bits = Primes[MinMod].d;
54: Pinv = (((double)1.0)/((double)Prime));
55:
56: MNpol_square_DFT( d1, C1, Prod, Proot, Low0bits, Prime, Pinv, wk );
57:
58: return 0;
59: }
60:
61: int FFT_pol_product( d1, C1, d2, C2, Prod, MinMod, wk)
62: unsigned int d1, d2;
63: int MinMod;
64: unsigned int C1[], C2[], Prod[], wk[];
65: {
66:
67: unsigned int Low0bits;
68: unsigned int Proot, Prime;
69: double Pinv;
70: void MNpol_product_DFT();
71:
72: if ( MinMod < 0 || MinMod > NPrimes - 1 ) return 2;
73: Prime = Primes[MinMod].prime;
74: Proot = Primes[MinMod].primroot;
75: Low0bits = Primes[MinMod].d;
76: Pinv = (((double)1.0)/((double)Prime));
77: MNpol_product_DFT( d1, C1, d2, C2, Prod, Proot, Low0bits, Prime, Pinv, wk );
78: return 0;
79: }
80:
/*
 * Pair of timings (execution time, GC time) filled in by get_eg().
 * The DFT routines in this file bracket their transforms with two get_eg()
 * calls and pass both snapshots to add_eg() together with a global
 * accumulator: eg_fore for forward transforms, eg_back for inverse ones.
 * get_eg/add_eg are defined elsewhere; presumably add_eg adds the
 * difference of the two snapshots into the accumulator -- TODO confirm.
 */
struct oEGT {
	double exectime;
	double gctime;
};

extern struct oEGT eg_fore;
extern struct oEGT eg_back;

void get_eg(struct oEGT *eg);
void add_eg(struct oEGT *acc, struct oEGT *t0, struct oEGT *t1);
89:
/*
 * MNpol_product_DFT: multiply two polynomials modulo the prime P by DFT.
 *
 *   d1, C1   degree and coefficients C1[0..d1] of the first factor
 *   d2, C2   degree and coefficients C2[0..d2] of the second factor
 *   Prod     receives the d1+d2+1 product coefficients
 *   a        root handed to C_PREP_ALPHA (FFT_pol_product passes the
 *            primroot field of the Primes[] entry)
 *   low0s    handed to C_PREP_ALPHA (FFT_pol_product passes the d field);
 *            presumably the number of trailing zero bits of P-1 -- the
 *            test driver uses 16 for P = 2^16+1 -- TODO confirm
 *   P, Pinv  the modulus and 1.0/P
 *   wk       scratch space, laid out below
 *
 * n is the smallest power of two greater than d1+d2 (at least 2).
 * NOTE(review): nothing here checks log2n <= low0s.  If d1+d2+1 exceeds
 * 2^low0s there is presumably no suitable root of unity mod P and the
 * result is meaningless, so callers must bound the degrees.
 */
void MNpol_product_DFT( d1, C1, d2, C2, Prod, a, low0s, P, Pinv, wk )
unsigned int d1, d2, low0s;
unsigned int C1[], C2[], Prod[], a, P, wk[];
double Pinv;
/*
 * wk[] must hold at least (11/2)*n words:
 *   dft1[n] | dft2[n] | dftprod[n] | powa[n/2] | true_wk[2n]
 */
{
	int i, d, n, log2n, halfn;
	unsigned int *dft1, *dft2, *dftprod, *powa, *true_wk, ninv;
	struct oEGT eg0,eg1;

	d = d1 + d2;
	/* n = smallest power of two > d (minimum 2); log2n = log2(n). */
	for ( n = 2, log2n = 1; n <= d; n <<= 1) log2n++;
	halfn = n >> 1;
	/* Carve the workspace into the five regions listed above. */
	dft1 = wk, dft2 = &wk[n];
	dftprod = &dft2[n];
	powa = &dftprod[n];
	true_wk = &powa[halfn]; /* 2*n areas */
	/* Fill powa[] (halfn words reserved) from the root a; presumably the
	   powers of an n-th root of unity used by both transforms (see dft.h). */
	C_PREP_ALPHA( a, low0s, log2n, halfn -1, powa, P, Pinv );
#ifdef TIMING
	RECORD_TIME(1);
#endif
	/* Forward transforms of both inputs, bracketed by get_eg() and
	   accounted to eg_fore. */
	get_eg(&eg0);
	C_DFT_FORE( C1, d1 + 1, 1, log2n, powa, dft1, 1, P, Pinv, true_wk );
#ifdef TIMING
	RECORD_TIME(2);
#endif
	C_DFT_FORE( C2, d2 + 1, 1, log2n, powa, dft2, 1, P, Pinv, true_wk );
	get_eg(&eg1); add_eg(&eg_fore,&eg0,&eg1);
#ifdef TIMING
	RECORD_TIME(3);
#endif
	/* Pointwise product of the two transforms. */
	for ( i = 0; i < n; i++ ) {
		AxBmodP( dftprod[i], unsigned int, dft1[i], dft2[i], P, Pinv );
	}
#ifdef TIMING
	RECORD_TIME(4);
#endif
	/*
	 * ninv = n^{-1} mod P.  If P = c*2^L + 1 with 1 <= log2n <= L, then
	 * P >> log2n = c*2^(L-log2n), so n*(P - (P >> log2n)) = n*P - (P-1),
	 * which is 1 mod P.
	 */
	ninv = P - (unsigned int)(((int)P) >> log2n);
	/* Inverse transform, accounted to eg_back; presumably only output
	   entries 0..d (d+1 values), scaled by ninv, are written to Prod. */
	get_eg(&eg0);
	C_DFT_BACK( dftprod, n, 1, log2n, powa, Prod, 1, 0, d + 1, ninv, P, Pinv, true_wk );
	get_eg(&eg1); add_eg(&eg_back,&eg0,&eg1);
}
139:
/*
 * MNpol_square_DFT: square the polynomial C1[0..d1] modulo the prime P by DFT.
 *
 *   d1, C1   degree and coefficients of the input
 *   Prod     receives the 2*d1+1 coefficients of C1^2
 *   a, low0s root and bit count handed to C_PREP_ALPHA (FFT_pol_square
 *            passes the primroot and d fields of the Primes[] entry)
 *   P, Pinv  the modulus and 1.0/P
 *   wk       scratch space, laid out below
 *
 * Same scheme as MNpol_product_DFT, but with one forward transform whose
 * values are squared pointwise.  n is the smallest power of two greater
 * than 2*d1 (at least 2).
 * NOTE(review): as in the product routine, log2n <= low0s is not checked.
 */
void MNpol_square_DFT( d1, C1, Prod, a, low0s, P, Pinv, wk )
unsigned int d1, low0s;
unsigned int C1[], Prod[], a, P, wk[];
double Pinv;
/*
 * wk[] must hold at least (11/2)*n words, laid out as
 *   dft1[n] | dft2[n] (reserved, unused here) | dftprod[n] | powa[n/2] | true_wk[2n]
 * so the requirement matches MNpol_product_DFT with d2 = d1.
 */
{
	int i, d, n, log2n, halfn;
	unsigned int *dft1, *dft2, *dftprod, *powa, *true_wk, ninv;
	struct oEGT eg0,eg1;

	d = d1 + d1;
	/* n = smallest power of two > d (minimum 2); log2n = log2(n). */
	for ( n = 2, log2n = 1; n <= d; n <<= 1) log2n++;
	halfn = n >> 1;
	/* Same carving as the product routine; dft2 only positions dftprod. */
	dft1 = wk, dft2 = &wk[n];
	dftprod = &dft2[n];
	powa = &dftprod[n];
	true_wk = &powa[halfn]; /* 2*n areas */
	/* Fill powa[] from the root a (see dft.h for C_PREP_ALPHA). */
	C_PREP_ALPHA( a, low0s, log2n, halfn -1, powa, P, Pinv );
#ifdef TIMING
	RECORD_TIME(1);
#endif
	/* Single forward transform, accounted to eg_fore. */
	get_eg(&eg0);
	C_DFT_FORE( C1, d1 + 1, 1, log2n, powa, dft1, 1, P, Pinv, true_wk );
	get_eg(&eg1); add_eg(&eg_fore,&eg0,&eg1);
	/* Pointwise square of the transform. */
	for ( i = 0; i < n; i++ ) {
		AxBmodP( dftprod[i], unsigned int, dft1[i], dft1[i], P, Pinv );
	}
#ifdef TIMING
	RECORD_TIME(4);
#endif
	/* ninv = n^{-1} mod P when 2^log2n divides P-1 (see MNpol_product_DFT). */
	ninv = P - (unsigned int)(((int)P) >> log2n);
	/* Inverse transform into Prod, accounted to eg_back. */
	get_eg(&eg0);
	C_DFT_BACK( dftprod, n, 1, log2n, powa, Prod, 1, 0, d + 1, ninv, P, Pinv, true_wk );
	get_eg(&eg1); add_eg(&eg_back,&eg0,&eg1);
}
182:
183: /****************** TEST CODE ****************/
184:
185: /*
186: #define TEST
187: */
188: #ifdef TEST
189:
190: #define RETI 100
191:
192: #include <stdio.h>
193:
194: #define Prime 65537 /* 2^16 + 1 */
195: #define Low0bits 16
196: #define PrimRoot 3
197:
198: #define PInv (((double)1.0)/((double)Prime))
199:
200: #define MaxNPnts (1<<Low0bits)
201:
202: static unsigned int Pol1[MaxNPnts], Pol2[MaxNPnts], PolProdDFT[MaxNPnts], PolProd[MaxNPnts], wk[6*MaxNPnts];
203:
204: static void generate_pol(), dump_pol(), compare_vals();
205:
206: #ifdef TIMING
207: static void time_spent();
208:
209: #define pr_timing(pTiming) fprintf(stderr, "user:%3d.%06d + sys:%3d.%06d = %3d.%06d", \
210: (pTiming)->user.tv_sec,(pTiming)->user.tv_usec, \
211: (pTiming)->sys.tv_sec,(pTiming)->sys.tv_usec, \
212: (pTiming)->user.tv_sec+(pTiming)->sys.tv_sec, \
213: (pTiming)->user.tv_usec+(pTiming)->sys.tv_usec)
214:
215: #define Pr_timing(i) pr_timing(&time_duration[i])
216: #endif /* TIMING */
217: void main( argc, argv )
218: int argc;
219: char **argv;
220: {
221: int i, ac = argc - 1, d1, d2, d;
222:
223: int ii, *error;
224:
225: char **av = &argv[1];
226: unsigned int mnP = (unsigned int)Prime, mnProot = (unsigned int)PrimRoot;
227: double Pinv = PInv;
228:
229: for ( ; ac >= 2; ac -= 2, av += 2 ) {
230: d = (d1 = atoi( av[0] )) + (d2 = atoi( av[1] ));
231: if ( d1 <= 0 || d2 <= 0 || d >= MaxNPnts ) {
232: fprintf( stderr, "*** invalid degrees %s & %s. (sum < %d)\n",
233: av[0], av[1], MaxNPnts );
234: continue;
235: }
236: generate_pol( d1, Pol1 );
237: /* dump_pol( d1, Pol1 ); */
238: generate_pol( d2, Pol2 );
239: /* dump_pol( d2, Pol2 ); */
240: /**/
241:
242:
243:
244: for ( ii=0; ii < RETI; ii++ ){
245: #ifdef TIMING
246: RECORD_TIME(0);
247: #endif
248: MNpol_product_DFT( d1, Pol1, d2, Pol2, PolProdDFT, mnProot, Low0bits, mnP, Pinv, wk );
249: #ifdef TIMING
250: RECORD_TIME(5);
251: time_spent( 0, 1, 0 );
252: time_spent( 1, 2, 1 );
253: time_spent( 2, 3, 2 );
254: time_spent( 3, 4, 3 );
255: time_spent( 4, 5, 4 );
256: time_spent( 0, 5, 5 );
257: #endif
258: }
259:
260:
261: fprintf( stderr, "DFT mul done ( %d times )\n", RETI );
262:
263:
264:
265:
266:
267: for ( ii=0; ii < RETI; ii++ ){
268: #ifdef TIMING
269: RECORD_TIME(10);
270: #endif
271: MNpol_product( d1, Pol1, d2, Pol2, PolProd, mnP, Pinv );
272: #ifdef TIMING
273: RECORD_TIME(11);
274: time_spent( 10, 11, 10 );
275: #endif
276: }
277:
278:
279:
280: fprintf( stderr, "mul by classical alg. done ( %d times )\n", RETI );
281: /**/
282:
283:
284: /* PolProdDFT[20] = 0.0; */
285:
286:
287: compare_vals( PolProdDFT, PolProd, d+1 , error);
288: if ( *error == 0)
289: fprintf( stderr, "******* Result is OK *******\n" );
290: else
291: fprintf( stderr, "******* Result is NG *******\n" );
292:
293: /**/
294: #ifdef TIMING
295: fprintf( stderr, "DFT mul prep: " ); Pr_timing( 0 );
296: fprintf( stderr, "\nDFT transform 1: " ); Pr_timing( 1 );
297: fprintf( stderr, "\nDFT transform 2: " ); Pr_timing( 2 );
298: fprintf( stderr, "\nDFT mult: " ); Pr_timing( 3 );
299: fprintf( stderr, "\nDFT inv trans: " ); Pr_timing( 4 );
300: fprintf( stderr, "\nDFT pol-mult Tot: " ); Pr_timing( 5 );
301: /**/
302: fprintf( stderr, "\nClassical mult: " ); Pr_timing( 10 );
303: fprintf( stderr, "\n" );
304: #endif /* TIMING */
305: }
306: }
307:
308: static void generate_pol( d, cf )
309: int d;
310: unsigned int cf[];
311: {
312: unsigned int mnP = (unsigned int)Prime;
313: int i, c;
314:
315: for ( i = 0; i < d; i++ ) cf[i] = random() % Prime;
316: while ( (c = random() % Prime) == 0 ) ;
317: cf[d] = c;
318: }
319:
320: static void dump_pol( d, cf )
321: int d;
322: unsigned int cf[];
323: {
324: int i;
325:
326: printf( "Pol of degree %d over Z/(%d) :\n", d, Prime );
327: for ( i = 0; i <= d; i++ ) {
328: if ( (i%5) == 0 ) printf( "\t" );
329: printf( "%10d,", (int)cf[i] );
330: if ( (i%5) == 4 ) printf( "\n" );
331: }
332: if ( (i%5) != 0 ) printf( "\n" );
333: }
334:
/*
 * compare_vals: compare v1[0..n-1] against v2[0..n-1].
 *
 * On return *error holds the largest absolute difference between
 * corresponding entries, clamped to INT_MAX, so 0 means the arrays agree.
 * n <= 0 compares nothing and stores 0.
 */
static void compare_vals(unsigned int v1[], unsigned int v2[], int n, int *error)
{
	unsigned int maxdiff = 0, diff;
	int i;

	for ( i = 0; i < n; i++ ) {
		/* larger minus smaller in unsigned arithmetic: no signed overflow,
		   no abs(INT_MIN) */
		diff = v1[i] > v2[i] ? v1[i] - v2[i] : v2[i] - v1[i];
		if ( diff > maxdiff )
			maxdiff = diff;
	}
	*error = maxdiff > (unsigned int)INT_MAX ? INT_MAX : (int)maxdiff;
}
375:
#ifdef TIMING

/*
 * Add the interval (*to - *from) to *acc.  The interval's microseconds are
 * borrowed into range first, and *acc is carried once afterwards, which
 * keeps acc->tv_usec below 1000000 provided it started that way.
 */
static void accumulate_interval(struct timeval *acc, struct timeval *from,
		struct timeval *to)
{
	long sec = to->tv_sec - from->tv_sec;
	long usec = to->tv_usec - from->tv_usec;

	if ( usec < 0 ) {
		usec += 1000000;
		sec--;
	}
	acc->tv_sec += sec;
	acc->tv_usec += usec;
	if ( acc->tv_usec > 999999 ) {
		acc->tv_usec -= 1000000;
		acc->tv_sec++;
	}
}

/*
 * time_spent: add the user and system CPU time elapsed between snapshots
 * ru_time[s] and ru_time[e] to the running total time_duration[n].
 */
static void time_spent(int s, int e, int n)
{
	struct timing *pt = &time_duration[n];

	accumulate_interval( &pt->user, &ru_time[s].ru_utime, &ru_time[e].ru_utime );
	accumulate_interval( &pt->sys, &ru_time[s].ru_stime, &ru_time[e].ru_stime );
}

#endif /* TIMING */
420:
421: #endif /* TEST */
422:
423: void MNpol_product( d1, C1, d2, C2, Prod, P, Pinv )
424: unsigned int d1, d2;
425: unsigned int C1[], C2[], Prod[], P;
426: double Pinv;
427: {
428: unsigned int i, j;
429: unsigned int c;
430:
431: c = C1[0];
432: AxBmodP( Prod[0], unsigned int, c, C2[0], P, Pinv );
433: for ( i = 1; i <= d2; i++ ) {
434: AxBmodPnostrchk( Prod[i], unsigned int, c, C2[i], P, Pinv );
435: }
436: c = C1[d1];
437: if ( d1 > d2 ) {
438: for ( i = d2 + 1; i < d1; i++ ) Prod[i] = (unsigned int)0;
439: for ( i = 0; i < d2; i++ ) {
440: AxBmodPnostrchk( Prod[i+d1], unsigned int, c, C2[i], P, Pinv );
441: }
442: } else {
443: j = d2 - d1;
444: for ( i = 0; i <= j; i++ ) {
445: AxBplusCmodPnostrchk( Prod[i+d1], unsigned int, c, C2[i], Prod[i+d1], P, Pinv );
446: }
447: for ( ; i < d2; i++ ) {
448: AxBmodPnostrchk( Prod[i+d1], unsigned int, c, C2[i], P, Pinv );
449: }
450: }
451: AxBmodP( Prod[d1+d2], unsigned int, c, C2[d2], P, Pinv );
452: /**/
453: for ( j = 1; j < d1 - 1; j++ ) {
454: c = C1[j];
455: if ( c == (unsigned int)0 ) {
456: Prod[j] = AstrictmodP( Prod[j], P );
457: continue;
458: }
459: AxBplusCmodP( Prod[j], unsigned int, c, C2[0], Prod[j], P, Pinv );
460: for ( i = 1; i <= d2; i++ ) {
461: AxBplusCmodPnostrchk( Prod[i+j], unsigned int, c, C2[i], Prod[i+j], P, Pinv );
462: }
463: }
464: c = C1[j = d1-1];
465: for ( i = 0; i <= d2; i++ ) {
466: AxBplusCmodP( Prod[i+j], unsigned int, c, C2[i], Prod[i+j], P, Pinv );
467: }
468: }
FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>