Changeset 2567 in Sophya for trunk/SophyaExt/LinAlg/intflapack.cc


Ignore:
Timestamp:
Jul 27, 2004, 9:59:05 AM (21 years ago)
Author:
cmv
Message:
  • change argument for SVD_DC
  • add Least Square with SVD DC (cmv 27/07/04)
File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/SophyaExt/LinAlg/intflapack.cc

    r2563 r2567  
    11#include <iostream>
     2#include <math.h>
    23#include "intflapack.h"
    34#include "tvector.h"
    45#include "tmatrix.h"
    56#include <typeinfo>
     7
     8#define GARDMEM 5
    69
    710/*************** Pour memoire  (Christophe) ***************
     
    114117              complex<r_8>* b, int_4* ldb, complex<r_8>* work, int_4* lwork, int_4* info);
    115118
     119// Driver for solving systems in the least-squares (chi-square) sense using SVD Divide & Conquer
     120  void sgelsd_(int_4* m,int_4* n,int_4* nrhs,r_4* a,int_4* lda,
     121              r_4* b,int_4* ldb,r_4* s,r_4* rcond,int_4* rank,
     122              r_4* work,int_4* lwork,int_4* iwork,int_4* info);
     123  void dgelsd_(int_4* m,int_4* n,int_4* nrhs,r_8* a,int_4* lda,
     124              r_8* b,int_4* ldb,r_8* s,r_8* rcond,int_4* rank,
     125              r_8* work,int_4* lwork,int_4* iwork,int_4* info);
     126  void cgelsd_(int_4* m,int_4* n,int_4* nrhs,complex<r_4>* a,int_4* lda,
     127              complex<r_4>* b,int_4* ldb,r_4* s,r_4* rcond,int_4* rank,
     128              complex<r_4>* work,int_4* lwork,r_4* rwork,int_4* iwork,int_4* info);
     129  void zgelsd_(int_4* m,int_4* n,int_4* nrhs,complex<r_8>* a,int_4* lda,
     130              complex<r_8>* b,int_4* ldb,r_8* s,r_8* rcond,int_4* rank,
     131              complex<r_8>* work,int_4* lwork,r_8* rwork,int_4* iwork,int_4* info);
     132
    116133// Driver pour decomposition SVD
    117134  void sgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, r_4* a, int_4* lda,
     
    305322
    306323  if (typeid(T) == typeid(r_4) ) {
    307     lwork = ilaenv_en_C(1,"SSYTRF",struplo,n,-1,-1,-1) * n  +5;
    308     work = new T[lwork];
     324    lwork = ilaenv_en_C(1,"SSYTRF",struplo,n,-1,-1,-1) * n;
     325    work = new T[lwork  +GARDMEM];
    309326    ssysv_(&uplo, &n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb,
    310327          (r_4 *)work, &lwork, &info);
    311328  } else if (typeid(T) == typeid(r_8) )  {
    312     lwork = ilaenv_en_C(1,"DSYTRF",struplo,n,-1,-1,-1) * n  +5;
    313     work = new T[lwork];
     329    lwork = ilaenv_en_C(1,"DSYTRF",struplo,n,-1,-1,-1) * n;
     330    work = new T[lwork  +GARDMEM];
    314331    dsysv_(&uplo, &n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb,
    315332          (r_8 *)work, &lwork, &info);
    316333  } else if (typeid(T) == typeid(complex<r_4>) )  {
    317     lwork = ilaenv_en_C(1,"CSYTRF",struplo,n,-1,-1,-1) * n  +5;
    318     work = new T[lwork];
     334    lwork = ilaenv_en_C(1,"CSYTRF",struplo,n,-1,-1,-1) * n;
     335    work = new T[lwork  +GARDMEM];
    319336    csysv_(&uplo, &n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
    320337          (complex<r_4> *)b.Data(), &ldb,
    321338          (complex<r_4> *)work, &lwork, &info);
    322339  } else if (typeid(T) == typeid(complex<r_8>) )  {
    323     lwork = ilaenv_en_C(1,"ZSYTRF",struplo,n,-1,-1,-1) * n  +5;
    324     work = new T[lwork];
     340    lwork = ilaenv_en_C(1,"ZSYTRF",struplo,n,-1,-1,-1) * n;
     341    work = new T[lwork  +GARDMEM];
    325342    zsysv_(&uplo, &n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
    326343          (complex<r_8> *)b.Data(), &ldb,
     
    420437}
    421438
     439////////////////////////////////////////////////////////////////////////////////////
     440//! Interface to Lapack least squares solver driver s/d/c/zgelsd().
     441/*! Solves the linear least squares problem defined by an m-by-n matrix
     442  \b a and an m element vector \b b , using SVD factorization Divide and Conquer.
     443  Inout arrays should have FortranMemory mapping (column packed).
     444  \param rcond : definition of zero value (S(i) <= RCOND*S(0) are treated as zero).
     445                 If RCOND < 0, machine precision is used instead.
     446  \param a : input matrix, overwritten on output
     447  \param b : input vector b overwritten by solution on output (beware of size changing)
     448  \param x : output matrix of solutions.
     449  \return : return code from lapack driver _gelsd()
     450  \warning : b is not resized.
     451 */
     452template <class T>
     453int LapackServer<T>::LeastSquareSolveSVD_DC(TMatrix<T>& a,TMatrix<T>& b,TVector<r_8>& s,int_4& rank,r_8 rcond)
     454{
     455  if ( ( a.NbDimensions() != 2 ) )
     456    throw(SzMismatchError("LapackServer::LeastSquareSolveSVD_DC(a,b) a != 2"));
     457 
     458  if (!a.IsPacked() || !b.IsPacked())
     459     throw(SzMismatchError("LapackServer::LeastSquareSolveSVD_DC(a,b) a Or b Not Packed"));
     460
     461  int_4 m = a.NRows();
     462  int_4 n = a.NCols();
     463
     464  if(b.NRows() != n)
     465     throw(SzMismatchError("LapackServer::LeastSquareSolveSVD_DC(a,b) bad matching dim between a and b"));
     466
     467  int_4 nrhs = b.NCols();
     468  int_4 minmn = (m < n) ? m : n;
     469  int_4 maxmn = (m > n) ? m : n;
     470
     471  if(b.NRows() != n)
     472     throw(SzMismatchError("LapackServer::LeastSquareSolveSVD_DC(a,b) bad matching dim between a and b"));
     473
     474  int_4 lda = m;
     475  int_4 ldb = maxmn;
     476  int_4 info;
     477
     478  { // Use {} for automatic des-alloc bsave
     479  TMatrix<T> bsave(n,nrhs); bsave.SetMemoryMapping(BaseArray::FortranMemoryMapping);
     480  bsave = b;
     481  b.ReSize(maxmn,nrhs); b = (T) 0;
     482  for(int i=0;i<n;i++) for(int j=0;j<nrhs;j++) b(i,j) = bsave(i,j);
     483  }
     484  s.ReSize(minmn);
     485 
     486  int_4 smlsiz = 25;
     487  if(typeid(T) == typeid(r_4) )               smlsiz = ilaenv_en_C(9,"SGELSD"," ",0,0,0,0);
     488  else if(typeid(T) == typeid(r_8) )          smlsiz = ilaenv_en_C(9,"DGELSD"," ",0,0,0,0);
     489  else if(typeid(T) == typeid(complex<r_4>) ) smlsiz = ilaenv_en_C(9,"CGELSD"," ",0,0,0,0);
     490  else if(typeid(T) == typeid(complex<r_8>) ) smlsiz = ilaenv_en_C(9,"ZGELSD"," ",0,0,0,0);
     491  if(smlsiz<0) smlsiz = 0;
     492
     493  r_8 dum = log((r_8)minmn/(r_8)(smlsiz+1.)) / log(2.);
     494  int_4 nlvl = int_4(dum) + 1; if(nlvl<0) nlvl = 0;
     495
     496  if(typeid(T) == typeid(r_4) ) {
     497    int_4 lwork = 12*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + minmn*nrhs + (smlsiz+1)*(smlsiz+1);
     498    r_4* work = new r_4[lwork +GARDMEM];
     499    int_4* iwork = new int_4[3*minmn*nlvl+11*minmn  +GARDMEM];
     500    r_4* sloc = new r_4[minmn];
     501    r_4 srcond = rcond;
     502    sgelsd_(&m,&n,&nrhs,(r_4*)a.Data(),&lda,
     503           (r_4*)b.Data(),&ldb,(r_4*)sloc,&srcond,&rank,
     504           (r_4*)work,&lwork,(int_4*)iwork,&info);
     505    for(int_4 i=0;i<minmn;i++) s(i) = sloc[i];
     506    delete [] sloc; delete [] work; delete [] iwork;
     507  } else if(typeid(T) == typeid(r_8) )  {
     508    int_4 lwork = 12*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + minmn*nrhs + (smlsiz+1)*(smlsiz+1);
     509    r_8* work = new r_8[lwork +GARDMEM];
     510    int_4* iwork = new int_4[3*minmn*nlvl+11*minmn  +GARDMEM];
     511    dgelsd_(&m,&n,&nrhs,(r_8*)a.Data(),&lda,
     512           (r_8*)b.Data(),&ldb,(r_8*)s.Data(),&rcond,&rank,
     513           (r_8*)work,&lwork,(int_4*)iwork,&info);
     514    delete [] work; delete [] iwork;
     515  } else if(typeid(T) == typeid(complex<r_4>) )  {
     516    int_4 lwork = 2*minmn + minmn*nrhs;
     517    complex<r_4>* work = new complex<r_4>[lwork +GARDMEM];
     518    int_4 lrwork = 10*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + 3*smlsiz*nrhs + (smlsiz+1)*(smlsiz+1);
     519    r_4* rwork = new r_4[lrwork +GARDMEM];
     520    int_4* iwork = new int_4[3*minmn*nlvl+11*minmn  +GARDMEM];
     521    r_4* sloc = new r_4[minmn];
     522    r_4 srcond = rcond;
     523    cgelsd_(&m,&n,&nrhs,(complex<r_4>*)a.Data(),&lda,
     524           (complex<r_4>*)b.Data(),&ldb,(r_4*)sloc,&srcond,&rank,
     525           (complex<r_4>*)work,&lwork,(r_4*)rwork,(int_4*)iwork,&info);
     526    for(int_4 i=0;i<minmn;i++) s(i) = sloc[i];
     527    delete [] sloc; delete [] work; delete [] rwork; delete [] iwork;
     528  } else if(typeid(T) == typeid(complex<r_8>) )  {
     529    int_4 lwork = 2*minmn + minmn*nrhs;
     530    complex<r_8>* work = new complex<r_8>[lwork +GARDMEM];
     531    int_4 lrwork = 10*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + 3*smlsiz*nrhs + (smlsiz+1)*(smlsiz+1);
     532    r_8* rwork = new r_8[lrwork +GARDMEM];
     533    int_4* iwork = new int_4[3*minmn*nlvl+11*minmn  +GARDMEM];
     534    r_8 srcond = rcond;
     535    zgelsd_(&m,&n,&nrhs,(complex<r_8>*)a.Data(),&lda,
     536           (complex<r_8>*)b.Data(),&ldb,(r_8*)s.Data(),&srcond,&rank,
     537           (complex<r_8>*)work,&lwork,(r_8*)rwork,(int_4*)iwork,&info);
     538    delete [] work; delete [] rwork; delete [] iwork;
     539  } else {
     540    string tn = typeid(T).name();
     541    cerr << " LapackServer::LeastSquareSolveSVD_DC(a,b) - Unsupported DataType T = " << tn << endl;
     542    throw TypeMismatchExc("LapackServer::LeastSquareSolveSVD_DC(a,b) - Unsupported DataType (T)");
     543  }
     544
     545  return(info);
     546}
     547
    422548
    423549////////////////////////////////////////////////////////////////////////////////////
     
    521647  int_4 ldu = up->Step(up->ColsKA());
    522648  int_4 ldvt = vtp->Step(vtp->ColsKA());
     649  int_4 info;
    523650
    524651  int_4 lwork = maxmn*5*wspace_size_factor;
    525652  T * work = new T[lwork];
    526   int_4 info;
    527653
    528654  if (typeid(T) == typeid(r_4) ) {
     
    535661            (r_8 *)work, &lwork, &info);
    536662  } else if (typeid(T) == typeid(complex<r_4>) ) {
    537     r_4 * rwork = new r_4[5*minmn +5];
     663    r_4 * rwork = new r_4[5*minmn +GARDMEM];
    538664    r_4 * sloc  = new r_4[minmn];
    539665    cgesvd_(&jobu, &jobvt, &m, &n, (complex<r_4> *)a.Data(), &lda,
     
    544670    delete [] rwork; delete [] sloc;
    545671  } else if (typeid(T) == typeid(complex<r_8>) )  {
    546     r_8 * rwork = new r_8[5*minmn +5];
     672    r_8 * rwork = new r_8[5*minmn +GARDMEM];
    547673    r_8 * sloc  = new r_8[minmn];
    548674    zgesvd_(&jobu, &jobvt, &m, &n, (complex<r_8> *)a.Data(), &lda,
     
    595721    r_4* sloc = new r_4[minmn];
    596722    int_4 lwork = 3*minmn*minmn + supermax;
    597     r_4* work = new r_4[lwork +5];
    598     int_4* iwork = new int_4[8*minmn +5];
     723    r_4* work = new r_4[lwork +GARDMEM];
     724    int_4* iwork = new int_4[8*minmn +GARDMEM];
    599725    sgesdd_(&jobz,&m,&n,(r_4*)a.Data(),&lda,
    600726           (r_4*)sloc,(r_4*)u.Data(),&ldu,(r_4*)vt.Data(),&ldvt,
     
    604730  } else if(typeid(T) == typeid(r_8) ) {
    605731    int_4 lwork = 3*minmn*minmn + supermax;
    606     r_8* work = new r_8[lwork +5];
    607     int_4* iwork = new int_4[8*minmn +5];
     732    r_8* work = new r_8[lwork +GARDMEM];
     733    int_4* iwork = new int_4[8*minmn +GARDMEM];
    608734    dgesdd_(&jobz,&m,&n,(r_8*)a.Data(),&lda,
    609735           (r_8*)s.Data(),(r_8*)u.Data(),&ldu,(r_8*)vt.Data(),&ldvt,
     
    613739    r_4* sloc = new r_4[minmn];
    614740    int_4 lwork = minmn*minmn+2*minmn+maxmn;
    615     complex<r_4>* work = new complex<r_4>[lwork +5];
    616     r_4* rwork = new r_4[5*minmn*minmn+5*minmn +5];
    617     int_4* iwork = new int_4[8*minmn +5];
     741    complex<r_4>* work = new complex<r_4>[lwork +GARDMEM];
     742    r_4* rwork = new r_4[5*minmn*minmn+5*minmn +GARDMEM];
     743    int_4* iwork = new int_4[8*minmn +GARDMEM];
    618744    cgesdd_(&jobz,&m,&n,(complex<r_4>*)a.Data(),&lda,
    619745           (r_4*)sloc,(complex<r_4>*)u.Data(),&ldu,(complex<r_4>*)vt.Data(),&ldvt,
     
    623749  } else if(typeid(T) == typeid(complex<r_8>) )  {
    624750    int_4 lwork = minmn*minmn+2*minmn+maxmn;
    625     complex<r_8>* work = new complex<r_8>[lwork +5];
    626     r_8* rwork = new r_8[5*minmn*minmn+5*minmn +5];
    627     int_4* iwork = new int_4[8*minmn +5];
     751    complex<r_8>* work = new complex<r_8>[lwork +GARDMEM];
     752    r_8* rwork = new r_8[5*minmn*minmn+5*minmn +GARDMEM];
     753    int_4* iwork = new int_4[8*minmn +GARDMEM];
    628754    zgesdd_(&jobz,&m,&n,(complex<r_8>*)a.Data(),&lda,
    629755           (r_8*)s.Data(),(complex<r_8>*)u.Data(),&ldu,(complex<r_8>*)vt.Data(),&ldvt,
     
    672798
    673799  if (typeid(T) == typeid(r_4) ) {
    674     int_4 lwork = 3*n-1 +5; r_4* work = new r_4[lwork];
     800    int_4 lwork = 3*n-1; r_4* work = new r_4[lwork  +GARDMEM];
    675801    r_4* w = new r_4[n];
    676802    ssyev_(&jobz,&uplo,&n,(r_4 *)a.Data(),&lda,(r_4 *)w,(r_4 *)work,&lwork,&info);
     
    678804    delete [] work; delete [] w;
    679805  } else if (typeid(T) == typeid(r_8) )  {
    680     int_4 lwork = 3*n-1 +5; r_8* work = new r_8[lwork];
     806    int_4 lwork = 3*n-1; r_8* work = new r_8[lwork  +GARDMEM];
    681807    r_8* w = new r_8[n];
    682808    dsyev_(&jobz,&uplo,&n,(r_8 *)a.Data(),&lda,(r_8 *)w,(r_8 *)work,&lwork,&info);
     
    684810    delete [] work; delete [] w;
    685811  } else if (typeid(T) == typeid(complex<r_4>) )  {
    686     int_4 lwork = 2*n-1 +5; complex<r_4>* work = new complex<r_4>[lwork];
    687     r_4* rwork = new r_4[3*n-2  +5]; r_4* w = new r_4[n];
     812    int_4 lwork = 2*n-1; complex<r_4>* work = new complex<r_4>[lwork  +GARDMEM];
     813    r_4* rwork = new r_4[3*n-2  +GARDMEM]; r_4* w = new r_4[n];
    688814    cheev_(&jobz,&uplo,&n,(complex<r_4> *)a.Data(),&lda,(r_4 *)w
    689815          ,(complex<r_4> *)work,&lwork,(r_4 *)rwork,&info);
     
    691817    delete [] work; delete [] rwork; delete [] w;
    692818  } else if (typeid(T) == typeid(complex<r_8>) )  {
    693     int_4 lwork = 2*n-1 +5; complex<r_8>* work = new complex<r_8>[lwork];
    694     r_8* rwork = new r_8[3*n-2  +5]; r_8* w = new r_8[n];
     819    int_4 lwork = 2*n-1; complex<r_8>* work = new complex<r_8>[lwork  +GARDMEM];
     820    r_8* rwork = new r_8[3*n-2  +GARDMEM]; r_8* w = new r_8[n];
    695821    zheev_(&jobz,&uplo,&n,(complex<r_8> *)a.Data(),&lda,(r_8 *)w
    696822          ,(complex<r_8> *)work,&lwork,(r_8 *)rwork,&info);
     
    755881
    756882  if (typeid(T) == typeid(r_4) ) {
    757     int_4 lwork = 4*n +5; r_4* work = new r_4[lwork];
     883    int_4 lwork = 4*n; r_4* work = new r_4[lwork  +GARDMEM];
    758884    r_4* wr = new r_4[n]; r_4* wi = new r_4[n]; r_4* vl = NULL;
    759885    sgeev_(&jobvl,&jobvr,&n,(r_4 *)a.Data(),&lda,(r_4 *)wr,(r_4 *)wi,
     
    763889    delete [] work; delete [] wr; delete [] wi;
    764890  } else if (typeid(T) == typeid(r_8) )  {
    765     int_4 lwork = 4*n +5; r_8* work = new r_8[lwork];
     891    int_4 lwork = 4*n; r_8* work = new r_8[lwork  +GARDMEM];
    766892    r_8* wr = new r_8[n]; r_8* wi = new r_8[n]; r_8* vl = NULL;
    767893    dgeev_(&jobvl,&jobvr,&n,(r_8 *)a.Data(),&lda,(r_8 *)wr,(r_8 *)wi,
     
    771897    delete [] work; delete [] wr; delete [] wi;
    772898  } else if (typeid(T) == typeid(complex<r_4>) )  {
    773     int_4 lwork = 2*n +5; complex<r_4>* work = new complex<r_4>[lwork];
    774     r_4* rwork = new r_4[2*n+5]; r_4* vl = NULL; TVector< complex<r_4> > w(n);
     899    int_4 lwork = 2*n; complex<r_4>* work = new complex<r_4>[lwork  +GARDMEM];
     900    r_4* rwork = new r_4[2*n  +GARDMEM]; r_4* vl = NULL; TVector< complex<r_4> > w(n);
    775901    cgeev_(&jobvl,&jobvr,&n,(complex<r_4> *)a.Data(),&lda,(complex<r_4> *)w.Data(),
    776902           (complex<r_4> *)vl,&ldvl,(complex<r_4> *)evec.Data(),&ldvr,
     
    779905    delete [] work; delete [] rwork;
    780906  } else if (typeid(T) == typeid(complex<r_8>) )  {
    781     int_4 lwork = 2*n +5; complex<r_8>* work = new complex<r_8>[lwork];
    782     r_8* rwork = new r_8[2*n+5]; r_8* vl = NULL;
     907    int_4 lwork = 2*n; complex<r_8>* work = new complex<r_8>[lwork  +GARDMEM];
     908    r_8* rwork = new r_8[2*n  +GARDMEM]; r_8* vl = NULL;
    783909    zgeev_(&jobvl,&jobvr,&n,(complex<r_8> *)a.Data(),&lda,(complex<r_8> *)eval.Data(),
    784910           (complex<r_8> *)vl,&ldvl,(complex<r_8> *)evec.Data(),&ldvr,
Note: See TracChangeset for help on using the changeset viewer.