[BACK]Return to mul_fft.c CVS log [TXT][DIR] Up to [local] / OpenXM_contrib / gmp / mpn / generic

Annotation of OpenXM_contrib/gmp/mpn/generic/mul_fft.c, Revision 1.1.1.2

1.1       maekawa     1: /* An implementation in GMP of Scho"nhage's fast multiplication algorithm
                      2:    modulo 2^N+1, by Paul Zimmermann, INRIA Lorraine, February 1998.
                      3:
                      4:    THE CONTENTS OF THIS FILE ARE FOR INTERNAL USE AND THE FUNCTIONS HAVE
                      5:    MUTABLE INTERFACES.  IT IS ONLY SAFE TO REACH THEM THROUGH DOCUMENTED
                      6:    INTERFACES.  IT IS ALMOST GUARANTEED THAT THEY'LL CHANGE OR DISAPPEAR IN
                      7:    A FUTURE GNU MP RELEASE.
                      8:
1.1.1.2 ! ohara       9: Copyright 1998, 1999, 2000, 2001, 2002 Free Software Foundation, Inc.
1.1       maekawa    10:
                     11: This file is part of the GNU MP Library.
                     12:
                     13: The GNU MP Library is free software; you can redistribute it and/or modify
                     14: it under the terms of the GNU Lesser General Public License as published by
                     15: the Free Software Foundation; either version 2.1 of the License, or (at your
                     16: option) any later version.
                     17:
                     18: The GNU MP Library is distributed in the hope that it will be useful, but
                     19: WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
                     20: or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
                     21: License for more details.
                     22:
                     23: You should have received a copy of the GNU Lesser General Public License
                     24: along with the GNU MP Library; see the file COPYING.LIB.  If not, write to
                     25: the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
                     26: MA 02111-1307, USA. */
                     27:
                     28:
                     29: /* References:
                     30:
                     31:    Schnelle Multiplikation grosser Zahlen, by Arnold Scho"nhage and Volker
                     32:    Strassen, Computing 7, p. 281-292, 1971.
                     33:
                     34:    Asymptotically fast algorithms for the numerical multiplication
                     35:    and division of polynomials with complex coefficients, by Arnold Scho"nhage,
                     36:    Computer Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.
                     37:
                     38:    Tapes versus Pointers, a study in implementing fast algorithms,
                     39:    by Arnold Scho"nhage, Bulletin of the EATCS, 30, p. 23-32, 1986.
                     40:
                     41:    See also http://www.loria.fr/~zimmerma/bignum
                     42:
                     43:
                     44:    Future:
                     45:
                     46:    It might be possible to avoid a small number of MPN_COPYs by using a
                     47:    rotating temporary or two.
                     48:
                     49:    Multiplications of unequal sized operands can be done with this code, but
                     50:    it needs a tighter test for identifying squaring (same sizes as well as
                     51:    same pointers).  */
                     52:
                     53:
                     54: #include <stdio.h>
                     55: #include "gmp.h"
                     56: #include "gmp-impl.h"
                     57:
                     58:
/* Change this to "#define TRACE(x) x" for some traces.
   As defined here TRACE(...) expands to nothing, so the printf calls wrapped
   in it are not compiled at all.  */
#define TRACE(x)



/* FFT size thresholds in limbs.  Row 0 (MUL_FFT_TABLE) is used for
   multiplications and row 1 (SQR_FFT_TABLE) for squarings; mpn_fft_best_k
   selects the row with its sqr flag and stops scanning at the first 0
   entry.  */
FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] = {
  MUL_FFT_TABLE,
  SQR_FFT_TABLE
};


/* Forward declaration: mpn_fft_mul_modF_K calls mpn_mul_fft_internal for
   its recursive case, and that function is defined further down.  */
static void mpn_mul_fft_internal
_PROTO ((mp_ptr, mp_srcptr, mp_srcptr, mp_size_t, int, int, mp_ptr *, mp_ptr *,
        mp_ptr, mp_ptr, mp_size_t, mp_size_t, mp_size_t, int **, mp_ptr,int));
1.1       maekawa    73:
                     74:
                     75: /* Find the best k to use for a mod 2^(n*BITS_PER_MP_LIMB)+1 FFT.
                     76:    sqr==0 if for a multiply, sqr==1 for a square */
                     77: int
                     78: mpn_fft_best_k (mp_size_t n, int sqr)
                     79: {
1.1.1.2 ! ohara      80:   int i;
1.1       maekawa    81:
                     82:   for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
                     83:     if (n < mpn_fft_table[sqr][i])
                     84:       return i + FFT_FIRST_K;
                     85:
                     86:   /* treat 4*last as one further entry */
1.1.1.2 ! ohara      87:   if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
1.1       maekawa    88:     return i + FFT_FIRST_K;
                     89:   else
                     90:     return i + FFT_FIRST_K + 1;
                     91: }
                     92:
                     93:
                     94: /* Returns smallest possible number of limbs >= pl for a fft of size 2^k.
1.1.1.2 ! ohara      95:
        !            96:    FIXME: Is this N rounded up to the next multiple of (2^k)*BITS_PER_MP_LIMB
        !            97:    bits and therefore simply pl rounded up to a multiple of 2^k? */
1.1       maekawa    98:
                     99: mp_size_t
                    100: mpn_fft_next_size (mp_size_t pl, int k)
                    101: {
                    102:   mp_size_t N, M;
1.1.1.2 ! ohara     103:   int K;
1.1       maekawa   104:
                    105:   /*  if (k==0) k = mpn_fft_best_k (pl, sqr); */
1.1.1.2 ! ohara     106:   N = pl * BITS_PER_MP_LIMB;
        !           107:   K = 1 << k;
        !           108:   if (N % K)
        !           109:     N = (N / K + 1) * K;
        !           110:   M = N / K;
        !           111:   if (M % BITS_PER_MP_LIMB)
        !           112:     N = ((M / BITS_PER_MP_LIMB) + 1) * BITS_PER_MP_LIMB * K;
        !           113:   return N / BITS_PER_MP_LIMB;
1.1       maekawa   114: }
                    115:
                    116:
                    117: static void
1.1.1.2 ! ohara     118: mpn_fft_initl (int **l, int k)
1.1       maekawa   119: {
1.1.1.2 ! ohara     120:   int i, j, K;
1.1       maekawa   121:
1.1.1.2 ! ohara     122:   l[0][0] = 0;
        !           123:   for (i = 1,K = 2; i <= k; i++,K *= 2)
        !           124:     {
        !           125:       for (j = 0; j < K / 2; j++)
        !           126:        {
        !           127:          l[i][j] = 2 * l[i - 1][j];
        !           128:          l[i][K / 2 + j] = 1 + l[i][j];
1.1       maekawa   129:        }
                    130:     }
                    131: }
                    132:
                    133:
/* a <- a*2^e mod 2^(n*BITS_PER_MP_LIMB)+1

   ap[0..n] holds the residue (n low limbs plus a small high limb ap[n]) and
   is overwritten with the result.  tp is scratch space of n+1 limbs.

   Let N = n*BITS_PER_MP_LIMB.  Since 2^N == -1, 2^e == (-1)^(e/N) * 2^(e%N).
   The shift by e%N is done first and the sign is applied at the end.
   NOTE(review): e is assumed to be non-negative -- confirm with callers.  */
static void
mpn_fft_mul_2exp_modF (mp_ptr ap, int e, mp_size_t n, mp_ptr tp)
{
  int d, sh, i;
  mp_limb_t cc;

  d = e % (n * BITS_PER_MP_LIMB);      /* 2^e = (+/-) 2^d */
  sh = d % BITS_PER_MP_LIMB;
  /* Sub-limb part of the shift, into tp[0..n].  This relies on ap[n] being
     small enough that no bits are shifted out of tp[n].  */
  if (sh != 0)
    mpn_lshift (tp, ap, n + 1, sh);    /* no carry here */
  else
    MPN_COPY (tp, ap, n + 1);
  d /= BITS_PER_MP_LIMB;               /* now shift of d limbs to the left */
  if (d)
    {
      /* ap[d..n-1] = tp[0..n-d-1], ap[0..d-1] = -tp[n-d..n-1] */
      /* Limbs shifted past position n-1 are multiplied by 2^N == -1, so they
         re-enter at the bottom negated.  The negation is done in two's
         complement over d limbs: complement here, then add 1 below.  */
      /* mpn_xor would be more efficient here */
      for (i = d - 1; i >= 0; i--)
        ap[i] = ~tp[n - d + i];
      /* If the +1 carries out, tp[n-d..n-1] was zero and the negation is
         exact (cc = 0).  Otherwise the negation borrowed one unit of
         2^(d limbs), repaid by subtracting 1 from the upper part (cc = 1).  */
      cc = 1 - mpn_add_1 (ap, ap, d, CNST_LIMB(1));
      if (cc != 0)
        cc = mpn_sub_1 (ap + d, tp, n - d, CNST_LIMB(1));
      else
        MPN_COPY (ap + d, tp, n - d);
      /* tp[n] lands at limb n+d, i.e. 2^N * 2^(d limbs) == -2^(d limbs), so
         it is subtracted from the upper part.  */
      cc += mpn_sub_1 (ap + d, ap + d, n - d, tp[n]);
      /* cc now counts borrows out of limb n-1.  Each is -2^N == +1, so
         they are added back at the bottom.  */
      if (cc != 0)
        ap[n] = mpn_add_1 (ap, ap, n, cc);
      else
        ap[n] = 0;
    }
  /* Whole-limb shift of zero: tp[n]*2^N == -tp[n], so subtract it from the
     low limbs.  A borrow (-2^N == +1) is compensated by adding 1.  */
  else if ((ap[n] = mpn_sub_1 (ap, tp, n, tp[n])))
    {
      ap[n] = mpn_add_1 (ap, ap, n, CNST_LIMB(1));
    }
  /* Apply the sign (-1)^(e/N).  Write x = x_low + ap[n]*2^N.  Since
     ~x_low = 2^N-1-x_low == -2-x_low, we get -x == ~x_low + 2 + ap[n].  */
  if ((e / (n * BITS_PER_MP_LIMB)) % 2)
    {
      mp_limb_t c;
      mpn_com_n (ap, ap, n);
      c = ap[n] + 2;
      ap[n] = 0;
      mpn_incr_u (ap, c);
    }
}
                    178:
                    179:
                    180: /* a <- a+b mod 2^(n*BITS_PER_MP_LIMB)+1 */
                    181: static void
1.1.1.2 ! ohara     182: mpn_fft_add_modF (mp_ptr ap, mp_ptr bp, int n)
1.1       maekawa   183: {
                    184:   mp_limb_t c;
                    185:
1.1.1.2 ! ohara     186:   c = ap[n] + bp[n] + mpn_add_n (ap, ap, bp, n);
        !           187:   if (c > 1)
        !           188:     {
        !           189:       ap[n] = c - 1;
        !           190:       mpn_decr_u (ap, 1);
        !           191:     }
        !           192:   else
        !           193:     ap[n] = c;
1.1       maekawa   194: }
                    195:
                    196:
                    197: /* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
1.1.1.2 ! ohara     198:          N=n*BITS_PER_MP_LIMB
        !           199:          2^omega is a primitive root mod 2^N+1
1.1       maekawa   200:    output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1 */
                    201:
                    202: static void
1.1.1.2 ! ohara     203: mpn_fft_fft_sqr (mp_ptr *Ap, mp_size_t K, int **ll,
        !           204:                 mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
1.1       maekawa   205: {
1.1.1.2 ! ohara     206:   if (K == 2)
        !           207:     {
        !           208: #if HAVE_NATIVE_mpn_addsub_n
        !           209:       if (mpn_addsub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1)
        !           210:        Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, CNST_LIMB(1));
        !           211: #else
        !           212:       MPN_COPY (tp, Ap[0], n + 1);
        !           213:       mpn_add_n (Ap[0], Ap[0], Ap[inc],n + 1);
        !           214:       if (mpn_sub_n (Ap[inc], tp, Ap[inc],n + 1))
        !           215:        Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, CNST_LIMB(1));
1.1       maekawa   216: #endif
                    217:     }
1.1.1.2 ! ohara     218:   else
        !           219:     {
        !           220:       int j, inc2 = 2 * inc;
        !           221:       int *lk = *ll;
        !           222:       mp_ptr tmp;
1.1       maekawa   223:       TMP_DECL(marker);
                    224:
                    225:       TMP_MARK(marker);
1.1.1.2 ! ohara     226:       tmp = TMP_ALLOC_LIMBS (n + 1);
        !           227:       mpn_fft_fft_sqr (Ap, K/2,ll-1,2 * omega,n,inc2, tp);
        !           228:       mpn_fft_fft_sqr (Ap+inc, K/2,ll-1,2 * omega,n,inc2, tp);
        !           229:       /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
        !           230:         A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
        !           231:       for (j = 0; j < K / 2; j++,lk += 2,Ap += 2 * inc)
        !           232:        {
        !           233:          MPN_COPY (tp, Ap[inc], n + 1);
        !           234:          mpn_fft_mul_2exp_modF (Ap[inc], lk[1] * omega, n, tmp);
        !           235:          mpn_fft_add_modF (Ap[inc], Ap[0], n);
        !           236:          mpn_fft_mul_2exp_modF (tp, lk[0] * omega, n, tmp);
        !           237:          mpn_fft_add_modF (Ap[0], tp, n);
1.1       maekawa   238:        }
1.1.1.2 ! ohara     239:       TMP_FREE(marker);
1.1       maekawa   240:     }
                    241: }
                    242:
                    243:
                    244: /* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
1.1.1.2 ! ohara     245:          N=n*BITS_PER_MP_LIMB
        !           246:         2^omega is a primitive root mod 2^N+1
1.1       maekawa   247:    output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1 */
                    248:
                    249: static void
1.1.1.2 ! ohara     250: mpn_fft_fft (mp_ptr *Ap, mp_ptr *Bp, mp_size_t K, int **ll,
        !           251:             mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
1.1       maekawa   252: {
1.1.1.2 ! ohara     253:   if (K == 2)
        !           254:     {
        !           255: #if HAVE_NATIVE_mpn_addsub_n
        !           256:       if (mpn_addsub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1)
        !           257:        Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, CNST_LIMB(1));
        !           258: #else
        !           259:       MPN_COPY (tp, Ap[0], n + 1);
        !           260:       mpn_add_n (Ap[0], Ap[0], Ap[inc],n + 1);
        !           261:       if (mpn_sub_n (Ap[inc], tp, Ap[inc],n + 1))
        !           262:        Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, CNST_LIMB(1));
        !           263: #endif
        !           264: #if HAVE_NATIVE_mpn_addsub_n
        !           265:       if (mpn_addsub_n (Bp[0], Bp[inc], Bp[0], Bp[inc], n + 1) & 1)
        !           266:        Bp[inc][n] = mpn_add_1 (Bp[inc], Bp[inc], n, CNST_LIMB(1));
        !           267: #else
        !           268:       MPN_COPY (tp, Bp[0], n + 1);
        !           269:       mpn_add_n (Bp[0], Bp[0], Bp[inc],n + 1);
        !           270:       if (mpn_sub_n (Bp[inc], tp, Bp[inc],n + 1))
        !           271:        Bp[inc][n] = mpn_add_1 (Bp[inc], Bp[inc], n, CNST_LIMB(1));
1.1       maekawa   272: #endif
                    273:     }
1.1.1.2 ! ohara     274:   else
        !           275:     {
        !           276:       int j, inc2=2 * inc;
        !           277:       int *lk = *ll;
        !           278:       mp_ptr tmp;
        !           279:       TMP_DECL(marker);
        !           280:
        !           281:       TMP_MARK(marker);
        !           282:       tmp = TMP_ALLOC_LIMBS (n + 1);
        !           283:       mpn_fft_fft (Ap, Bp, K/2,ll-1,2 * omega,n,inc2, tp);
        !           284:       mpn_fft_fft (Ap+inc, Bp+inc, K/2,ll-1,2 * omega,n,inc2, tp);
        !           285:       /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
        !           286:         A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
        !           287:       for (j = 0; j < K / 2; j++,lk += 2,Ap += 2 * inc,Bp += 2 * inc)
        !           288:        {
        !           289:          MPN_COPY (tp, Ap[inc], n + 1);
        !           290:          mpn_fft_mul_2exp_modF (Ap[inc], lk[1] * omega, n, tmp);
        !           291:          mpn_fft_add_modF (Ap[inc], Ap[0], n);
        !           292:          mpn_fft_mul_2exp_modF (tp, lk[0] * omega, n, tmp);
        !           293:          mpn_fft_add_modF (Ap[0], tp, n);
        !           294:          MPN_COPY (tp, Bp[inc], n + 1);
        !           295:          mpn_fft_mul_2exp_modF (Bp[inc], lk[1] * omega, n, tmp);
        !           296:          mpn_fft_add_modF (Bp[inc], Bp[0], n);
        !           297:          mpn_fft_mul_2exp_modF (tp, lk[0] * omega, n, tmp);
        !           298:          mpn_fft_add_modF (Bp[0], tp, n);
1.1       maekawa   299:        }
1.1.1.2 ! ohara     300:       TMP_FREE(marker);
        !           301:     }
        !           302: }
        !           303:
        !           304:
        !           305: /* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*BITS_PER_MP_LIMB)+1,
        !           306:    by subtracting that modulus if necessary.
        !           307:
        !           308:    If ap[0..n] is exactly 2^(n*BITS_PER_MP_LIMB) then the sub_1 produces a
        !           309:    borrow and the limbs must be zeroed out again.  This will occur very
        !           310:    infrequently.  */
        !           311:
        !           312: static void
        !           313: mpn_fft_norm (mp_ptr ap, mp_size_t n)
        !           314: {
        !           315:   ASSERT (ap[n] <= 1);
        !           316:
        !           317:   if (ap[n])
        !           318:     {
        !           319:       if ((ap[n] = mpn_sub_1 (ap, ap, n, CNST_LIMB(1))))
        !           320:        MPN_ZERO (ap, n);
1.1       maekawa   321:     }
                    322: }
                    323:
                    324:
/* a[i] <- a[i]*b[i] mod 2^(n*BITS_PER_MP_LIMB)+1 for 0 <= i < K

   ap and bp are arrays of K pointers to (n+1)-limb residues; each product
   overwrites the corresponding ap[i].  ap == bp selects squaring.
   - At or above the *_FFT_MODF_THRESHOLD, each product recurses into
     mpn_mul_fft_internal with a smaller FFT.
   - Below it, a plain n-limb multiplication is followed by a reduction
     modulo 2^N+1 (N = n*BITS_PER_MP_LIMB).  */
static void
mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, int K)
{
  int i;
  int sqr = (ap == bp);
  TMP_DECL(marker);

  TMP_MARK(marker);

  if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      int k, K2,nprime2,Nprime2,M2,maxLK,l,Mp2;
      int **_fft_l;
      mp_ptr *Ap,*Bp,A,B,T;

      /* Inner FFT parameters: K2 = 2^k pieces of l limbs (M2 bits) each,
         transformed modulo 2^Nprime2+1.  Nprime2 is rounded up to a multiple
         of maxLK = max(K2, BITS_PER_MP_LIMB).  */
      k = mpn_fft_best_k (n, sqr);
      K2 = 1<<k;
      maxLK = (K2>BITS_PER_MP_LIMB) ? K2 : BITS_PER_MP_LIMB;
      M2 = n*BITS_PER_MP_LIMB/K2;
      l = n/K2;
      Nprime2 = ((2 * M2+k+2+maxLK)/maxLK)*maxLK; /* ceil((2*M2+k+3)/maxLK)*maxLK */
      nprime2 = Nprime2/BITS_PER_MP_LIMB;
      Mp2 = Nprime2/K2;

      /* Workspace shared by all K recursive products: K2 residues of
         nprime2+1 limbs for each operand (A, then B right after it), plus
         one residue of scratch T.  */
      Ap = TMP_ALLOC_MP_PTRS (K2);
      Bp = TMP_ALLOC_MP_PTRS (K2);
      A = TMP_ALLOC_LIMBS (2 * K2 * (nprime2 + 1));
      T = TMP_ALLOC_LIMBS (nprime2 + 1);
      B = A + K2 * (nprime2 + 1);
      _fft_l = TMP_ALLOC_TYPE (k + 1, int*);
      for (i = 0; i <= k; i++)
        _fft_l[i] = TMP_ALLOC_TYPE (1<<i, int);
      mpn_fft_initl (_fft_l, k);

      TRACE (printf ("recurse: %dx%d limbs -> %d times %dx%d (%1.2f)\n", n,
                    n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
      for (i = 0; i < K; i++,ap++,bp++)
        {
          /* Reduce each operand below 2^N+1 before recursing.  In the
             squaring case *bp is the same residue, so it is reduced once.  */
          mpn_fft_norm (*ap, n);
          if (!sqr) mpn_fft_norm (*bp, n);
          mpn_mul_fft_internal (*ap, *ap, *bp, n, k, K2, Ap, Bp, A, B, nprime2,
                               l, Mp2, _fft_l, T, 1);
        }
    }
  else
    {
      mp_ptr a, b, tp, tpn;
      mp_limb_t cc;
      int n2 = 2 * n;
      tp = TMP_ALLOC_LIMBS (n2);
      tpn = tp+n;
      TRACE (printf ("  mpn_mul_n %d of %d limbs\n", K, n));
      for (i = 0; i < K; i++)
        {
          a = *ap++;
          b = *bp++;
          /* tp[0..2n-1] = low(a) * low(b) */
          if (sqr)
            mpn_sqr_n (tp, a, n);
          else
            mpn_mul_n (tp, b, a, n);
          /* Add the cross terms a[n]*b*2^N and b[n]*a*2^N at limb n.  The
             high limbs are treated as 0 or 1.  cc counts carries out of limb
             2n-1, each worth 2^(2N) == 1; the "+ a[n]" term is
             a[n]*b[n]*2^(2N).  */
          if (a[n] != 0)
            cc = mpn_add_n (tpn, tpn, b, n);
          else
            cc = 0;
          if (b[n] != 0)
            cc += mpn_add_n (tpn, tpn, a, n) + a[n];
          /* Fold the 2^(2N) units back in at the bottom.  A second
             wrap-around cannot carry again.  */
          if (cc != 0)
            {
              cc = mpn_add_1 (tp, tp, n2, cc);
              ASSERT_NOCARRY (mpn_add_1 (tp, tp, n2, cc));
            }
          /* lo + hi*2^N == lo - hi.  A borrow means the result is short by
             2^N == -1, so add 1; a[n] receives the carry of that add.  */
          a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
        }
    }
  TMP_FREE(marker);
}
                    402:
                    403:
                    404: /* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
                    405:    output: K*A[0] K*A[K-1] ... K*A[1] */
                    406:
                    407: static void
1.1.1.2 ! ohara     408: mpn_fft_fftinv (mp_ptr *Ap, int K, mp_size_t omega, mp_size_t n, mp_ptr tp)
1.1       maekawa   409: {
1.1.1.2 ! ohara     410:   if (K == 2)
        !           411:     {
        !           412: #if HAVE_NATIVE_mpn_addsub_n
        !           413:       if (mpn_addsub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1)
        !           414:        Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, CNST_LIMB(1));
        !           415: #else
        !           416:       MPN_COPY (tp, Ap[0], n + 1);
        !           417:       mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
        !           418:       if (mpn_sub_n (Ap[1], tp, Ap[1], n + 1))
        !           419:        Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, CNST_LIMB(1));
1.1       maekawa   420: #endif
                    421:     }
1.1.1.2 ! ohara     422:   else
        !           423:     {
        !           424:       int j, K2 = K / 2;
        !           425:       mp_ptr *Bp = Ap + K2, tmp;
        !           426:       TMP_DECL(marker);
        !           427:
        !           428:       TMP_MARK(marker);
        !           429:       tmp = TMP_ALLOC_LIMBS (n + 1);
        !           430:       mpn_fft_fftinv (Ap, K2, 2 * omega, n, tp);
        !           431:       mpn_fft_fftinv (Bp, K2, 2 * omega, n, tp);
        !           432:       /* A[j]     <- A[j] + omega^j A[j+K/2]
        !           433:         A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
        !           434:       for (j = 0; j < K2; j++,Ap++,Bp++)
        !           435:        {
        !           436:          MPN_COPY (tp, Bp[0], n + 1);
        !           437:          mpn_fft_mul_2exp_modF (Bp[0], (j + K2) * omega, n, tmp);
        !           438:          mpn_fft_add_modF (Bp[0], Ap[0], n);
        !           439:          mpn_fft_mul_2exp_modF (tp, j * omega, n, tmp);
        !           440:          mpn_fft_add_modF (Ap[0], tp, n);
1.1       maekawa   441:        }
1.1.1.2 ! ohara     442:       TMP_FREE(marker);
1.1       maekawa   443:     }
                    444: }
                    445:
                    446:
                    447: /* A <- A/2^k mod 2^(n*BITS_PER_MP_LIMB)+1 */
                    448: static void
1.1.1.2 ! ohara     449: mpn_fft_div_2exp_modF (mp_ptr ap, int k, mp_size_t n, mp_ptr tp)
1.1       maekawa   450: {
1.1.1.2 ! ohara     451:   int i;
        !           452:
        !           453:   i = 2 * n * BITS_PER_MP_LIMB;
        !           454:   i = (i - k) % i;
        !           455:   mpn_fft_mul_2exp_modF (ap, i, n, tp);
        !           456:   /* 1/2^k = 2^(2nL-k) mod 2^(n*BITS_PER_MP_LIMB)+1 */
        !           457:   /* normalize so that A < 2^(n*BITS_PER_MP_LIMB)+1 */
        !           458:   mpn_fft_norm (ap, n);
1.1       maekawa   459: }
                    460:
                    461:
                    462: /* R <- A mod 2^(n*BITS_PER_MP_LIMB)+1, n<=an<=3*n */
                    463: static void
1.1.1.2 ! ohara     464: mpn_fft_norm_modF (mp_ptr rp, mp_ptr ap, mp_size_t n, mp_size_t an)
1.1       maekawa   465: {
                    466:   mp_size_t l;
                    467:
1.1.1.2 ! ohara     468:   if (an > 2 * n)
        !           469:     {
        !           470:       l = n;
        !           471:       rp[n] = mpn_add_1 (rp + an - 2 * n, ap + an - 2 * n, 3 * n - an,
        !           472:                        mpn_add_n (rp, ap, ap + 2 * n, an - 2 * n));
        !           473:     }
        !           474:   else
        !           475:     {
        !           476:       l = an - n;
        !           477:       MPN_COPY (rp, ap, n);
        !           478:       rp[n] = 0;
        !           479:     }
        !           480:   if (mpn_sub_n (rp, rp, ap + n, l))
        !           481:     {
        !           482:       if (mpn_sub_1 (rp + l, rp + l, n + 1 - l, CNST_LIMB(1)))
        !           483:        rp[n] = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
        !           484:     }
1.1       maekawa   485: }
                    486:
                    487:
/* op <- n*m mod 2^(pl*BITS_PER_MP_LIMB)+1 using a transform of length K = 2^k.

   op receives pl+1 limbs (written by mpn_fft_norm_modF).  n == m (same
   pointer) selects squaring.  The inputs are cut into K pieces of l limbs.
   The pieces are transformed as (nprime+1)-limb residues modulo 2^N'+1,
   with N' = nprime*BITS_PER_MP_LIMB.
   - Ap/Bp are K pointer slots.
   - A/B are the backing storage for those residues: each has K*(nprime+1)
     limbs, and B is later reused for the result accumulator p.
   - _fft_l are the index tables from mpn_fft_initl.
   - T is nprime+1 limbs of scratch.
   When rec is non-zero the inputs are decomposed here.  When rec is 0,
   Ap/Bp are presumably filled by the caller (mpn_mul_fft, whose body is not
   fully visible here) -- TODO confirm.  */
static void
mpn_mul_fft_internal (mp_ptr op, mp_srcptr n, mp_srcptr m, mp_size_t pl,
                      int k, int K,
                      mp_ptr *Ap, mp_ptr *Bp,
                      mp_ptr A, mp_ptr B,
                      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
                      int **_fft_l,
                      mp_ptr T, int rec)
{
  int i, sqr, pla, lo, sh, j;
  mp_ptr p;

  sqr = n == m;

  TRACE (printf ("pl=%d k=%d K=%d np=%d l=%d Mp=%d rec=%d sqr=%d\n",
                 pl,k,K,nprime,l,Mp,rec,sqr));

  /* decomposition of inputs into arrays Ap[i] and Bp[i] */
  /* Piece i is l limbs of the input, zero-extended to nprime+1 limbs and
     weighted by 2^(i*Mp) modulo 2^N'+1.  */
  if (rec)
    for (i = 0; i < K; i++)
      {
        Ap[i] = A + i * (nprime + 1); Bp[i] = B + i * (nprime + 1);
        /* store the next M bits of n into A[i] */
        /* supposes that M is a multiple of BITS_PER_MP_LIMB */
        MPN_COPY (Ap[i], n, l); n += l;
        MPN_ZERO (Ap[i]+l, nprime + 1 - l);
        /* set most significant bits of n and m (important in recursive calls) */
        /* After the last piece, n[0] is the limb just past the K*l pieces,
           i.e. the high limb of an (n+1)-limb residue operand.  */
        if (i == K - 1)
          Ap[i][l] = n[0];
        mpn_fft_mul_2exp_modF (Ap[i], i * Mp, nprime, T);
        if (!sqr)
          {
            MPN_COPY (Bp[i], m, l); m += l;
            MPN_ZERO (Bp[i] + l, nprime + 1 - l);
            if (i == K - 1)
              Bp[i][l] = m[0];
            mpn_fft_mul_2exp_modF (Bp[i], i * Mp, nprime, T);
          }
      }

  /* direct fft's */
  if (sqr)
    mpn_fft_fft_sqr (Ap, K, _fft_l + k, 2 * Mp, nprime, 1, T);
  else
    mpn_fft_fft (Ap, Bp, K, _fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, (sqr) ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft */
  /* Removes the factor 2^k = K left by the inverse transform together with
     the piece weights.  The inverse output order is K*A[0], K*A[K-1], ...,
     K*A[1], hence the (K - i) % K indexing.  */
  for (i = 0; i < K; i++)
    mpn_fft_div_2exp_modF (Ap[i], k + ((K - i) % K) * Mp, nprime, T);

  /* addition of terms in result p */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  /* From here on sqr is reused as a signed carry counter, and the block-local
     n below shadows the parameter n.  */
  sqr = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1,lo = l * i + nprime,sh = l * i; i >= 0; i--,lo -= l,sh -= l)
    {
      mp_ptr n = p+sh;
      j = (K-i)%K;
      /* Add coefficient i at limb offset sh = i*l.  A carry ripples up to
         p[pla-1], and a carry past that is counted in sqr.  */
      if (mpn_add_n (n, n, Ap[j], nprime + 1))
        sqr += mpn_add_1 (n + nprime + 1, n + nprime + 1, pla - sh - nprime - 1, CNST_LIMB(1));
      T[2 * l]=i + 1; /* T = (i + 1)*2^(2*M) */
      /* A coefficient above that bound is taken to be a negative value
         represented modulo 2^N'+1.  Subtract the modulus: 1 at limb sh and
         2^N' at limb lo = sh + nprime.  */
      if (mpn_cmp (Ap[j],T,nprime + 1)>0)
        { /* subtract 2^N'+1 */
          sqr -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
          sqr -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
        }
    }
    /* Fold the carry at limb pla back in, using 2^(pl limbs) == -1 modulo
       2^(pl*BITS_PER_MP_LIMB)+1.  */
    if (sqr == -1)
      {
        /* -2^(pla limbs) == +2^((pla-pl) limbs): add 1 at limb pla-pl */
        if ((sqr = mpn_add_1 (p + pla - pl,p + pla - pl,pl, CNST_LIMB(1))))
          {
            /* p[pla-pl]...p[pla-1] are all zero */
            mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
            mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
          }
      }
    else if (sqr == 1)
      {
        /* +2^(pla limbs) == +2^((pla-2pl) limbs) when pla >= 2pl, otherwise
           == -2^((pla-pl) limbs) */
        if (pla >= 2 * pl)
          {
            while ((sqr = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, sqr)))
              ;
          }
        else
          {
            sqr = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, sqr);
            ASSERT (sqr == 0);
          }
      }
    else
      ASSERT (sqr == 0);

    /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
       < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
       < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
    mpn_fft_norm_modF (op, p, pl, pla);
}
                    593:
                    594:
                    595: /* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*BITS_PER_MP_LIMB
                    596:    n and m have respectively nl and ml limbs
                    597:    op must have space for pl+1 limbs
1.1.1.2 ! ohara     598:    One must have pl = mpn_fft_next_size (pl, k).
1.1       maekawa   599: */
                    600:
                    601: void
                    602: mpn_mul_fft (mp_ptr op, mp_size_t pl,
1.1.1.2 ! ohara     603:             mp_srcptr n, mp_size_t nl,
        !           604:             mp_srcptr m, mp_size_t ml,
        !           605:             int k)
        !           606: {
        !           607:   int K,maxLK,i,j;
        !           608:   mp_size_t N, Nprime, nprime, M, Mp, l;
        !           609:   mp_ptr *Ap,*Bp, A, T, B;
        !           610:   int **_fft_l;
        !           611:   int sqr = (n == m && nl == ml);
        !           612:   TMP_DECL(marker);
        !           613:
        !           614:   TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
        !           615:   ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);
        !           616:
        !           617:   TMP_MARK(marker);
        !           618:   N = pl * BITS_PER_MP_LIMB;
        !           619:   _fft_l = TMP_ALLOC_TYPE (k + 1, int*);
        !           620:   for (i = 0; i <= k; i++)
        !           621:     _fft_l[i] = TMP_ALLOC_TYPE (1<<i, int);
        !           622:   mpn_fft_initl (_fft_l, k);
        !           623:   K = 1<<k;
        !           624:   M = N/K;     /* N = 2^k M */
        !           625:   l = M/BITS_PER_MP_LIMB;
        !           626:   maxLK = (K>BITS_PER_MP_LIMB) ? K : BITS_PER_MP_LIMB;
        !           627:
        !           628:   Nprime = ((2 * M + k + 2 + maxLK) / maxLK) * maxLK; /* ceil((2*M+k+3)/maxLK)*maxLK; */
        !           629:   nprime = Nprime / BITS_PER_MP_LIMB;
        !           630:   TRACE (printf ("N=%d K=%d, M=%d, l=%d, maxLK=%d, Np=%d, np=%d\n",
        !           631:                 N, K, M, l, maxLK, Nprime, nprime));
        !           632:   if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
        !           633:     {
        !           634:       maxLK = (1 << mpn_fft_best_k (nprime,n == m)) * BITS_PER_MP_LIMB;
        !           635:       if (Nprime % maxLK)
        !           636:        {
        !           637:          Nprime = ((Nprime / maxLK) + 1) * maxLK;
        !           638:          nprime = Nprime / BITS_PER_MP_LIMB;
        !           639:        }
1.1       maekawa   640:       TRACE (printf ("new maxLK=%d, Np=%d, np=%d\n", maxLK, Nprime, nprime));
                    641:     }
                    642:
1.1.1.2 ! ohara     643:   T = TMP_ALLOC_LIMBS (nprime + 1);
        !           644:   Mp = Nprime/K;
1.1       maekawa   645:
1.1.1.2 ! ohara     646:   TRACE (printf ("%dx%d limbs -> %d times %dx%d limbs (%1.2f)\n",
        !           647:                pl,pl,K,nprime,nprime,2.0*(double)N/Nprime/K);
        !           648:         printf ("   temp space %ld\n", 2 * K * (nprime + 1)));
        !           649:
        !           650:   A = __GMP_ALLOCATE_FUNC_LIMBS (2 * K * (nprime + 1));
        !           651:   B = A + K * (nprime + 1);
        !           652:   Ap = TMP_ALLOC_MP_PTRS (K);
        !           653:   Bp = TMP_ALLOC_MP_PTRS (K);
        !           654:   /* special decomposition for main call */
        !           655:   for (i = 0; i < K; i++)
        !           656:     {
        !           657:       Ap[i] = A + i * (nprime + 1); Bp[i] = B + i * (nprime + 1);
1.1       maekawa   658:       /* store the next M bits of n into A[i] */
                    659:       /* supposes that M is a multiple of BITS_PER_MP_LIMB */
1.1.1.2 ! ohara     660:       if (nl > 0)
        !           661:        {
        !           662:          j = (nl>=l) ? l : nl; /* limbs to store in Ap[i] */
        !           663:          MPN_COPY (Ap[i], n, j); n += l;
        !           664:          MPN_ZERO (Ap[i] + j, nprime + 1 - j);
        !           665:          mpn_fft_mul_2exp_modF (Ap[i], i * Mp, nprime, T);
        !           666:        }
        !           667:       else MPN_ZERO (Ap[i], nprime + 1);
1.1       maekawa   668:       nl -= l;
1.1.1.2 ! ohara     669:       if (n != m)
        !           670:        {
        !           671:          if (ml > 0)
        !           672:            {
        !           673:              j = (ml>=l) ? l : ml; /* limbs to store in Bp[i] */
        !           674:              MPN_COPY (Bp[i], m, j); m += l;
        !           675:              MPN_ZERO (Bp[i] + j, nprime + 1 - j);
        !           676:              mpn_fft_mul_2exp_modF (Bp[i], i * Mp, nprime, T);
        !           677:            }
        !           678:          else
        !           679:            MPN_ZERO (Bp[i], nprime + 1);
1.1       maekawa   680:        }
                    681:       ml -= l;
                    682:     }
1.1.1.2 ! ohara     683:   mpn_mul_fft_internal (op, n, m, pl, k, K, Ap, Bp, A, B, nprime, l, Mp, _fft_l, T, 0);
        !           684:   TMP_FREE(marker);
        !           685:   __GMP_FREE_FUNC_LIMBS (A, 2 * K * (nprime + 1));
1.1       maekawa   686: }
                    687:
                    688:
                    689: /* Multiply {n,nl}*{m,ml} and write the result to {op,nl+ml}.
                    690:
                    691:    FIXME: Duplicating the result like this is wasteful, do something better
                    692:    perhaps at the norm_modF stage above. */
                    693:
                    694: void
                    695: mpn_mul_fft_full (mp_ptr op,
1.1.1.2 ! ohara     696:                  mp_srcptr n, mp_size_t nl,
        !           697:                  mp_srcptr m, mp_size_t ml)
1.1       maekawa   698: {
1.1.1.2 ! ohara     699:   mp_ptr pad_op;
        !           700:   mp_size_t pl;
        !           701:   int k;
        !           702:   int sqr = (n == m && nl == ml);
1.1       maekawa   703:
1.1.1.2 ! ohara     704:   k = mpn_fft_best_k (nl + ml, sqr);
        !           705:   pl = mpn_fft_next_size (nl + ml, k);
1.1       maekawa   706:
1.1.1.2 ! ohara     707:   TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl=%ld k=%d\n", nl, ml, pl, k));
1.1       maekawa   708:
1.1.1.2 ! ohara     709:   pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl + 1);
1.1       maekawa   710:   mpn_mul_fft (pad_op, pl, n, nl, m, ml, k);
                    711:
1.1.1.2 ! ohara     712:   ASSERT_MPN_ZERO_P (pad_op + nl + ml, pl + 1 - (nl + ml));
        !           713:   MPN_COPY (op, pad_op, nl + ml);
1.1       maekawa   714:
1.1.1.2 ! ohara     715:   __GMP_FREE_FUNC_LIMBS (pad_op, pl + 1);
1.1       maekawa   716: }

FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>