[BACK]Return to standard_complex_linear_solvers.adb CVS log [TXT][DIR] Up to [local] / OpenXM_contrib / PHC / Ada / Math_Lib / Matrices

Annotation of OpenXM_contrib/PHC/Ada/Math_Lib/Matrices/standard_complex_linear_solvers.adb, Revision 1.1.1.1

1.1       maekawa     1: with Standard_Complex_Numbers;           use Standard_Complex_Numbers;
                      2:
                      3: package body Standard_Complex_Linear_Solvers is
                      4:
                      5: -- AUXLILIARIES :
                      6:
                      7:   function cabs ( c : Complex_Number ) return double_float is
                      8:   begin
                      9:     return (ABS(REAL_PART(c)) + ABS(IMAG_PART(c)));
                     10:   end cabs;
                     11:
                     12:   function dconjg ( x : Complex_Number ) return Complex_Number is
                     13:   begin
                     14:     return Create(REAL_PART(x),-IMAG_PART(x));
                     15:   end dconjg;
                     16:
                     17:   function csign ( x,y : Complex_Number ) return Complex_Number is
                     18:   begin
                     19:     return (Create(cabs(x)) * y / Create(cabs(y)));
                     20:   end csign;
                     21:
                     22: -- TARGET ROUTINES :
                     23:
                     24:   procedure Scale ( a : in out Matrix; b : in out Vector ) is
                     25:
                     26:     fac : Complex_Number;
                     27:
                     28:     function Maximum ( a : in Matrix; i : in integer ) return Complex_Number is
                     29:
                     30:       res : integer := a'first(2);
                     31:       max : double_float := cabs(a(i,res));
                     32:       tmp : double_float;
                     33:
                     34:     begin
                     35:       for j in a'first(2)+1..a'last(2) loop
                     36:         tmp := cabs(a(i,j));
                     37:         if tmp > max
                     38:          then max := tmp; res := j;
                     39:         end if;
                     40:       end loop;
                     41:       return a(i,res);
                     42:     end Maximum;
                     43:
                     44:     procedure Divide ( a : in out Matrix; b : in out Vector;
                     45:                        i : in integer; fac : in Complex_Number ) is
                     46:     begin
                     47:       for j in a'range(2) loop
                     48:         a(i,j) := a(i,j)/fac;
                     49:       end loop;
                     50:       b(i) := b(i)/fac;
                     51:     end Divide;
                     52:
                     53:   begin
                     54:     for i in a'range(1) loop
                     55:       fac := Maximum(a,i);
                     56:       Divide(a,b,i,fac);
                     57:     end loop;
                     58:   end Scale;
                     59:
                     60: -- TARGET ROUTINES :
                     61:
                     62:   procedure lufac ( a : in out Matrix; n : in integer;
                     63:                     ipvt : out Standard_Natural_Vectors.Vector;
                     64:                     info : out natural ) is
                     65:
                     66:     kp1,l,nm1 : integer;
                     67:     smax : double_float;
                     68:     temp : Complex_Number;
                     69:
                     70:   begin
                     71:     info := 0;
                     72:     nm1 := n - 1;
                     73:     if nm1 >= 1
                     74:      then for k in 1..nm1 loop
                     75:             kp1 := k + 1;
                     76:
                     77:           -- find the pivot index l
                     78:
                     79:             l := k; smax := cabs(a(k,k));  --modulus(a(k,k));
                     80:             for i in kp1..n loop
                     81:               if cabs(a(i,k)) > smax --modulus(a(i,k)) > smax
                     82:                then l := i;
                     83:                     smax := cabs(a(i,k)); --modulus(a(i,k));
                     84:               end if;
                     85:             end loop;
                     86:             ipvt(k) := l;
                     87:
                     88:             if smax = 0.0
                     89:              then -- this column is already triangularized
                     90:                   info := k;
                     91:              else
                     92:
                     93:                   -- interchange if necessary
                     94:
                     95:                   if l /= k
                     96:                    then temp := a(l,k);
                     97:                         a(l,k) := a(k,k);
                     98:                         a(k,k) := temp;
                     99:                   end if;
                    100:
                    101:                   -- compute multipliers
                    102:
                    103:                   temp := -Create(1.0)/a(k,k);
                    104:                   for i in kp1..n loop
                    105:                     a(i,k) := temp * a(i,k);
                    106:                   end loop;
                    107:
                    108:                   -- row elimination with column indexing
                    109:
                    110:                   for j in kp1..n loop
                    111:                     temp := a(l,j);
                    112:                     if l /= k
                    113:                      then a(l,j) := a(k,j);
                    114:                           a(k,j) := temp;
                    115:                     end if;
                    116:                     for i in kp1..n loop
                    117:                       a(i,j) := a(i,j) + temp * a(i,k);
                    118:                     end loop;
                    119:                   end loop;
                    120:
                    121:             end if;
                    122:          end loop;
                    123:     end if;
                    124:     ipvt(n) := n;
                    125:     if AbsVal(a(n,n)) = 0.0
                    126:      then info := n;
                    127:     end if;
                    128:   end lufac;
                    129:
                    130:   procedure lufco ( a : in out Matrix; n : in integer;
                    131:                     ipvt : out Standard_Natural_Vectors.Vector;
                    132:                     rcond : out double_float ) is
                    133:
                    134:   -- NOTE :
                    135:   --   rcond = 1/(norm(a)*(estimate of norm(inverse(a))))
                    136:   --   estimate = norm(z)/norm(y) where a*z = y and ctrans(a)*y = e.
                    137:   --   ctrans(a) is the conjugate transpose of a.
                    138:   --   The components of e are chosen to cause maximum local
                    139:   --   growth in teh elements of w where ctrans(u)*w = e.
                    140:   --   The vectors are frequently rescaled to avoid overflow.
                    141:
                    142:     z : Standard_Complex_Vectors.Vector(1..n);
                    143:     info,kb,kp1,l : integer;
                    144:     s,sm,sum,anorm,ynorm : double_float;
                    145:     ek,t,wk,wkm : Complex_Number;
                    146:     ipvtt : Standard_Natural_Vectors.Vector(1..n);
                    147:
                    148:   begin
                    149:     anorm := 0.0;                                    -- compute 1-norm of a
                    150:     for j in 1..n loop
                    151:       sum := 0.0;
                    152:       for i in 1..n loop
                    153:         sum := sum + cabs(a(i,j));
                    154:       end loop;
                    155:       if sum > anorm
                    156:        then anorm := sum;
                    157:       end if;
                    158:     end loop;
                    159:     lufac(a,n,ipvtt,info);                                        -- factor
                    160:     for i in 1..n loop
                    161:       ipvt(i) := ipvtt(i);
                    162:     end loop;
                    163:     ek := Create(1.0);                              -- solve ctrans(u)*w = e
                    164:     for j in 1..n loop
                    165:       z(j) := Create(0.0);
                    166:     end loop;
                    167:     for k in 1..n loop
                    168:       if cabs(z(k)) /= 0.0
                    169:        then ek := csign(ek,-z(k));
                    170:       end if;
                    171:       if cabs(ek-z(k)) > cabs(a(k,k))
                    172:        then s := cabs(a(k,k))/cabs(ek-z(k));
                    173:             z := Create(s) * z;
                    174:             ek := Create(s) * ek;
                    175:       end if;
                    176:       wk := ek - z(k);
                    177:       wkm := -ek - z(k);
                    178:       s := cabs(wk);
                    179:       sm := cabs(wkm);
                    180:       if cabs(a(k,k)) = 0.0
                    181:        then wk := Create(1.0);
                    182:             wkm := Create(1.0);
                    183:        else wk := wk / dconjg(a(k,k));
                    184:             wkm := wkm / dconjg(a(k,k));
                    185:       end if;
                    186:       kp1 := k + 1;
                    187:        if kp1 <= n
                    188:         then for j in kp1..n loop
                    189:                sm := sm + cabs(z(j)+wkm*dconjg(a(k,j)));
                    190:                z(j) := z(j) + wk*dconjg(a(k,j));
                    191:                s := s + cabs(z(j));
                    192:              end loop;
                    193:              if s < sm
                    194:               then t := wkm - wk;
                    195:                    wk := wkm;
                    196:                    for j in kp1..n loop
                    197:                      z(j) := z(j) + t*dconjg(a(k,j));
                    198:                    end loop;
                    199:              end if;
                    200:        end if;
                    201:        z(k) := wk;
                    202:      end loop;
                    203:      sum := 0.0;
                    204:      for i in 1..n loop
                    205:        sum := sum + cabs(z(i));
                    206:      end loop;
                    207:      s := 1.0 / sum;
                    208:      z := Create(s) * z;
                    209:      for k in 1..n loop                           -- solve ctrans(l)*y = w
                    210:        kb := n+1-k;
                    211:        if kb < n
                    212:         then t := Create(0.0);
                    213:              for i in (kb+1)..n loop
                    214:                t := t + dconjg(a(i,kb))*z(i);
                    215:              end loop;
                    216:              z(kb) := z(kb) + t;
                    217:        end if;
                    218:        if cabs(z(kb)) > 1.0
                    219:         then s := 1.0 / cabs(z(kb));
                    220:              z := Create(s) * z;
                    221:        end if;
                    222:        l := ipvtt(kb);
                    223:        t := z(l);
                    224:        z(l) := z(kb);
                    225:        z(kb)  := t;
                    226:      end loop;
                    227:      sum := 0.0;
                    228:      for i in 1..n loop
                    229:        sum := sum + cabs(z(i));
                    230:      end loop;
                    231:      s := 1.0 / sum;
                    232:      z := Create(s) * z;
                    233:      ynorm := 1.0;
                    234:      for k in 1..n loop                                    -- solve l*v = y
                    235:        l := ipvtt(k);
                    236:        t := z(l);
                    237:        z(l) := z(k);
                    238:        z(k) := t;
                    239:        if k < n
                    240:         then for i in (k+1)..n loop
                    241:                z(i) := z(i) + t * a(i,k);
                    242:              end loop;
                    243:        end if;
                    244:        if cabs(z(k)) > 1.0
                    245:         then s := 1.0 / cabs(z(k));
                    246:              z := Create(s) * z;
                    247:              ynorm := s * ynorm;
                    248:        end if;
                    249:      end loop;
                    250:      sum := 0.0;
                    251:      for i in 1..n loop
                    252:        sum := sum + cabs(z(i));
                    253:      end loop;
                    254:      s := 1.0 / sum;
                    255:      z := Create(s) * z;
                    256:      ynorm := s * ynorm;
                    257:      for k in 1..n loop                                    -- solve u*z = v
                    258:        kb := n+1-k;
                    259:        if cabs(z(kb)) > cabs(a(kb,kb))
                    260:         then s := cabs(a(kb,kb)) / cabs(z(kb));
                    261:              z := Create(s) * z;
                    262:              ynorm := s * ynorm;
                    263:        end if;
                    264:        if cabs(a(kb,kb)) = 0.0
                    265:         then z(kb) := Create(1.0);
                    266:         else z(kb) := z(kb) / a(kb,kb);
                    267:        end if;
                    268:        t := -z(kb);
                    269:        for i in 1..(kb-1) loop
                    270:          z(i) := z(i) + t * a(i,kb);
                    271:        end loop;
                    272:      end loop;
                    273:      sum := 0.0;                                       -- make znorm = 1.0
                    274:      for i in 1..n loop
                    275:        sum := sum + cabs(z(i));
                    276:      end loop;
                    277:      s := 1.0 / sum;
                    278:      z := Create(s) * z;
                    279:      ynorm := s * ynorm;
                    280:      if anorm = 0.0
                    281:       then rcond := 0.0;
                    282:       else rcond := ynorm/anorm;
                    283:      end if;
                    284:   end lufco;
                    285:
                    286:   procedure lusolve ( a : in Matrix; n : in integer;
                    287:                       ipvt : in Standard_Natural_Vectors.Vector;
                    288:                       b : in out Vector ) is
                    289:
                    290:     l,nm1,kb : integer;
                    291:     temp : Complex_Number;
                    292:
                    293:   begin
                    294:     nm1 := n-1;
                    295:     if nm1 >= 1                                             -- solve l*y = b
                    296:      then for k in 1..nm1 loop
                    297:             l := ipvt(k);
                    298:             temp := b(l);
                    299:             if l /= k
                    300:              then b(l) := b(k);
                    301:                   b(k) := temp;
                    302:             end if;
                    303:             for i in (k+1)..n loop
                    304:               b(i) := b(i) + temp * a(i,k);
                    305:             end loop;
                    306:           end loop;
                    307:     end if;
                    308:     for k in 1..n loop                                     -- solve u*x = y
                    309:       kb := n+1-k;
                    310:       b(kb) := b(kb) / a(kb,kb);
                    311:       temp := -b(kb);
                    312:       for j in 1..(kb-1) loop
                    313:         b(j) := b(j) + temp * a(j,kb);
                    314:       end loop;
                    315:     end loop;
                    316:   end lusolve;
                    317:
                    318:   procedure Triangulate ( a : in out Matrix; n,m : in integer ) is
                    319:
                    320:     max,cbs : double_float;
                    321:     temp : Complex_Number;
                    322:     pivot,k,kcolumn : integer;
                    323:     tol : constant double_float := 10.0**(-10);
                    324:
                    325:   begin
                    326:     k := 1;
                    327:     kcolumn := 1;
                    328:     while (k <= n) and (kcolumn <= m) loop
                    329:       max := 0.0;                                             -- find pivot
                    330:       pivot := 0;
                    331:       for l in k..n loop
                    332:         cbs := cabs(a(l,kcolumn));
                    333:         if (cbs > tol) and then (cbs > max)
                    334:          then max := cbs;
                    335:               pivot := l;
                    336:         end if;
                    337:       end loop;
                    338:       if pivot = 0
                    339:        then kcolumn := kcolumn + 1;
                    340:        else if pivot /= k                       -- interchange if necessary
                    341:              then for i in 1..m loop
                    342:                     temp := a(pivot,i);
                    343:                     a(pivot,i) := a(k,i);
                    344:                     a(k,i) := temp;
                    345:                   end loop;
                    346:             end if;
                    347:             for j in (kcolumn+1)..m loop                   -- triangulate a
                    348:               a(k,j) := a(k,j) / a(k,kcolumn);
                    349:             end loop;
                    350:             a(k,kcolumn) := Create(1.0);
                    351:             for i in (k+1)..n loop
                    352:               for j in (kcolumn+1)..m loop
                    353:                 a(i,j) := a(i,j) - a(i,kcolumn) * a(k,j);
                    354:               end loop;
                    355:               a(i,kcolumn) := Create(0.0);
                    356:             end loop;
                    357:             k := k + 1;
                    358:             kcolumn := kcolumn + 1;
                    359:       end if;
                    360:     end loop;
                    361:   end Triangulate;
                    362:
                    363:   procedure Diagonalize ( a : in out Matrix; n,m : in integer ) is
                    364:
                    365:     max : double_float;
                    366:     temp : Complex_Number;
                    367:     pivot,k,kcolumn : integer;
                    368:
                    369:   begin
                    370:     k := 1;
                    371:     kcolumn := 1;
                    372:     while (k <= n) and (kcolumn <= m) loop
                    373:       max := 0.0;                                               -- find pivot
                    374:       for l in k..n loop
                    375:         if cabs(a(l,kcolumn)) > max
                    376:          then max := cabs(a(l,kcolumn));
                    377:               pivot := l;
                    378:         end if;
                    379:       end loop;
                    380:       if max = 0.0
                    381:        then kcolumn := kcolumn + 1;
                    382:        else if pivot /= k                        -- interchange if necessary
                    383:              then for i in 1..m loop
                    384:                     temp := a(pivot,i);
                    385:                     a(pivot,i) := a(k,i);
                    386:                     a(k,i) := temp;
                    387:                   end loop;
                    388:             end if;
                    389:             for j in (kcolumn+1)..m loop                    -- diagonalize a
                    390:               a(k,j) := a(k,j) / a(k,kcolumn);
                    391:             end loop;
                    392:             a(k,kcolumn) := Create(1.0);
                    393:             for i in 1..(k-1) loop
                    394:               for j in (kcolumn+1)..m loop
                    395:                 a(i,j) := a(i,j) - a(i,kcolumn) * a(k,j);
                    396:               end loop;
                    397:             end loop;
                    398:             for i in (k+1)..n loop
                    399:               for j in (kcolumn+1)..m loop
                    400:                 a(i,j) := a(i,j) - a(i,kcolumn) * a(k,j);
                    401:               end loop;
                    402:             end loop;
                    403:             for j in 1..(k-1) loop
                    404:               a(j,kcolumn) := Create(0.0);
                    405:             end loop;
                    406:             for j in (k+1)..n loop
                    407:               a(j,kcolumn) := Create(0.0);
                    408:             end loop;
                    409:             k := k + 1;
                    410:             kcolumn := kcolumn + 1;
                    411:       end if;
                    412:     end loop;
                    413:   end Diagonalize;
                    414:
                    415: end Standard_Complex_Linear_Solvers;

FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>