source: Sophya/trunk/SophyaExt/LinAlg/intflapack.cc@ 2979

Last change on this file since 2979 was 2964, checked in by ansari, 19 years ago

Remplacement NotFoundExc par MathExc ds intflapack.cc , Reza 2/6/2006

File size: 48.3 KB
Line 
#include <iostream>
#include <math.h>
#include <stdio.h>
#include <string.h>
#include <typeinfo>
#include <vector>
#include "sopnamsp.h"
#include "intflapack.h"
#include "tvector.h"
#include "tmatrix.h"
8
// Number of extra elements appended to the LAPACK work / iwork / rwork
// arrays allocated below (presumably a safety margin against slightly
// under-estimated workspace sizes).
#define GARDMEM (5)
10
/*************** For reference (Christophe) ***************
Memory layout (FORTRAN) of LAPACK vectors and matrices:
13
141./ --- REAL X(N):
15 if an array X of dimension (N) holds a vector x,
16 then X(i) holds "x_i" for i=1,...,N
17
182./ --- REAL A(LDA,N):
19 if a two-dimensional array A of dimension (LDA,N) holds an m-by-n matrix A,
 then A(i,j) holds "a_ij" for i=1,...,m and j=1,...,n (LDA must be at least m).
21 Note that array arguments are usually declared in the software as assumed-size
22 arrays (last dimension *), for example: REAL A(LDA,*)
 --- Storage in memory:
24 | 11 12 13 14 |
25 Ex: Real A(4,4): A = | 21 22 23 24 |
26 | 31 32 33 34 |
27 | 41 42 43 44 |
28 memoire: {11 21 31 41} {12 22 32 42} {13 23 33 43} {14 24 34 44}
 The first index (row) "i" varies fastest, then the second (column):
 (store the whole first column, then the whole second column,
 ..., then the whole last column)
32***********************************************************/
33
34/*!
35 \defgroup LinAlg LinAlg module
36 This module contains classes and functions for complex linear
37 algebra on arrays. This module is intended mainly to have
38 classes implementing C++ interfaces between Sophya objects
39 and external linear algebra libraries, such as LAPACK.
40*/
41
42/*!
43 \class SOPHYA::LapackServer
44 \ingroup LinAlg
45 This class implements an interface to LAPACK library driver routines.
 LAPACK (Linear Algebra PACKage) is a collection of high performance
 routines to solve common problems in numerical linear algebra.
 It is available from http://www.netlib.org.
49
50 The present version of LapackServer (Feb 2005) provides
51 interfaces for the linear system solver, singular value
52 decomposition (SVD), Least square solver and
53 eigen value / eigen vector decomposition.
54 Only arrays with BaseArray::FortranMemoryMapping
 can be handled by LapackServer. LapackServer can be instantiated
 for single and double precision real or complex array types.
57 \warning The input array is overwritten in most cases.
58 The example below shows solving a linear system A*X = B
59
60 \code
61 #include "intflapack.h"
62 // ...
63 // Use FortranMemoryMapping as default
64 BaseArray::SetDefaultMemoryMapping(BaseArray::FortranMemoryMapping);
 // Create and fill the arrays A and B
66 int n = 20;
67 Matrix A(n, n);
68 A = RandomSequence();
69 Vector X(n),B(n);
70 X = RandomSequence();
71 B = A*X;
72 // Solve the linear system A*X = B
73 LapackServer<r_8> lps;
74 lps.LinSolve(A,B);
75 // We get the result in B, which should be equal to X ...
76 // Compute the difference B-X ;
77 Vector diff = B-X;
78 \endcode
79
80*/
81
/*
 December 2005: following the port to AIX xlC.
 The Fortran routine names are mapped through macros, either to the bare
 name or to the name followed by an underscore _, depending on the system.
*/
/* Link-time symbol of each Fortran LAPACK routine used in this file.
   SOPHYA_LAPACK_FNAME(x) gives the bare name x on AIX, and the name
   followed by an underscore (e.g. dgesv -> dgesv_) on the other systems. */
#ifdef AIX
#define SOPHYA_LAPACK_FNAME(fname) fname
#else
#define SOPHYA_LAPACK_FNAME(fname) fname##_
#endif

// Workspace / tuning parameter query
#define ilaenv SOPHYA_LAPACK_FNAME(ilaenv)

// General linear systems
#define sgesv  SOPHYA_LAPACK_FNAME(sgesv)
#define dgesv  SOPHYA_LAPACK_FNAME(dgesv)
#define cgesv  SOPHYA_LAPACK_FNAME(cgesv)
#define zgesv  SOPHYA_LAPACK_FNAME(zgesv)

// Symmetric linear systems
#define ssysv  SOPHYA_LAPACK_FNAME(ssysv)
#define dsysv  SOPHYA_LAPACK_FNAME(dsysv)
#define csysv  SOPHYA_LAPACK_FNAME(csysv)
#define zsysv  SOPHYA_LAPACK_FNAME(zsysv)

// Least squares (QR / LQ)
#define sgels  SOPHYA_LAPACK_FNAME(sgels)
#define dgels  SOPHYA_LAPACK_FNAME(dgels)
#define cgels  SOPHYA_LAPACK_FNAME(cgels)
#define zgels  SOPHYA_LAPACK_FNAME(zgels)

// Least squares (SVD, divide and conquer)
#define sgelsd SOPHYA_LAPACK_FNAME(sgelsd)
#define dgelsd SOPHYA_LAPACK_FNAME(dgelsd)
#define cgelsd SOPHYA_LAPACK_FNAME(cgelsd)
#define zgelsd SOPHYA_LAPACK_FNAME(zgelsd)

// Singular value decomposition
#define sgesvd SOPHYA_LAPACK_FNAME(sgesvd)
#define dgesvd SOPHYA_LAPACK_FNAME(dgesvd)
#define cgesvd SOPHYA_LAPACK_FNAME(cgesvd)
#define zgesvd SOPHYA_LAPACK_FNAME(zgesvd)

// Singular value decomposition, divide and conquer
#define sgesdd SOPHYA_LAPACK_FNAME(sgesdd)
#define dgesdd SOPHYA_LAPACK_FNAME(dgesdd)
#define cgesdd SOPHYA_LAPACK_FNAME(cgesdd)
#define zgesdd SOPHYA_LAPACK_FNAME(zgesdd)

// Symmetric / Hermitian eigen problems
#define ssyev  SOPHYA_LAPACK_FNAME(ssyev)
#define dsyev  SOPHYA_LAPACK_FNAME(dsyev)
#define cheev  SOPHYA_LAPACK_FNAME(cheev)
#define zheev  SOPHYA_LAPACK_FNAME(zheev)

// General eigen problems
#define sgeev  SOPHYA_LAPACK_FNAME(sgeev)
#define dgeev  SOPHYA_LAPACK_FNAME(dgeev)
#define cgeev  SOPHYA_LAPACK_FNAME(cgeev)
#define zgeev  SOPHYA_LAPACK_FNAME(zgeev)
175////////////////////////////////////////////////////////////////////////////////////
176extern "C" {
177// Le calculateur de workingspace
178 int_4 ilaenv(int_4 *ispec,char *name,char *opts,int_4 *n1,int_4 *n2,int_4 *n3,int_4 *n4,
179 int_4 nc1,int_4 nc2);
180
181// Drivers pour resolution de systemes lineaires
182 void sgesv(int_4* n, int_4* nrhs, r_4* a, int_4* lda,
183 int_4* ipiv, r_4* b, int_4* ldb, int_4* info);
184 void dgesv(int_4* n, int_4* nrhs, r_8* a, int_4* lda,
185 int_4* ipiv, r_8* b, int_4* ldb, int_4* info);
186 void cgesv(int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
187 int_4* ipiv, complex<r_4>* b, int_4* ldb, int_4* info);
188 void zgesv(int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
189 int_4* ipiv, complex<r_8>* b, int_4* ldb, int_4* info);
190
191// Drivers pour resolution de systemes lineaires symetriques
192 void ssysv(char* uplo, int_4* n, int_4* nrhs, r_4* a, int_4* lda,
193 int_4* ipiv, r_4* b, int_4* ldb,
194 r_4* work, int_4* lwork, int_4* info);
195 void dsysv(char* uplo, int_4* n, int_4* nrhs, r_8* a, int_4* lda,
196 int_4* ipiv, r_8* b, int_4* ldb,
197 r_8* work, int_4* lwork, int_4* info);
198 void csysv(char* uplo, int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
199 int_4* ipiv, complex<r_4>* b, int_4* ldb,
200 complex<r_4>* work, int_4* lwork, int_4* info);
201 void zsysv(char* uplo, int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
202 int_4* ipiv, complex<r_8>* b, int_4* ldb,
203 complex<r_8>* work, int_4* lwork, int_4* info);
204
205// Driver pour resolution de systemes au sens de Xi2
206 void sgels(char * trans, int_4* m, int_4* n, int_4* nrhs, r_4* a, int_4* lda,
207 r_4* b, int_4* ldb, r_4* work, int_4* lwork, int_4* info);
208 void dgels(char * trans, int_4* m, int_4* n, int_4* nrhs, r_8* a, int_4* lda,
209 r_8* b, int_4* ldb, r_8* work, int_4* lwork, int_4* info);
210 void cgels(char * trans, int_4* m, int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
211 complex<r_4>* b, int_4* ldb, complex<r_4>* work, int_4* lwork, int_4* info);
212 void zgels(char * trans, int_4* m, int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
213 complex<r_8>* b, int_4* ldb, complex<r_8>* work, int_4* lwork, int_4* info);
214
215// Driver pour resolution de systemes au sens de Xi2 par SVD Divide & Conquer
216 void sgelsd(int_4* m,int_4* n,int_4* nrhs,r_4* a,int_4* lda,
217 r_4* b,int_4* ldb,r_4* s,r_4* rcond,int_4* rank,
218 r_4* work,int_4* lwork,int_4* iwork,int_4* info);
219 void dgelsd(int_4* m,int_4* n,int_4* nrhs,r_8* a,int_4* lda,
220 r_8* b,int_4* ldb,r_8* s,r_8* rcond,int_4* rank,
221 r_8* work,int_4* lwork,int_4* iwork,int_4* info);
222 void cgelsd(int_4* m,int_4* n,int_4* nrhs,complex<r_4>* a,int_4* lda,
223 complex<r_4>* b,int_4* ldb,r_4* s,r_4* rcond,int_4* rank,
224 complex<r_4>* work,int_4* lwork,r_4* rwork,int_4* iwork,int_4* info);
225 void zgelsd(int_4* m,int_4* n,int_4* nrhs,complex<r_8>* a,int_4* lda,
226 complex<r_8>* b,int_4* ldb,r_8* s,r_8* rcond,int_4* rank,
227 complex<r_8>* work,int_4* lwork,r_8* rwork,int_4* iwork,int_4* info);
228
229// Driver pour decomposition SVD
230 void sgesvd(char* jobu, char* jobvt, int_4* m, int_4* n, r_4* a, int_4* lda,
231 r_4* s, r_4* u, int_4* ldu, r_4* vt, int_4* ldvt,
232 r_4* work, int_4* lwork, int_4* info);
233 void dgesvd(char* jobu, char* jobvt, int_4* m, int_4* n, r_8* a, int_4* lda,
234 r_8* s, r_8* u, int_4* ldu, r_8* vt, int_4* ldvt,
235 r_8* work, int_4* lwork, int_4* info);
236 void cgesvd(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_4>* a, int_4* lda,
237 r_4* s, complex<r_4>* u, int_4* ldu, complex<r_4>* vt, int_4* ldvt,
238 complex<r_4>* work, int_4* lwork, r_4* rwork, int_4* info);
239 void zgesvd(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_8>* a, int_4* lda,
240 r_8* s, complex<r_8>* u, int_4* ldu, complex<r_8>* vt, int_4* ldvt,
241 complex<r_8>* work, int_4* lwork, r_8* rwork, int_4* info);
242
243// Driver pour decomposition SVD Divide and Conquer
244 void sgesdd(char* jobz, int_4* m, int_4* n, r_4* a, int_4* lda,
245 r_4* s, r_4* u, int_4* ldu, r_4* vt, int_4* ldvt,
246 r_4* work, int_4* lwork, int_4* iwork, int_4* info);
247 void dgesdd(char* jobz, int_4* m, int_4* n, r_8* a, int_4* lda,
248 r_8* s, r_8* u, int_4* ldu, r_8* vt, int_4* ldvt,
249 r_8* work, int_4* lwork, int_4* iwork, int_4* info);
250 void cgesdd(char* jobz, int_4* m, int_4* n, complex<r_4>* a, int_4* lda,
251 r_4* s, complex<r_4>* u, int_4* ldu, complex<r_4>* vt, int_4* ldvt,
252 complex<r_4>* work, int_4* lwork, r_4* rwork, int_4* iwork, int_4* info);
253 void zgesdd(char* jobz, int_4* m, int_4* n, complex<r_8>* a, int_4* lda,
254 r_8* s, complex<r_8>* u, int_4* ldu, complex<r_8>* vt, int_4* ldvt,
255 complex<r_8>* work, int_4* lwork, r_8* rwork, int_4* iwork, int_4* info);
256
257// Driver pour eigen decomposition for symetric/hermitian matrices
258 void ssyev(char* jobz, char* uplo, int_4* n, r_4* a, int_4* lda, r_4* w,
259 r_4* work, int_4 *lwork, int_4* info);
260 void dsyev(char* jobz, char* uplo, int_4* n, r_8* a, int_4* lda, r_8* w,
261 r_8* work, int_4 *lwork, int_4* info);
262 void cheev(char* jobz, char* uplo, int_4* n, complex<r_4>* a, int_4* lda, r_4* w,
263 complex<r_4>* work, int_4 *lwork, r_4* rwork, int_4* info);
264 void zheev(char* jobz, char* uplo, int_4* n, complex<r_8>* a, int_4* lda, r_8* w,
265 complex<r_8>* work, int_4 *lwork, r_8* rwork, int_4* info);
266
267// Driver pour eigen decomposition for general squared matrices
268 void sgeev(char* jobl, char* jobvr, int_4* n, r_4* a, int_4* lda, r_4* wr, r_4* wi,
269 r_4* vl, int_4* ldvl, r_4* vr, int_4* ldvr,
270 r_4* work, int_4 *lwork, int_4* info);
271 void dgeev(char* jobl, char* jobvr, int_4* n, r_8* a, int_4* lda, r_8* wr, r_8* wi,
272 r_8* vl, int_4* ldvl, r_8* vr, int_4* ldvr,
273 r_8* work, int_4 *lwork, int_4* info);
274 void cgeev(char* jobl, char* jobvr, int_4* n, complex<r_4>* a, int_4* lda, complex<r_4>* w,
275 complex<r_4>* vl, int_4* ldvl, complex<r_4>* vr, int_4* ldvr,
276 complex<r_4>* work, int_4 *lwork, r_4* rwork, int_4* info);
277 void zgeev(char* jobl, char* jobvr, int_4* n, complex<r_8>* a, int_4* lda, complex<r_8>* w,
278 complex<r_8>* vl, int_4* ldvl, complex<r_8>* vr, int_4* ldvr,
279 complex<r_8>* work, int_4 *lwork, r_8* rwork, int_4* info);
280
281}
282
// -------------- Class LapackServer<T> --------------
284
285////////////////////////////////////////////////////////////////////////////////////
286template <class T>
287LapackServer<T>::LapackServer(bool throw_on_error)
288 : Throw_On_Error(throw_on_error)
289{
290 SetWorkSpaceSizeFactor();
291}
292
293template <class T>
294LapackServer<T>::~LapackServer()
295{
296}
297
298////////////////////////////////////////////////////////////////////////////////////
299template <class T>
300int_4 LapackServer<T>::ilaenv_en_C(int_4 ispec,char *name,char *opts,int_4 n1,int_4 n2,int_4 n3,int_4 n4)
301{
302 int_4 nc1 = strlen(name), nc2 = strlen(opts), rc=0;
303 rc = ilaenv(&ispec,name,opts,&n1,&n2,&n3,&n4,nc1,nc2);
304 //cout<<"ilaenv_en_C("<<ispec<<","<<name<<"("<<nc1<<"),"<<opts<<"("<<nc2<<"),"
305 // <<n1<<","<<n2<<","<<n3<<","<<n4<<") = "<<rc<<endl;
306 return rc;
307}
308
309template <class T>
310int_4 LapackServer<T>::type2i4(void *val,int nbytes)
311// Retourne un entier contenant la valeur contenue dans val
312// - nbytes = nombre de bytes dans le contenu de val
313// ex: r_4 x = 3.4; type2i4(&x,4) -> 3
314// ex: r_8 x = 3.4; type2i4(&x,8) -> 3
315// ex: complex<r_4> x(3.4,7.8); type2i4(&x,4) -> 3
316// ex: complex<r_8> x(3.4,7.8); type2i4(&x,8) -> 3
317{
318 r_4* x4; r_8* x8; int_4 lw=0;
319 if(nbytes==4) {x4 = (r_4*)val; lw = (int_4)(*x4);}
320 else {x8 = (r_8*)val; lw = (int_4)(*x8);}
321 return lw;
322}
323
324////////////////////////////////////////////////////////////////////////////////////
325//! Interface to Lapack linear system solver driver s/d/c/zgesv().
326/*! Solve the linear system a * x = b using LU factorization.
327 Input arrays should have FortranMemory mapping (column packed).
328 \param a : input matrix, overwritten on output
329 \param b : input-output, input vector b, contains x on exit
330 \return : return code from lapack driver _gesv()
331 */
332template <class T>
333int LapackServer<T>::LinSolve(TArray<T>& a, TArray<T> & b)
334{
335 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
336 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b NbDimensions() != 2"));
337
338 int_4 rowa = a.RowsKA();
339 int_4 cola = a.ColsKA();
340 int_4 rowb = b.RowsKA();
341 int_4 colb = b.ColsKA();
342 if ( a.Size(rowa) != a.Size(cola))
343 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Not a square Array"));
344 if ( a.Size(rowa) != b.Size(rowb))
345 throw(SzMismatchError("LapackServer::LinSolve(a,b) RowSize(a <> b) "));
346
347 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
348 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b Not Column Packed"));
349
350 int_4 n = a.Size(rowa);
351 int_4 nrhs = b.Size(colb);
352 int_4 lda = a.Step(cola);
353 int_4 ldb = b.Step(colb);
354 int_4 info;
355 int_4* ipiv = new int_4[n];
356
357 if (typeid(T) == typeid(r_4) )
358 sgesv(&n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb, &info);
359 else if (typeid(T) == typeid(r_8) )
360 dgesv(&n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb, &info);
361 else if (typeid(T) == typeid(complex<r_4>) )
362 cgesv(&n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
363 (complex<r_4> *)b.Data(), &ldb, &info);
364 else if (typeid(T) == typeid(complex<r_8>) )
365 zgesv(&n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
366 (complex<r_8> *)b.Data(), &ldb, &info);
367 else {
368 delete[] ipiv;
369 string tn = typeid(T).name();
370 cerr << " LapackServer::LinSolve(a,b) - Unsupported DataType T = " << tn << endl;
371 throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
372 }
373 delete[] ipiv;
374 if(info!=0 && Throw_On_Error) {
375 char serr[128]; sprintf(serr,"LinSolve_Error info=%d",info);
376 throw MathExc(serr);
377 }
378 return(info);
379}
380
381////////////////////////////////////////////////////////////////////////////////////
382//! Interface to Lapack linear system solver driver s/d/c/zsysv().
383/*! Solve the linear system a * x = b with a symetric matrix using LU factorization.
384 Input arrays should have FortranMemory mapping (column packed).
385 \param a : input matrix symetric , overwritten on output
386 \param b : input-output, input vector b, contains x on exit
387 \return : return code from lapack driver _sysv()
388 */
389template <class T>
390int LapackServer<T>::LinSolveSym(TArray<T>& a, TArray<T> & b)
391// --- REMARQUES DE CMV ---
392// 1./ contrairement a ce qui est dit dans la doc, il s'agit
393// de matrices SYMETRIQUES complexes et non HERMITIENNES !!!
394// 2./ pourquoi les routines de LinSolve pour des matrices symetriques
395// sont plus de deux fois plus lentes que les LinSolve generales sur OSF
396// et sensiblement plus lentes sous Linux ???
397{
398 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
399 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) a Or b NbDimensions() != 2"));
400 int_4 rowa = a.RowsKA();
401 int_4 cola = a.ColsKA();
402 int_4 rowb = b.RowsKA();
403 int_4 colb = b.ColsKA();
404 if ( a.Size(rowa) != a.Size(cola))
405 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) a Not a square Array"));
406 if ( a.Size(rowa) != b.Size(rowb))
407 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) RowSize(a <> b) "));
408
409 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
410 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) a Or b Not Column Packed"));
411
412 int_4 n = a.Size(rowa);
413 int_4 nrhs = b.Size(colb);
414 int_4 lda = a.Step(cola);
415 int_4 ldb = b.Step(colb);
416 int_4 info = 0;
417 int_4* ipiv = new int_4[n];
418 int_4 lwork = -1;
419 T * work = NULL;
420 T wkget[2];
421
422 char uplo = 'U'; // char uplo = 'L';
423 char struplo[5]; struplo[0] = uplo; struplo[1] = '\0';
424
425 if (typeid(T) == typeid(r_4) ) {
426 ssysv(&uplo, &n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb,
427 (r_4 *)wkget, &lwork, &info);
428 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
429 ssysv(&uplo, &n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb,
430 (r_4 *)work, &lwork, &info);
431 } else if (typeid(T) == typeid(r_8) ) {
432 dsysv(&uplo, &n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb,
433 (r_8 *)wkget, &lwork, &info);
434 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
435 dsysv(&uplo, &n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb,
436 (r_8 *)work, &lwork, &info);
437 } else if (typeid(T) == typeid(complex<r_4>) ) {
438 csysv(&uplo, &n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
439 (complex<r_4> *)b.Data(), &ldb,
440 (complex<r_4> *)wkget, &lwork, &info);
441 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
442 csysv(&uplo, &n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
443 (complex<r_4> *)b.Data(), &ldb,
444 (complex<r_4> *)work, &lwork, &info);
445 } else if (typeid(T) == typeid(complex<r_8>) ) {
446 zsysv(&uplo, &n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
447 (complex<r_8> *)b.Data(), &ldb,
448 (complex<r_8> *)wkget, &lwork, &info);
449 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
450 zsysv(&uplo, &n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
451 (complex<r_8> *)b.Data(), &ldb,
452 (complex<r_8> *)work, &lwork, &info);
453 } else {
454 if(work) delete[] work;
455 delete[] ipiv;
456 string tn = typeid(T).name();
457 cerr << " LapackServer::LinSolveSym(a,b) - Unsupported DataType T = " << tn << endl;
458 throw TypeMismatchExc("LapackServer::LinSolveSym(a,b) - Unsupported DataType (T)");
459 }
460 if(work) delete[] work;
461 delete[] ipiv;
462 if(info!=0 && Throw_On_Error) {
463 char serr[128]; sprintf(serr,"LinSolveSym_Error info=%d",info);
464 throw MathExc(serr);
465 }
466 return(info);
467}
468
469////////////////////////////////////////////////////////////////////////////////////
470//! Interface to Lapack least squares solver driver s/d/c/zgels().
471/*! Solves the linear least squares problem defined by an m-by-n matrix
472 \b a and an m element vector \b b , using QR or LQ factorization .
473 A solution \b x to the overdetermined system of linear equations
474 b = a * x is computed, minimizing the norm of b-a*x.
475 Underdetermined systems (m<n) are not yet handled.
476 Inout arrays should have FortranMemory mapping (column packed).
477 \param a : input matrix, overwritten on output
478 \param b : input-output, input vector b, contains x on exit.
479 \return : return code from lapack driver _gels()
480 \warning : b is not resized.
481 */
482/*
483 $CHECK$ - A faire - cas m<n
484 If the linear system is underdetermined, the minimum norm
485 solution is computed.
486*/
487
488template <class T>
489int LapackServer<T>::LeastSquareSolve(TArray<T>& a, TArray<T> & b)
490{
491 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
492 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) a Or b NbDimensions() != 2"));
493
494 int_4 rowa = a.RowsKA();
495 int_4 cola = a.ColsKA();
496 int_4 rowb = b.RowsKA();
497 int_4 colb = b.ColsKA();
498
499
500 if ( a.Size(rowa) != b.Size(rowb))
501 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) RowSize(a <> b) "));
502
503 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
504 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) a Or b Not Column Packed"));
505
506 if ( a.Size(rowa) < a.Size(cola)) { // $CHECK$ - m<n a changer
507 cout << " LapackServer<T>::LeastSquareSolve() - m<n - Not yet implemented for "
508 << " underdetermined systems ! " << endl;
509 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) NRows<NCols - "));
510 }
511 int_4 m = a.Size(rowa);
512 int_4 n = a.Size(cola);
513 int_4 nrhs = b.Size(colb);
514
515 int_4 lda = a.Step(cola);
516 int_4 ldb = b.Step(colb);
517 int_4 info;
518
519 int_4 minmn = (m < n) ? m : n;
520 int_4 maxmn = (m > n) ? m : n;
521 int_4 maxmnrhs = (nrhs > maxmn) ? nrhs : maxmn;
522 if (maxmnrhs < 1) maxmnrhs = 1;
523
524 int_4 lwork = -1; //minmn+maxmnrhs*5;
525 T * work = NULL;
526 T wkget[2];
527
528 char trans = 'N';
529
530 if (typeid(T) == typeid(r_4) ) {
531 sgels(&trans, &m, &n, &nrhs, (r_4 *)a.Data(), &lda,
532 (r_4 *)b.Data(), &ldb, (r_4 *)wkget, &lwork, &info);
533 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
534 sgels(&trans, &m, &n, &nrhs, (r_4 *)a.Data(), &lda,
535 (r_4 *)b.Data(), &ldb, (r_4 *)work, &lwork, &info);
536 } else if (typeid(T) == typeid(r_8) ) {
537 dgels(&trans, &m, &n, &nrhs, (r_8 *)a.Data(), &lda,
538 (r_8 *)b.Data(), &ldb, (r_8 *)wkget, &lwork, &info);
539 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
540 dgels(&trans, &m, &n, &nrhs, (r_8 *)a.Data(), &lda,
541 (r_8 *)b.Data(), &ldb, (r_8 *)work, &lwork, &info);
542 } else if (typeid(T) == typeid(complex<r_4>) ) {
543 cgels(&trans, &m, &n, &nrhs, (complex<r_4> *)a.Data(), &lda,
544 (complex<r_4> *)b.Data(), &ldb, (complex<r_4> *)wkget, &lwork, &info);
545 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
546 cgels(&trans, &m, &n, &nrhs, (complex<r_4> *)a.Data(), &lda,
547 (complex<r_4> *)b.Data(), &ldb, (complex<r_4> *)work, &lwork, &info);
548 } else if (typeid(T) == typeid(complex<r_8>) ) {
549 zgels(&trans, &m, &n, &nrhs, (complex<r_8> *)a.Data(), &lda,
550 (complex<r_8> *)b.Data(), &ldb, (complex<r_8> *)wkget, &lwork, &info);
551 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
552 zgels(&trans, &m, &n, &nrhs, (complex<r_8> *)a.Data(), &lda,
553 (complex<r_8> *)b.Data(), &ldb, (complex<r_8> *)work, &lwork, &info);
554 } else {
555 if(work) delete [] work; work=NULL;
556 string tn = typeid(T).name();
557 cerr << " LapackServer::LeastSquareSolve(a,b) - Unsupported DataType T = " << tn << endl;
558 throw TypeMismatchExc("LapackServer::LeastSquareSolve(a,b) - Unsupported DataType (T)");
559 }
560 if(work) delete [] work;
561 if(info!=0 && Throw_On_Error) {
562 char serr[128]; sprintf(serr,"LeastSquareSolve_Error info=%d",info);
563 throw MathExc(serr);
564 }
565 return(info);
566}
567
568////////////////////////////////////////////////////////////////////////////////////
569//! Square matrix inversion using Lapack linear system solver
570/*! Compute the inverse of a square matrix using linear system solver routine
571 Input arrays should have FortranMemory mapping (column packed).
572 \param a : input matrix, overwritten on output
573 \param ainv : output matrix, contains inverse(a) on exit.
574 ainv is allocated if it has size 0
575 If not allocated, ainv is automatically
576 \return : return code from LapackServer::LinSolve()
577 \sa LapackServer::LinSolve()
578 */
579template <class T>
580int LapackServer<T>::ComputeInverse(TMatrix<T>& a, TMatrix<T> & ainv)
581{
582 if ( a.NbDimensions() != 2 )
583 throw(SzMismatchError("LapackServer::Inverse() NDim(a) != 2"));
584 if ( a.GetMemoryMapping() != BaseArray::FortranMemoryMapping )
585 throw(SzMismatchError("LapackServer::Inverse() a NOT in FortranMemoryMapping"));
586 if ( a.NRows() != a.NCols() )
587 throw(SzMismatchError("LapackServer::Inverse() a NOT square matrix (a.NRows!=a.NCols)"));
588 if (ainv.IsAllocated()) {
589 bool smo, ssz;
590 ssz = a.CompareSizes(ainv, smo);
591 if ( (ssz == false) || (smo == false) )
592 throw(SzMismatchError("LapackServer::Inverse() ainv<>a Size/MemOrg mismatch "));
593 }
594 else ainv.SetSize(a.NRows(), a.NCols(), BaseArray::FortranMemoryMapping, false);
595 ainv = IdentityMatrix();
596 return LinSolve(a, ainv);
597}
598
599////////////////////////////////////////////////////////////////////////////////////
600//! Interface to Lapack least squares solver driver s/d/c/zgelsd().
601/*! Solves the linear least squares problem defined by an m-by-n matrix
602 \b a and an m element vector \b b , using SVD factorization Divide and Conquer.
603 Inout arrays should have FortranMemory mapping (column packed).
604 \param rcond : definition of zero value (S(i) <= RCOND*S(0) are treated as zero).
605 If RCOND < 0, machine precision is used instead.
606 \param a : input matrix, overwritten on output
607 \param b : input vector b overwritten by solution on output (beware of size changing)
608 \param x : output matrix of solutions.
609 \param rank : output the rank of the matrix.
610 \return : return code from lapack driver _gelsd()
611 \warning : b is not resized.
612 */
613template <class T>
614int LapackServer<T>::LeastSquareSolveSVD_DC(TMatrix<T>& a,TMatrix<T>& b,TVector<r_8>& s,int_4& rank,r_8 rcond)
615{
616 if ( ( a.NbDimensions() != 2 ) )
617 throw(SzMismatchError("LapackServer::LeastSquareSolveSVD_DC(a,b) a != 2"));
618
619 if (!a.IsPacked() || !b.IsPacked())
620 throw(SzMismatchError("LapackServer::LeastSquareSolveSVD_DC(a,b) a Or b Not Packed"));
621
622 int_4 m = a.NRows();
623 int_4 n = a.NCols();
624
625 if(b.NRows() != m)
626 throw(SzMismatchError("LapackServer::LeastSquareSolveSVD_DC(a,b) bad matching dim between a and b"));
627
628 int_4 nrhs = b.NCols();
629 int_4 minmn = (m < n) ? m : n;
630 int_4 maxmn = (m > n) ? m : n;
631
632 int_4 lda = m;
633 int_4 ldb = maxmn;
634 int_4 info;
635
636 { // Use {} for automatic des-allocation of "bsave"
637 TMatrix<T> bsave(m,nrhs); bsave.SetMemoryMapping(BaseArray::FortranMemoryMapping);
638 bsave = b;
639 b.ReSize(maxmn,nrhs); b = (T) 0;
640 for(int i=0;i<m;i++) for(int j=0;j<nrhs;j++) b(i,j) = bsave(i,j);
641 } // Use {} for automatic des-allocation of "bsave"
642 s.ReSize(minmn);
643
644 int_4 smlsiz = 25; // Normallement ilaenv_en_C(9,...) renvoie toujours 25
645 if(typeid(T) == typeid(r_4) ) smlsiz = ilaenv_en_C(9,"SGELSD"," ",0,0,0,0);
646 else if(typeid(T) == typeid(r_8) ) smlsiz = ilaenv_en_C(9,"DGELSD"," ",0,0,0,0);
647 else if(typeid(T) == typeid(complex<r_4>) ) smlsiz = ilaenv_en_C(9,"CGELSD"," ",0,0,0,0);
648 else if(typeid(T) == typeid(complex<r_8>) ) smlsiz = ilaenv_en_C(9,"ZGELSD"," ",0,0,0,0);
649 if(smlsiz<0) smlsiz = 0;
650 r_8 dum = log((r_8)minmn/(r_8)(smlsiz+1.)) / log(2.);
651 int_4 nlvl = int_4(dum) + 1; if(nlvl<0) nlvl = 0;
652
653 T * work = NULL;
654 int_4 * iwork = NULL;
655 int_4 lwork=-1, lrwork;
656 T wkget[2];
657
658 if(typeid(T) == typeid(r_4) ) {
659 r_4* sloc = new r_4[minmn];
660 r_4 srcond = rcond;
661 iwork = new int_4[3*minmn*nlvl+11*minmn +GARDMEM];
662 sgelsd(&m,&n,&nrhs,(r_4*)a.Data(),&lda,
663 (r_4*)b.Data(),&ldb,(r_4*)sloc,&srcond,&rank,
664 (r_4*)wkget,&lwork,(int_4*)iwork,&info);
665 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
666 sgelsd(&m,&n,&nrhs,(r_4*)a.Data(),&lda,
667 (r_4*)b.Data(),&ldb,(r_4*)sloc,&srcond,&rank,
668 (r_4*)work,&lwork,(int_4*)iwork,&info);
669 for(int_4 i=0;i<minmn;i++) s(i) = sloc[i];
670 delete [] sloc;
671 } else if(typeid(T) == typeid(r_8) ) {
672 iwork = new int_4[3*minmn*nlvl+11*minmn +GARDMEM];
673 dgelsd(&m,&n,&nrhs,(r_8*)a.Data(),&lda,
674 (r_8*)b.Data(),&ldb,(r_8*)s.Data(),&rcond,&rank,
675 (r_8*)wkget,&lwork,(int_4*)iwork,&info);
676 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
677 dgelsd(&m,&n,&nrhs,(r_8*)a.Data(),&lda,
678 (r_8*)b.Data(),&ldb,(r_8*)s.Data(),&rcond,&rank,
679 (r_8*)work,&lwork,(int_4*)iwork,&info);
680 } else if(typeid(T) == typeid(complex<r_4>) ) {
681 // Cf meme remarque que ci-dessous (complex<r_8)
682 lrwork = 10*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + 3*smlsiz*nrhs + (smlsiz+1)*(smlsiz+1);
683 int_4 lrwork_d = 12*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + minmn*nrhs + (smlsiz+1)*(smlsiz+1);
684 if(lrwork_d > lrwork) lrwork = lrwork_d;
685 r_4* rwork = new r_4[lrwork +GARDMEM];
686 iwork = new int_4[3*minmn*nlvl+11*minmn +GARDMEM];
687 r_4* sloc = new r_4[minmn];
688 r_4 srcond = rcond;
689 cgelsd(&m,&n,&nrhs,(complex<r_4>*)a.Data(),&lda,
690 (complex<r_4>*)b.Data(),&ldb,(r_4*)sloc,&srcond,&rank,
691 (complex<r_4>*)wkget,&lwork,(r_4*)rwork,(int_4*)iwork,&info);
692 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
693 cgelsd(&m,&n,&nrhs,(complex<r_4>*)a.Data(),&lda,
694 (complex<r_4>*)b.Data(),&ldb,(r_4*)sloc,&srcond,&rank,
695 (complex<r_4>*)work,&lwork,(r_4*)rwork,(int_4*)iwork,&info);
696 for(int_4 i=0;i<minmn;i++) s(i) = sloc[i];
697 delete [] sloc; delete [] rwork;
698 } else if(typeid(T) == typeid(complex<r_8>) ) {
699 // CMV: Bizarrement, la formule donnee dans zgelsd() plante pour des N grands (500)
700 // On prend (par analogie) la formule pour "lwork" de dgelsd()
701 lrwork = 10*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + 3*smlsiz*nrhs + (smlsiz+1)*(smlsiz+1);
702 int_4 lrwork_d = 12*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + minmn*nrhs + (smlsiz+1)*(smlsiz+1);
703 if(lrwork_d > lrwork) lrwork = lrwork_d;
704 r_8* rwork = new r_8[lrwork +GARDMEM];
705 iwork = new int_4[3*minmn*nlvl+11*minmn +GARDMEM];
706 zgelsd(&m,&n,&nrhs,(complex<r_8>*)a.Data(),&lda,
707 (complex<r_8>*)b.Data(),&ldb,(r_8*)s.Data(),&rcond,&rank,
708 (complex<r_8>*)wkget,&lwork,(r_8*)rwork,(int_4*)iwork,&info);
709 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
710 zgelsd(&m,&n,&nrhs,(complex<r_8>*)a.Data(),&lda,
711 (complex<r_8>*)b.Data(),&ldb,(r_8*)s.Data(),&rcond,&rank,
712 (complex<r_8>*)work,&lwork,(r_8*)rwork,(int_4*)iwork,&info);
713 delete [] rwork;
714 } else {
715 if(work) delete [] work; work=NULL;
716 if(iwork) delete [] iwork; iwork=NULL;
717 string tn = typeid(T).name();
718 cerr << " LapackServer::LeastSquareSolveSVD_DC(a,b) - Unsupported DataType T = " << tn << endl;
719 throw TypeMismatchExc("LapackServer::LeastSquareSolveSVD_DC(a,b) - Unsupported DataType (T)");
720 }
721
722 if(work) delete [] work; if(iwork) delete [] iwork;
723 if(info!=0 && Throw_On_Error) {
724 char serr[128]; sprintf(serr,"LeastSquareSolveSVD_DC_Error info=%d",info);
725 throw MathExc(serr);
726 }
727 return(info);
728}
729
730
731////////////////////////////////////////////////////////////////////////////////////
732//! Interface to Lapack SVD driver s/d/c/zgesv().
733/*! Computes the vector of singular values of \b a. Input arrays
734 should have FortranMemoryMapping (column packed).
735 \param a : input m-by-n matrix
736 \param s : Vector of min(m,n) singular values (descending order)
737 \return : return code from lapack driver _gesvd()
738 */
739
740template <class T>
741int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s)
742{
743 return (SVDDriver(a, s, NULL, NULL) );
744}
745
746//! Interface to Lapack SVD driver s/d/c/zgesv().
747/*! Computes the vector of singular values of \b a, as well as
748 right and left singular vectors of \b a.
749 \f[
750 A = U \Sigma V^T , ( A = U \Sigma V^H \ complex)
751 \f]
752 \f[
753 A v_i = \sigma_i u_i \ and A^T u_i = \sigma_i v_i \ (A^H \ complex)
754 \f]
755 U and V are orthogonal (unitary) matrices.
756 \param a : input m-by-n matrix (in FortranMemoryMapping)
757 \param s : Vector of min(m,n) singular values (descending order)
758 \param u : m-by-m Matrix of left singular vectors
759 \param vt : Transpose of right singular vectors (n-by-n matrix).
760 \return : return code from lapack driver _gesvd()
761 */
762template <class T>
763int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s, TArray<T> & u, TArray<T> & vt)
764{
765 return (SVDDriver(a, s, &u, &vt) );
766}
767
768
769//! Interface to Lapack SVD driver s/d/c/zgesv().
770template <class T>
771int LapackServer<T>::SVDDriver(TArray<T>& a, TArray<T> & s, TArray<T>* up, TArray<T>* vtp)
772{
773 if ( ( a.NbDimensions() != 2 ) )
774 throw(SzMismatchError("LapackServer::SVDDriver(a, ...) a.NbDimensions() != 2"));
775
776 int_4 rowa = a.RowsKA();
777 int_4 cola = a.ColsKA();
778
779 if ( !a.IsPacked(rowa) )
780 throw(SzMismatchError("LapackServer::SVDDriver(a, ...) a Not Column Packed "));
781
782 int_4 m = a.Size(rowa);
783 int_4 n = a.Size(cola);
784 int_4 maxmn = (m > n) ? m : n;
785 int_4 minmn = (m < n) ? m : n;
786
787 char jobu, jobvt;
788 jobu = 'N';
789 jobvt = 'N';
790
791 sa_size_t sz[2];
792 if ( up != NULL) {
793 if ( dynamic_cast< TVector<T> * > (vtp) )
794 throw( TypeMismatchExc("LapackServer::SVDDriver() Wrong type (=TVector<T>) for u !") );
795 up->SetMemoryMapping(BaseArray::FortranMemoryMapping);
796 sz[0] = sz[1] = m;
797 up->ReSize(2, sz );
798 jobu = 'A';
799 }
800 else {
801 up = new TMatrix<T>(1,1);
802 jobu = 'N';
803 }
804 if ( vtp != NULL) {
805 if ( dynamic_cast< TVector<T> * > (vtp) )
806 throw( TypeMismatchExc("LapackServer::SVDDriver() Wrong type (=TVector<T>) for vt !") );
807 vtp->SetMemoryMapping(BaseArray::FortranMemoryMapping);
808 sz[0] = sz[1] = n;
809 vtp->ReSize(2, sz );
810 jobvt = 'A';
811 }
812 else {
813 vtp = new TMatrix<T>(1,1);
814 jobvt = 'N';
815 }
816
817 TVector<T> *vs = dynamic_cast< TVector<T> * > (&s);
818 if (vs) vs->ReSize(minmn);
819 else {
820 TMatrix<T> *ms = dynamic_cast< TMatrix<T> * > (&s);
821 if (ms) ms->ReSize(minmn,1);
822 else {
823 sz[0] = minmn; sz[1] = 1;
824 s.ReSize(1, sz);
825 }
826 }
827
828 int_4 lda = a.Step(a.ColsKA());
829 int_4 ldu = up->Step(up->ColsKA());
830 int_4 ldvt = vtp->Step(vtp->ColsKA());
831 int_4 info;
832
833 int_4 lwork = -1; // maxmn*5 *wspace_size_factor;
834 T * work = NULL; // = new T[lwork];
835 T wkget[2];
836
837 if (typeid(T) == typeid(r_4) ) {
838 sgesvd(&jobu, &jobvt, &m, &n, (r_4 *)a.Data(), &lda,
839 (r_4 *)s.Data(), (r_4 *) up->Data(), &ldu, (r_4 *)vtp->Data(), &ldvt,
840 (r_4 *)wkget, &lwork, &info);
841 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
842 sgesvd(&jobu, &jobvt, &m, &n, (r_4 *)a.Data(), &lda,
843 (r_4 *)s.Data(), (r_4 *) up->Data(), &ldu, (r_4 *)vtp->Data(), &ldvt,
844 (r_4 *)work, &lwork, &info);
845 } else if (typeid(T) == typeid(r_8) ) {
846 dgesvd(&jobu, &jobvt, &m, &n, (r_8 *)a.Data(), &lda,
847 (r_8 *)s.Data(), (r_8 *) up->Data(), &ldu, (r_8 *)vtp->Data(), &ldvt,
848 (r_8 *)wkget, &lwork, &info);
849 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
850 dgesvd(&jobu, &jobvt, &m, &n, (r_8 *)a.Data(), &lda,
851 (r_8 *)s.Data(), (r_8 *) up->Data(), &ldu, (r_8 *)vtp->Data(), &ldvt,
852 (r_8 *)work, &lwork, &info);
853 } else if (typeid(T) == typeid(complex<r_4>) ) {
854 r_4 * rwork = new r_4[5*minmn +GARDMEM];
855 r_4 * sloc = new r_4[minmn];
856 cgesvd(&jobu, &jobvt, &m, &n, (complex<r_4> *)a.Data(), &lda,
857 (r_4 *)sloc, (complex<r_4> *) up->Data(), &ldu,
858 (complex<r_4> *)vtp->Data(), &ldvt,
859 (complex<r_4> *)wkget, &lwork, (r_4 *)rwork, &info);
860 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
861 cgesvd(&jobu, &jobvt, &m, &n, (complex<r_4> *)a.Data(), &lda,
862 (r_4 *)sloc, (complex<r_4> *) up->Data(), &ldu,
863 (complex<r_4> *)vtp->Data(), &ldvt,
864 (complex<r_4> *)work, &lwork, (r_4 *)rwork, &info);
865 for(int_4 i=0;i<minmn;i++) s[i] = sloc[i];
866 delete [] rwork; delete [] sloc;
867 } else if (typeid(T) == typeid(complex<r_8>) ) {
868 r_8 * rwork = new r_8[5*minmn +GARDMEM];
869 r_8 * sloc = new r_8[minmn];
870 zgesvd(&jobu, &jobvt, &m, &n, (complex<r_8> *)a.Data(), &lda,
871 (r_8 *)sloc, (complex<r_8> *) up->Data(), &ldu,
872 (complex<r_8> *)vtp->Data(), &ldvt,
873 (complex<r_8> *)wkget, &lwork, (r_8 *)rwork, &info);
874 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
875 zgesvd(&jobu, &jobvt, &m, &n, (complex<r_8> *)a.Data(), &lda,
876 (r_8 *)sloc, (complex<r_8> *) up->Data(), &ldu,
877 (complex<r_8> *)vtp->Data(), &ldvt,
878 (complex<r_8> *)work, &lwork, (r_8 *)rwork, &info);
879 for(int_4 i=0;i<minmn;i++) s[i] = sloc[i];
880 delete [] rwork; delete [] sloc;
881 } else {
882 if(work) delete [] work; work=NULL;
883 if (jobu == 'N') delete up;
884 if (jobvt == 'N') delete vtp;
885 string tn = typeid(T).name();
886 cerr << " LapackServer::SVDDriver(...) - Unsupported DataType T = " << tn << endl;
887 throw TypeMismatchExc("LapackServer::SVDDriver(a,b) - Unsupported DataType (T)");
888 }
889
890 if(work) delete [] work;
891 if (jobu == 'N') delete up;
892 if (jobvt == 'N') delete vtp;
893 if(info!=0 && Throw_On_Error) {
894 char serr[128]; sprintf(serr,"SVDDriver_Error info=%d",info);
895 throw MathExc(serr);
896 }
897 return(info);
898}
899
900
901//! Interface to Lapack SVD driver s/d/c/zgesdd().
902/*! Same as SVD but with Divide and Conquer method */
903template <class T>
904int LapackServer<T>::SVD_DC(TMatrix<T>& a, TVector<r_8>& s, TMatrix<T>& u, TMatrix<T>& vt)
905{
906
907 if ( !a.IsPacked() )
908 throw(SzMismatchError("LapackServer::SVD_DC(a, ...) a Not Packed "));
909
910 int_4 m = a.NRows();
911 int_4 n = a.NCols();
912 int_4 maxmn = (m > n) ? m : n;
913 int_4 minmn = (m < n) ? m : n;
914 int_4 supermax = 4*minmn*minmn+4*minmn; if(maxmn>supermax) supermax=maxmn;
915
916 char jobz = 'A';
917
918 s.ReSize(minmn);
919 u.ReSize(m,m);
920 vt.ReSize(n,n);
921
922 int_4 lda = m;
923 int_4 ldu = m;
924 int_4 ldvt = n;
925 int_4 info;
926 int_4 lwork=-1;
927 T * work = NULL;
928 int_4 * iwork = NULL;
929 T wkget[2];
930
931 if(typeid(T) == typeid(r_4) ) {
932 r_4* sloc = new r_4[minmn];
933 iwork = new int_4[8*minmn +GARDMEM];
934 sgesdd(&jobz,&m,&n,(r_4*)a.Data(),&lda,
935 (r_4*)sloc,(r_4*)u.Data(),&ldu,(r_4*)vt.Data(),&ldvt,
936 (r_4*)wkget,&lwork,(int_4*)iwork,&info);
937 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
938 sgesdd(&jobz,&m,&n,(r_4*)a.Data(),&lda,
939 (r_4*)sloc,(r_4*)u.Data(),&ldu,(r_4*)vt.Data(),&ldvt,
940 (r_4*)work,&lwork,(int_4*)iwork,&info);
941 for(int_4 i=0;i<minmn;i++) s(i) = (r_8) sloc[i];
942 delete [] sloc;
943 } else if(typeid(T) == typeid(r_8) ) {
944 iwork = new int_4[8*minmn +GARDMEM];
945 dgesdd(&jobz,&m,&n,(r_8*)a.Data(),&lda,
946 (r_8*)s.Data(),(r_8*)u.Data(),&ldu,(r_8*)vt.Data(),&ldvt,
947 (r_8*)wkget,&lwork,(int_4*)iwork,&info);
948 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
949 dgesdd(&jobz,&m,&n,(r_8*)a.Data(),&lda,
950 (r_8*)s.Data(),(r_8*)u.Data(),&ldu,(r_8*)vt.Data(),&ldvt,
951 (r_8*)work,&lwork,(int_4*)iwork,&info);
952 } else if(typeid(T) == typeid(complex<r_4>) ) {
953 r_4* sloc = new r_4[minmn];
954 r_4* rwork = new r_4[5*minmn*minmn+5*minmn +GARDMEM];
955 iwork = new int_4[8*minmn +GARDMEM];
956 cgesdd(&jobz,&m,&n,(complex<r_4>*)a.Data(),&lda,
957 (r_4*)sloc,(complex<r_4>*)u.Data(),&ldu,(complex<r_4>*)vt.Data(),&ldvt,
958 (complex<r_4>*)wkget,&lwork,(r_4*)rwork,(int_4*)iwork,&info);
959 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
960 cgesdd(&jobz,&m,&n,(complex<r_4>*)a.Data(),&lda,
961 (r_4*)sloc,(complex<r_4>*)u.Data(),&ldu,(complex<r_4>*)vt.Data(),&ldvt,
962 (complex<r_4>*)work,&lwork,(r_4*)rwork,(int_4*)iwork,&info);
963 for(int_4 i=0;i<minmn;i++) s(i) = (r_8) sloc[i];
964 delete [] sloc; delete [] rwork;
965 } else if(typeid(T) == typeid(complex<r_8>) ) {
966 r_8* rwork = new r_8[5*minmn*minmn+5*minmn +GARDMEM];
967 iwork = new int_4[8*minmn +GARDMEM];
968 zgesdd(&jobz,&m,&n,(complex<r_8>*)a.Data(),&lda,
969 (r_8*)s.Data(),(complex<r_8>*)u.Data(),&ldu,(complex<r_8>*)vt.Data(),&ldvt,
970 (complex<r_8>*)wkget,&lwork,(r_8*)rwork,(int_4*)iwork,&info);
971 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
972 zgesdd(&jobz,&m,&n,(complex<r_8>*)a.Data(),&lda,
973 (r_8*)s.Data(),(complex<r_8>*)u.Data(),&ldu,(complex<r_8>*)vt.Data(),&ldvt,
974 (complex<r_8>*)work,&lwork,(r_8*)rwork,(int_4*)iwork,&info);
975 delete [] rwork;
976 } else {
977 if(work) delete [] work; work=NULL;
978 if(iwork) delete [] iwork; iwork=NULL;
979 string tn = typeid(T).name();
980 cerr << " LapackServer::SVD_DC(...) - Unsupported DataType T = " << tn << endl;
981 throw TypeMismatchExc("LapackServer::SVD_DC - Unsupported DataType (T)");
982 }
983
984 if(work) delete [] work; if(iwork) delete [] iwork;
985 if(info!=0 && Throw_On_Error) {
986 char serr[128]; sprintf(serr,"SVD_DC_Error info=%d",info);
987 throw MathExc(serr);
988 }
989 return(info);
990}
991
992
993////////////////////////////////////////////////////////////////////////////////////
994/*! Computes the eigen values and eigen vectors of a symetric (or hermitian) matrix \b a.
995 Input arrays should have FortranMemoryMapping (column packed).
996 \param a : input symetric (or hermitian) n-by-n matrix
997 \param b : Vector of eigenvalues (descending order)
998 \param eigenvector : if true compute eigenvectors, if not only eigenvalues
999 \param a : on return array of eigenvectors (same order than eval, one vector = one column)
1000 \return : return code from lapack driver
1001 */
1002
1003template <class T>
1004int LapackServer<T>::LapackEigenSym(TArray<T>& a, TVector<r_8>& b, bool eigenvector)
1005{
1006 if ( a.NbDimensions() != 2 )
1007 throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a NbDimensions() != 2"));
1008 int_4 rowa = a.RowsKA();
1009 int_4 cola = a.ColsKA();
1010 if ( a.Size(rowa) != a.Size(cola))
1011 throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a Not a square Array"));
1012 if (!a.IsPacked(rowa))
1013 throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a Not Column Packed"));
1014
1015 char uplo='U';
1016 char jobz='N'; if(eigenvector) jobz='V';
1017
1018 int_4 n = a.Size(rowa);
1019 int_4 lda = a.Step(cola);
1020 int_4 info = 0;
1021 int_4 lwork = -1;
1022 T * work = NULL;
1023 T wkget[2];
1024
1025 b.ReSize(n); b = 0.;
1026
1027 if (typeid(T) == typeid(r_4) ) {
1028 r_4* w = new r_4[n];
1029 ssyev(&jobz,&uplo,&n,(r_4 *)a.Data(),&lda,(r_4 *)w,(r_4 *)wkget,&lwork,&info);
1030 lwork = type2i4(&wkget[0],4); /* 3*n-1;*/ work = new T[lwork +GARDMEM];
1031 ssyev(&jobz,&uplo,&n,(r_4 *)a.Data(),&lda,(r_4 *)w,(r_4 *)work,&lwork,&info);
1032 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
1033 delete [] w;
1034 } else if (typeid(T) == typeid(r_8) ) {
1035 r_8* w = new r_8[n];
1036 dsyev(&jobz,&uplo,&n,(r_8 *)a.Data(),&lda,(r_8 *)w,(r_8 *)wkget,&lwork,&info);
1037 lwork = type2i4(&wkget[0],8); /* 3*n-1;*/ work = new T[lwork +GARDMEM];
1038 dsyev(&jobz,&uplo,&n,(r_8 *)a.Data(),&lda,(r_8 *)w,(r_8 *)work,&lwork,&info);
1039 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
1040 delete [] w;
1041 } else if (typeid(T) == typeid(complex<r_4>) ) {
1042 r_4* rwork = new r_4[3*n-2 +GARDMEM]; r_4* w = new r_4[n];
1043 cheev(&jobz,&uplo,&n,(complex<r_4> *)a.Data(),&lda,(r_4 *)w
1044 ,(complex<r_4> *)wkget,&lwork,(r_4 *)rwork,&info);
1045 lwork = type2i4(&wkget[0],4); /* 2*n-1;*/ work = new T[lwork +GARDMEM];
1046 cheev(&jobz,&uplo,&n,(complex<r_4> *)a.Data(),&lda,(r_4 *)w
1047 ,(complex<r_4> *)work,&lwork,(r_4 *)rwork,&info);
1048 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
1049 delete [] rwork; delete [] w;
1050 } else if (typeid(T) == typeid(complex<r_8>) ) {
1051 r_8* rwork = new r_8[3*n-2 +GARDMEM]; r_8* w = new r_8[n];
1052 zheev(&jobz,&uplo,&n,(complex<r_8> *)a.Data(),&lda,(r_8 *)w
1053 ,(complex<r_8> *)wkget,&lwork,(r_8 *)rwork,&info);
1054 lwork = type2i4(&wkget[0],8); /* 2*n-1;*/ work = new T[lwork +GARDMEM];
1055 zheev(&jobz,&uplo,&n,(complex<r_8> *)a.Data(),&lda,(r_8 *)w
1056 ,(complex<r_8> *)work,&lwork,(r_8 *)rwork,&info);
1057 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
1058 delete [] rwork; delete [] w;
1059 } else {
1060 if(work) delete [] work; work=NULL;
1061 string tn = typeid(T).name();
1062 cerr << " LapackServer::LapackEigenSym(a,b) - Unsupported DataType T = " << tn << endl;
1063 throw TypeMismatchExc("LapackServer::LapackEigenSym(a,b) - Unsupported DataType (T)");
1064 }
1065
1066 if(work) delete [] work;
1067 if(info!=0 && Throw_On_Error) {
1068 char serr[128]; sprintf(serr,"LapackEigenSym_Error info=%d",info);
1069 throw MathExc(serr);
1070 }
1071 return(info);
1072}
1073
1074////////////////////////////////////////////////////////////////////////////////////
1075/*! Computes the eigen values and eigen vectors of a general squared matrix \b a.
1076 Input arrays should have FortranMemoryMapping (column packed).
1077 \param a : input general n-by-n matrix
1078 \param eval : Vector of eigenvalues (complex double precision)
1079 \param evec : Matrix of eigenvector (same order than eval, one vector = one column)
1080 \param eigenvector : if true compute (right) eigenvectors, if not only eigenvalues
1081 \param a : on return array of eigenvectors
1082 \return : return code from lapack driver
1083 \verbatim
1084 eval : contains the computed eigenvalues.
1085 --- For real matrices "a" :
1086 Complex conjugate pairs of eigenvalues appear consecutively
1087 with the eigenvalue having the positive imaginary part first.
1088 evec : the right eigenvectors v(j) are stored one after another
1089 in the columns of evec, in the same order as their eigenvalues.
1090 --- For real matrices "a" :
1091 If the j-th eigenvalue is real, then v(j) = evec(:,j),
1092 the j-th column of evec.
1093 If the j-th and (j+1)-st eigenvalues form a complex
1094 conjugate pair, then v(j) = evec(:,j) + i*evec(:,j+1) and
1095 v(j+1) = evec(:,j) - i*evec(:,j+1).
1096 \endverbatim
1097*/
1098
1099template <class T>
1100int LapackServer<T>::LapackEigen(TArray<T>& a, TVector< complex<r_8> >& eval, TMatrix<T>& evec, bool eigenvector)
1101{
1102 if ( a.NbDimensions() != 2 )
1103 throw(SzMismatchError("LapackServer::LapackEigen(a,b) a NbDimensions() != 2"));
1104 int_4 rowa = a.RowsKA();
1105 int_4 cola = a.ColsKA();
1106 if ( a.Size(rowa) != a.Size(cola))
1107 throw(SzMismatchError("LapackServer::LapackEigen(a,b) a Not a square Array"));
1108 if (!a.IsPacked(rowa))
1109 throw(SzMismatchError("LapackServer::LapackEigen(a,b) a Not Column Packed"));
1110
1111 char jobvl = 'N';
1112 char jobvr = 'N'; if(eigenvector) jobvr='V';
1113
1114 int_4 n = a.Size(rowa);
1115 int_4 lda = a.Step(cola);
1116 int_4 info = 0;
1117
1118 eval.ReSize(n); eval = complex<r_8>(0.,0.);
1119 if(eigenvector) {evec.ReSize(n,n); evec = (T) 0.;}
1120 int_4 ldvr = n, ldvl = 1;
1121
1122 int_4 lwork = -1;
1123 T * work = NULL;
1124 T wkget[2];
1125
1126 if (typeid(T) == typeid(r_4) ) {
1127 r_4* wr = new r_4[n]; r_4* wi = new r_4[n]; r_4* vl = NULL;
1128 sgeev(&jobvl,&jobvr,&n,(r_4 *)a.Data(),&lda,(r_4 *)wr,(r_4 *)wi,
1129 (r_4 *)vl,&ldvl,(r_4 *)evec.Data(),&ldvr,
1130 (r_4 *)wkget,&lwork,&info);
1131 lwork = type2i4(&wkget[0],4); /* 4*n;*/ work = new T[lwork +GARDMEM];
1132 sgeev(&jobvl,&jobvr,&n,(r_4 *)a.Data(),&lda,(r_4 *)wr,(r_4 *)wi,
1133 (r_4 *)vl,&ldvl,(r_4 *)evec.Data(),&ldvr,
1134 (r_4 *)work,&lwork,&info);
1135 if(info==0) for(int i=0;i<n;i++) eval(i) = complex<r_8>(wr[i],wi[i]);
1136 delete [] wr; delete [] wi;
1137 } else if (typeid(T) == typeid(r_8) ) {
1138 r_8* wr = new r_8[n]; r_8* wi = new r_8[n]; r_8* vl = NULL;
1139 dgeev(&jobvl,&jobvr,&n,(r_8 *)a.Data(),&lda,(r_8 *)wr,(r_8 *)wi,
1140 (r_8 *)vl,&ldvl,(r_8 *)evec.Data(),&ldvr,
1141 (r_8 *)wkget,&lwork,&info);
1142 lwork = type2i4(&wkget[0],8); /* 4*n;*/ work = new T[lwork +GARDMEM];
1143 dgeev(&jobvl,&jobvr,&n,(r_8 *)a.Data(),&lda,(r_8 *)wr,(r_8 *)wi,
1144 (r_8 *)vl,&ldvl,(r_8 *)evec.Data(),&ldvr,
1145 (r_8 *)work,&lwork,&info);
1146 if(info==0) for(int i=0;i<n;i++) eval(i) = complex<r_8>(wr[i],wi[i]);
1147 delete [] wr; delete [] wi;
1148 } else if (typeid(T) == typeid(complex<r_4>) ) {
1149 r_4* rwork = new r_4[2*n +GARDMEM]; r_4* vl = NULL; TVector< complex<r_4> > w(n);
1150 cgeev(&jobvl,&jobvr,&n,(complex<r_4> *)a.Data(),&lda,(complex<r_4> *)w.Data(),
1151 (complex<r_4> *)vl,&ldvl,(complex<r_4> *)evec.Data(),&ldvr,
1152 (complex<r_4> *)wkget,&lwork,(r_4 *)rwork,&info);
1153 lwork = type2i4(&wkget[0],4); /* 2*n;*/ work = new T[lwork +GARDMEM];
1154 cgeev(&jobvl,&jobvr,&n,(complex<r_4> *)a.Data(),&lda,(complex<r_4> *)w.Data(),
1155 (complex<r_4> *)vl,&ldvl,(complex<r_4> *)evec.Data(),&ldvr,
1156 (complex<r_4> *)work,&lwork,(r_4 *)rwork,&info);
1157 if(info==0) for(int i=0;i<n;i++) eval(i) = w(i);
1158 delete [] rwork;
1159 } else if (typeid(T) == typeid(complex<r_8>) ) {
1160 r_8* rwork = new r_8[2*n +GARDMEM]; r_8* vl = NULL;
1161 zgeev(&jobvl,&jobvr,&n,(complex<r_8> *)a.Data(),&lda,(complex<r_8> *)eval.Data(),
1162 (complex<r_8> *)vl,&ldvl,(complex<r_8> *)evec.Data(),&ldvr,
1163 (complex<r_8> *)wkget,&lwork,(r_8 *)rwork,&info);
1164 lwork = type2i4(&wkget[0],8); /* 2*n;*/ work = new T[lwork +GARDMEM];
1165 zgeev(&jobvl,&jobvr,&n,(complex<r_8> *)a.Data(),&lda,(complex<r_8> *)eval.Data(),
1166 (complex<r_8> *)vl,&ldvl,(complex<r_8> *)evec.Data(),&ldvr,
1167 (complex<r_8> *)work,&lwork,(r_8 *)rwork,&info);
1168 delete [] rwork;
1169 } else {
1170 if(work) delete [] work; work=NULL;
1171 string tn = typeid(T).name();
1172 cerr << " LapackServer::LapackEigen(a,b) - Unsupported DataType T = " << tn << endl;
1173 throw TypeMismatchExc("LapackServer::LapackEigen(a,b) - Unsupported DataType (T)");
1174 }
1175
1176 if(work) delete [] work;
1177 if(info!=0 && Throw_On_Error) {
1178 char serr[128]; sprintf(serr,"LapackEigen_Error info=%d",info);
1179 throw MathExc(serr);
1180 }
1181 return(info);
1182}
1183
1184
1185
1186
///////////////////////////////////////////////////////////////
// Explicit instantiation of LapackServer for the four element types
// handled by the drivers above: r_8, r_4, complex<r_8>, complex<r_4>.

// Compilers relying on #pragma based template instantiation
#if defined(__CXX_PRAGMA_TEMPLATES__)
#pragma define_template LapackServer<r_8>
#pragma define_template LapackServer<r_4>
#pragma define_template LapackServer< complex<r_8> >
#pragma define_template LapackServer< complex<r_4> >
#endif

// Standard explicit instantiation (GNU / ANSI template model)
#if defined(GNU_TEMPLATES) || defined(ANSI_TEMPLATES)
namespace SOPHYA {
  template class LapackServer<r_8>;
  template class LapackServer<r_4>;
  template class LapackServer< complex<r_8> >;
  template class LapackServer< complex<r_4> >;
}
#endif
1203
#if defined(OS_LINUX)
// Entry point stub needed to link with f2c under Linux (libf2c.a presumably
// expects a MAIN__ symbol, the Fortran main program). It must never run:
// if it does, it reports the problem on cerr and throws a PError.
extern "C" void MAIN__();

void MAIN__()
{
  cerr << "MAIN__() function for linking with libf2c.a " << endl
       << " This function should never be called !!! " << endl;
  throw PError("MAIN__() should not be called - see intflapack.cc");
}
#endif
Note: See TracBrowser for help on using the repository browser.