source: Sophya/trunk/SophyaExt/LinAlg/intflapack.cc@ 2561

Last change on this file since 2561 was 2561, checked in by cmv, 21 years ago

add SVD decomp by Divide and Conquer (cmv 23/07/04)

File size: 33.8 KB
RevLine 
[2322]1#include <iostream>
[775]2#include "intflapack.h"
[1342]3#include "tvector.h"
4#include "tmatrix.h"
[814]5#include <typeinfo>
[775]6
[2556]7/*************** Pour memoire (Christophe) ***************
8Les dispositions memoires (FORTRAN) pour les vecteurs et matrices LAPACK:
9
101./ --- REAL X(N):
11 if an array X of dimension (N) holds a vector x,
12 then X(i) holds "x_i" for i=1,...,N
13
142./ --- REAL A(LDA,N):
15 if a two-dimensional array A of dimension (LDA,N) holds an m-by-n matrix A,
16 then A(i,j) holds "a_ij" for i=1,...,m et j=1,...,n (LDA must be at least m).
17 Note that array arguments are usually declared in the software as assumed-size
18 arrays (last dimension *), for example: REAL A(LDA,*)
19 --- Rangement en memoire:
20 | 11 12 13 14 |
21 Ex: Real A(4,4): A = | 21 22 23 24 |
22 | 31 32 33 34 |
23 | 41 42 43 44 |
24 memoire: {11 21 31 41} {12 22 32 42} {13 23 33 43} {14 24 34 44}
25 First indice (line) "i" varies then the second (column):
26 (put all the first column, then put all the second column,
27 ..., then put all the last column)
28***********************************************************/
29
[1424]30/*!
31 \defgroup LinAlg LinAlg module
32 This module contains classes and functions for complex linear
33 algebra on arrays. This module is intended mainly to have
34 classes implementing C++ interfaces between Sophya objects
35 and external linear algebra libraries, such as LAPACK.
36*/
37
38/*!
39 \class SOPHYA::LapackServer
40 \ingroup LinAlg
 This class implements an interface to LAPACK library driver routines.
 LAPACK (Linear Algebra PACKage) is a collection of high-performance
 routines for solving common problems in numerical linear algebra.
 It is available from http://www.netlib.org.
45
 The present version of our LapackServer (Feb 2001) provides only
 interfaces for the linear system solver and singular value
 decomposition (SVD). Only arrays with BaseArray::FortranMemoryMapping
 can be handled by LapackServer. LapackServer can be instantiated
 for single and double precision real or complex array types.
51
52 The example below shows solving a linear system A*X = B
53
54 \code
55 #include "intflapack.h"
56 // ...
57 // Use FortranMemoryMapping as default
58 BaseArray::SetDefaultMemoryMapping(BaseArray::FortranMemoryMapping);
 // Create and fill the arrays A and B
60 int n = 20;
61 Matrix A(n, n);
62 A = RandomSequence();
63 Vector X(n),B(n);
64 X = RandomSequence();
65 B = A*X;
66 // Solve the linear system A*X = B
67 LapackServer<r_8> lps;
68 lps.LinSolve(A,B);
69 // We get the result in B, which should be equal to X ...
70 // Compute the difference B-X ;
71 Vector diff = B-X;
72 \endcode
73
74*/
75
[2556]76////////////////////////////////////////////////////////////////////////////////////
[775]77extern "C" {
[2554]78// Le calculateur de workingspace
79 int_4 ilaenv_(int_4 *ispec,char *name,char *opts,int_4 *n1,int_4 *n2,int_4 *n3,int_4 *n4,
80 int_4 nc1,int_4 nc2);
81
[1342]82// Drivers pour resolution de systemes lineaires
83 void sgesv_(int_4* n, int_4* nrhs, r_4* a, int_4* lda,
84 int_4* ipiv, r_4* b, int_4* ldb, int_4* info);
85 void dgesv_(int_4* n, int_4* nrhs, r_8* a, int_4* lda,
86 int_4* ipiv, r_8* b, int_4* ldb, int_4* info);
87 void cgesv_(int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
88 int_4* ipiv, complex<r_4>* b, int_4* ldb, int_4* info);
89 void zgesv_(int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
90 int_4* ipiv, complex<r_8>* b, int_4* ldb, int_4* info);
91
[2554]92// Drivers pour resolution de systemes lineaires symetriques
93 void ssysv_(char* uplo, int_4* n, int_4* nrhs, r_4* a, int_4* lda,
94 int_4* ipiv, r_4* b, int_4* ldb,
95 r_4* work, int_4* lwork, int_4* info);
96 void dsysv_(char* uplo, int_4* n, int_4* nrhs, r_8* a, int_4* lda,
97 int_4* ipiv, r_8* b, int_4* ldb,
98 r_8* work, int_4* lwork, int_4* info);
99 void csysv_(char* uplo, int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
100 int_4* ipiv, complex<r_4>* b, int_4* ldb,
101 complex<r_4>* work, int_4* lwork, int_4* info);
102 void zsysv_(char* uplo, int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
103 int_4* ipiv, complex<r_8>* b, int_4* ldb,
104 complex<r_8>* work, int_4* lwork, int_4* info);
105
106// Driver pour resolution de systemes au sens de Xi2
[1494]107 void sgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, r_4* a, int_4* lda,
108 r_4* b, int_4* ldb, r_4* work, int_4* lwork, int_4* info);
109 void dgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, r_8* a, int_4* lda,
110 r_8* b, int_4* ldb, r_8* work, int_4* lwork, int_4* info);
111 void cgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
112 complex<r_4>* b, int_4* ldb, complex<r_4>* work, int_4* lwork, int_4* info);
113 void zgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
114 complex<r_8>* b, int_4* ldb, complex<r_8>* work, int_4* lwork, int_4* info);
115
[1342]116// Driver pour decomposition SVD
117 void sgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, r_4* a, int_4* lda,
118 r_4* s, r_4* u, int_4* ldu, r_4* vt, int_4* ldvt,
119 r_4* work, int_4* lwork, int_4* info);
120 void dgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, r_8* a, int_4* lda,
121 r_8* s, r_8* u, int_4* ldu, r_8* vt, int_4* ldvt,
122 r_8* work, int_4* lwork, int_4* info);
123 void cgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_4>* a, int_4* lda,
[2559]124 r_4* s, complex<r_4>* u, int_4* ldu, complex<r_4>* vt, int_4* ldvt,
125 complex<r_4>* work, int_4* lwork, r_4* rwork, int_4* info);
[1342]126 void zgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_8>* a, int_4* lda,
[2559]127 r_8* s, complex<r_8>* u, int_4* ldu, complex<r_8>* vt, int_4* ldvt,
128 complex<r_8>* work, int_4* lwork, r_8* rwork, int_4* info);
[2556]129
[2561]130// Driver pour decomposition SVD Divide and Conquer
131 void sgesdd_(char* jobz, int_4* m, int_4* n, r_4* a, int_4* lda,
132 r_4* s, r_4* u, int_4* ldu, r_4* vt, int_4* ldvt,
133 r_4* work, int_4* lwork, int_4* iwork, int_4* info);
134 void dgesdd_(char* jobz, int_4* m, int_4* n, r_8* a, int_4* lda,
135 r_8* s, r_8* u, int_4* ldu, r_8* vt, int_4* ldvt,
136 r_8* work, int_4* lwork, int_4* iwork, int_4* info);
137 void cgesdd_(char* jobz, int_4* m, int_4* n, complex<r_4>* a, int_4* lda,
138 r_4* s, complex<r_4>* u, int_4* ldu, complex<r_4>* vt, int_4* ldvt,
139 complex<r_4>* work, int_4* lwork, r_4* rwork, int_4* iwork, int_4* info);
140 void zgesdd_(char* jobz, int_4* m, int_4* n, complex<r_8>* a, int_4* lda,
141 r_8* s, complex<r_8>* u, int_4* ldu, complex<r_8>* vt, int_4* ldvt,
142 complex<r_8>* work, int_4* lwork, r_8* rwork, int_4* iwork, int_4* info);
143
[2556]144// Driver pour eigen decomposition for symetric/hermitian matrices
145 void ssyev_(char* jobz, char* uplo, int_4* n, r_4* a, int_4* lda, r_4* w,
146 r_4* work, int_4 *lwork, int_4* info);
147 void dsyev_(char* jobz, char* uplo, int_4* n, r_8* a, int_4* lda, r_8* w,
148 r_8* work, int_4 *lwork, int_4* info);
149 void cheev_(char* jobz, char* uplo, int_4* n, complex<r_4>* a, int_4* lda, r_4* w,
150 complex<r_4>* work, int_4 *lwork, r_4* rwork, int_4* info);
151 void zheev_(char* jobz, char* uplo, int_4* n, complex<r_8>* a, int_4* lda, r_8* w,
152 complex<r_8>* work, int_4 *lwork, r_8* rwork, int_4* info);
153
154// Driver pour eigen decomposition for general squared matrices
155 void sgeev_(char* jobl, char* jobvr, int_4* n, r_4* a, int_4* lda, r_4* wr, r_4* wi,
156 r_4* vl, int_4* ldvl, r_4* vr, int_4* ldvr,
157 r_4* work, int_4 *lwork, int_4* info);
158 void dgeev_(char* jobl, char* jobvr, int_4* n, r_8* a, int_4* lda, r_8* wr, r_8* wi,
159 r_8* vl, int_4* ldvl, r_8* vr, int_4* ldvr,
160 r_8* work, int_4 *lwork, int_4* info);
161 void cgeev_(char* jobl, char* jobvr, int_4* n, complex<r_4>* a, int_4* lda, complex<r_4>* w,
162 complex<r_4>* vl, int_4* ldvl, complex<r_4>* vr, int_4* ldvr,
163 complex<r_4>* work, int_4 *lwork, r_4* rwork, int_4* info);
164 void zgeev_(char* jobl, char* jobvr, int_4* n, complex<r_8>* a, int_4* lda, complex<r_8>* w,
165 complex<r_8>* vl, int_4* ldvl, complex<r_8>* vr, int_4* ldvr,
166 complex<r_8>* work, int_4 *lwork, r_8* rwork, int_4* info);
167
[775]168}
169
[1342]170// -------------- Classe LapackServer<T> --------------
171
[2556]172////////////////////////////////////////////////////////////////////////////////////
[814]173template <class T>
[1344]174LapackServer<T>::LapackServer()
[1342]175{
176 SetWorkSpaceSizeFactor();
177}
178
179template <class T>
[1344]180LapackServer<T>::~LapackServer()
[1342]181{
182}
183
// --- WARNING: POSSIBLE FUTURE BUG (CMV) --- REZA, PLEASE READ.
//  -> This wretched Fortran/C interface
//  In the LAPACK Fortran routines:
//   Call from C with:
//     int_4 lwork = -1;
//     SUBROUTINE SSYSV( UPLO,N,NRHS,A,LDA,IPIV,B,LDB,WORK,LWORK,INFO)
//     INTEGER INFO, LDA, LDB, LWORK, N, NRHS
//     LOGICAL LQUERY
//     LQUERY = ( LWORK.EQ.-1 )
//     ELSE IF( LWORK.LT.1 .AND. .NOT.LQUERY ) THEN
//  ==> the test is evaluated correctly under Linux but not under OSF
//  ==> under OSF "LWORK.EQ.-1" is FALSE when lwork=-1 is passed as argument
//  ==> FOR REZA: confusion between 4-byte and 8-byte integers ??? (odd, we would have noticed earlier?)
[2556]197////////////////////////////////////////////////////////////////////////////////////
[2554]198template <class T>
199int_4 LapackServer<T>::ilaenv_en_C(int_4 ispec,char *name,char *opts,int_4 n1,int_4 n2,int_4 n3,int_4 n4)
200{
201 int_4 nc1 = strlen(name);
202 int_4 nc2 = strlen(opts);
203 int_4 rc=0;
204 rc = ilaenv_(&ispec,name,opts,&n1,&n2,&n3,&n4,nc1,nc2);
205 //cout<<"ilaenv_en_C("<<ispec<<","<<name<<"("<<nc1<<"),"<<opts<<"("<<nc2<<"),"
206 // <<n1<<","<<n2<<","<<n3<<","<<n4<<") = "<<rc<<endl;
207 return rc;
208}
209
[2556]210////////////////////////////////////////////////////////////////////////////////////
[1424]211//! Interface to Lapack linear system solver driver s/d/c/zgesvd().
212/*! Solve the linear system a * x = b. Input arrays
213 should have FortranMemory mapping (column packed).
214 \param a : input matrix, overwritten on output
215 \param b : input-output, input vector b, contains x on exit
216 \return : return code from lapack driver _gesv()
217 */
[1342]218template <class T>
[1042]219int LapackServer<T>::LinSolve(TArray<T>& a, TArray<T> & b)
[814]220{
221 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
222 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b NbDimensions() != 2"));
223
[1342]224 int_4 rowa = a.RowsKA();
225 int_4 cola = a.ColsKA();
226 int_4 rowb = b.RowsKA();
227 int_4 colb = b.ColsKA();
[814]228 if ( a.Size(rowa) != a.Size(cola))
229 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Not a square Array"));
[1042]230 if ( a.Size(rowa) != b.Size(rowb))
[814]231 throw(SzMismatchError("LapackServer::LinSolve(a,b) RowSize(a <> b) "));
232
233 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
[1342]234 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b Not Column Packed"));
[814]235
236 int_4 n = a.Size(rowa);
237 int_4 nrhs = b.Size(colb);
238 int_4 lda = a.Step(cola);
239 int_4 ldb = b.Step(colb);
240 int_4 info;
241 int_4* ipiv = new int_4[n];
242
243 if (typeid(T) == typeid(r_4) )
244 sgesv_(&n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb, &info);
245 else if (typeid(T) == typeid(r_8) )
246 dgesv_(&n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb, &info);
247 else if (typeid(T) == typeid(complex<r_4>) )
248 cgesv_(&n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
249 (complex<r_4> *)b.Data(), &ldb, &info);
250 else if (typeid(T) == typeid(complex<r_8>) )
251 zgesv_(&n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
252 (complex<r_8> *)b.Data(), &ldb, &info);
253 else {
254 delete[] ipiv;
255 string tn = typeid(T).name();
256 cerr << " LapackServer::LinSolve(a,b) - Unsupported DataType T = " << tn << endl;
257 throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
258 }
259 delete[] ipiv;
[1042]260 return(info);
[814]261}
262
[2556]263////////////////////////////////////////////////////////////////////////////////////
[2554]264//! Interface to Lapack linear system solver driver s/d/c/zsysvd().
265/*! Solve the linear system a * x = b with a symetric. Input arrays
266 should have FortranMemory mapping (column packed).
267 \param a : input matrix symetric , overwritten on output
268 \param b : input-output, input vector b, contains x on exit
[2561]269 \return : return code from lapack driver
[2554]270 */
271template <class T>
272int LapackServer<T>::LinSolveSym(TArray<T>& a, TArray<T> & b)
273// --- REMARQUES DE CMV ---
274// 1./ contrairement a ce qui est dit dans la doc, il s'agit
275// de matrices SYMETRIQUES complexes et non HERMITIENNES !!!
276// 2./ pourquoi les routines de LinSolve pour des matrices symetriques
[2556]277// sont plus de deux fois plus lentes que les LinSolve generales sur OSF
278// et sensiblement plus lentes sous Linux ???
[2554]279{
280 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
281 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) a Or b NbDimensions() != 2"));
282 int_4 rowa = a.RowsKA();
283 int_4 cola = a.ColsKA();
284 int_4 rowb = b.RowsKA();
285 int_4 colb = b.ColsKA();
286 if ( a.Size(rowa) != a.Size(cola))
287 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) a Not a square Array"));
288 if ( a.Size(rowa) != b.Size(rowb))
289 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) RowSize(a <> b) "));
290
291 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
292 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) a Or b Not Column Packed"));
293
294 int_4 n = a.Size(rowa);
295 int_4 nrhs = b.Size(colb);
296 int_4 lda = a.Step(cola);
297 int_4 ldb = b.Step(colb);
298 int_4 info = 0;
299 int_4* ipiv = new int_4[n];
300 int_4 lwork = -1;
301 T * work = NULL;
302
303 char uplo = 'U'; // char uplo = 'L';
304 char struplo[5]; struplo[0] = uplo; struplo[1] = '\0';
305
306 if (typeid(T) == typeid(r_4) ) {
[2556]307 lwork = ilaenv_en_C(1,"SSYTRF",struplo,n,-1,-1,-1) * n +5;
308 work = new T[lwork];
[2554]309 ssysv_(&uplo, &n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb,
310 (r_4 *)work, &lwork, &info);
311 } else if (typeid(T) == typeid(r_8) ) {
[2556]312 lwork = ilaenv_en_C(1,"DSYTRF",struplo,n,-1,-1,-1) * n +5;
313 work = new T[lwork];
[2554]314 dsysv_(&uplo, &n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb,
315 (r_8 *)work, &lwork, &info);
316 } else if (typeid(T) == typeid(complex<r_4>) ) {
[2556]317 lwork = ilaenv_en_C(1,"CSYTRF",struplo,n,-1,-1,-1) * n +5;
318 work = new T[lwork];
[2554]319 csysv_(&uplo, &n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
320 (complex<r_4> *)b.Data(), &ldb,
321 (complex<r_4> *)work, &lwork, &info);
322 } else if (typeid(T) == typeid(complex<r_8>) ) {
[2556]323 lwork = ilaenv_en_C(1,"ZSYTRF",struplo,n,-1,-1,-1) * n +5;
324 work = new T[lwork];
[2554]325 zsysv_(&uplo, &n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
326 (complex<r_8> *)b.Data(), &ldb,
327 (complex<r_8> *)work, &lwork, &info);
328 } else {
[2556]329 if(work) delete[] work;
[2554]330 delete[] ipiv;
331 string tn = typeid(T).name();
332 cerr << " LapackServer::LinSolveSym(a,b) - Unsupported DataType T = " << tn << endl;
333 throw TypeMismatchExc("LapackServer::LinSolveSym(a,b) - Unsupported DataType (T)");
334 }
[2556]335 if(work) delete[] work;
[2554]336 delete[] ipiv;
337 return(info);
338}
339
[2556]340////////////////////////////////////////////////////////////////////////////////////
[1566]341//! Interface to Lapack least squares solver driver s/d/c/zgels().
342/*! Solves the linear least squares problem defined by an m-by-n matrix
343 \b a and an m element vector \b b .
344 A solution \b x to the overdetermined system of linear equations
345 b = a * x is computed, minimizing the norm of b-a*x.
346 Underdetermined systems (m<n) are not yet handled.
347 Inout arrays should have FortranMemory mapping (column packed).
348 \param a : input matrix, overwritten on output
349 \param b : input-output, input vector b, contains x on exit.
350 \return : return code from lapack driver _gels()
351 \warning : b is not resized.
352 */
353/*
354 $CHECK$ - A faire - cas m<n
355 If the linear system is underdetermined, the minimum norm
356 solution is computed.
357*/
358
[1494]359template <class T>
360int LapackServer<T>::LeastSquareSolve(TArray<T>& a, TArray<T> & b)
361{
362 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
[2561]363 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) a Or b NbDimensions() != 2"));
[1494]364
365 int_4 rowa = a.RowsKA();
366 int_4 cola = a.ColsKA();
367 int_4 rowb = b.RowsKA();
368 int_4 colb = b.ColsKA();
369
370
371 if ( a.Size(rowa) != b.Size(rowb))
372 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) RowSize(a <> b) "));
373
374 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
[1566]375 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) a Or b Not Column Packed"));
[1494]376
[1566]377 if ( a.Size(rowa) < a.Size(cola)) { // $CHECK$ - m<n a changer
378 cout << " LapackServer<T>::LeastSquareSolve() - m<n - Not yet implemented for "
379 << " underdetermined systems ! " << endl;
380 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) NRows<NCols - "));
381 }
[1494]382 int_4 m = a.Size(rowa);
383 int_4 n = a.Size(cola);
384 int_4 nrhs = b.Size(colb);
385
386 int_4 lda = a.Step(cola);
387 int_4 ldb = b.Step(colb);
388 int_4 info;
389
390 int_4 minmn = (m < n) ? m : n;
391 int_4 maxmn = (m > n) ? m : n;
392 int_4 maxmnrhs = (nrhs > maxmn) ? nrhs : maxmn;
393 if (maxmnrhs < 1) maxmnrhs = 1;
394
395 int_4 lwork = minmn+maxmnrhs*5;
396 T * work = new T[lwork];
397
398 char trans = 'N';
399
400 if (typeid(T) == typeid(r_4) )
401 sgels_(&trans, &m, &n, &nrhs, (r_4 *)a.Data(), &lda,
402 (r_4 *)b.Data(), &ldb, (r_4 *)work, &lwork, &info);
403 else if (typeid(T) == typeid(r_8) )
404 dgels_(&trans, &m, &n, &nrhs, (r_8 *)a.Data(), &lda,
405 (r_8 *)b.Data(), &ldb, (r_8 *)work, &lwork, &info);
406 else if (typeid(T) == typeid(complex<r_4>) )
407 cgels_(&trans, &m, &n, &nrhs, (complex<r_4> *)a.Data(), &lda,
408 (complex<r_4> *)b.Data(), &ldb, (complex<r_4> *)work, &lwork, &info);
409 else if (typeid(T) == typeid(complex<r_8>) )
410 zgels_(&trans, &m, &n, &nrhs, (complex<r_8> *)a.Data(), &lda,
411 (complex<r_8> *)b.Data(), &ldb, (complex<r_8> *)work, &lwork, &info);
412 else {
413 delete[] work;
414 string tn = typeid(T).name();
415 cerr << " LapackServer::LeastSquareSolve(a,b) - Unsupported DataType T = " << tn << endl;
416 throw TypeMismatchExc("LapackServer::LeastSquareSolve(a,b) - Unsupported DataType (T)");
417 }
418 delete[] work;
419 return(info);
420}
421
422
[2556]423////////////////////////////////////////////////////////////////////////////////////
[1424]424//! Interface to Lapack SVD driver s/d/c/zgesv().
425/*! Computes the vector of singular values of \b a. Input arrays
426 should have FortranMemoryMapping (column packed).
427 \param a : input m-by-n matrix
428 \param s : Vector of min(m,n) singular values (descending order)
429 \return : return code from lapack driver _gesvd()
430 */
431
[1342]432template <class T>
433int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s)
434{
435 return (SVDDriver(a, s, NULL, NULL) );
436}
437
[1424]438//! Interface to Lapack SVD driver s/d/c/zgesv().
439/*! Computes the vector of singular values of \b a, as well as
440 right and left singular vectors of \b a.
441 \f[
442 A = U \Sigma V^T , ( A = U \Sigma V^H \ complex)
443 \f]
444 \f[
445 A v_i = \sigma_i u_i \ and A^T u_i = \sigma_i v_i \ (A^H \ complex)
446 \f]
447 U and V are orthogonal (unitary) matrices.
448 \param a : input m-by-n matrix (in FotranMemoryMapping)
449 \param s : Vector of min(m,n) singular values (descending order)
450 \param u : Matrix of left singular vectors
451 \param vt : Transpose of right singular vectors.
452 \return : return code from lapack driver _gesvd()
453 */
[1342]454template <class T>
455int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s, TArray<T> & u, TArray<T> & vt)
456{
457 return (SVDDriver(a, s, &u, &vt) );
458}
459
[1424]460
461//! Interface to Lapack SVD driver s/d/c/zgesv().
[1342]462template <class T>
463int LapackServer<T>::SVDDriver(TArray<T>& a, TArray<T> & s, TArray<T>* up, TArray<T>* vtp)
464{
465 if ( ( a.NbDimensions() != 2 ) )
[2561]466 throw(SzMismatchError("LapackServer::SVDDriver(a, ...) a.NbDimensions() != 2"));
[1342]467
468 int_4 rowa = a.RowsKA();
469 int_4 cola = a.ColsKA();
470
471 if ( !a.IsPacked(rowa) )
[2561]472 throw(SzMismatchError("LapackServer::SVDDriver(a, ...) a Not Column Packed "));
[1342]473
474 int_4 m = a.Size(rowa);
475 int_4 n = a.Size(cola);
476 int_4 maxmn = (m > n) ? m : n;
477 int_4 minmn = (m < n) ? m : n;
478
479 char jobu, jobvt;
480 jobu = 'N';
481 jobvt = 'N';
482
483 sa_size_t sz[2];
484 if ( up != NULL) {
485 if ( dynamic_cast< TVector<T> * > (vtp) )
[2561]486 throw( TypeMismatchExc("LapackServer::SVDDriver() Wrong type (=TVector<T>) for u !") );
[1342]487 up->SetMemoryMapping(BaseArray::FortranMemoryMapping);
488 sz[0] = sz[1] = m;
489 up->ReSize(2, sz );
490 jobu = 'A';
491 }
492 else {
493 up = new TMatrix<T>(1,1);
494 jobu = 'N';
495 }
496 if ( vtp != NULL) {
497 if ( dynamic_cast< TVector<T> * > (vtp) )
[2561]498 throw( TypeMismatchExc("LapackServer::SVDDriver() Wrong type (=TVector<T>) for vt !") );
[1342]499 vtp->SetMemoryMapping(BaseArray::FortranMemoryMapping);
500 sz[0] = sz[1] = n;
501 vtp->ReSize(2, sz );
502 jobvt = 'A';
503 }
504 else {
505 vtp = new TMatrix<T>(1,1);
506 jobvt = 'N';
507 }
508
509 TVector<T> *vs = dynamic_cast< TVector<T> * > (&s);
510 if (vs) vs->ReSize(minmn);
511 else {
512 TMatrix<T> *ms = dynamic_cast< TMatrix<T> * > (&s);
513 if (ms) ms->ReSize(minmn,1);
514 else {
515 sz[0] = minmn; sz[1] = 1;
516 s.ReSize(1, sz);
517 }
518 }
519
520 int_4 lda = a.Step(a.ColsKA());
521 int_4 ldu = up->Step(up->ColsKA());
522 int_4 ldvt = vtp->Step(vtp->ColsKA());
523
524 int_4 lwork = maxmn*5*wspace_size_factor;
525 T * work = new T[lwork];
[2561]526 int_4 info;
[1342]527
[2559]528 if (typeid(T) == typeid(r_4) ) {
[1342]529 sgesvd_(&jobu, &jobvt, &m, &n, (r_4 *)a.Data(), &lda,
530 (r_4 *)s.Data(), (r_4 *) up->Data(), &ldu, (r_4 *)vtp->Data(), &ldvt,
531 (r_4 *)work, &lwork, &info);
[2559]532 } else if (typeid(T) == typeid(r_8) ) {
[1342]533 dgesvd_(&jobu, &jobvt, &m, &n, (r_8 *)a.Data(), &lda,
534 (r_8 *)s.Data(), (r_8 *) up->Data(), &ldu, (r_8 *)vtp->Data(), &ldvt,
535 (r_8 *)work, &lwork, &info);
[2559]536 } else if (typeid(T) == typeid(complex<r_4>) ) {
537 r_4 * rwork = new r_4[5*minmn +5];
538 r_4 * sloc = new r_4[minmn];
[1342]539 cgesvd_(&jobu, &jobvt, &m, &n, (complex<r_4> *)a.Data(), &lda,
[2559]540 (r_4 *)sloc, (complex<r_4> *) up->Data(), &ldu,
[1342]541 (complex<r_4> *)vtp->Data(), &ldvt,
[2559]542 (complex<r_4> *)work, &lwork, (r_4 *)rwork, &info);
543 for(int_4 i=0;i<minmn;i++) s[i] = sloc[i];
544 delete [] rwork; delete [] sloc;
545 } else if (typeid(T) == typeid(complex<r_8>) ) {
546 r_8 * rwork = new r_8[5*minmn +5];
547 r_8 * sloc = new r_8[minmn];
[1342]548 zgesvd_(&jobu, &jobvt, &m, &n, (complex<r_8> *)a.Data(), &lda,
[2559]549 (r_8 *)sloc, (complex<r_8> *) up->Data(), &ldu,
[1342]550 (complex<r_8> *)vtp->Data(), &ldvt,
[2559]551 (complex<r_8> *)work, &lwork, (r_8 *)rwork, &info);
552 for(int_4 i=0;i<minmn;i++) s[i] = sloc[i];
553 delete [] rwork; delete [] sloc;
554 } else {
[1342]555 if (jobu == 'N') delete up;
556 if (jobvt == 'N') delete vtp;
557 string tn = typeid(T).name();
558 cerr << " LapackServer::SVDDriver(...) - Unsupported DataType T = " << tn << endl;
[2561]559 throw TypeMismatchExc("LapackServer::SVDDriver(a,b) - Unsupported DataType (T)");
[1342]560 }
561
562 if (jobu == 'N') delete up;
563 if (jobvt == 'N') delete vtp;
564 return(info);
565}
566
[2556]567
[2561]568//! Interface to Lapack SVD driver s/d/c/zgesdd().
569/*! Same as SVD but with Divide and Conquer method */
570template <class T>
571int LapackServer<T>::SVD_DC(TMatrix<T>& a, TVector<T>& s, TMatrix<T>& u, TMatrix<T>& vt)
572{
573
574 if ( !a.IsPacked() )
575 throw(SzMismatchError("LapackServer::SVD_DC(a, ...) a Not Packed "));
576
577 int_4 m = a.NRows();
578 int_4 n = a.NCols();
579 int_4 maxmn = (m > n) ? m : n;
580 int_4 minmn = (m < n) ? m : n;
581 int_4 supermax = 4*minmn*minmn+4*minmn; if(maxmn>supermax) supermax=maxmn;
582
583 char jobz = 'A';
584
585 s.ReSize(minmn);
586 u.ReSize(m,m);
587 vt.ReSize(n,n);
588
589 int_4 lda = n;
590 int_4 ldu = u.NCols();
591 int_4 ldvt = vt.NCols();
592 int_4 info;
593
594 if(typeid(T) == typeid(r_4) ) {
595 int_4 lwork = 3*minmn*minmn + supermax;
596 r_4* work = new r_4[lwork +5];
597 int_4* iwork = new int_4[8*minmn +5];
598 sgesdd_(&jobz,&m,&n,(r_4*)a.Data(),&lda,
599 (r_4*)s.Data(),(r_4*)u.Data(),&ldu,(r_4*)vt.Data(),&ldvt,
600 (r_4*)work,&lwork,(int_4*)iwork,&info);
601 delete [] work; delete [] iwork;
602 } else if(typeid(T) == typeid(r_8) ) {
603 int_4 lwork = 3*minmn*minmn + supermax;
604 r_8* work = new r_8[lwork +5];
605 int_4* iwork = new int_4[8*minmn +5];
606 dgesdd_(&jobz,&m,&n,(r_8*)a.Data(),&lda,
607 (r_8*)s.Data(),(r_8*)u.Data(),&ldu,(r_8*)vt.Data(),&ldvt,
608 (r_8*)work,&lwork,(int_4*)iwork,&info);
609 delete [] work; delete [] iwork;
610 } else if(typeid(T) == typeid(complex<r_4>) ) {
611 r_4* sloc = new r_4[minmn];
612 int_4 lwork = minmn*minmn+2*minmn+maxmn;
613 complex<r_4>* work = new complex<r_4>[lwork +5];
614 r_4* rwork = new r_4[5*minmn*minmn+5*minmn +5];
615 int_4* iwork = new int_4[8*minmn +5];
616 cgesdd_(&jobz,&m,&n,(complex<r_4>*)a.Data(),&lda,
617 (r_4*)sloc,(complex<r_4>*)u.Data(),&ldu,(complex<r_4>*)vt.Data(),&ldvt,
618 (complex<r_4>*)work,&lwork,(r_4*)rwork,(int_4*)iwork,&info);
619 for(int_4 i=0;i<minmn;i++) s[i] = sloc[i];
620 delete [] sloc; delete [] work; delete [] rwork; delete [] iwork;
621 } else if(typeid(T) == typeid(complex<r_8>) ) {
622 r_8* sloc = new r_8[minmn];
623 int_4 lwork = minmn*minmn+2*minmn+maxmn;
624 complex<r_8>* work = new complex<r_8>[lwork +5];
625 r_8* rwork = new r_8[5*minmn*minmn+5*minmn +5];
626 int_4* iwork = new int_4[8*minmn +5];
627 zgesdd_(&jobz,&m,&n,(complex<r_8>*)a.Data(),&lda,
628 (r_8*)sloc,(complex<r_8>*)u.Data(),&ldu,(complex<r_8>*)vt.Data(),&ldvt,
629 (complex<r_8>*)work,&lwork,(r_8*)rwork,(int_4*)iwork,&info);
630 for(int_4 i=0;i<minmn;i++) s[i] = sloc[i];
631 delete [] sloc; delete [] work; delete [] rwork; delete [] iwork;
632 } else {
633 string tn = typeid(T).name();
634 cerr << " LapackServer::SVD_DC(...) - Unsupported DataType T = " << tn << endl;
635 throw TypeMismatchExc("LapackServer::SVD_DC - Unsupported DataType (T)");
636 }
637
638 return(info);
639}
640
641
[2556]642////////////////////////////////////////////////////////////////////////////////////
643/*! Computes the eigen values and eigen vectors of a symetric (or hermitian) matrix \b a.
644 Input arrays should have FortranMemoryMapping (column packed).
645 \param a : input symetric (or hermitian) n-by-n matrix
646 \param b : Vector of eigenvalues (descending order)
647 \param eigenvector : if true compute eigenvectors, if not only eigenvalues
648 \param a : on return array of eigenvectors (same order than eval, one vector = one column)
[2561]649 \return : return code from lapack driver
[2556]650 */
651
652template <class T>
653int LapackServer<T>::LapackEigenSym(TArray<T>& a, TVector<r_8>& b, bool eigenvector)
654{
655 if ( a.NbDimensions() != 2 )
656 throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a NbDimensions() != 2"));
657 int_4 rowa = a.RowsKA();
658 int_4 cola = a.ColsKA();
659 if ( a.Size(rowa) != a.Size(cola))
660 throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a Not a square Array"));
661 if (!a.IsPacked(rowa))
662 throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a Not Column Packed"));
663
[2561]664 char uplo='U';
[2556]665 char jobz='N'; if(eigenvector) jobz='V';
666
667 int_4 n = a.Size(rowa);
668 int_4 lda = a.Step(cola);
669 int_4 info = 0;
670
671 b.ReSize(n); b = 0.;
672
673 if (typeid(T) == typeid(r_4) ) {
674 int_4 lwork = 3*n-1 +5; r_4* work = new r_4[lwork];
675 r_4* w = new r_4[n];
[2561]676 ssyev_(&jobz,&uplo,&n,(r_4 *)a.Data(),&lda,(r_4 *)w,(r_4 *)work,&lwork,&info);
[2556]677 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
678 delete [] work; delete [] w;
679 } else if (typeid(T) == typeid(r_8) ) {
680 int_4 lwork = 3*n-1 +5; r_8* work = new r_8[lwork];
681 r_8* w = new r_8[n];
[2561]682 dsyev_(&jobz,&uplo,&n,(r_8 *)a.Data(),&lda,(r_8 *)w,(r_8 *)work,&lwork,&info);
[2556]683 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
684 delete [] work; delete [] w;
685 } else if (typeid(T) == typeid(complex<r_4>) ) {
686 int_4 lwork = 2*n-1 +5; complex<r_4>* work = new complex<r_4>[lwork];
687 r_4* rwork = new r_4[3*n-2 +5]; r_4* w = new r_4[n];
[2561]688 cheev_(&jobz,&uplo,&n,(complex<r_4> *)a.Data(),&lda,(r_4 *)w
[2556]689 ,(complex<r_4> *)work,&lwork,(r_4 *)rwork,&info);
690 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
691 delete [] work; delete [] rwork; delete [] w;
692 } else if (typeid(T) == typeid(complex<r_8>) ) {
693 int_4 lwork = 2*n-1 +5; complex<r_8>* work = new complex<r_8>[lwork];
694 r_8* rwork = new r_8[3*n-2 +5]; r_8* w = new r_8[n];
[2561]695 zheev_(&jobz,&uplo,&n,(complex<r_8> *)a.Data(),&lda,(r_8 *)w
[2556]696 ,(complex<r_8> *)work,&lwork,(r_8 *)rwork,&info);
697 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
698 delete [] work; delete [] rwork; delete [] w;
699 } else {
700 string tn = typeid(T).name();
701 cerr << " LapackServer::LapackEigenSym(a,b) - Unsupported DataType T = " << tn << endl;
702 throw TypeMismatchExc("LapackServer::LapackEigenSym(a,b) - Unsupported DataType (T)");
703 }
704
705 return(info);
706}
707
708////////////////////////////////////////////////////////////////////////////////////
709/*! Computes the eigen values and eigen vectors of a general squared matrix \b a.
710 Input arrays should have FortranMemoryMapping (column packed).
711 \param a : input general n-by-n matrix
712 \param eval : Vector of eigenvalues (complex double precision)
713 \param evec : Matrix of eigenvector (same order than eval, one vector = one column)
714 \param eigenvector : if true compute (right) eigenvectors, if not only eigenvalues
715 \param a : on return array of eigenvectors
[2561]716 \return : return code from lapack driver
[2556]717 \verbatim
718 eval : contains the computed eigenvalues.
719 --- For real matrices "a" :
720 Complex conjugate pairs of eigenvalues appear consecutively
721 with the eigenvalue having the positive imaginary part first.
722 evec : the right eigenvectors v(j) are stored one after another
723 in the columns of evec, in the same order as their eigenvalues.
724 --- For real matrices "a" :
725 If the j-th eigenvalue is real, then v(j) = evec(:,j),
726 the j-th column of evec.
727 If the j-th and (j+1)-st eigenvalues form a complex
728 conjugate pair, then v(j) = evec(:,j) + i*evec(:,j+1) and
729 v(j+1) = evec(:,j) - i*evec(:,j+1).
730 \endverbatim
731*/
732
733template <class T>
734int LapackServer<T>::LapackEigen(TArray<T>& a, TVector< complex<r_8> >& eval, TMatrix<T>& evec, bool eigenvector)
735{
736 if ( a.NbDimensions() != 2 )
737 throw(SzMismatchError("LapackServer::LapackEigen(a,b) a NbDimensions() != 2"));
738 int_4 rowa = a.RowsKA();
739 int_4 cola = a.ColsKA();
740 if ( a.Size(rowa) != a.Size(cola))
741 throw(SzMismatchError("LapackServer::LapackEigen(a,b) a Not a square Array"));
742 if (!a.IsPacked(rowa))
743 throw(SzMismatchError("LapackServer::LapackEigen(a,b) a Not Column Packed"));
744
[2561]745 char jobvl = 'N';
[2556]746 char jobvr = 'N'; if(eigenvector) jobvr='V';
747
748 int_4 n = a.Size(rowa);
749 int_4 lda = a.Step(cola);
750 int_4 info = 0;
751
752 eval.ReSize(n); eval = complex<r_8>(0.,0.);
753 if(eigenvector) {evec.ReSize(n,n); evec = (T) 0.;}
754 int_4 ldvr = n, ldvl = 1;
755
756 if (typeid(T) == typeid(r_4) ) {
757 int_4 lwork = 4*n +5; r_4* work = new r_4[lwork];
758 r_4* wr = new r_4[n]; r_4* wi = new r_4[n]; r_4* vl = NULL;
[2561]759 sgeev_(&jobvl,&jobvr,&n,(r_4 *)a.Data(),&lda,(r_4 *)wr,(r_4 *)wi,
[2556]760 (r_4 *)vl,&ldvl,(r_4 *)evec.Data(),&ldvr,
761 (r_4 *)work,&lwork,&info);
762 if(info==0) for(int i=0;i<n;i++) eval(i) = complex<r_8>(wr[i],wi[i]);
763 delete [] work; delete [] wr; delete [] wi;
764 } else if (typeid(T) == typeid(r_8) ) {
765 int_4 lwork = 4*n +5; r_8* work = new r_8[lwork];
766 r_8* wr = new r_8[n]; r_8* wi = new r_8[n]; r_8* vl = NULL;
[2561]767 dgeev_(&jobvl,&jobvr,&n,(r_8 *)a.Data(),&lda,(r_8 *)wr,(r_8 *)wi,
[2556]768 (r_8 *)vl,&ldvl,(r_8 *)evec.Data(),&ldvr,
769 (r_8 *)work,&lwork,&info);
770 if(info==0) for(int i=0;i<n;i++) eval(i) = complex<r_8>(wr[i],wi[i]);
771 delete [] work; delete [] wr; delete [] wi;
772 } else if (typeid(T) == typeid(complex<r_4>) ) {
773 int_4 lwork = 2*n +5; complex<r_4>* work = new complex<r_4>[lwork];
774 r_4* rwork = new r_4[2*n+5]; r_4* vl = NULL; TVector< complex<r_4> > w(n);
[2561]775 cgeev_(&jobvl,&jobvr,&n,(complex<r_4> *)a.Data(),&lda,(complex<r_4> *)w.Data(),
[2556]776 (complex<r_4> *)vl,&ldvl,(complex<r_4> *)evec.Data(),&ldvr,
777 (complex<r_4> *)work,&lwork,(r_4 *)rwork,&info);
778 if(info==0) for(int i=0;i<n;i++) eval(i) = w(i);
779 delete [] work; delete [] rwork;
780 } else if (typeid(T) == typeid(complex<r_8>) ) {
781 int_4 lwork = 2*n +5; complex<r_8>* work = new complex<r_8>[lwork];
782 r_8* rwork = new r_8[2*n+5]; r_8* vl = NULL;
[2561]783 zgeev_(&jobvl,&jobvr,&n,(complex<r_8> *)a.Data(),&lda,(complex<r_8> *)eval.Data(),
[2556]784 (complex<r_8> *)vl,&ldvl,(complex<r_8> *)evec.Data(),&ldvr,
785 (complex<r_8> *)work,&lwork,(r_8 *)rwork,&info);
786 delete [] work; delete [] rwork;
787 } else {
788 string tn = typeid(T).name();
789 cerr << " LapackServer::LapackEigen(a,b) - Unsupported DataType T = " << tn << endl;
790 throw TypeMismatchExc("LapackServer::LapackEigen(a,b) - Unsupported DataType (T)");
791 }
792
793 return(info);
794}
795
796
797
798
799
800////////////////////////////////////////////////////////////////////////////////////
[775]801void rztest_lapack(TArray<r_4>& aa, TArray<r_4>& bb)
802{
803 if ( aa.NbDimensions() != 2 ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a Matrix"));
804 if ( aa.SizeX() != aa.SizeY()) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a square Matrix"));
805 if ( bb.NbDimensions() != 2 ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a Matrix"));
[788]806 if ( bb.SizeX() != aa.SizeX() ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A <> B "));
[775]807 if ( !bb.IsPacked() || !bb.IsPacked() )
808 throw(SzMismatchError("rztest_lapack(TMatrix<r_4> Not packed A or B "));
809
[788]810 int_4 n = aa.SizeX();
811 int_4 nrhs = bb.SizeY();
[775]812 int_4 lda = n;
[788]813 int_4 ldb = bb.SizeX();
[775]814 int_4 info;
815 int_4* ipiv = new int_4[n];
816 sgesv_(&n, &nrhs, aa.Data(), &lda, ipiv, bb.Data(), &ldb, &info);
[814]817 delete[] ipiv;
[775]818 cout << "rztest_lapack/Info= " << info << endl;
819 cout << aa << "\n" << bb << endl;
820 return;
821}
[814]822
823///////////////////////////////////////////////////////////////
// Explicit instantiations of LapackServer for r_4, r_8 and the two
// corresponding complex types.

// Compilers driven by "#pragma define_template"
#if defined(__CXX_PRAGMA_TEMPLATES__)
#pragma define_template LapackServer<r_4>
#pragma define_template LapackServer<r_8>
#pragma define_template LapackServer< complex<r_4> >
#pragma define_template LapackServer< complex<r_8> >
#endif

// Compilers using the "template class X<T>;" explicit instantiation syntax
#if defined(GNU_TEMPLATES) || defined(ANSI_TEMPLATES)
template class LapackServer<r_4>;
template class LapackServer<r_8>;
template class LapackServer< complex<r_4> >;
template class LapackServer< complex<r_8> >;
#endif
837
#if defined(OS_LINUX)
// Needed for linking with f2c under Linux: provides the MAIN__ symbol
// with C linkage. It only reports the error and throws if ever called.
extern "C" void MAIN__();

void MAIN__()
{
  cerr << "MAIN__() function for linking with libf2c.a " << endl
       << " This function should never be called !!! " << endl;
  throw PError("MAIN__() should not be called - see intflapack.cc");
}
#endif
Note: See TracBrowser for help on using the repository browser.