[BACK]Return to mul_fft.c CVS log [TXT][DIR] Up to [local] / OpenXM_contrib / gmp / mpn / generic

Annotation of OpenXM_contrib/gmp/mpn/generic/mul_fft.c, Revision 1.1.1.1

1.1       maekawa     1: /* An implementation in GMP of Scho"nhage's fast multiplication algorithm
                      2:    modulo 2^N+1, by Paul Zimmermann, INRIA Lorraine, February 1998.
                      3:
                      4:    THE CONTENTS OF THIS FILE ARE FOR INTERNAL USE AND THE FUNCTIONS HAVE
                      5:    MUTABLE INTERFACES.  IT IS ONLY SAFE TO REACH THEM THROUGH DOCUMENTED
                      6:    INTERFACES.  IT IS ALMOST GUARANTEED THAT THEY'LL CHANGE OR DISAPPEAR IN
                      7:    A FUTURE GNU MP RELEASE.
                      8:
                      9: Copyright (C) 1998, 1999, 2000 Free Software Foundation, Inc.
                     10:
                     11: This file is part of the GNU MP Library.
                     12:
                     13: The GNU MP Library is free software; you can redistribute it and/or modify
                     14: it under the terms of the GNU Lesser General Public License as published by
                     15: the Free Software Foundation; either version 2.1 of the License, or (at your
                     16: option) any later version.
                     17:
                     18: The GNU MP Library is distributed in the hope that it will be useful, but
                     19: WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
                     20: or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
                     21: License for more details.
                     22:
                     23: You should have received a copy of the GNU Lesser General Public License
                     24: along with the GNU MP Library; see the file COPYING.LIB.  If not, write to
                     25: the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
                     26: MA 02111-1307, USA. */
                     27:
                     28:
                     29: /* References:
                     30:
                     31:    Schnelle Multiplikation grosser Zahlen, by Arnold Scho"nhage and Volker
                     32:    Strassen, Computing 7, p. 281-292, 1971.
                     33:
                     34:    Asymptotically fast algorithms for the numerical multiplication
                     35:    and division of polynomials with complex coefficients, by Arnold Scho"nhage,
                     36:    Computer Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.
                     37:
                     38:    Tapes versus Pointers, a study in implementing fast algorithms,
                     39:    by Arnold Scho"nhage, Bulletin of the EATCS, 30, p. 23-32, 1986.
                     40:
                     41:    See also http://www.loria.fr/~zimmerma/bignum
                     42:
                     43:
                     44:    Future:
                     45:
                     46:    K==2 isn't needed in the current uses of this code and the bits specific
                     47:    for that could be dropped.
                     48:
                     49:    It might be possible to avoid a small number of MPN_COPYs by using a
                     50:    rotating temporary or two.
                     51:
                     52:    Multiplications of unequal sized operands can be done with this code, but
                     53:    it needs a tighter test for identifying squaring (same sizes as well as
                     54:    same pointers).  */
                     55:
                     56:
                     57: #include <stdio.h>
                     58: #include "gmp.h"
                     59: #include "gmp-impl.h"
                     60:
                     61:
                     62: /* Change this to "#define TRACE(x) x" for some traces. */
                     63: #define TRACE(x)
                     64:
                     65:
                     66:
                     67: FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] = {
                     68:   FFT_MUL_TABLE,
                     69:   FFT_SQR_TABLE
                     70: };
                     71:
                     72:
                     73: static void mpn_mul_fft_internal
                     74: _PROTO ((mp_limb_t *op, mp_srcptr n, mp_srcptr m, mp_size_t pl,
                     75:          int k, int K,
                     76:          mp_limb_t **Ap, mp_limb_t **Bp,
                     77:          mp_limb_t *A, mp_limb_t *B,
                     78:          mp_size_t nprime, mp_size_t l, mp_size_t Mp, int **_fft_l,
                     79:          mp_limb_t *T, int rec));
                     80:
                     81:
                     82: /* Find the best k to use for a mod 2^(n*BITS_PER_MP_LIMB)+1 FFT.
                     83:    sqr==0 if for a multiply, sqr==1 for a square */
                     84: int
                     85: #if __STDC__
                     86: mpn_fft_best_k (mp_size_t n, int sqr)
                     87: #else
                     88: mpn_fft_best_k (n, sqr)
                     89:      mp_size_t n;
                     90:      int       sqr;
                     91: #endif
                     92: {
                     93:   mp_size_t  t;
                     94:   int        i;
                     95:
                     96:   for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
                     97:     if (n < mpn_fft_table[sqr][i])
                     98:       return i + FFT_FIRST_K;
                     99:
                    100:   /* treat 4*last as one further entry */
                    101:   if (i == 0 || n < 4*mpn_fft_table[sqr][i-1])
                    102:     return i + FFT_FIRST_K;
                    103:   else
                    104:     return i + FFT_FIRST_K + 1;
                    105: }
                    106:
                    107:
                    108: /* Returns smallest possible number of limbs >= pl for a fft of size 2^k.
                    109:    FIXME: Is this simply pl rounded up to the next multiple of 2^k ?  */
                    110:
                    111: mp_size_t
                    112: #if __STDC__
                    113: mpn_fft_next_size (mp_size_t pl, int k)
                    114: #else
                    115: mpn_fft_next_size (pl, k)
                    116:      mp_size_t pl;
                    117:      int       k;
                    118: #endif
                    119: {
                    120:   mp_size_t N, M;
                    121:   int       K;
                    122:
                    123:   /*  if (k==0) k = mpn_fft_best_k (pl, sqr); */
                    124:   N = pl*BITS_PER_MP_LIMB;
                    125:   K = 1<<k;
                    126:   if (N%K) N=(N/K+1)*K;
                    127:   M = N/K;
                    128:   if (M%BITS_PER_MP_LIMB) N=((M/BITS_PER_MP_LIMB)+1)*BITS_PER_MP_LIMB*K;
                    129:   return (N/BITS_PER_MP_LIMB);
                    130: }
                    131:
                    132:
                    133: static void
                    134: #if __STDC__
                    135: mpn_fft_initl(int **l, int k)
                    136: #else
                    137: mpn_fft_initl(l, k)
                    138:      int  **l;
                    139:      int  k;
                    140: #endif
                    141: {
                    142:     int i,j,K;
                    143:
                    144:     l[0][0] = 0;
                    145:     for (i=1,K=2;i<=k;i++,K*=2) {
                    146:        for (j=0;j<K/2;j++) {
                    147:            l[i][j] = 2*l[i-1][j];
                    148:            l[i][K/2+j] = 1+l[i][j];
                    149:        }
                    150:     }
                    151: }
                    152:
                    153:
                    154: /* a <- -a mod 2^(n*BITS_PER_MP_LIMB)+1 */
                    155: static void
                    156: #if __STDC__
                    157: mpn_fft_neg_modF(mp_limb_t *ap, mp_size_t n)
                    158: #else
                    159: mpn_fft_neg_modF(ap, n)
                    160:      mp_limb_t *ap;
                    161:      mp_size_t n;
                    162: #endif
                    163: {
                    164:   mp_limb_t c;
                    165:
                    166:   c = ap[n]+2;
                    167:   mpn_com_n (ap, ap, n);
                    168:   ap[n]=0; mpn_incr_u(ap, c);
                    169: }
                    170:
                    171:
                    172: /* a <- a*2^e mod 2^(n*BITS_PER_MP_LIMB)+1 */
                    173: static void
                    174: #if __STDC__
                    175: mpn_fft_mul_2exp_modF(mp_limb_t *ap, int e, mp_size_t n, mp_limb_t *tp)
                    176: #else
                    177: mpn_fft_mul_2exp_modF(ap, e, n, tp)
                    178:      mp_limb_t *ap;
                    179:      int e;
                    180:      mp_size_t n;
                    181:      mp_limb_t *tp;
                    182: #endif
                    183: {
                    184:   int d, sh, i; mp_limb_t cc;
                    185:
                    186:   d = e%(n*BITS_PER_MP_LIMB); /* 2^e = (+/-) 2^d */
                    187:   sh = d % BITS_PER_MP_LIMB;
                    188:   if (sh) mpn_lshift(tp, ap, n+1, sh); /* no carry here */
                    189:   else MPN_COPY(tp, ap, n+1);
                    190:   d /= BITS_PER_MP_LIMB; /* now shift of d limbs to the left */
                    191:  if (d) {
                    192:    /* ap[d..n-1] = tp[0..n-d-1], ap[0..d-1] = -tp[n-d..n-1] */
                    193:    /* mpn_xor would be more efficient here */
                    194:    for (i=d-1;i>=0;i--) ap[i] = ~tp[n-d+i];
                    195:    cc = 1-mpn_add_1(ap, ap, d, 1);
                    196:    if (cc) cc=mpn_sub_1(ap+d, tp, n-d, 1);
                    197:    else MPN_COPY(ap+d, tp, n-d);
                    198:    if (cc+=mpn_sub_1(ap+d, ap+d, n-d, tp[n]))
                    199:      ap[n]=mpn_add_1(ap, ap, n, cc);
                    200:    else ap[n]=0;
                    201:   }
                    202:   else if ((ap[n]=mpn_sub_1(ap, tp, n, tp[n]))) {
                    203:     ap[n]=mpn_add_1(ap, ap, n, 1);
                    204:   }
                    205:   if ((e/(n*BITS_PER_MP_LIMB))%2) mpn_fft_neg_modF(ap, n);
                    206: }
                    207:
                    208:
                    209: /* a <- a+b mod 2^(n*BITS_PER_MP_LIMB)+1 */
                    210: static void
                    211: #if __STDC__
                    212: mpn_fft_add_modF (mp_limb_t *ap, mp_limb_t *bp, int n)
                    213: #else
                    214: mpn_fft_add_modF (ap, bp, n)
                    215:      mp_limb_t *ap,*bp;
                    216:      int n;
                    217: #endif
                    218: {
                    219:   mp_limb_t c;
                    220:
                    221:   c = ap[n] + bp[n] + mpn_add_n(ap, ap, bp, n);
                    222:   if (c>1) c -= 1+mpn_sub_1(ap,ap,n,1);
                    223:   ap[n]=c;
                    224: }
                    225:
                    226:
                    227: /* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
                    228:           N=n*BITS_PER_MP_LIMB
                    229:           2^omega is a primitive root mod 2^N+1
                    230:    output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1 */
                    231:
                    232: static void
                    233: #if __STDC__
                    234: mpn_fft_fft_sqr (mp_limb_t **Ap, mp_size_t K, int **ll,
                    235:                  mp_size_t omega, mp_size_t n, mp_size_t inc, mp_limb_t *tp)
                    236: #else
                    237: mpn_fft_fft_sqr(Ap,K,ll,omega,n,inc,tp)
                    238: mp_limb_t **Ap,*tp;
                    239: mp_size_t K,omega,n,inc;
                    240: int       **ll;
                    241: #endif
                    242: {
                    243:   if (K==2) {
                    244: #ifdef ADDSUB
                    245:       if (mpn_addsub_n(Ap[0], Ap[inc], Ap[0], Ap[inc], n+1) & 1)
                    246: #else
                    247:       MPN_COPY(tp, Ap[0], n+1);
                    248:       mpn_add_n(Ap[0], Ap[0], Ap[inc],n+1);
                    249:       if (mpn_sub_n(Ap[inc], tp, Ap[inc],n+1))
                    250: #endif
                    251:        Ap[inc][n] = mpn_add_1(Ap[inc], Ap[inc], n, 1);
                    252:     }
                    253:     else {
                    254:       int       j, inc2=2*inc;
                    255:       int       *lk = *ll;
                    256:       mp_limb_t *tmp;
                    257:       TMP_DECL(marker);
                    258:
                    259:       TMP_MARK(marker);
                    260:       tmp = TMP_ALLOC_LIMBS (n+1);
                    261:        mpn_fft_fft_sqr(Ap, K/2,ll-1,2*omega,n,inc2, tp);
                    262:        mpn_fft_fft_sqr(Ap+inc, K/2,ll-1,2*omega,n,inc2, tp);
                    263:        /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
                    264:           A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
                    265:        for (j=0;j<K/2;j++,lk+=2,Ap+=2*inc) {
                    266:          MPN_COPY(tp, Ap[inc], n+1);
                    267:          mpn_fft_mul_2exp_modF(Ap[inc], lk[1]*omega, n, tmp);
                    268:          mpn_fft_add_modF(Ap[inc], Ap[0], n);
                    269:          mpn_fft_mul_2exp_modF(tp,lk[0]*omega, n, tmp);
                    270:          mpn_fft_add_modF(Ap[0], tp, n);
                    271:        }
                    272:         TMP_FREE(marker);
                    273:     }
                    274: }
                    275:
                    276:
                    277: /* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
                    278:           N=n*BITS_PER_MP_LIMB
                    279:          2^omega is a primitive root mod 2^N+1
                    280:    output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1 */
                    281:
                    282: static void
                    283: #if __STDC__
                    284: mpn_fft_fft (mp_limb_t **Ap, mp_limb_t **Bp, mp_size_t K, int **ll,
                    285:              mp_size_t omega, mp_size_t n, mp_size_t inc, mp_limb_t *tp)
                    286: #else
                    287: mpn_fft_fft(Ap,Bp,K,ll,omega,n,inc,tp)
                    288:      mp_limb_t **Ap,**Bp,*tp;
                    289:      mp_size_t K,omega,n,inc;
                    290:      int       **ll;
                    291: #endif
                    292: {
                    293:   if (K==2) {
                    294: #ifdef ADDSUB
                    295:       if (mpn_addsub_n(Ap[0], Ap[inc], Ap[0], Ap[inc], n+1) & 1)
                    296: #else
                    297:       MPN_COPY(tp, Ap[0], n+1);
                    298:       mpn_add_n(Ap[0], Ap[0], Ap[inc],n+1);
                    299:       if (mpn_sub_n(Ap[inc], tp, Ap[inc],n+1))
                    300: #endif
                    301:        Ap[inc][n] = mpn_add_1(Ap[inc], Ap[inc], n, 1);
                    302: #ifdef ADDSUB
                    303:       if (mpn_addsub_n(Bp[0], Bp[inc], Bp[0], Bp[inc], n+1) & 1)
                    304: #else
                    305:       MPN_COPY(tp, Bp[0], n+1);
                    306:       mpn_add_n(Bp[0], Bp[0], Bp[inc],n+1);
                    307:       if (mpn_sub_n(Bp[inc], tp, Bp[inc],n+1))
                    308: #endif
                    309:        Bp[inc][n] = mpn_add_1(Bp[inc], Bp[inc], n, 1);
                    310:     }
                    311:     else {
                    312:        int       j, inc2=2*inc;
                    313:         int       *lk=*ll;
                    314:         mp_limb_t *tmp;
                    315:        TMP_DECL(marker);
                    316:
                    317:        TMP_MARK(marker);
                    318:        tmp = TMP_ALLOC_LIMBS (n+1);
                    319:        mpn_fft_fft(Ap, Bp, K/2,ll-1,2*omega,n,inc2, tp);
                    320:        mpn_fft_fft(Ap+inc, Bp+inc, K/2,ll-1,2*omega,n,inc2, tp);
                    321:        /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
                    322:           A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
                    323:        for (j=0;j<K/2;j++,lk+=2,Ap+=2*inc,Bp+=2*inc) {
                    324:          MPN_COPY(tp, Ap[inc], n+1);
                    325:          mpn_fft_mul_2exp_modF(Ap[inc], lk[1]*omega, n, tmp);
                    326:          mpn_fft_add_modF(Ap[inc], Ap[0], n);
                    327:          mpn_fft_mul_2exp_modF(tp,lk[0]*omega, n, tmp);
                    328:          mpn_fft_add_modF(Ap[0], tp, n);
                    329:          MPN_COPY(tp, Bp[inc], n+1);
                    330:          mpn_fft_mul_2exp_modF(Bp[inc], lk[1]*omega, n, tmp);
                    331:          mpn_fft_add_modF(Bp[inc], Bp[0], n);
                    332:          mpn_fft_mul_2exp_modF(tp,lk[0]*omega, n, tmp);
                    333:          mpn_fft_add_modF(Bp[0], tp, n);
                    334:        }
                    335:        TMP_FREE(marker);
                    336:     }
                    337: }
                    338:
                    339:
                    340: /* a[i] <- a[i]*b[i] mod 2^(n*BITS_PER_MP_LIMB)+1 for 0 <= i < K */
                    341: static void
                    342: #if __STDC__
                    343: mpn_fft_mul_modF_K (mp_limb_t **ap, mp_limb_t **bp, mp_size_t n, int K)
                    344: #else
                    345: mpn_fft_mul_modF_K(ap, bp, n, K)
                    346:      mp_limb_t **ap, **bp;
                    347:      mp_size_t n;
                    348:      int       K;
                    349: #endif
                    350: {
                    351:   int  i;
                    352:   int  sqr = (ap == bp);
                    353:   TMP_DECL(marker);
                    354:
                    355:   TMP_MARK(marker);
                    356:
                    357:   if (n >= (sqr ? FFT_MODF_SQR_THRESHOLD : FFT_MODF_MUL_THRESHOLD)) {
                    358:     int k, K2,nprime2,Nprime2,M2,maxLK,l,Mp2;
                    359:     int       **_fft_l;
                    360:     mp_limb_t **Ap,**Bp,*A,*B,*T;
                    361:
                    362:     k = mpn_fft_best_k (n, sqr);
                    363:     K2 = 1<<k;
                    364:     maxLK = (K2>BITS_PER_MP_LIMB) ? K2 : BITS_PER_MP_LIMB;
                    365:     M2 = n*BITS_PER_MP_LIMB/K2;
                    366:     l = n/K2;
                    367:     Nprime2 = ((2*M2+k+2+maxLK)/maxLK)*maxLK; /* ceil((2*M2+k+3)/maxLK)*maxLK*/
                    368:     nprime2 = Nprime2/BITS_PER_MP_LIMB;
                    369:     Mp2 = Nprime2/K2;
                    370:
                    371:     Ap = TMP_ALLOC_MP_PTRS (K2);
                    372:     Bp = TMP_ALLOC_MP_PTRS (K2);
                    373:     A = TMP_ALLOC_LIMBS (2*K2*(nprime2+1));
                    374:     T = TMP_ALLOC_LIMBS (nprime2+1);
                    375:     B = A + K2*(nprime2+1);
                    376:     _fft_l = TMP_ALLOC_TYPE (k+1, int*);
                    377:     for (i=0;i<=k;i++)
                    378:       _fft_l[i] = TMP_ALLOC_TYPE (1<<i, int);
                    379:     mpn_fft_initl(_fft_l, k);
                    380:
                    381:     TRACE (printf("recurse: %dx%d limbs -> %d times %dx%d (%1.2f)\n", n,
                    382:                   n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
                    383:
                    384:     for (i=0;i<K;i++,ap++,bp++)
                    385:       mpn_mul_fft_internal(*ap, *ap, *bp, n, k, K2, Ap, Bp, A, B, nprime2,
                    386:         l, Mp2, _fft_l, T, 1);
                    387:   }
                    388:   else {
                    389:      mp_limb_t *a, *b, cc, *tp, *tpn; int n2=2*n;
                    390:      tp = TMP_ALLOC_LIMBS (n2);
                    391:      tpn = tp+n;
                    392:      TRACE (printf ("  mpn_mul_n %d of %d limbs\n", K, n));
                    393:      for (i=0;i<K;i++) {
                    394:         a = *ap++; b=*bp++;
                    395:         if (sqr)
                    396:           mpn_sqr_n(tp, a, n);
                    397:         else
                    398:           mpn_mul_n(tp, b, a, n);
                    399:        if (a[n]) cc=mpn_add_n(tpn, tpn, b, n); else cc=0;
                    400:        if (b[n]) cc += mpn_add_n(tpn, tpn, a, n) + a[n];
                    401:        if (cc) {
                    402:           cc = mpn_add_1(tp, tp, n2, cc);
                    403:           ASSERT_NOCARRY (mpn_add_1(tp, tp, n2, cc));
                    404:         }
                    405:        a[n] = mpn_sub_n(a, tp, tpn, n) && mpn_add_1(a, a, n, 1);
                    406:      }
                    407:   }
                    408:   TMP_FREE(marker);
                    409: }
                    410:
                    411:
                    412: /* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
                    413:    output: K*A[0] K*A[K-1] ... K*A[1] */
                    414:
                    415: static void
                    416: #if __STDC__
                    417: mpn_fft_fftinv (mp_limb_t **Ap, int K, mp_size_t omega, mp_size_t n,
                    418:                 mp_limb_t *tp)
                    419: #else
                    420: mpn_fft_fftinv(Ap,K,omega,n,tp)
                    421:      mp_limb_t **Ap, *tp;
                    422:      int       K;
                    423:      mp_size_t omega, n;
                    424: #endif
                    425: {
                    426:     if (K==2) {
                    427: #ifdef ADDSUB
                    428:       if (mpn_addsub_n(Ap[0], Ap[1], Ap[0], Ap[1], n+1) & 1)
                    429: #else
                    430:       MPN_COPY(tp, Ap[0], n+1);
                    431:       mpn_add_n(Ap[0], Ap[0], Ap[1], n+1);
                    432:       if (mpn_sub_n(Ap[1], tp, Ap[1], n+1))
                    433: #endif
                    434:         Ap[1][n] = mpn_add_1(Ap[1], Ap[1], n, 1);
                    435:     }
                    436:     else {
                    437:        int j, K2=K/2; mp_limb_t **Bp=Ap+K2, *tmp;
                    438:        TMP_DECL(marker);
                    439:
                    440:        TMP_MARK(marker);
                    441:        tmp = TMP_ALLOC_LIMBS (n+1);
                    442:        mpn_fft_fftinv(Ap, K2, 2*omega, n, tp);
                    443:        mpn_fft_fftinv(Bp, K2, 2*omega, n, tp);
                    444:        /* A[j]     <- A[j] + omega^j A[j+K/2]
                    445:           A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
                    446:         for (j=0;j<K2;j++,Ap++,Bp++) {
                    447:          MPN_COPY(tp, Bp[0], n+1);
                    448:          mpn_fft_mul_2exp_modF(Bp[0], (j+K2)*omega, n, tmp);
                    449:          mpn_fft_add_modF(Bp[0], Ap[0], n);
                    450:          mpn_fft_mul_2exp_modF(tp, j*omega, n, tmp);
                    451:          mpn_fft_add_modF(Ap[0], tp, n);
                    452:        }
                    453:        TMP_FREE(marker);
                    454:     }
                    455: }
                    456:
                    457:
                    458: /* A <- A/2^k mod 2^(n*BITS_PER_MP_LIMB)+1 */
                    459: static void
                    460: #if __STDC__
                    461: mpn_fft_div_2exp_modF (mp_limb_t *ap, int k, mp_size_t n, mp_limb_t *tp)
                    462: #else
                    463: mpn_fft_div_2exp_modF(ap,k,n,tp)
                    464:      mp_limb_t *ap,*tp;
                    465:      int       k;
                    466:      mp_size_t n;
                    467: #endif
                    468: {
                    469:     int i;
                    470:
                    471:     i = 2*n*BITS_PER_MP_LIMB;
                    472:     i = (i-k) % i;
                    473:     mpn_fft_mul_2exp_modF(ap,i,n,tp);
                    474:     /* 1/2^k = 2^(2nL-k) mod 2^(n*BITS_PER_MP_LIMB)+1 */
                    475:     /* normalize so that A < 2^(n*BITS_PER_MP_LIMB)+1 */
                    476:     if (ap[n]==1) {
                    477:       for (i=0;i<n && ap[i]==0;i++);
                    478:       if (i<n) {
                    479:        ap[n]=0;
                    480:        mpn_sub_1(ap, ap, n, 1);
                    481:       }
                    482:     }
                    483: }
                    484:
                    485:
                    486: /* R <- A mod 2^(n*BITS_PER_MP_LIMB)+1, n<=an<=3*n */
                    487: static void
                    488: #if __STDC__
                    489: mpn_fft_norm_modF(mp_limb_t *rp, mp_limb_t *ap, mp_size_t n, mp_size_t an)
                    490: #else
                    491: mpn_fft_norm_modF(rp, ap, n, an)
                    492:      mp_limb_t *rp;
                    493:      mp_limb_t *ap;
                    494:      mp_size_t n;
                    495:      mp_size_t an;
                    496: #endif
                    497: {
                    498:   mp_size_t l;
                    499:
                    500:    if (an>2*n) {
                    501:      l = n;
                    502:      rp[n] = mpn_add_1(rp+an-2*n, ap+an-2*n, 3*n-an,
                    503:                       mpn_add_n(rp,ap,ap+2*n,an-2*n));
                    504:    }
                    505:    else {
                    506:      l = an-n;
                    507:      MPN_COPY(rp, ap, n);
                    508:      rp[n]=0;
                    509:    }
                    510:    if (mpn_sub_n(rp,rp,ap+n,l)) {
                    511:      if (mpn_sub_1(rp+l,rp+l,n+1-l,1))
                    512:        rp[n]=mpn_add_1(rp,rp,n,1);
                    513:    }
                    514: }
                    515:
                    516:
                    517: static void
                    518: #if __STDC__
                    519: mpn_mul_fft_internal(mp_limb_t *op, mp_srcptr n, mp_srcptr m, mp_size_t pl,
                    520:                      int k, int K,
                    521:                      mp_limb_t **Ap, mp_limb_t **Bp,
                    522:                      mp_limb_t *A, mp_limb_t *B,
                    523:                      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
                    524:                      int **_fft_l,
                    525:                      mp_limb_t *T, int rec)
                    526: #else
                    527: mpn_mul_fft_internal(op,n,m,pl,k,K,Ap,Bp,A,B,nprime,l,Mp,_fft_l,T,rec)
                    528:      mp_limb_t *op;
                    529:      mp_srcptr n, m;
                    530:      mp_limb_t **Ap,**Bp,*A,*B,*T;
                    531:      mp_size_t pl,nprime;
                    532:      int       **_fft_l;
                    533:      int       k,K,l,Mp,rec;
                    534: #endif
                    535: {
                    536:   int       i, sqr, pla, lo, sh, j;
                    537:   mp_limb_t *p;
                    538:
                    539:     sqr = (n==m);
                    540:
                    541:     TRACE (printf ("pl=%d k=%d K=%d np=%d l=%d Mp=%d rec=%d sqr=%d\n",
                    542:                    pl,k,K,nprime,l,Mp,rec,sqr));
                    543:
                    544:     /* decomposition of inputs into arrays Ap[i] and Bp[i] */
                    545:     if (rec) for (i=0;i<K;i++) {
                    546:       Ap[i] = A+i*(nprime+1); Bp[i] = B+i*(nprime+1);
                    547:       /* store the next M bits of n into A[i] */
                    548:       /* supposes that M is a multiple of BITS_PER_MP_LIMB */
                    549:       MPN_COPY(Ap[i], n, l); n+=l; MPN_ZERO(Ap[i]+l, nprime+1-l);
                    550:       /* set most significant bits of n and m (important in recursive calls) */
                    551:       if (i==K-1) Ap[i][l]=n[0];
                    552:       mpn_fft_mul_2exp_modF(Ap[i], i*Mp, nprime, T);
                    553:       if (!sqr) {
                    554:        MPN_COPY(Bp[i], m, l); m+=l; MPN_ZERO(Bp[i]+l, nprime+1-l);
                    555:        if (i==K-1) Bp[i][l]=m[0];
                    556:        mpn_fft_mul_2exp_modF(Bp[i], i*Mp, nprime, T);
                    557:       }
                    558:     }
                    559:
                    560:     /* direct fft's */
                    561:     if (sqr) mpn_fft_fft_sqr(Ap,K,_fft_l+k,2*Mp,nprime,1, T);
                    562:     else mpn_fft_fft(Ap,Bp,K,_fft_l+k,2*Mp,nprime,1, T);
                    563:
                    564:     /* term to term multiplications */
                    565:     mpn_fft_mul_modF_K(Ap, (sqr) ? Ap : Bp, nprime, K);
                    566:
                    567:     /* inverse fft's */
                    568:     mpn_fft_fftinv(Ap, K, 2*Mp, nprime, T);
                    569:
                    570:     /* division of terms after inverse fft */
                    571:     for (i=0;i<K;i++) mpn_fft_div_2exp_modF(Ap[i],k+((K-i)%K)*Mp,nprime, T);
                    572:
                    573:     /* addition of terms in result p */
                    574:     MPN_ZERO(T,nprime+1);
                    575:     pla = l*(K-1)+nprime+1; /* number of required limbs for p */
                    576:     p = B; /* B has K*(n'+1) limbs, which is >= pla, i.e. enough */
                    577:     MPN_ZERO(p, pla);
                    578:     sqr=0; /* will accumulate the (signed) carry at p[pla] */
                    579:     for (i=K-1,lo=l*i+nprime,sh=l*i;i>=0;i--,lo-=l,sh-=l) {
                    580:         mp_ptr n = p+sh;
                    581:        j = (K-i)%K;
                    582:        if (mpn_add_n(n,n,Ap[j],nprime+1))
                    583:          sqr += mpn_add_1(n+nprime+1,n+nprime+1,pla-sh-nprime-1,1);
                    584:        T[2*l]=i+1; /* T = (i+1)*2^(2*M) */
                    585:        if (mpn_cmp(Ap[j],T,nprime+1)>0) { /* subtract 2^N'+1 */
                    586:          sqr -= mpn_sub_1(n,n,pla-sh,1);
                    587:          sqr -= mpn_sub_1(p+lo,p+lo,pla-lo,1);
                    588:        }
                    589:     }
                    590:     if (sqr==-1) {
                    591:       if ((sqr=mpn_add_1(p+pla-pl,p+pla-pl,pl,1))) {
                    592:        /* p[pla-pl]...p[pla-1] are all zero */
                    593:         mpn_sub_1(p+pla-pl-1,p+pla-pl-1,pl+1,1);
                    594:        mpn_sub_1(p+pla-1,p+pla-1,1,1);
                    595:       }
                    596:     }
                    597:     else if (sqr==1) {
                    598:            if (pla>=2*pl)
                    599:              while ((sqr=mpn_add_1(p+pla-2*pl,p+pla-2*pl,2*pl,sqr)));
                    600:            else {
                    601:              sqr = mpn_sub_1(p+pla-pl,p+pla-pl,pl,sqr);
                    602:               ASSERT (sqr == 0);
                    603:            }
                    604:     }
                    605:     else
                    606:       ASSERT (sqr == 0);
                    607:
                    608:     /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
                    609:               < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
                    610:              < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
                    611:     mpn_fft_norm_modF(op,p,pl,pla);
                    612: }
                    613:
                    614:
                    615: /* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*BITS_PER_MP_LIMB
                    616:    n and m have respectively nl and ml limbs
                    617:    op must have space for pl+1 limbs
                    618:    One must have pl = mpn_fft_next_size(pl, k).
                    619: */
                    620:
                    621: void
                    622: #if __STDC__
                    623: mpn_mul_fft (mp_ptr op, mp_size_t pl,
                    624:              mp_srcptr n, mp_size_t nl,
                    625:              mp_srcptr m, mp_size_t ml,
                    626:              int k)
                    627: #else
                    628: mpn_mul_fft (op, pl, n, nl, m, ml, k)
                    629:      mp_ptr    op;
                    630:      mp_size_t pl;
                    631:      mp_srcptr n;
                    632:      mp_size_t nl;
                    633:      mp_srcptr m;
                    634:      mp_size_t ml;
                    635:      int k;
                    636: #endif
                    637: {
                    638:     int        K,maxLK,i,j;
                    639:     mp_size_t  N,Nprime,nprime,M,Mp,l;
                    640:     mp_limb_t  **Ap,**Bp,*A,*T,*B;
                    641:     int        **_fft_l;
                    642:     int        sqr = (n==m && nl==ml);
                    643:     TMP_DECL(marker);
                    644:
                    645:     TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n",
                    646:                    pl, nl, ml, k));
                    647:     ASSERT_ALWAYS (mpn_fft_next_size(pl, k) == pl);
                    648:
                    649:     TMP_MARK(marker);
                    650:     N = pl*BITS_PER_MP_LIMB;
                    651:     _fft_l = TMP_ALLOC_TYPE (k+1, int*);
                    652:     for (i=0;i<=k;i++)
                    653:       _fft_l[i] = TMP_ALLOC_TYPE (1<<i, int);
                    654:     mpn_fft_initl(_fft_l, k);
                    655:     K = 1<<k;
                    656:     M = N/K;   /* N = 2^k M */
                    657:     l = M/BITS_PER_MP_LIMB;
                    658:     maxLK = (K>BITS_PER_MP_LIMB) ? K : BITS_PER_MP_LIMB;
                    659:
                    660:     Nprime = ((2*M+k+2+maxLK)/maxLK)*maxLK; /* ceil((2*M+k+3)/maxLK)*maxLK; */
                    661:     nprime = Nprime/BITS_PER_MP_LIMB;
                    662:     TRACE (printf ("N=%d K=%d, M=%d, l=%d, maxLK=%d, Np=%d, np=%d\n",
                    663:                    N, K, M, l, maxLK, Nprime, nprime));
                    664:     if (nprime >= (sqr ? FFT_MODF_SQR_THRESHOLD : FFT_MODF_MUL_THRESHOLD)) {
                    665:       maxLK = (1<<mpn_fft_best_k(nprime,n==m))*BITS_PER_MP_LIMB;
                    666:       if (Nprime % maxLK) {
                    667:        Nprime=((Nprime/maxLK)+1)*maxLK;
                    668:        nprime = Nprime/BITS_PER_MP_LIMB;
                    669:       }
                    670:       TRACE (printf ("new maxLK=%d, Np=%d, np=%d\n", maxLK, Nprime, nprime));
                    671:     }
                    672:
                    673:     T = TMP_ALLOC_LIMBS (nprime+1);
                    674:     Mp = Nprime/K;
                    675:
                    676:     TRACE (printf("%dx%d limbs -> %d times %dx%d limbs (%1.2f)\n",
                    677:                   pl,pl,K,nprime,nprime,2.0*(double)N/Nprime/K);
                    678:            printf("   temp space %ld\n", 2*K*(nprime+1)));
                    679:
                    680:     A = _MP_ALLOCATE_FUNC_LIMBS (2*K*(nprime+1));
                    681:     B = A+K*(nprime+1);
                    682:     Ap = TMP_ALLOC_MP_PTRS (K);
                    683:     Bp = TMP_ALLOC_MP_PTRS (K);
                    684:     /* special decomposition for main call */
                    685:     for (i=0;i<K;i++) {
                    686:       Ap[i] = A+i*(nprime+1); Bp[i] = B+i*(nprime+1);
                    687:       /* store the next M bits of n into A[i] */
                    688:       /* supposes that M is a multiple of BITS_PER_MP_LIMB */
                    689:       if (nl>0) {
                    690:        j = (nl>=l) ? l : nl; /* limbs to store in Ap[i] */
                    691:        MPN_COPY(Ap[i], n, j); n+=l; MPN_ZERO(Ap[i]+j, nprime+1-j);
                    692:        mpn_fft_mul_2exp_modF(Ap[i], i*Mp, nprime, T);
                    693:       }
                    694:       else MPN_ZERO(Ap[i], nprime+1);
                    695:       nl -= l;
                    696:       if (n!=m) {
                    697:        if (ml>0) {
                    698:          j = (ml>=l) ? l : ml; /* limbs to store in Bp[i] */
                    699:          MPN_COPY(Bp[i], m, j); m+=l; MPN_ZERO(Bp[i]+j, nprime+1-j);
                    700:          mpn_fft_mul_2exp_modF(Bp[i], i*Mp, nprime, T);
                    701:        }
                    702:        else MPN_ZERO(Bp[i], nprime+1);
                    703:       }
                    704:       ml -= l;
                    705:     }
                    706:     mpn_mul_fft_internal(op,n,m,pl,k,K,Ap,Bp,A,B,nprime,l,Mp,_fft_l,T,0);
                    707:     TMP_FREE(marker);
                    708:     _MP_FREE_FUNC_LIMBS (A, 2*K*(nprime+1));
                    709: }
                    710:
                    711:
                    712: #if WANT_ASSERT
                    713: static int
                    714: #if __STDC__
                    715: mpn_zero_p (mp_ptr p, mp_size_t n)
                    716: #else
                    717:      mpn_zero_p (p, n)
                    718:      mp_ptr p;
                    719:      mp_size_t n;
                    720: #endif
                    721: {
                    722:   mp_size_t i;
                    723:
                    724:   for (i = 0; i < n; i++)
                    725:     {
                    726:       if (p[i] != 0)
                    727:         return 0;
                    728:     }
                    729:
                    730:   return 1;
                    731: }
                    732: #endif
                    733:
                    734:
                    735: /* Multiply {n,nl}*{m,ml} and write the result to {op,nl+ml}.
                    736:
                    737:    FIXME: Duplicating the result like this is wasteful, do something better
                    738:    perhaps at the norm_modF stage above. */
                    739:
                    740: void
                    741: #if __STDC__
                    742: mpn_mul_fft_full (mp_ptr op,
                    743:                   mp_srcptr n, mp_size_t nl,
                    744:                   mp_srcptr m, mp_size_t ml)
                    745: #else
                    746: mpn_mul_fft_full (op, n, nl, m, ml)
                    747:      mp_ptr    op;
                    748:      mp_srcptr n;
                    749:      mp_size_t nl;
                    750:      mp_srcptr m;
                    751:      mp_size_t ml;
                    752: #endif
                    753: {
                    754:   mp_ptr     pad_op;
                    755:   mp_size_t  pl;
                    756:   int        k;
                    757:   int        sqr = (n==m && nl==ml);
                    758:
                    759:   k = mpn_fft_best_k (nl+ml, sqr);
                    760:   pl = mpn_fft_next_size (nl+ml, k);
                    761:
                    762:   TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl=%ld k=%d\n",
                    763:                  nl, ml, pl, k));
                    764:
                    765:   pad_op = _MP_ALLOCATE_FUNC_LIMBS (pl+1);
                    766:   mpn_mul_fft (pad_op, pl, n, nl, m, ml, k);
                    767:
                    768:   ASSERT (mpn_zero_p (pad_op+nl+ml, pl+1-(nl+ml)));
                    769:   MPN_COPY (op, pad_op, nl+ml);
                    770:
                    771:   _MP_FREE_FUNC_LIMBS (pad_op, pl+1);
                    772: }

FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>