Changeset 1342 in Sophya for trunk/SophyaExt/LinAlg/intflapack.cc


Ignore:
Timestamp:
Nov 24, 2000, 10:47:37 AM (25 years ago)
Author:
ansari
Message:

Ajout de SVD ds LapackServer - Reza 24/11/2000

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/SophyaExt/LinAlg/intflapack.cc

    r1042 r1342  
    11#include <iostream.h>
    22#include "intflapack.h"
     3#include "tvector.h"
     4#include "tmatrix.h"
    35#include <typeinfo>
    46
    57extern "C" {
    6 void sgesv_(int_4* n, int_4* nrhs, r_4* a, int_4* lda,
    7             int_4* ipiv, r_4* b, int_4* ldb, int_4* info);
    8 void dgesv_(int_4* n, int_4* nrhs, r_8* a, int_4* lda,
    9             int_4* ipiv, r_8* b, int_4* ldb, int_4* info);
    10 void cgesv_(int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
    11             int_4* ipiv, complex<r_4>* b, int_4* ldb, int_4* info);
    12 void zgesv_(int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
    13             int_4* ipiv, complex<r_8>* b, int_4* ldb, int_4* info);
     8// Drivers pour resolution de systemes lineaires
     9  void sgesv_(int_4* n, int_4* nrhs, r_4* a, int_4* lda,
     10              int_4* ipiv, r_4* b, int_4* ldb, int_4* info);
     11  void dgesv_(int_4* n, int_4* nrhs, r_8* a, int_4* lda,
     12              int_4* ipiv, r_8* b, int_4* ldb, int_4* info);
     13  void cgesv_(int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
     14              int_4* ipiv, complex<r_4>* b, int_4* ldb, int_4* info);
     15  void zgesv_(int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
     16              int_4* ipiv, complex<r_8>* b, int_4* ldb, int_4* info);
     17
     18// Driver pour decomposition SVD
     19  void sgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, r_4* a, int_4* lda,
     20               r_4* s, r_4* u, int_4* ldu, r_4* vt, int_4* ldvt,
     21               r_4* work, int_4* lwork, int_4* info);
     22  void dgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, r_8* a, int_4* lda,
     23               r_8* s, r_8* u, int_4* ldu, r_8* vt, int_4* ldvt,
     24               r_8* work, int_4* lwork, int_4* info);
     25  void cgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_4>* a, int_4* lda,
     26               complex<r_4>* s, complex<r_4>* u, int_4* ldu, complex<r_4>* vt, int_4* ldvt,
     27               complex<r_4>* work, int_4* lwork, int_4* info);
     28  void zgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_8>* a, int_4* lda,
     29               complex<r_8>* s, complex<r_8>* u, int_4* ldu, complex<r_8>* vt, int_4* ldvt,
     30               complex<r_8>* work, int_4* lwork, int_4* info);
     31               
     32}
     33
     34
     35//   -------------- Classe LapackServer<T> --------------
     36
     37template <class T>
     38LapackServer<T>::LapackServer<T>()
     39{
     40  SetWorkSpaceSizeFactor();
     41}
     42
     43template <class T>
     44LapackServer<T>::~LapackServer<T>()
     45{
    1446}
    1547
     
    2052    throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b NbDimensions() != 2"));
    2153 
    22   uint_4 rowa = a.RowsKA();
    23   uint_4 cola = a.ColsKA();
    24   uint_4 rowb = b.RowsKA();
    25   uint_4 colb = b.ColsKA();
     54  int_4 rowa = a.RowsKA();
     55  int_4 cola = a.ColsKA();
     56  int_4 rowb = b.RowsKA();
     57  int_4 colb = b.ColsKA();
    2658  if ( a.Size(rowa) !=  a.Size(cola))
    2759    throw(SzMismatchError("LapackServer::LinSolve(a,b) a Not a square Array"));
     
    3062
    3163  if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
    32      throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b Not Packed Columns "));
     64     throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b Not Column Packed"));
    3365
    3466  int_4 n = a.Size(rowa);
     
    5688  }
    5789  delete[] ipiv;
     90  return(info);
     91}
     92
     93template <class T>
     94int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s)
     95{
     96  return (SVDDriver(a, s, NULL, NULL) );
     97}
     98
     99template <class T>
     100int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s, TArray<T> & u, TArray<T> & vt)
     101{
     102  return (SVDDriver(a, s, &u, &vt) );
     103}
     104
     105template <class T>
     106int LapackServer<T>::SVDDriver(TArray<T>& a, TArray<T> & s, TArray<T>* up, TArray<T>* vtp)
     107{
     108  if ( ( a.NbDimensions() != 2 )  )
     109    throw(SzMismatchError("LapackServer::SVD(a, ...) a.NbDimensions() != 2"));
     110
     111  int_4 rowa = a.RowsKA();
     112  int_4 cola = a.ColsKA();
     113
     114  if ( !a.IsPacked(rowa) )
     115     throw(SzMismatchError("LapackServer::SVD(a, ...) a Not Column Packed "));
     116
     117  int_4 m = a.Size(rowa);
     118  int_4 n = a.Size(cola);
     119  int_4 maxmn = (m > n) ? m : n;
     120  int_4 minmn = (m < n) ? m : n;
     121
     122  char jobu, jobvt;
     123  jobu = 'N';
     124  jobvt = 'N';
     125
     126  sa_size_t sz[2];
     127  if ( up != NULL) {
     128    if ( dynamic_cast< TVector<T> * > (vtp) )
     129      throw( TypeMismatchExc("LapackServer::SVD() Wrong type (=TVector<T>) for u !") );
     130    up->SetMemoryMapping(BaseArray::FortranMemoryMapping);
     131    sz[0] = sz[1] = m;
     132    up->ReSize(2, sz );
     133    jobu = 'A';
     134  }
     135  else {
     136    up = new TMatrix<T>(1,1);
     137    jobu = 'N';
     138  }
     139  if ( vtp != NULL) {
     140    if ( dynamic_cast< TVector<T> * > (vtp) )
     141      throw( TypeMismatchExc("LapackServer::SVD() Wrong type (=TVector<T>) for vt !") );
     142    vtp->SetMemoryMapping(BaseArray::FortranMemoryMapping);
     143    sz[0] = sz[1] = n;
     144    vtp->ReSize(2, sz );
     145    jobvt = 'A';
     146  }
     147  else {
     148    vtp = new TMatrix<T>(1,1);
     149    jobvt = 'N';
     150  }
     151
     152  TVector<T> *vs = dynamic_cast< TVector<T> * > (&s);
     153  if (vs) vs->ReSize(minmn);
     154  else {
     155    TMatrix<T> *ms = dynamic_cast< TMatrix<T> * > (&s);
     156    if (ms) ms->ReSize(minmn,1);
     157    else  {
     158      sz[0] = minmn; sz[1] = 1;
     159      s.ReSize(1, sz);
     160    }
     161  }
     162 
     163  int_4 lda = a.Step(a.ColsKA());
     164  int_4 ldu = up->Step(up->ColsKA());
     165  int_4 ldvt = vtp->Step(vtp->ColsKA());
     166
     167  int_4 lwork = maxmn*5*wspace_size_factor;
     168  T * work = new T[lwork];
     169  int_4 info; 
     170
     171  if (typeid(T) == typeid(r_4) )
     172    sgesvd_(&jobu, &jobvt, &m, &n, (r_4 *)a.Data(), &lda,
     173            (r_4 *)s.Data(), (r_4 *) up->Data(), &ldu, (r_4 *)vtp->Data(), &ldvt,
     174            (r_4 *)work, &lwork, &info);
     175  else if (typeid(T) == typeid(r_8) )
     176    dgesvd_(&jobu, &jobvt, &m, &n, (r_8 *)a.Data(), &lda,
     177            (r_8 *)s.Data(), (r_8 *) up->Data(), &ldu, (r_8 *)vtp->Data(), &ldvt,
     178            (r_8 *)work, &lwork, &info);
     179  else if (typeid(T) == typeid(complex<r_4>) )
     180    cgesvd_(&jobu, &jobvt, &m, &n, (complex<r_4> *)a.Data(), &lda,
     181            (complex<r_4> *)s.Data(), (complex<r_4> *) up->Data(), &ldu,
     182            (complex<r_4> *)vtp->Data(), &ldvt,
     183            (complex<r_4> *)work, &lwork, &info);
     184  else if (typeid(T) == typeid(complex<r_8>) )
     185    zgesvd_(&jobu, &jobvt, &m, &n, (complex<r_8> *)a.Data(), &lda,
     186            (complex<r_8> *)s.Data(), (complex<r_8> *) up->Data(), &ldu,
     187            (complex<r_8> *)vtp->Data(), &ldvt,
     188            (complex<r_8> *)work, &lwork, &info);
     189  else {
     190    if (jobu == 'N') delete up;
     191    if (jobvt == 'N') delete vtp;
     192    string tn = typeid(T).name();
     193    cerr << " LapackServer::SVDDriver(...) - Unsupported DataType T = " << tn << endl;
     194    throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
     195  }
     196
     197  if (jobu == 'N') delete up;
     198  if (jobvt == 'N') delete vtp;
    58199  return(info);
    59200}
Note: See TracChangeset for help on using the changeset viewer.