Changeset 2561 in Sophya for trunk/SophyaExt/LinAlg


Ignore:
Timestamp:
Jul 23, 2004, 12:53:35 PM (21 years ago)
Author:
cmv
Message:

add SVD decomp by Divide and Conquer (cmv 23/07/04)

Location:
trunk/SophyaExt/LinAlg
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/SophyaExt/LinAlg/intflapack.cc

    r2559 r2561  
    128128               complex<r_8>* work, int_4* lwork, r_8* rwork, int_4* info);
    129129
     130// Driver pour decomposition SVD Divide and Conquer
     131  void sgesdd_(char* jobz, int_4* m, int_4* n, r_4* a, int_4* lda,
     132               r_4* s, r_4* u, int_4* ldu, r_4* vt, int_4* ldvt,
     133               r_4* work, int_4* lwork, int_4* iwork, int_4* info);
     134  void dgesdd_(char* jobz, int_4* m, int_4* n, r_8* a, int_4* lda,
     135               r_8* s, r_8* u, int_4* ldu, r_8* vt, int_4* ldvt,
     136               r_8* work, int_4* lwork, int_4* iwork, int_4* info);
     137  void cgesdd_(char* jobz, int_4* m, int_4* n, complex<r_4>* a, int_4* lda,
     138               r_4* s, complex<r_4>* u, int_4* ldu, complex<r_4>* vt, int_4* ldvt,
     139               complex<r_4>* work, int_4* lwork, r_4* rwork, int_4* iwork, int_4* info);
     140  void zgesdd_(char* jobz, int_4* m, int_4* n, complex<r_8>* a, int_4* lda,
     141               r_8* s, complex<r_8>* u, int_4* ldu, complex<r_8>* vt, int_4* ldvt,
     142               complex<r_8>* work, int_4* lwork, r_8* rwork, int_4* iwork, int_4* info);
     143
    130144// Driver pour eigen decomposition for symetric/hermitian matrices
    131145  void ssyev_(char* jobz, char* uplo, int_4* n, r_4* a, int_4* lda, r_4* w,
     
    253267  \param a : input matrix symetric , overwritten on output
    254268  \param b : input-output, input vector b, contains x on exit
    255   \return : return code from lapack driver _gesv()
     269  \return : return code from lapack driver
    256270 */
    257271template <class T>
     
    347361{
    348362  if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
    349     throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b NbDimensions() != 2"));
     363    throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) a Or b NbDimensions() != 2"));
    350364 
    351365  int_4 rowa = a.RowsKA();
     
    450464{
    451465  if ( ( a.NbDimensions() != 2 )  )
    452     throw(SzMismatchError("LapackServer::SVD(a, ...) a.NbDimensions() != 2"));
     466    throw(SzMismatchError("LapackServer::SVDDriver(a, ...) a.NbDimensions() != 2"));
    453467
    454468  int_4 rowa = a.RowsKA();
     
    456470
    457471  if ( !a.IsPacked(rowa) )
    458      throw(SzMismatchError("LapackServer::SVD(a, ...) a Not Column Packed "));
     472     throw(SzMismatchError("LapackServer::SVDDriver(a, ...) a Not Column Packed "));
    459473
    460474  int_4 m = a.Size(rowa);
     
    470484  if ( up != NULL) {
    471485    if ( dynamic_cast< TVector<T> * > (vtp) )
    472       throw( TypeMismatchExc("LapackServer::SVD() Wrong type (=TVector<T>) for u !") );
     486      throw( TypeMismatchExc("LapackServer::SVDDriver() Wrong type (=TVector<T>) for u !") );
    473487    up->SetMemoryMapping(BaseArray::FortranMemoryMapping);
    474488    sz[0] = sz[1] = m;
     
    482496  if ( vtp != NULL) {
    483497    if ( dynamic_cast< TVector<T> * > (vtp) )
    484       throw( TypeMismatchExc("LapackServer::SVD() Wrong type (=TVector<T>) for vt !") );
     498      throw( TypeMismatchExc("LapackServer::SVDDriver() Wrong type (=TVector<T>) for vt !") );
    485499    vtp->SetMemoryMapping(BaseArray::FortranMemoryMapping);
    486500    sz[0] = sz[1] = n;
     
    510524  int_4 lwork = maxmn*5*wspace_size_factor;
    511525  T * work = new T[lwork];
    512   int_4 info; 
     526  int_4 info;
    513527
    514528  if (typeid(T) == typeid(r_4) ) {
     
    543557    string tn = typeid(T).name();
    544558    cerr << " LapackServer::SVDDriver(...) - Unsupported DataType T = " << tn << endl;
    545     throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
     559    throw TypeMismatchExc("LapackServer::SVDDriver(a,b) - Unsupported DataType (T)");
    546560  }
    547561
    548562  if (jobu == 'N') delete up;
    549563  if (jobvt == 'N') delete vtp;
     564  return(info);
     565}
     566
     567
     568//! Interface to Lapack SVD driver s/d/c/zgesdd().
     569/*! Same as SVD but with Divide and Conquer method */
     570template <class T>
     571int LapackServer<T>::SVD_DC(TMatrix<T>& a, TVector<T>& s, TMatrix<T>& u, TMatrix<T>& vt)
     572{
     573
     574  if ( !a.IsPacked() )
     575     throw(SzMismatchError("LapackServer::SVD_DC(a, ...) a Not Packed "));
     576
     577  int_4 m = a.NRows();
     578  int_4 n = a.NCols();
     579  int_4 maxmn = (m > n) ? m : n;
     580  int_4 minmn = (m < n) ? m : n;
     581  int_4 supermax = 4*minmn*minmn+4*minmn; if(maxmn>supermax) supermax=maxmn;
     582
     583  char jobz = 'A';
     584
     585  s.ReSize(minmn);
     586  u.ReSize(m,m);
     587  vt.ReSize(n,n);
     588 
     589  int_4 lda = n;
     590  int_4 ldu = u.NCols();
     591  int_4 ldvt = vt.NCols();
     592  int_4 info; 
     593
     594  if(typeid(T) == typeid(r_4) ) {
     595    int_4 lwork = 3*minmn*minmn + supermax;
     596    r_4* work = new r_4[lwork +5];
     597    int_4* iwork = new int_4[8*minmn +5];
     598    sgesdd_(&jobz,&m,&n,(r_4*)a.Data(),&lda,
     599           (r_4*)s.Data(),(r_4*)u.Data(),&ldu,(r_4*)vt.Data(),&ldvt,
     600           (r_4*)work,&lwork,(int_4*)iwork,&info);
     601    delete [] work; delete [] iwork;
     602  } else if(typeid(T) == typeid(r_8) ) {
     603    int_4 lwork = 3*minmn*minmn + supermax;
     604    r_8* work = new r_8[lwork +5];
     605    int_4* iwork = new int_4[8*minmn +5];
     606    dgesdd_(&jobz,&m,&n,(r_8*)a.Data(),&lda,
     607           (r_8*)s.Data(),(r_8*)u.Data(),&ldu,(r_8*)vt.Data(),&ldvt,
     608           (r_8*)work,&lwork,(int_4*)iwork,&info);
     609    delete [] work; delete [] iwork;
     610  } else if(typeid(T) == typeid(complex<r_4>) ) {
     611    r_4* sloc = new r_4[minmn];
     612    int_4 lwork = minmn*minmn+2*minmn+maxmn;
     613    complex<r_4>* work = new complex<r_4>[lwork +5];
     614    r_4* rwork = new r_4[5*minmn*minmn+5*minmn +5];
     615    int_4* iwork = new int_4[8*minmn +5];
     616    cgesdd_(&jobz,&m,&n,(complex<r_4>*)a.Data(),&lda,
     617           (r_4*)sloc,(complex<r_4>*)u.Data(),&ldu,(complex<r_4>*)vt.Data(),&ldvt,
     618           (complex<r_4>*)work,&lwork,(r_4*)rwork,(int_4*)iwork,&info);
     619    for(int_4 i=0;i<minmn;i++) s[i] = sloc[i];
     620    delete [] sloc; delete [] work; delete [] rwork; delete [] iwork;
     621  } else if(typeid(T) == typeid(complex<r_8>) )  {
     622    r_8* sloc = new r_8[minmn];
     623    int_4 lwork = minmn*minmn+2*minmn+maxmn;
     624    complex<r_8>* work = new complex<r_8>[lwork +5];
     625    r_8* rwork = new r_8[5*minmn*minmn+5*minmn +5];
     626    int_4* iwork = new int_4[8*minmn +5];
     627    zgesdd_(&jobz,&m,&n,(complex<r_8>*)a.Data(),&lda,
     628           (r_8*)sloc,(complex<r_8>*)u.Data(),&ldu,(complex<r_8>*)vt.Data(),&ldvt,
     629           (complex<r_8>*)work,&lwork,(r_8*)rwork,(int_4*)iwork,&info);
     630    for(int_4 i=0;i<minmn;i++) s[i] = sloc[i];
     631    delete [] sloc; delete [] work; delete [] rwork; delete [] iwork;
     632  } else {
     633    string tn = typeid(T).name();
     634    cerr << " LapackServer::SVD_DC(...) - Unsupported DataType T = " << tn << endl;
     635    throw TypeMismatchExc("LapackServer::SVD_DC - Unsupported DataType (T)");
     636  }
     637
    550638  return(info);
    551639}
     
    559647  \param eigenvector : if true compute eigenvectors, if not only eigenvalues
    560648  \param a : on return array of eigenvectors (same order than eval, one vector = one column)
    561   \return : return code from lapack driver _gesvd()
     649  \return : return code from lapack driver
    562650 */
    563651
     
    574662    throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a Not Column Packed"));
    575663
    576   char uplo='U'; char struplo[5]; struplo[0]=uplo; struplo[1]='\0';
     664  char uplo='U';
    577665  char jobz='N'; if(eigenvector) jobz='V';
    578   char strjobz[5]; strjobz[0]=jobz; strjobz[1]='\0';
    579666
    580667  int_4 n = a.Size(rowa);
     
    587674    int_4 lwork = 3*n-1 +5; r_4* work = new r_4[lwork];
    588675    r_4* w = new r_4[n];
    589     ssyev_(strjobz,struplo,&n,(r_4 *)a.Data(),&lda,(r_4 *)w,(r_4 *)work,&lwork,&info);
     676    ssyev_(&jobz,&uplo,&n,(r_4 *)a.Data(),&lda,(r_4 *)w,(r_4 *)work,&lwork,&info);
    590677    if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
    591678    delete [] work; delete [] w;
     
    593680    int_4 lwork = 3*n-1 +5; r_8* work = new r_8[lwork];
    594681    r_8* w = new r_8[n];
    595     dsyev_(strjobz,struplo,&n,(r_8 *)a.Data(),&lda,(r_8 *)w,(r_8 *)work,&lwork,&info);
     682    dsyev_(&jobz,&uplo,&n,(r_8 *)a.Data(),&lda,(r_8 *)w,(r_8 *)work,&lwork,&info);
    596683    if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
    597684    delete [] work; delete [] w;
     
    599686    int_4 lwork = 2*n-1 +5; complex<r_4>* work = new complex<r_4>[lwork];
    600687    r_4* rwork = new r_4[3*n-2  +5]; r_4* w = new r_4[n];
    601     cheev_(strjobz,struplo,&n,(complex<r_4> *)a.Data(),&lda,(r_4 *)w
     688    cheev_(&jobz,&uplo,&n,(complex<r_4> *)a.Data(),&lda,(r_4 *)w
    602689          ,(complex<r_4> *)work,&lwork,(r_4 *)rwork,&info);
    603690    if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
     
    606693    int_4 lwork = 2*n-1 +5; complex<r_8>* work = new complex<r_8>[lwork];
    607694    r_8* rwork = new r_8[3*n-2  +5]; r_8* w = new r_8[n];
    608     zheev_(strjobz,struplo,&n,(complex<r_8> *)a.Data(),&lda,(r_8 *)w
     695    zheev_(&jobz,&uplo,&n,(complex<r_8> *)a.Data(),&lda,(r_8 *)w
    609696          ,(complex<r_8> *)work,&lwork,(r_8 *)rwork,&info);
    610697    if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
     
    627714  \param eigenvector : if true compute (right) eigenvectors, if not only eigenvalues
    628715  \param a : on return array of eigenvectors
    629   \return : return code from lapack driver _gesvd()
     716  \return : return code from lapack driver
    630717  \verbatim
    631718  eval : contains the computed eigenvalues.
     
    656743    throw(SzMismatchError("LapackServer::LapackEigen(a,b) a Not Column Packed"));
    657744
    658   char jobvl = 'N'; char strjobvl[5]; strjobvl[0] = jobvl; strjobvl[1] = '\0';
     745  char jobvl = 'N';
    659746  char jobvr = 'N'; if(eigenvector) jobvr='V';
    660   char strjobvr[5]; strjobvr[0] = jobvr; strjobvr[1] = '\0';
    661747
    662748  int_4 n = a.Size(rowa);
     
    671757    int_4 lwork = 4*n +5; r_4* work = new r_4[lwork];
    672758    r_4* wr = new r_4[n]; r_4* wi = new r_4[n]; r_4* vl = NULL;
    673     sgeev_(strjobvl,strjobvr,&n,(r_4 *)a.Data(),&lda,(r_4 *)wr,(r_4 *)wi,
     759    sgeev_(&jobvl,&jobvr,&n,(r_4 *)a.Data(),&lda,(r_4 *)wr,(r_4 *)wi,
    674760           (r_4 *)vl,&ldvl,(r_4 *)evec.Data(),&ldvr,
    675761           (r_4 *)work,&lwork,&info);
     
    679765    int_4 lwork = 4*n +5; r_8* work = new r_8[lwork];
    680766    r_8* wr = new r_8[n]; r_8* wi = new r_8[n]; r_8* vl = NULL;
    681     dgeev_(strjobvl,strjobvr,&n,(r_8 *)a.Data(),&lda,(r_8 *)wr,(r_8 *)wi,
     767    dgeev_(&jobvl,&jobvr,&n,(r_8 *)a.Data(),&lda,(r_8 *)wr,(r_8 *)wi,
    682768           (r_8 *)vl,&ldvl,(r_8 *)evec.Data(),&ldvr,
    683769           (r_8 *)work,&lwork,&info);
     
    687773    int_4 lwork = 2*n +5; complex<r_4>* work = new complex<r_4>[lwork];
    688774    r_4* rwork = new r_4[2*n+5]; r_4* vl = NULL; TVector< complex<r_4> > w(n);
    689     cgeev_(strjobvl,strjobvr,&n,(complex<r_4> *)a.Data(),&lda,(complex<r_4> *)w.Data(),
     775    cgeev_(&jobvl,&jobvr,&n,(complex<r_4> *)a.Data(),&lda,(complex<r_4> *)w.Data(),
    690776           (complex<r_4> *)vl,&ldvl,(complex<r_4> *)evec.Data(),&ldvr,
    691777           (complex<r_4> *)work,&lwork,(r_4 *)rwork,&info);
     
    695781    int_4 lwork = 2*n +5; complex<r_8>* work = new complex<r_8>[lwork];
    696782    r_8* rwork = new r_8[2*n+5]; r_8* vl = NULL;
    697     zgeev_(strjobvl,strjobvr,&n,(complex<r_8> *)a.Data(),&lda,(complex<r_8> *)eval.Data(),
     783    zgeev_(&jobvl,&jobvr,&n,(complex<r_8> *)a.Data(),&lda,(complex<r_8> *)eval.Data(),
    698784           (complex<r_8> *)vl,&ldvl,(complex<r_8> *)evec.Data(),&ldvr,
    699785           (complex<r_8> *)work,&lwork,(r_8 *)rwork,&info);
  • trunk/SophyaExt/LinAlg/intflapack.h

    r2556 r2561  
    2020  virtual int SVD(TArray<T>& a, TArray<T> & s);
    2121  virtual int SVD(TArray<T>& a, TArray<T> & s, TArray<T> & u, TArray<T> & vt);
     22  virtual int SVD_DC(TMatrix<T>& a, TVector<T>& s, TMatrix<T>& u, TMatrix<T>& vt);
    2223 
    2324  virtual int LapackEigenSym(TArray<T>& a, TVector<r_8>& b, bool eigenvector=true);
     
    8384
    8485/*! \ingroup LinAlg
     86    \fn LapackSVD_DC(TMatrix<T>&, TVector<T>&, TMatrix<T>&, TMatrix<T>&)
     87    \brief SVD decomposition DC using LapackServer.
     88*/
     89template <class T>
     90inline int LapackSVD_DC(TMatrix<T>& a, TVector<T>& s, TMatrix<T>& u, TMatrix<T>& vt)
     91{ LapackServer<T> lps; return( lps.SVD_DC(a, s, u, vt) ); }
     92
     93
     94/*! \ingroup LinAlg
    8595    \fn LapackEigenSym(TArray<T>&, TArray<T> &)
    8696    \brief Compute the eigenvalues and eigenvectors of A (symetric or hermitian).
Note: See TracChangeset for help on using the changeset viewer.