source: Sophya/trunk/SophyaExt/LinAlg/intflapack.cc@ 2556

Last change on this file since 2556 was 2556, checked in by cmv, 21 years ago
  • Introduction de l'interface Lapack d'inversion des matrices symetriques
  • Introduction de l'interface Lapack de recherche de valeurs et vecteurs propres (cas general, symetrique et hermitique)
  • Introduction d'une fonction d'interface pour le calculateur de workspace (ilaenv_)
  • Commentaires sur les diverses methodes et sur les matrices FORTRAN
  • Pour tester cf Tests/tsttminv.cc

(cmv, 21/07/04)

File size: 30.2 KB
Line 
1#include <iostream>
2#include "intflapack.h"
3#include "tvector.h"
4#include "tmatrix.h"
5#include <typeinfo>
6
/*************** Memo (Christophe) ***************
FORTRAN memory layouts of LAPACK vectors and matrices:

1./ --- REAL X(N):
 if an array X of dimension (N) holds a vector x,
 then X(i) holds "x_i" for i=1,...,N

2./ --- REAL A(LDA,N):
 if a two-dimensional array A of dimension (LDA,N) holds an m-by-n matrix A,
 then A(i,j) holds "a_ij" for i=1,...,m and j=1,...,n (LDA must be at least m).
 Note that array arguments are usually declared in the software as assumed-size
 arrays (last dimension *), for example: REAL A(LDA,*)
 --- Storage in memory (column-major):
                          | 11 12 13 14 |
 Ex: REAL A(4,4):    A =  | 21 22 23 24 |
                          | 31 32 33 34 |
                          | 41 42 43 44 |
 memory: {11 21 31 41} {12 22 32 42} {13 23 33 43} {14 24 34 44}
 The first index (row) "i" varies fastest, then the second (column):
 (all of the first column is stored, then all of the second column,
 ..., then all of the last column)
***********************************************************/
29
30/*!
31 \defgroup LinAlg LinAlg module
32 This module contains classes and functions for complex linear
33 algebra on arrays. This module is intended mainly to have
34 classes implementing C++ interfaces between Sophya objects
35 and external linear algebra libraries, such as LAPACK.
36*/
37
/*!
  \class SOPHYA::LapackServer
  \ingroup LinAlg
  This class implements an interface to LAPACK library driver routines.
  LAPACK (Linear Algebra PACKage) is a collection of high-performance
  routines to solve common problems in numerical linear algebra.
  It is available from http://www.netlib.org.

  The first version of our LapackServer (Feb 2001) provided only
  interfaces for the linear system solver and singular value
  decomposition (SVD); interfaces for symmetric linear systems,
  least squares problems and eigenvalue/eigenvector computations
  (general, symmetric and hermitian matrices) were added later.
  Only arrays with BaseArray::FortranMemoryMapping
  can be handled by LapackServer. LapackServer can be instantiated
  for single and double precision real or complex array types.

  The example below shows how to solve a linear system A*X = B

  \code
  #include "intflapack.h"
  // ...
  // Use FortranMemoryMapping as default
  BaseArray::SetDefaultMemoryMapping(BaseArray::FortranMemoryMapping);
  // Create and fill the arrays A and B
  int n = 20;
  Matrix A(n, n);
  A = RandomSequence();
  Vector X(n),B(n);
  X = RandomSequence();
  B = A*X;
  // Solve the linear system A*X = B
  LapackServer<r_8> lps;
  lps.LinSolve(A,B);
  // We get the result in B, which should be equal to X ...
  // Compute the difference B-X ;
  Vector diff = B-X;
  \endcode

*/
75
76////////////////////////////////////////////////////////////////////////////////////
77extern "C" {
78// Le calculateur de workingspace
79 int_4 ilaenv_(int_4 *ispec,char *name,char *opts,int_4 *n1,int_4 *n2,int_4 *n3,int_4 *n4,
80 int_4 nc1,int_4 nc2);
81
82// Drivers pour resolution de systemes lineaires
83 void sgesv_(int_4* n, int_4* nrhs, r_4* a, int_4* lda,
84 int_4* ipiv, r_4* b, int_4* ldb, int_4* info);
85 void dgesv_(int_4* n, int_4* nrhs, r_8* a, int_4* lda,
86 int_4* ipiv, r_8* b, int_4* ldb, int_4* info);
87 void cgesv_(int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
88 int_4* ipiv, complex<r_4>* b, int_4* ldb, int_4* info);
89 void zgesv_(int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
90 int_4* ipiv, complex<r_8>* b, int_4* ldb, int_4* info);
91
92// Drivers pour resolution de systemes lineaires symetriques
93 void ssysv_(char* uplo, int_4* n, int_4* nrhs, r_4* a, int_4* lda,
94 int_4* ipiv, r_4* b, int_4* ldb,
95 r_4* work, int_4* lwork, int_4* info);
96 void dsysv_(char* uplo, int_4* n, int_4* nrhs, r_8* a, int_4* lda,
97 int_4* ipiv, r_8* b, int_4* ldb,
98 r_8* work, int_4* lwork, int_4* info);
99 void csysv_(char* uplo, int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
100 int_4* ipiv, complex<r_4>* b, int_4* ldb,
101 complex<r_4>* work, int_4* lwork, int_4* info);
102 void zsysv_(char* uplo, int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
103 int_4* ipiv, complex<r_8>* b, int_4* ldb,
104 complex<r_8>* work, int_4* lwork, int_4* info);
105
106// Driver pour resolution de systemes au sens de Xi2
107 void sgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, r_4* a, int_4* lda,
108 r_4* b, int_4* ldb, r_4* work, int_4* lwork, int_4* info);
109 void dgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, r_8* a, int_4* lda,
110 r_8* b, int_4* ldb, r_8* work, int_4* lwork, int_4* info);
111 void cgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
112 complex<r_4>* b, int_4* ldb, complex<r_4>* work, int_4* lwork, int_4* info);
113 void zgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
114 complex<r_8>* b, int_4* ldb, complex<r_8>* work, int_4* lwork, int_4* info);
115
116// Driver pour decomposition SVD
117 void sgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, r_4* a, int_4* lda,
118 r_4* s, r_4* u, int_4* ldu, r_4* vt, int_4* ldvt,
119 r_4* work, int_4* lwork, int_4* info);
120 void dgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, r_8* a, int_4* lda,
121 r_8* s, r_8* u, int_4* ldu, r_8* vt, int_4* ldvt,
122 r_8* work, int_4* lwork, int_4* info);
123 void cgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_4>* a, int_4* lda,
124 complex<r_4>* s, complex<r_4>* u, int_4* ldu, complex<r_4>* vt, int_4* ldvt,
125 complex<r_4>* work, int_4* lwork, int_4* info);
126 void zgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_8>* a, int_4* lda,
127 complex<r_8>* s, complex<r_8>* u, int_4* ldu, complex<r_8>* vt, int_4* ldvt,
128 complex<r_8>* work, int_4* lwork, int_4* info);
129
130// Driver pour eigen decomposition for symetric/hermitian matrices
131 void ssyev_(char* jobz, char* uplo, int_4* n, r_4* a, int_4* lda, r_4* w,
132 r_4* work, int_4 *lwork, int_4* info);
133 void dsyev_(char* jobz, char* uplo, int_4* n, r_8* a, int_4* lda, r_8* w,
134 r_8* work, int_4 *lwork, int_4* info);
135 void cheev_(char* jobz, char* uplo, int_4* n, complex<r_4>* a, int_4* lda, r_4* w,
136 complex<r_4>* work, int_4 *lwork, r_4* rwork, int_4* info);
137 void zheev_(char* jobz, char* uplo, int_4* n, complex<r_8>* a, int_4* lda, r_8* w,
138 complex<r_8>* work, int_4 *lwork, r_8* rwork, int_4* info);
139
140// Driver pour eigen decomposition for general squared matrices
141 void sgeev_(char* jobl, char* jobvr, int_4* n, r_4* a, int_4* lda, r_4* wr, r_4* wi,
142 r_4* vl, int_4* ldvl, r_4* vr, int_4* ldvr,
143 r_4* work, int_4 *lwork, int_4* info);
144 void dgeev_(char* jobl, char* jobvr, int_4* n, r_8* a, int_4* lda, r_8* wr, r_8* wi,
145 r_8* vl, int_4* ldvl, r_8* vr, int_4* ldvr,
146 r_8* work, int_4 *lwork, int_4* info);
147 void cgeev_(char* jobl, char* jobvr, int_4* n, complex<r_4>* a, int_4* lda, complex<r_4>* w,
148 complex<r_4>* vl, int_4* ldvl, complex<r_4>* vr, int_4* ldvr,
149 complex<r_4>* work, int_4 *lwork, r_4* rwork, int_4* info);
150 void zgeev_(char* jobl, char* jobvr, int_4* n, complex<r_8>* a, int_4* lda, complex<r_8>* w,
151 complex<r_8>* vl, int_4* ldvl, complex<r_8>* vr, int_4* ldvr,
152 complex<r_8>* work, int_4 *lwork, r_8* rwork, int_4* info);
153
154}
155
156// -------------- Classe LapackServer<T> --------------
157
158////////////////////////////////////////////////////////////////////////////////////
159template <class T>
160LapackServer<T>::LapackServer()
161{
162 SetWorkSpaceSizeFactor();
163}
164
165template <class T>
166LapackServer<T>::~LapackServer()
167{
168}
169
// --- WARNING: POSSIBLE FUTURE BUG (CMV) --- REZA, PLEASE READ.
// -> A Fortran/C interface pitfall.
// In the LAPACK Fortran routines, when called from C with
//    int_4 lwork = -1;
//    SUBROUTINE SSYSV( UPLO,N,NRHS,A,LDA,IPIV,B,LDB,WORK,LWORK,INFO)
//    INTEGER INFO, LDA, LDB, LWORK, N, NRHS
//    LOGICAL LQUERY
//    LQUERY = ( LWORK.EQ.-1 )
//    ELSE IF( LWORK.LT.1 .AND. .NOT.LQUERY ) THEN
// ==> the test is interpreted correctly under Linux but not under OSF
// ==> under OSF "LWORK.EQ.-1" is FALSE when lwork=-1 is passed as an argument
// ==> FOR REZA: 4-byte / 8-byte integer confusion??? (odd -- we would have seen it before?)
// This is presumably why workspace sizes are computed explicitly below
// (via ilaenv_en_C) rather than through an LWORK=-1 workspace query.
////////////////////////////////////////////////////////////////////////////////////
184template <class T>
185int_4 LapackServer<T>::ilaenv_en_C(int_4 ispec,char *name,char *opts,int_4 n1,int_4 n2,int_4 n3,int_4 n4)
186{
187 int_4 nc1 = strlen(name);
188 int_4 nc2 = strlen(opts);
189 int_4 rc=0;
190 rc = ilaenv_(&ispec,name,opts,&n1,&n2,&n3,&n4,nc1,nc2);
191 //cout<<"ilaenv_en_C("<<ispec<<","<<name<<"("<<nc1<<"),"<<opts<<"("<<nc2<<"),"
192 // <<n1<<","<<n2<<","<<n3<<","<<n4<<") = "<<rc<<endl;
193 return rc;
194}
195
196////////////////////////////////////////////////////////////////////////////////////
197//! Interface to Lapack linear system solver driver s/d/c/zgesvd().
198/*! Solve the linear system a * x = b. Input arrays
199 should have FortranMemory mapping (column packed).
200 \param a : input matrix, overwritten on output
201 \param b : input-output, input vector b, contains x on exit
202 \return : return code from lapack driver _gesv()
203 */
204template <class T>
205int LapackServer<T>::LinSolve(TArray<T>& a, TArray<T> & b)
206{
207 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
208 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b NbDimensions() != 2"));
209
210 int_4 rowa = a.RowsKA();
211 int_4 cola = a.ColsKA();
212 int_4 rowb = b.RowsKA();
213 int_4 colb = b.ColsKA();
214 if ( a.Size(rowa) != a.Size(cola))
215 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Not a square Array"));
216 if ( a.Size(rowa) != b.Size(rowb))
217 throw(SzMismatchError("LapackServer::LinSolve(a,b) RowSize(a <> b) "));
218
219 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
220 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b Not Column Packed"));
221
222 int_4 n = a.Size(rowa);
223 int_4 nrhs = b.Size(colb);
224 int_4 lda = a.Step(cola);
225 int_4 ldb = b.Step(colb);
226 int_4 info;
227 int_4* ipiv = new int_4[n];
228
229 if (typeid(T) == typeid(r_4) )
230 sgesv_(&n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb, &info);
231 else if (typeid(T) == typeid(r_8) )
232 dgesv_(&n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb, &info);
233 else if (typeid(T) == typeid(complex<r_4>) )
234 cgesv_(&n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
235 (complex<r_4> *)b.Data(), &ldb, &info);
236 else if (typeid(T) == typeid(complex<r_8>) )
237 zgesv_(&n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
238 (complex<r_8> *)b.Data(), &ldb, &info);
239 else {
240 delete[] ipiv;
241 string tn = typeid(T).name();
242 cerr << " LapackServer::LinSolve(a,b) - Unsupported DataType T = " << tn << endl;
243 throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
244 }
245 delete[] ipiv;
246 return(info);
247}
248
249////////////////////////////////////////////////////////////////////////////////////
250//! Interface to Lapack linear system solver driver s/d/c/zsysvd().
251/*! Solve the linear system a * x = b with a symetric. Input arrays
252 should have FortranMemory mapping (column packed).
253 \param a : input matrix symetric , overwritten on output
254 \param b : input-output, input vector b, contains x on exit
255 \return : return code from lapack driver _gesv()
256 */
257template <class T>
258int LapackServer<T>::LinSolveSym(TArray<T>& a, TArray<T> & b)
259// --- REMARQUES DE CMV ---
260// 1./ contrairement a ce qui est dit dans la doc, il s'agit
261// de matrices SYMETRIQUES complexes et non HERMITIENNES !!!
262// 2./ pourquoi les routines de LinSolve pour des matrices symetriques
263// sont plus de deux fois plus lentes que les LinSolve generales sur OSF
264// et sensiblement plus lentes sous Linux ???
265{
266 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
267 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) a Or b NbDimensions() != 2"));
268 int_4 rowa = a.RowsKA();
269 int_4 cola = a.ColsKA();
270 int_4 rowb = b.RowsKA();
271 int_4 colb = b.ColsKA();
272 if ( a.Size(rowa) != a.Size(cola))
273 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) a Not a square Array"));
274 if ( a.Size(rowa) != b.Size(rowb))
275 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) RowSize(a <> b) "));
276
277 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
278 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) a Or b Not Column Packed"));
279
280 int_4 n = a.Size(rowa);
281 int_4 nrhs = b.Size(colb);
282 int_4 lda = a.Step(cola);
283 int_4 ldb = b.Step(colb);
284 int_4 info = 0;
285 int_4* ipiv = new int_4[n];
286 int_4 lwork = -1;
287 T * work = NULL;
288
289 char uplo = 'U'; // char uplo = 'L';
290 char struplo[5]; struplo[0] = uplo; struplo[1] = '\0';
291
292 if (typeid(T) == typeid(r_4) ) {
293 lwork = ilaenv_en_C(1,"SSYTRF",struplo,n,-1,-1,-1) * n +5;
294 work = new T[lwork];
295 ssysv_(&uplo, &n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb,
296 (r_4 *)work, &lwork, &info);
297 } else if (typeid(T) == typeid(r_8) ) {
298 lwork = ilaenv_en_C(1,"DSYTRF",struplo,n,-1,-1,-1) * n +5;
299 work = new T[lwork];
300 dsysv_(&uplo, &n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb,
301 (r_8 *)work, &lwork, &info);
302 } else if (typeid(T) == typeid(complex<r_4>) ) {
303 lwork = ilaenv_en_C(1,"CSYTRF",struplo,n,-1,-1,-1) * n +5;
304 work = new T[lwork];
305 csysv_(&uplo, &n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
306 (complex<r_4> *)b.Data(), &ldb,
307 (complex<r_4> *)work, &lwork, &info);
308 } else if (typeid(T) == typeid(complex<r_8>) ) {
309 lwork = ilaenv_en_C(1,"ZSYTRF",struplo,n,-1,-1,-1) * n +5;
310 work = new T[lwork];
311 zsysv_(&uplo, &n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
312 (complex<r_8> *)b.Data(), &ldb,
313 (complex<r_8> *)work, &lwork, &info);
314 } else {
315 if(work) delete[] work;
316 delete[] ipiv;
317 string tn = typeid(T).name();
318 cerr << " LapackServer::LinSolveSym(a,b) - Unsupported DataType T = " << tn << endl;
319 throw TypeMismatchExc("LapackServer::LinSolveSym(a,b) - Unsupported DataType (T)");
320 }
321 if(work) delete[] work;
322 delete[] ipiv;
323 return(info);
324}
325
326////////////////////////////////////////////////////////////////////////////////////
327//! Interface to Lapack least squares solver driver s/d/c/zgels().
328/*! Solves the linear least squares problem defined by an m-by-n matrix
329 \b a and an m element vector \b b .
330 A solution \b x to the overdetermined system of linear equations
331 b = a * x is computed, minimizing the norm of b-a*x.
332 Underdetermined systems (m<n) are not yet handled.
333 Inout arrays should have FortranMemory mapping (column packed).
334 \param a : input matrix, overwritten on output
335 \param b : input-output, input vector b, contains x on exit.
336 \return : return code from lapack driver _gels()
337 \warning : b is not resized.
338 */
339/*
340 $CHECK$ - A faire - cas m<n
341 If the linear system is underdetermined, the minimum norm
342 solution is computed.
343*/
344
345template <class T>
346int LapackServer<T>::LeastSquareSolve(TArray<T>& a, TArray<T> & b)
347{
348 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
349 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b NbDimensions() != 2"));
350
351 int_4 rowa = a.RowsKA();
352 int_4 cola = a.ColsKA();
353 int_4 rowb = b.RowsKA();
354 int_4 colb = b.ColsKA();
355
356
357 if ( a.Size(rowa) != b.Size(rowb))
358 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) RowSize(a <> b) "));
359
360 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
361 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) a Or b Not Column Packed"));
362
363 if ( a.Size(rowa) < a.Size(cola)) { // $CHECK$ - m<n a changer
364 cout << " LapackServer<T>::LeastSquareSolve() - m<n - Not yet implemented for "
365 << " underdetermined systems ! " << endl;
366 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) NRows<NCols - "));
367 }
368 int_4 m = a.Size(rowa);
369 int_4 n = a.Size(cola);
370 int_4 nrhs = b.Size(colb);
371
372 int_4 lda = a.Step(cola);
373 int_4 ldb = b.Step(colb);
374 int_4 info;
375
376 int_4 minmn = (m < n) ? m : n;
377 int_4 maxmn = (m > n) ? m : n;
378 int_4 maxmnrhs = (nrhs > maxmn) ? nrhs : maxmn;
379 if (maxmnrhs < 1) maxmnrhs = 1;
380
381 int_4 lwork = minmn+maxmnrhs*5;
382 T * work = new T[lwork];
383
384 char trans = 'N';
385
386 if (typeid(T) == typeid(r_4) )
387 sgels_(&trans, &m, &n, &nrhs, (r_4 *)a.Data(), &lda,
388 (r_4 *)b.Data(), &ldb, (r_4 *)work, &lwork, &info);
389 else if (typeid(T) == typeid(r_8) )
390 dgels_(&trans, &m, &n, &nrhs, (r_8 *)a.Data(), &lda,
391 (r_8 *)b.Data(), &ldb, (r_8 *)work, &lwork, &info);
392 else if (typeid(T) == typeid(complex<r_4>) )
393 cgels_(&trans, &m, &n, &nrhs, (complex<r_4> *)a.Data(), &lda,
394 (complex<r_4> *)b.Data(), &ldb, (complex<r_4> *)work, &lwork, &info);
395 else if (typeid(T) == typeid(complex<r_8>) )
396 zgels_(&trans, &m, &n, &nrhs, (complex<r_8> *)a.Data(), &lda,
397 (complex<r_8> *)b.Data(), &ldb, (complex<r_8> *)work, &lwork, &info);
398 else {
399 delete[] work;
400 string tn = typeid(T).name();
401 cerr << " LapackServer::LeastSquareSolve(a,b) - Unsupported DataType T = " << tn << endl;
402 throw TypeMismatchExc("LapackServer::LeastSquareSolve(a,b) - Unsupported DataType (T)");
403 }
404 delete[] work;
405 return(info);
406}
407
408
409////////////////////////////////////////////////////////////////////////////////////
410//! Interface to Lapack SVD driver s/d/c/zgesv().
411/*! Computes the vector of singular values of \b a. Input arrays
412 should have FortranMemoryMapping (column packed).
413 \param a : input m-by-n matrix
414 \param s : Vector of min(m,n) singular values (descending order)
415 \return : return code from lapack driver _gesvd()
416 */
417
418template <class T>
419int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s)
420{
421 return (SVDDriver(a, s, NULL, NULL) );
422}
423
424//! Interface to Lapack SVD driver s/d/c/zgesv().
425/*! Computes the vector of singular values of \b a, as well as
426 right and left singular vectors of \b a.
427 \f[
428 A = U \Sigma V^T , ( A = U \Sigma V^H \ complex)
429 \f]
430 \f[
431 A v_i = \sigma_i u_i \ and A^T u_i = \sigma_i v_i \ (A^H \ complex)
432 \f]
433 U and V are orthogonal (unitary) matrices.
434 \param a : input m-by-n matrix (in FotranMemoryMapping)
435 \param s : Vector of min(m,n) singular values (descending order)
436 \param u : Matrix of left singular vectors
437 \param vt : Transpose of right singular vectors.
438 \return : return code from lapack driver _gesvd()
439 */
440template <class T>
441int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s, TArray<T> & u, TArray<T> & vt)
442{
443 return (SVDDriver(a, s, &u, &vt) );
444}
445
446
447//! Interface to Lapack SVD driver s/d/c/zgesv().
448template <class T>
449int LapackServer<T>::SVDDriver(TArray<T>& a, TArray<T> & s, TArray<T>* up, TArray<T>* vtp)
450{
451 if ( ( a.NbDimensions() != 2 ) )
452 throw(SzMismatchError("LapackServer::SVD(a, ...) a.NbDimensions() != 2"));
453
454 int_4 rowa = a.RowsKA();
455 int_4 cola = a.ColsKA();
456
457 if ( !a.IsPacked(rowa) )
458 throw(SzMismatchError("LapackServer::SVD(a, ...) a Not Column Packed "));
459
460 int_4 m = a.Size(rowa);
461 int_4 n = a.Size(cola);
462 int_4 maxmn = (m > n) ? m : n;
463 int_4 minmn = (m < n) ? m : n;
464
465 char jobu, jobvt;
466 jobu = 'N';
467 jobvt = 'N';
468
469 sa_size_t sz[2];
470 if ( up != NULL) {
471 if ( dynamic_cast< TVector<T> * > (vtp) )
472 throw( TypeMismatchExc("LapackServer::SVD() Wrong type (=TVector<T>) for u !") );
473 up->SetMemoryMapping(BaseArray::FortranMemoryMapping);
474 sz[0] = sz[1] = m;
475 up->ReSize(2, sz );
476 jobu = 'A';
477 }
478 else {
479 up = new TMatrix<T>(1,1);
480 jobu = 'N';
481 }
482 if ( vtp != NULL) {
483 if ( dynamic_cast< TVector<T> * > (vtp) )
484 throw( TypeMismatchExc("LapackServer::SVD() Wrong type (=TVector<T>) for vt !") );
485 vtp->SetMemoryMapping(BaseArray::FortranMemoryMapping);
486 sz[0] = sz[1] = n;
487 vtp->ReSize(2, sz );
488 jobvt = 'A';
489 }
490 else {
491 vtp = new TMatrix<T>(1,1);
492 jobvt = 'N';
493 }
494
495 TVector<T> *vs = dynamic_cast< TVector<T> * > (&s);
496 if (vs) vs->ReSize(minmn);
497 else {
498 TMatrix<T> *ms = dynamic_cast< TMatrix<T> * > (&s);
499 if (ms) ms->ReSize(minmn,1);
500 else {
501 sz[0] = minmn; sz[1] = 1;
502 s.ReSize(1, sz);
503 }
504 }
505
506 int_4 lda = a.Step(a.ColsKA());
507 int_4 ldu = up->Step(up->ColsKA());
508 int_4 ldvt = vtp->Step(vtp->ColsKA());
509
510 int_4 lwork = maxmn*5*wspace_size_factor;
511 T * work = new T[lwork];
512 int_4 info;
513
514 if (typeid(T) == typeid(r_4) )
515 sgesvd_(&jobu, &jobvt, &m, &n, (r_4 *)a.Data(), &lda,
516 (r_4 *)s.Data(), (r_4 *) up->Data(), &ldu, (r_4 *)vtp->Data(), &ldvt,
517 (r_4 *)work, &lwork, &info);
518 else if (typeid(T) == typeid(r_8) )
519 dgesvd_(&jobu, &jobvt, &m, &n, (r_8 *)a.Data(), &lda,
520 (r_8 *)s.Data(), (r_8 *) up->Data(), &ldu, (r_8 *)vtp->Data(), &ldvt,
521 (r_8 *)work, &lwork, &info);
522 else if (typeid(T) == typeid(complex<r_4>) )
523 cgesvd_(&jobu, &jobvt, &m, &n, (complex<r_4> *)a.Data(), &lda,
524 (complex<r_4> *)s.Data(), (complex<r_4> *) up->Data(), &ldu,
525 (complex<r_4> *)vtp->Data(), &ldvt,
526 (complex<r_4> *)work, &lwork, &info);
527 else if (typeid(T) == typeid(complex<r_8>) )
528 zgesvd_(&jobu, &jobvt, &m, &n, (complex<r_8> *)a.Data(), &lda,
529 (complex<r_8> *)s.Data(), (complex<r_8> *) up->Data(), &ldu,
530 (complex<r_8> *)vtp->Data(), &ldvt,
531 (complex<r_8> *)work, &lwork, &info);
532 else {
533 if (jobu == 'N') delete up;
534 if (jobvt == 'N') delete vtp;
535 string tn = typeid(T).name();
536 cerr << " LapackServer::SVDDriver(...) - Unsupported DataType T = " << tn << endl;
537 throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
538 }
539
540 if (jobu == 'N') delete up;
541 if (jobvt == 'N') delete vtp;
542 return(info);
543}
544
545
546////////////////////////////////////////////////////////////////////////////////////
547/*! Computes the eigen values and eigen vectors of a symetric (or hermitian) matrix \b a.
548 Input arrays should have FortranMemoryMapping (column packed).
549 \param a : input symetric (or hermitian) n-by-n matrix
550 \param b : Vector of eigenvalues (descending order)
551 \param eigenvector : if true compute eigenvectors, if not only eigenvalues
552 \param a : on return array of eigenvectors (same order than eval, one vector = one column)
553 \return : return code from lapack driver _gesvd()
554 */
555
556template <class T>
557int LapackServer<T>::LapackEigenSym(TArray<T>& a, TVector<r_8>& b, bool eigenvector)
558{
559 if ( a.NbDimensions() != 2 )
560 throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a NbDimensions() != 2"));
561 int_4 rowa = a.RowsKA();
562 int_4 cola = a.ColsKA();
563 if ( a.Size(rowa) != a.Size(cola))
564 throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a Not a square Array"));
565 if (!a.IsPacked(rowa))
566 throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a Not Column Packed"));
567
568 char uplo='U'; char struplo[5]; struplo[0]=uplo; struplo[1]='\0';
569 char jobz='N'; if(eigenvector) jobz='V';
570 char strjobz[5]; strjobz[0]=jobz; strjobz[1]='\0';
571
572 int_4 n = a.Size(rowa);
573 int_4 lda = a.Step(cola);
574 int_4 info = 0;
575
576 b.ReSize(n); b = 0.;
577
578 if (typeid(T) == typeid(r_4) ) {
579 int_4 lwork = 3*n-1 +5; r_4* work = new r_4[lwork];
580 r_4* w = new r_4[n];
581 ssyev_(strjobz,struplo,&n,(r_4 *)a.Data(),&lda,(r_4 *)w,(r_4 *)work,&lwork,&info);
582 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
583 delete [] work; delete [] w;
584 } else if (typeid(T) == typeid(r_8) ) {
585 int_4 lwork = 3*n-1 +5; r_8* work = new r_8[lwork];
586 r_8* w = new r_8[n];
587 dsyev_(strjobz,struplo,&n,(r_8 *)a.Data(),&lda,(r_8 *)w,(r_8 *)work,&lwork,&info);
588 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
589 delete [] work; delete [] w;
590 } else if (typeid(T) == typeid(complex<r_4>) ) {
591 int_4 lwork = 2*n-1 +5; complex<r_4>* work = new complex<r_4>[lwork];
592 r_4* rwork = new r_4[3*n-2 +5]; r_4* w = new r_4[n];
593 cheev_(strjobz,struplo,&n,(complex<r_4> *)a.Data(),&lda,(r_4 *)w
594 ,(complex<r_4> *)work,&lwork,(r_4 *)rwork,&info);
595 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
596 delete [] work; delete [] rwork; delete [] w;
597 } else if (typeid(T) == typeid(complex<r_8>) ) {
598 int_4 lwork = 2*n-1 +5; complex<r_8>* work = new complex<r_8>[lwork];
599 r_8* rwork = new r_8[3*n-2 +5]; r_8* w = new r_8[n];
600 zheev_(strjobz,struplo,&n,(complex<r_8> *)a.Data(),&lda,(r_8 *)w
601 ,(complex<r_8> *)work,&lwork,(r_8 *)rwork,&info);
602 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
603 delete [] work; delete [] rwork; delete [] w;
604 } else {
605 string tn = typeid(T).name();
606 cerr << " LapackServer::LapackEigenSym(a,b) - Unsupported DataType T = " << tn << endl;
607 throw TypeMismatchExc("LapackServer::LapackEigenSym(a,b) - Unsupported DataType (T)");
608 }
609
610 return(info);
611}
612
613////////////////////////////////////////////////////////////////////////////////////
614/*! Computes the eigen values and eigen vectors of a general squared matrix \b a.
615 Input arrays should have FortranMemoryMapping (column packed).
616 \param a : input general n-by-n matrix
617 \param eval : Vector of eigenvalues (complex double precision)
618 \param evec : Matrix of eigenvector (same order than eval, one vector = one column)
619 \param eigenvector : if true compute (right) eigenvectors, if not only eigenvalues
620 \param a : on return array of eigenvectors
621 \return : return code from lapack driver _gesvd()
622 \verbatim
623 eval : contains the computed eigenvalues.
624 --- For real matrices "a" :
625 Complex conjugate pairs of eigenvalues appear consecutively
626 with the eigenvalue having the positive imaginary part first.
627 evec : the right eigenvectors v(j) are stored one after another
628 in the columns of evec, in the same order as their eigenvalues.
629 --- For real matrices "a" :
630 If the j-th eigenvalue is real, then v(j) = evec(:,j),
631 the j-th column of evec.
632 If the j-th and (j+1)-st eigenvalues form a complex
633 conjugate pair, then v(j) = evec(:,j) + i*evec(:,j+1) and
634 v(j+1) = evec(:,j) - i*evec(:,j+1).
635 \endverbatim
636*/
637
638template <class T>
639int LapackServer<T>::LapackEigen(TArray<T>& a, TVector< complex<r_8> >& eval, TMatrix<T>& evec, bool eigenvector)
640{
641 if ( a.NbDimensions() != 2 )
642 throw(SzMismatchError("LapackServer::LapackEigen(a,b) a NbDimensions() != 2"));
643 int_4 rowa = a.RowsKA();
644 int_4 cola = a.ColsKA();
645 if ( a.Size(rowa) != a.Size(cola))
646 throw(SzMismatchError("LapackServer::LapackEigen(a,b) a Not a square Array"));
647 if (!a.IsPacked(rowa))
648 throw(SzMismatchError("LapackServer::LapackEigen(a,b) a Not Column Packed"));
649
650 char jobvl = 'N'; char strjobvl[5]; strjobvl[0] = jobvl; strjobvl[1] = '\0';
651 char jobvr = 'N'; if(eigenvector) jobvr='V';
652 char strjobvr[5]; strjobvr[0] = jobvr; strjobvr[1] = '\0';
653
654 int_4 n = a.Size(rowa);
655 int_4 lda = a.Step(cola);
656 int_4 info = 0;
657
658 eval.ReSize(n); eval = complex<r_8>(0.,0.);
659 if(eigenvector) {evec.ReSize(n,n); evec = (T) 0.;}
660 int_4 ldvr = n, ldvl = 1;
661
662 if (typeid(T) == typeid(r_4) ) {
663 int_4 lwork = 4*n +5; r_4* work = new r_4[lwork];
664 r_4* wr = new r_4[n]; r_4* wi = new r_4[n]; r_4* vl = NULL;
665 sgeev_(strjobvl,strjobvr,&n,(r_4 *)a.Data(),&lda,(r_4 *)wr,(r_4 *)wi,
666 (r_4 *)vl,&ldvl,(r_4 *)evec.Data(),&ldvr,
667 (r_4 *)work,&lwork,&info);
668 if(info==0) for(int i=0;i<n;i++) eval(i) = complex<r_8>(wr[i],wi[i]);
669 delete [] work; delete [] wr; delete [] wi;
670 } else if (typeid(T) == typeid(r_8) ) {
671 int_4 lwork = 4*n +5; r_8* work = new r_8[lwork];
672 r_8* wr = new r_8[n]; r_8* wi = new r_8[n]; r_8* vl = NULL;
673 dgeev_(strjobvl,strjobvr,&n,(r_8 *)a.Data(),&lda,(r_8 *)wr,(r_8 *)wi,
674 (r_8 *)vl,&ldvl,(r_8 *)evec.Data(),&ldvr,
675 (r_8 *)work,&lwork,&info);
676 if(info==0) for(int i=0;i<n;i++) eval(i) = complex<r_8>(wr[i],wi[i]);
677 delete [] work; delete [] wr; delete [] wi;
678 } else if (typeid(T) == typeid(complex<r_4>) ) {
679 int_4 lwork = 2*n +5; complex<r_4>* work = new complex<r_4>[lwork];
680 r_4* rwork = new r_4[2*n+5]; r_4* vl = NULL; TVector< complex<r_4> > w(n);
681 cgeev_(strjobvl,strjobvr,&n,(complex<r_4> *)a.Data(),&lda,(complex<r_4> *)w.Data(),
682 (complex<r_4> *)vl,&ldvl,(complex<r_4> *)evec.Data(),&ldvr,
683 (complex<r_4> *)work,&lwork,(r_4 *)rwork,&info);
684 if(info==0) for(int i=0;i<n;i++) eval(i) = w(i);
685 delete [] work; delete [] rwork;
686 } else if (typeid(T) == typeid(complex<r_8>) ) {
687 int_4 lwork = 2*n +5; complex<r_8>* work = new complex<r_8>[lwork];
688 r_8* rwork = new r_8[2*n+5]; r_8* vl = NULL;
689 zgeev_(strjobvl,strjobvr,&n,(complex<r_8> *)a.Data(),&lda,(complex<r_8> *)eval.Data(),
690 (complex<r_8> *)vl,&ldvl,(complex<r_8> *)evec.Data(),&ldvr,
691 (complex<r_8> *)work,&lwork,(r_8 *)rwork,&info);
692 delete [] work; delete [] rwork;
693 } else {
694 string tn = typeid(T).name();
695 cerr << " LapackServer::LapackEigen(a,b) - Unsupported DataType T = " << tn << endl;
696 throw TypeMismatchExc("LapackServer::LapackEigen(a,b) - Unsupported DataType (T)");
697 }
698
699 return(info);
700}
701
702
703
704
705
706////////////////////////////////////////////////////////////////////////////////////
707void rztest_lapack(TArray<r_4>& aa, TArray<r_4>& bb)
708{
709 if ( aa.NbDimensions() != 2 ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a Matrix"));
710 if ( aa.SizeX() != aa.SizeY()) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a square Matrix"));
711 if ( bb.NbDimensions() != 2 ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a Matrix"));
712 if ( bb.SizeX() != aa.SizeX() ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A <> B "));
713 if ( !bb.IsPacked() || !bb.IsPacked() )
714 throw(SzMismatchError("rztest_lapack(TMatrix<r_4> Not packed A or B "));
715
716 int_4 n = aa.SizeX();
717 int_4 nrhs = bb.SizeY();
718 int_4 lda = n;
719 int_4 ldb = bb.SizeX();
720 int_4 info;
721 int_4* ipiv = new int_4[n];
722 sgesv_(&n, &nrhs, aa.Data(), &lda, ipiv, bb.Data(), &ldb, &info);
723 delete[] ipiv;
724 cout << "rztest_lapack/Info= " << info << endl;
725 cout << aa << "\n" << bb << endl;
726 return;
727}
728
///////////////////////////////////////////////////////////////
// Explicit instantiation of LapackServer for the four element types it
// supports, using whichever template-instantiation mechanism the build
// configuration selects.
#if defined(__CXX_PRAGMA_TEMPLATES__)
// Compilers that instantiate through #pragma define_template
#pragma define_template LapackServer<r_4>
#pragma define_template LapackServer<r_8>
#pragma define_template LapackServer< complex<r_4> >
#pragma define_template LapackServer< complex<r_8> >
#endif

#if defined(ANSI_TEMPLATES) || defined(GNU_TEMPLATES)
// Standard explicit instantiation definitions
template class LapackServer<r_4>;
template class LapackServer<r_8>;
template class LapackServer< complex<r_4> >;
template class LapackServer< complex<r_8> >;
#endif
743
#if defined(OS_LINUX)
// Link-time stub for f2c under Linux: linking against libf2c.a needs a C
// symbol named MAIN__ (presumably the Fortran main-program entry that
// libf2c's own main() refers to). It must never actually run.
extern "C" void MAIN__();

void MAIN__()
{
  cerr << "MAIN__() function for linking with libf2c.a " << endl
       << " This function should never be called !!! " << endl;
  throw PError("MAIN__() should not be called - see intflapack.cc");
}
#endif
Note: See TracBrowser for help on using the repository browser.