source: Sophya/trunk/SophyaExt/LinAlg/intflapack.cc

Last change on this file was 3740, checked in by cmv, 16 years ago

bug avec SetMemoryMapping, cmv+rz 08/02/2010

File size: 48.6 KB
Line 
1#include <iostream>
2#include <string.h>
3#include <math.h>
4#include "sopnamsp.h"
5#include "intflapack.h"
6#include "sspvflags.h"
7
8#include "tvector.h"
9#include "tmatrix.h"
10#include <typeinfo>
11
// Number of extra guard elements added to the size of the LAPACK work arrays
// allocated in this file (see the "new ...[... +GARDMEM]" allocations).
#define GARDMEM 5
13
14/*************** Pour memoire (Christophe) ***************
15Les dispositions memoires (FORTRAN) pour les vecteurs et matrices LAPACK:
16
171./ --- REAL X(N):
18 if an array X of dimension (N) holds a vector x,
19 then X(i) holds "x_i" for i=1,...,N
20
212./ --- REAL A(LDA,N):
22 if a two-dimensional array A of dimension (LDA,N) holds an m-by-n matrix A,
23 then A(i,j) holds "a_ij" for i=1,...,m et j=1,...,n (LDA must be at least m).
24 Note that array arguments are usually declared in the software as assumed-size
25 arrays (last dimension *), for example: REAL A(LDA,*)
26 --- Rangement en memoire:
27 | 11 12 13 14 |
28 Ex: Real A(4,4): A = | 21 22 23 24 |
29 | 31 32 33 34 |
30 | 41 42 43 44 |
31 memoire: {11 21 31 41} {12 22 32 42} {13 23 33 43} {14 24 34 44}
32 First indice (line) "i" varies then the second (column):
33 (put all the first column, then put all the second column,
34 ..., then put all the last column)
35***********************************************************/
36
37/*!
38 \defgroup LinAlg LinAlg module
39 This module contains classes and functions for complex linear
40 algebra on arrays. This module is intended mainly to have
41 classes implementing C++ interfaces between Sophya objects
42 and external linear algebra libraries, such as LAPACK.
43*/
44
45/*!
46 \class SOPHYA::LapackServer
47 \ingroup LinAlg
 This class implements an interface to LAPACK library driver routines.
 LAPACK (Linear Algebra PACKage) is a collection of high performance
 routines to solve common problems in numerical linear algebra.
 It is available from http://www.netlib.org.
52
53 The present version of LapackServer (Feb 2005) provides
54 interfaces for the linear system solver, singular value
55 decomposition (SVD), Least square solver and
56 eigen value / eigen vector decomposition.
 Only arrays with BaseArray::FortranMemoryMapping
 can be handled by LapackServer. LapackServer can be instantiated
 for single and double precision real or complex array types.
60 \warning The input array is overwritten in most cases.
61 The example below shows solving a linear system A*X = B
62
63 \code
64 #include "intflapack.h"
65 // ...
66 // Use FortranMemoryMapping as default
67 BaseArray::SetDefaultMemoryMapping(BaseArray::FortranMemoryMapping);
 // Create and fill the arrays A and B
69 int n = 20;
70 Matrix A(n, n);
71 A = RandomSequence();
72 Vector X(n),B(n);
73 X = RandomSequence();
74 B = A*X;
75 // Solve the linear system A*X = B
76 LapackServer<r_8> lps;
77 lps.LinSolve(A,B);
78 // We get the result in B, which should be equal to X ...
79 // Compute the difference B-X ;
80 Vector diff = B-X;
81 \endcode
82
83*/
84
/*
  December 2005 : continuation of the AIX xlC port.
  The Fortran routine names used in this file are mapped through macros:
  with a trailing underscore on every system except AIX, where the plain
  name is used unchanged.
*/
#ifndef AIX

/* Default: Fortran symbols carry a trailing underscore */
#define ilaenv ilaenv_

#define sgesv  sgesv_
#define dgesv  dgesv_
#define cgesv  cgesv_
#define zgesv  zgesv_

#define ssysv  ssysv_
#define dsysv  dsysv_
#define csysv  csysv_
#define zsysv  zsysv_

#define sgels  sgels_
#define dgels  dgels_
#define cgels  cgels_
#define zgels  zgels_

#define sgelsd sgelsd_
#define dgelsd dgelsd_
#define cgelsd cgelsd_
#define zgelsd zgelsd_

#define sgesvd sgesvd_
#define dgesvd dgesvd_
#define cgesvd cgesvd_
#define zgesvd zgesvd_

#define sgesdd sgesdd_
#define dgesdd dgesdd_
#define cgesdd cgesdd_
#define zgesdd zgesdd_

#define ssyev  ssyev_
#define dsyev  dsyev_
#define cheev  cheev_
#define zheev  zheev_

#define sgeev  sgeev_
#define dgeev  dgeev_
#define cgeev  cgeev_
#define zgeev  zgeev_

#else

/* AIX: names are used as they are (no underscore appended) */
#define ilaenv ilaenv

#define sgesv  sgesv
#define dgesv  dgesv
#define cgesv  cgesv
#define zgesv  zgesv

#define ssysv  ssysv
#define dsysv  dsysv
#define csysv  csysv
#define zsysv  zsysv

#define sgels  sgels
#define dgels  dgels
#define cgels  cgels
#define zgels  zgels

#define sgelsd sgelsd
#define dgelsd dgelsd
#define cgelsd cgelsd
#define zgelsd zgelsd

#define sgesvd sgesvd
#define dgesvd dgesvd
#define cgesvd cgesvd
#define zgesvd zgesvd

#define sgesdd sgesdd
#define dgesdd dgesdd
#define cgesdd cgesdd
#define zgesdd zgesdd

#define ssyev  ssyev
#define dsyev  dsyev
#define cheev  cheev
#define zheev  zheev

#define sgeev  sgeev
#define dgeev  dgeev
#define cgeev  cgeev
#define zgeev  zgeev

#endif
178////////////////////////////////////////////////////////////////////////////////////
179extern "C" {
180// Le calculateur de workingspace
181 int_4 ilaenv(int_4 *ispec,const char *name,const char *opts,int_4 *n1,int_4 *n2,int_4 *n3,int_4 *n4,
182 int_4 nc1,int_4 nc2);
183
184// Drivers pour resolution de systemes lineaires
185 void sgesv(int_4* n, int_4* nrhs, r_4* a, int_4* lda,
186 int_4* ipiv, r_4* b, int_4* ldb, int_4* info);
187 void dgesv(int_4* n, int_4* nrhs, r_8* a, int_4* lda,
188 int_4* ipiv, r_8* b, int_4* ldb, int_4* info);
189 void cgesv(int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
190 int_4* ipiv, complex<r_4>* b, int_4* ldb, int_4* info);
191 void zgesv(int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
192 int_4* ipiv, complex<r_8>* b, int_4* ldb, int_4* info);
193
194// Drivers pour resolution de systemes lineaires symetriques
195 void ssysv(char* uplo, int_4* n, int_4* nrhs, r_4* a, int_4* lda,
196 int_4* ipiv, r_4* b, int_4* ldb,
197 r_4* work, int_4* lwork, int_4* info);
198 void dsysv(char* uplo, int_4* n, int_4* nrhs, r_8* a, int_4* lda,
199 int_4* ipiv, r_8* b, int_4* ldb,
200 r_8* work, int_4* lwork, int_4* info);
201 void csysv(char* uplo, int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
202 int_4* ipiv, complex<r_4>* b, int_4* ldb,
203 complex<r_4>* work, int_4* lwork, int_4* info);
204 void zsysv(char* uplo, int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
205 int_4* ipiv, complex<r_8>* b, int_4* ldb,
206 complex<r_8>* work, int_4* lwork, int_4* info);
207
208// Driver pour resolution de systemes au sens de Xi2
209 void sgels(char * trans, int_4* m, int_4* n, int_4* nrhs, r_4* a, int_4* lda,
210 r_4* b, int_4* ldb, r_4* work, int_4* lwork, int_4* info);
211 void dgels(char * trans, int_4* m, int_4* n, int_4* nrhs, r_8* a, int_4* lda,
212 r_8* b, int_4* ldb, r_8* work, int_4* lwork, int_4* info);
213 void cgels(char * trans, int_4* m, int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
214 complex<r_4>* b, int_4* ldb, complex<r_4>* work, int_4* lwork, int_4* info);
215 void zgels(char * trans, int_4* m, int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
216 complex<r_8>* b, int_4* ldb, complex<r_8>* work, int_4* lwork, int_4* info);
217
218// Driver pour resolution de systemes au sens de Xi2 par SVD Divide & Conquer
219 void sgelsd(int_4* m,int_4* n,int_4* nrhs,r_4* a,int_4* lda,
220 r_4* b,int_4* ldb,r_4* s,r_4* rcond,int_4* rank,
221 r_4* work,int_4* lwork,int_4* iwork,int_4* info);
222 void dgelsd(int_4* m,int_4* n,int_4* nrhs,r_8* a,int_4* lda,
223 r_8* b,int_4* ldb,r_8* s,r_8* rcond,int_4* rank,
224 r_8* work,int_4* lwork,int_4* iwork,int_4* info);
225 void cgelsd(int_4* m,int_4* n,int_4* nrhs,complex<r_4>* a,int_4* lda,
226 complex<r_4>* b,int_4* ldb,r_4* s,r_4* rcond,int_4* rank,
227 complex<r_4>* work,int_4* lwork,r_4* rwork,int_4* iwork,int_4* info);
228 void zgelsd(int_4* m,int_4* n,int_4* nrhs,complex<r_8>* a,int_4* lda,
229 complex<r_8>* b,int_4* ldb,r_8* s,r_8* rcond,int_4* rank,
230 complex<r_8>* work,int_4* lwork,r_8* rwork,int_4* iwork,int_4* info);
231
232// Driver pour decomposition SVD
233 void sgesvd(char* jobu, char* jobvt, int_4* m, int_4* n, r_4* a, int_4* lda,
234 r_4* s, r_4* u, int_4* ldu, r_4* vt, int_4* ldvt,
235 r_4* work, int_4* lwork, int_4* info);
236 void dgesvd(char* jobu, char* jobvt, int_4* m, int_4* n, r_8* a, int_4* lda,
237 r_8* s, r_8* u, int_4* ldu, r_8* vt, int_4* ldvt,
238 r_8* work, int_4* lwork, int_4* info);
239 void cgesvd(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_4>* a, int_4* lda,
240 r_4* s, complex<r_4>* u, int_4* ldu, complex<r_4>* vt, int_4* ldvt,
241 complex<r_4>* work, int_4* lwork, r_4* rwork, int_4* info);
242 void zgesvd(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_8>* a, int_4* lda,
243 r_8* s, complex<r_8>* u, int_4* ldu, complex<r_8>* vt, int_4* ldvt,
244 complex<r_8>* work, int_4* lwork, r_8* rwork, int_4* info);
245
246// Driver pour decomposition SVD Divide and Conquer
247 void sgesdd(char* jobz, int_4* m, int_4* n, r_4* a, int_4* lda,
248 r_4* s, r_4* u, int_4* ldu, r_4* vt, int_4* ldvt,
249 r_4* work, int_4* lwork, int_4* iwork, int_4* info);
250 void dgesdd(char* jobz, int_4* m, int_4* n, r_8* a, int_4* lda,
251 r_8* s, r_8* u, int_4* ldu, r_8* vt, int_4* ldvt,
252 r_8* work, int_4* lwork, int_4* iwork, int_4* info);
253 void cgesdd(char* jobz, int_4* m, int_4* n, complex<r_4>* a, int_4* lda,
254 r_4* s, complex<r_4>* u, int_4* ldu, complex<r_4>* vt, int_4* ldvt,
255 complex<r_4>* work, int_4* lwork, r_4* rwork, int_4* iwork, int_4* info);
256 void zgesdd(char* jobz, int_4* m, int_4* n, complex<r_8>* a, int_4* lda,
257 r_8* s, complex<r_8>* u, int_4* ldu, complex<r_8>* vt, int_4* ldvt,
258 complex<r_8>* work, int_4* lwork, r_8* rwork, int_4* iwork, int_4* info);
259
260// Driver pour eigen decomposition for symetric/hermitian matrices
261 void ssyev(char* jobz, char* uplo, int_4* n, r_4* a, int_4* lda, r_4* w,
262 r_4* work, int_4 *lwork, int_4* info);
263 void dsyev(char* jobz, char* uplo, int_4* n, r_8* a, int_4* lda, r_8* w,
264 r_8* work, int_4 *lwork, int_4* info);
265 void cheev(char* jobz, char* uplo, int_4* n, complex<r_4>* a, int_4* lda, r_4* w,
266 complex<r_4>* work, int_4 *lwork, r_4* rwork, int_4* info);
267 void zheev(char* jobz, char* uplo, int_4* n, complex<r_8>* a, int_4* lda, r_8* w,
268 complex<r_8>* work, int_4 *lwork, r_8* rwork, int_4* info);
269
270// Driver pour eigen decomposition for general squared matrices
271 void sgeev(char* jobl, char* jobvr, int_4* n, r_4* a, int_4* lda, r_4* wr, r_4* wi,
272 r_4* vl, int_4* ldvl, r_4* vr, int_4* ldvr,
273 r_4* work, int_4 *lwork, int_4* info);
274 void dgeev(char* jobl, char* jobvr, int_4* n, r_8* a, int_4* lda, r_8* wr, r_8* wi,
275 r_8* vl, int_4* ldvl, r_8* vr, int_4* ldvr,
276 r_8* work, int_4 *lwork, int_4* info);
277 void cgeev(char* jobl, char* jobvr, int_4* n, complex<r_4>* a, int_4* lda, complex<r_4>* w,
278 complex<r_4>* vl, int_4* ldvl, complex<r_4>* vr, int_4* ldvr,
279 complex<r_4>* work, int_4 *lwork, r_4* rwork, int_4* info);
280 void zgeev(char* jobl, char* jobvr, int_4* n, complex<r_8>* a, int_4* lda, complex<r_8>* w,
281 complex<r_8>* vl, int_4* ldvl, complex<r_8>* vr, int_4* ldvr,
282 complex<r_8>* work, int_4 *lwork, r_8* rwork, int_4* info);
283
284}
285
286// -------------- Classe LapackServer<T> --------------
287
288////////////////////////////////////////////////////////////////////////////////////
289template <class T>
290LapackServer<T>::LapackServer(bool throw_on_error)
291 : Throw_On_Error(throw_on_error)
292{
293 SetWorkSpaceSizeFactor();
294}
295
296template <class T>
297LapackServer<T>::~LapackServer()
298{
299}
300
301////////////////////////////////////////////////////////////////////////////////////
302template <class T>
303int_4 LapackServer<T>::ilaenv_en_C(int_4 ispec,const char *name,const char *opts,int_4 n1,int_4 n2,int_4 n3,int_4 n4)
304{
305 int_4 nc1 = strlen(name), nc2 = strlen(opts), rc=0;
306 rc = ilaenv(&ispec,name,opts,&n1,&n2,&n3,&n4,nc1,nc2);
307 //cout<<"ilaenv_en_C("<<ispec<<","<<name<<"("<<nc1<<"),"<<opts<<"("<<nc2<<"),"
308 // <<n1<<","<<n2<<","<<n3<<","<<n4<<") = "<<rc<<endl;
309 return rc;
310}
311
312template <class T>
313int_4 LapackServer<T>::type2i4(void *val,int nbytes)
314// Retourne un entier contenant la valeur contenue dans val
315// - nbytes = nombre de bytes dans le contenu de val
316// ex: r_4 x = 3.4; type2i4(&x,4) -> 3
317// ex: r_8 x = 3.4; type2i4(&x,8) -> 3
318// ex: complex<r_4> x(3.4,7.8); type2i4(&x,4) -> 3
319// ex: complex<r_8> x(3.4,7.8); type2i4(&x,8) -> 3
320{
321 r_4* x4; r_8* x8; int_4 lw=0;
322 if(nbytes==4) {x4 = (r_4*)val; lw = (int_4)(*x4);}
323 else {x8 = (r_8*)val; lw = (int_4)(*x8);}
324 return lw;
325}
326
327////////////////////////////////////////////////////////////////////////////////////
328//! Interface to Lapack linear system solver driver s/d/c/zgesv().
329/*! Solve the linear system a * x = b using LU factorization.
330 Input arrays should have FortranMemory mapping (column packed).
331 \param a : input matrix, overwritten on output
332 \param b : input-output, input vector b, contains x on exit
333 \return : return code from lapack driver _gesv()
334 */
335template <class T>
336int LapackServer<T>::LinSolve(TArray<T>& a, TArray<T> & b)
337{
338 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
339 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b NbDimensions() != 2"));
340
341 int_4 rowa = a.RowsKA();
342 int_4 cola = a.ColsKA();
343 int_4 rowb = b.RowsKA();
344 int_4 colb = b.ColsKA();
345 if ( a.Size(rowa) != a.Size(cola))
346 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Not a square Array"));
347 if ( a.Size(rowa) != b.Size(rowb))
348 throw(SzMismatchError("LapackServer::LinSolve(a,b) RowSize(a <> b) "));
349
350 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
351 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b Not Column Packed"));
352
353 int_4 n = a.Size(rowa);
354 int_4 nrhs = b.Size(colb);
355 int_4 lda = a.Step(cola);
356 int_4 ldb = b.Step(colb);
357 int_4 info;
358 int_4* ipiv = new int_4[n];
359
360 if (typeid(T) == typeid(r_4) )
361 sgesv(&n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb, &info);
362 else if (typeid(T) == typeid(r_8) )
363 dgesv(&n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb, &info);
364 else if (typeid(T) == typeid(complex<r_4>) )
365 cgesv(&n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
366 (complex<r_4> *)b.Data(), &ldb, &info);
367 else if (typeid(T) == typeid(complex<r_8>) )
368 zgesv(&n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
369 (complex<r_8> *)b.Data(), &ldb, &info);
370 else {
371 delete[] ipiv;
372 string tn = typeid(T).name();
373 cerr << " LapackServer::LinSolve(a,b) - Unsupported DataType T = " << tn << endl;
374 throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
375 }
376 delete[] ipiv;
377 if(info!=0 && Throw_On_Error) {
378 char serr[128]; sprintf(serr,"LinSolve_Error info=%d",info);
379 throw MathExc(serr);
380 }
381 return(info);
382}
383
384////////////////////////////////////////////////////////////////////////////////////
385//! Interface to Lapack linear system solver driver s/d/c/zsysv().
386/*! Solve the linear system a * x = b with a symetric matrix using LU factorization.
387 Input arrays should have FortranMemory mapping (column packed).
388 \param a : input matrix symetric , overwritten on output
389 \param b : input-output, input vector b, contains x on exit
390 \return : return code from lapack driver _sysv()
391 */
392template <class T>
393int LapackServer<T>::LinSolveSym(TArray<T>& a, TArray<T> & b)
394// --- REMARQUES DE CMV ---
395// 1./ contrairement a ce qui est dit dans la doc, il s'agit
396// de matrices SYMETRIQUES complexes et non HERMITIENNES !!!
397// 2./ pourquoi les routines de LinSolve pour des matrices symetriques
398// sont plus de deux fois plus lentes que les LinSolve generales sur OSF
399// et sensiblement plus lentes sous Linux ???
400{
401 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
402 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) a Or b NbDimensions() != 2"));
403 int_4 rowa = a.RowsKA();
404 int_4 cola = a.ColsKA();
405 int_4 rowb = b.RowsKA();
406 int_4 colb = b.ColsKA();
407 if ( a.Size(rowa) != a.Size(cola))
408 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) a Not a square Array"));
409 if ( a.Size(rowa) != b.Size(rowb))
410 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) RowSize(a <> b) "));
411
412 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
413 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) a Or b Not Column Packed"));
414
415 int_4 n = a.Size(rowa);
416 int_4 nrhs = b.Size(colb);
417 int_4 lda = a.Step(cola);
418 int_4 ldb = b.Step(colb);
419 int_4 info = 0;
420 int_4* ipiv = new int_4[n];
421 int_4 lwork = -1;
422 T * work = NULL;
423 T wkget[2];
424
425 char uplo = 'U'; // char uplo = 'L';
426 char struplo[5]; struplo[0] = uplo; struplo[1] = '\0';
427
428 if (typeid(T) == typeid(r_4) ) {
429 ssysv(&uplo, &n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb,
430 (r_4 *)wkget, &lwork, &info);
431 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
432 ssysv(&uplo, &n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb,
433 (r_4 *)work, &lwork, &info);
434 } else if (typeid(T) == typeid(r_8) ) {
435 dsysv(&uplo, &n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb,
436 (r_8 *)wkget, &lwork, &info);
437 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
438 dsysv(&uplo, &n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb,
439 (r_8 *)work, &lwork, &info);
440 } else if (typeid(T) == typeid(complex<r_4>) ) {
441 csysv(&uplo, &n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
442 (complex<r_4> *)b.Data(), &ldb,
443 (complex<r_4> *)wkget, &lwork, &info);
444 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
445 csysv(&uplo, &n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
446 (complex<r_4> *)b.Data(), &ldb,
447 (complex<r_4> *)work, &lwork, &info);
448 } else if (typeid(T) == typeid(complex<r_8>) ) {
449 zsysv(&uplo, &n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
450 (complex<r_8> *)b.Data(), &ldb,
451 (complex<r_8> *)wkget, &lwork, &info);
452 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
453 zsysv(&uplo, &n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
454 (complex<r_8> *)b.Data(), &ldb,
455 (complex<r_8> *)work, &lwork, &info);
456 } else {
457 if(work) delete[] work;
458 delete[] ipiv;
459 string tn = typeid(T).name();
460 cerr << " LapackServer::LinSolveSym(a,b) - Unsupported DataType T = " << tn << endl;
461 throw TypeMismatchExc("LapackServer::LinSolveSym(a,b) - Unsupported DataType (T)");
462 }
463 if(work) delete[] work;
464 delete[] ipiv;
465 if(info!=0 && Throw_On_Error) {
466 char serr[128]; sprintf(serr,"LinSolveSym_Error info=%d",info);
467 throw MathExc(serr);
468 }
469 return(info);
470}
471
472////////////////////////////////////////////////////////////////////////////////////
473//! Interface to Lapack least squares solver driver s/d/c/zgels().
474/*! Solves the linear least squares problem defined by an m-by-n matrix
475 \b a and an m element vector \b b , using QR or LQ factorization .
476 A solution \b x to the overdetermined system of linear equations
477 b = a * x is computed, minimizing the norm of b-a*x.
478 Underdetermined systems (m<n) are not yet handled.
479 Inout arrays should have FortranMemory mapping (column packed).
480 \param a : input matrix, overwritten on output
481 \param b : input-output, input vector b, contains x on exit.
482 \return : return code from lapack driver _gels()
483 \warning : b is not resized.
484 */
485/*
486 $CHECK$ - A faire - cas m<n
487 If the linear system is underdetermined, the minimum norm
488 solution is computed.
489*/
490
491template <class T>
492int LapackServer<T>::LeastSquareSolve(TArray<T>& a, TArray<T> & b)
493{
494 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
495 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) a Or b NbDimensions() != 2"));
496
497 int_4 rowa = a.RowsKA();
498 int_4 cola = a.ColsKA();
499 int_4 rowb = b.RowsKA();
500 int_4 colb = b.ColsKA();
501
502
503 if ( a.Size(rowa) != b.Size(rowb))
504 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) RowSize(a <> b) "));
505
506 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
507 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) a Or b Not Column Packed"));
508
509 if ( a.Size(rowa) < a.Size(cola)) { // $CHECK$ - m<n a changer
510 cout << " LapackServer<T>::LeastSquareSolve() - m<n - Not yet implemented for "
511 << " underdetermined systems ! " << endl;
512 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) NRows<NCols - "));
513 }
514 int_4 m = a.Size(rowa);
515 int_4 n = a.Size(cola);
516 int_4 nrhs = b.Size(colb);
517
518 int_4 lda = a.Step(cola);
519 int_4 ldb = b.Step(colb);
520 int_4 info;
521
522 //unused: int_4 minmn = (m < n) ? m : n;
523 int_4 maxmn = (m > n) ? m : n;
524 int_4 maxmnrhs = (nrhs > maxmn) ? nrhs : maxmn;
525 if (maxmnrhs < 1) maxmnrhs = 1;
526
527 int_4 lwork = -1; //minmn+maxmnrhs*5;
528 T * work = NULL;
529 T wkget[2];
530
531 char trans = 'N';
532
533 if (typeid(T) == typeid(r_4) ) {
534 sgels(&trans, &m, &n, &nrhs, (r_4 *)a.Data(), &lda,
535 (r_4 *)b.Data(), &ldb, (r_4 *)wkget, &lwork, &info);
536 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
537 sgels(&trans, &m, &n, &nrhs, (r_4 *)a.Data(), &lda,
538 (r_4 *)b.Data(), &ldb, (r_4 *)work, &lwork, &info);
539 } else if (typeid(T) == typeid(r_8) ) {
540 dgels(&trans, &m, &n, &nrhs, (r_8 *)a.Data(), &lda,
541 (r_8 *)b.Data(), &ldb, (r_8 *)wkget, &lwork, &info);
542 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
543 dgels(&trans, &m, &n, &nrhs, (r_8 *)a.Data(), &lda,
544 (r_8 *)b.Data(), &ldb, (r_8 *)work, &lwork, &info);
545 } else if (typeid(T) == typeid(complex<r_4>) ) {
546 cgels(&trans, &m, &n, &nrhs, (complex<r_4> *)a.Data(), &lda,
547 (complex<r_4> *)b.Data(), &ldb, (complex<r_4> *)wkget, &lwork, &info);
548 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
549 cgels(&trans, &m, &n, &nrhs, (complex<r_4> *)a.Data(), &lda,
550 (complex<r_4> *)b.Data(), &ldb, (complex<r_4> *)work, &lwork, &info);
551 } else if (typeid(T) == typeid(complex<r_8>) ) {
552 zgels(&trans, &m, &n, &nrhs, (complex<r_8> *)a.Data(), &lda,
553 (complex<r_8> *)b.Data(), &ldb, (complex<r_8> *)wkget, &lwork, &info);
554 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
555 zgels(&trans, &m, &n, &nrhs, (complex<r_8> *)a.Data(), &lda,
556 (complex<r_8> *)b.Data(), &ldb, (complex<r_8> *)work, &lwork, &info);
557 } else {
558 if(work) delete [] work; work=NULL;
559 string tn = typeid(T).name();
560 cerr << " LapackServer::LeastSquareSolve(a,b) - Unsupported DataType T = " << tn << endl;
561 throw TypeMismatchExc("LapackServer::LeastSquareSolve(a,b) - Unsupported DataType (T)");
562 }
563 if(work) delete [] work;
564 if(info!=0 && Throw_On_Error) {
565 char serr[128]; sprintf(serr,"LeastSquareSolve_Error info=%d",info);
566 throw MathExc(serr);
567 }
568 return(info);
569}
570
571////////////////////////////////////////////////////////////////////////////////////
572//! Square matrix inversion using Lapack linear system solver
573/*! Compute the inverse of a square matrix using linear system solver routine
574 Input arrays should have FortranMemory mapping (column packed).
575 \param a : input matrix, overwritten on output
576 \param ainv : output matrix, contains inverse(a) on exit.
577 ainv is allocated if it has size 0
578 If not allocated, ainv is automatically
579 \return : return code from LapackServer::LinSolve()
580 \sa LapackServer::LinSolve()
581 */
582template <class T>
583int LapackServer<T>::ComputeInverse(TMatrix<T>& a, TMatrix<T> & ainv)
584{
585 if ( a.NbDimensions() != 2 )
586 throw(SzMismatchError("LapackServer::Inverse() NDim(a) != 2"));
587 if ( a.GetMemoryMapping() != BaseArray::FortranMemoryMapping )
588 throw(SzMismatchError("LapackServer::Inverse() a NOT in FortranMemoryMapping"));
589 if ( a.NRows() != a.NCols() )
590 throw(SzMismatchError("LapackServer::Inverse() a NOT square matrix (a.NRows!=a.NCols)"));
591 if (ainv.IsAllocated()) {
592 bool smo, ssz;
593 ssz = a.CompareSizes(ainv, smo);
594 if ( (ssz == false) || (smo == false) )
595 throw(SzMismatchError("LapackServer::Inverse() ainv<>a Size/MemOrg mismatch "));
596 }
597 else ainv.SetSize(a.NRows(), a.NCols(), BaseArray::FortranMemoryMapping, false);
598 ainv = IdentityMatrix();
599 return LinSolve(a, ainv);
600}
601
602////////////////////////////////////////////////////////////////////////////////////
603//! Interface to Lapack least squares solver driver s/d/c/zgelsd().
604/*! Solves the linear least squares problem defined by an m-by-n matrix
605 \b a and an m element vector \b b , using SVD factorization Divide and Conquer.
606 Inout arrays should have FortranMemory mapping (column packed).
607 \param rcond : definition of zero value (S(i) <= RCOND*S(0) are treated as zero).
608 If RCOND < 0, machine precision is used instead.
609 \param a : input matrix, overwritten on output
610 \param b : input vector b overwritten by solution on output (beware of size changing)
611 \param x : output matrix of solutions.
612 \param rank : output the rank of the matrix.
613 \return : return code from lapack driver _gelsd()
614 \warning : b is not resized.
615 */
616template <class T>
617int LapackServer<T>::LeastSquareSolveSVD_DC(TMatrix<T>& a,TMatrix<T>& b,TVector<r_8>& s,int_4& rank,r_8 rcond)
618{
619#ifdef LAPACK_V2_EXTSOP
620 throw NotAvailableOperation("LapackServer::LeastSquareSolveSVD_DC(a,b) NOT implemented in LapackV2") ;
621#else
622 if ( ( a.NbDimensions() != 2 ) )
623 throw(SzMismatchError("LapackServer::LeastSquareSolveSVD_DC(a,b) a != 2"));
624
625 if (!a.IsPacked() || !b.IsPacked())
626 throw(SzMismatchError("LapackServer::LeastSquareSolveSVD_DC(a,b) a Or b Not Packed"));
627
628 int_4 m = a.NRows();
629 int_4 n = a.NCols();
630
631 if(b.NRows() != m)
632 throw(SzMismatchError("LapackServer::LeastSquareSolveSVD_DC(a,b) bad matching dim between a and b"));
633
634 int_4 nrhs = b.NCols();
635 int_4 minmn = (m < n) ? m : n;
636 int_4 maxmn = (m > n) ? m : n;
637
638 int_4 lda = m;
639 int_4 ldb = maxmn;
640 int_4 info;
641
642 { // Use {} for automatic des-allocation of "bsave"
643 TMatrix<T> bsave = b;
644 b.ReSize(maxmn,nrhs); b = (T) 0;
645 for(int i=0;i<m;i++) for(int j=0;j<nrhs;j++) b(i,j) = bsave(i,j);
646 } // Use {} for automatic des-allocation of "bsave"
647 s.ReSize(minmn);
648
649 int_4 smlsiz = 25; // Normallement ilaenv_en_C(9,...) renvoie toujours 25
650 if(typeid(T) == typeid(r_4) ) smlsiz = ilaenv_en_C(9,"SGELSD"," ",0,0,0,0);
651 else if(typeid(T) == typeid(r_8) ) smlsiz = ilaenv_en_C(9,"DGELSD"," ",0,0,0,0);
652 else if(typeid(T) == typeid(complex<r_4>) ) smlsiz = ilaenv_en_C(9,"CGELSD"," ",0,0,0,0);
653 else if(typeid(T) == typeid(complex<r_8>) ) smlsiz = ilaenv_en_C(9,"ZGELSD"," ",0,0,0,0);
654 if(smlsiz<0) smlsiz = 0;
655 r_8 dum = log((r_8)minmn/(r_8)(smlsiz+1.)) / log(2.);
656 int_4 nlvl = int_4(dum) + 1; if(nlvl<0) nlvl = 0;
657
658 T * work = NULL;
659 int_4 * iwork = NULL;
660 int_4 lwork=-1, lrwork;
661 T wkget[2];
662
663 if(typeid(T) == typeid(r_4) ) {
664 r_4* sloc = new r_4[minmn];
665 r_4 srcond = rcond;
666 iwork = new int_4[3*minmn*nlvl+11*minmn +GARDMEM];
667 sgelsd(&m,&n,&nrhs,(r_4*)a.Data(),&lda,
668 (r_4*)b.Data(),&ldb,(r_4*)sloc,&srcond,&rank,
669 (r_4*)wkget,&lwork,(int_4*)iwork,&info);
670 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
671 sgelsd(&m,&n,&nrhs,(r_4*)a.Data(),&lda,
672 (r_4*)b.Data(),&ldb,(r_4*)sloc,&srcond,&rank,
673 (r_4*)work,&lwork,(int_4*)iwork,&info);
674 for(int_4 i=0;i<minmn;i++) s(i) = sloc[i];
675 delete [] sloc;
676 } else if(typeid(T) == typeid(r_8) ) {
677 iwork = new int_4[3*minmn*nlvl+11*minmn +GARDMEM];
678 dgelsd(&m,&n,&nrhs,(r_8*)a.Data(),&lda,
679 (r_8*)b.Data(),&ldb,(r_8*)s.Data(),&rcond,&rank,
680 (r_8*)wkget,&lwork,(int_4*)iwork,&info);
681 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
682 dgelsd(&m,&n,&nrhs,(r_8*)a.Data(),&lda,
683 (r_8*)b.Data(),&ldb,(r_8*)s.Data(),&rcond,&rank,
684 (r_8*)work,&lwork,(int_4*)iwork,&info);
685 } else if(typeid(T) == typeid(complex<r_4>) ) {
686 // Cf meme remarque que ci-dessous (complex<r_8)
687 lrwork = 10*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + 3*smlsiz*nrhs + (smlsiz+1)*(smlsiz+1);
688 int_4 lrwork_d = 12*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + minmn*nrhs + (smlsiz+1)*(smlsiz+1);
689 if(lrwork_d > lrwork) lrwork = lrwork_d;
690 r_4* rwork = new r_4[lrwork +GARDMEM];
691 iwork = new int_4[3*minmn*nlvl+11*minmn +GARDMEM];
692 r_4* sloc = new r_4[minmn];
693 r_4 srcond = rcond;
694 cgelsd(&m,&n,&nrhs,(complex<r_4>*)a.Data(),&lda,
695 (complex<r_4>*)b.Data(),&ldb,(r_4*)sloc,&srcond,&rank,
696 (complex<r_4>*)wkget,&lwork,(r_4*)rwork,(int_4*)iwork,&info);
697 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
698 cgelsd(&m,&n,&nrhs,(complex<r_4>*)a.Data(),&lda,
699 (complex<r_4>*)b.Data(),&ldb,(r_4*)sloc,&srcond,&rank,
700 (complex<r_4>*)work,&lwork,(r_4*)rwork,(int_4*)iwork,&info);
701 for(int_4 i=0;i<minmn;i++) s(i) = sloc[i];
702 delete [] sloc; delete [] rwork;
703 } else if(typeid(T) == typeid(complex<r_8>) ) {
704 // CMV: Bizarrement, la formule donnee dans zgelsd() plante pour des N grands (500)
705 // On prend (par analogie) la formule pour "lwork" de dgelsd()
706 lrwork = 10*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + 3*smlsiz*nrhs + (smlsiz+1)*(smlsiz+1);
707 int_4 lrwork_d = 12*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + minmn*nrhs + (smlsiz+1)*(smlsiz+1);
708 if(lrwork_d > lrwork) lrwork = lrwork_d;
709 r_8* rwork = new r_8[lrwork +GARDMEM];
710 iwork = new int_4[3*minmn*nlvl+11*minmn +GARDMEM];
711 zgelsd(&m,&n,&nrhs,(complex<r_8>*)a.Data(),&lda,
712 (complex<r_8>*)b.Data(),&ldb,(r_8*)s.Data(),&rcond,&rank,
713 (complex<r_8>*)wkget,&lwork,(r_8*)rwork,(int_4*)iwork,&info);
714 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
715 zgelsd(&m,&n,&nrhs,(complex<r_8>*)a.Data(),&lda,
716 (complex<r_8>*)b.Data(),&ldb,(r_8*)s.Data(),&rcond,&rank,
717 (complex<r_8>*)work,&lwork,(r_8*)rwork,(int_4*)iwork,&info);
718 delete [] rwork;
719 } else {
720 if(work) delete [] work; work=NULL;
721 if(iwork) delete [] iwork; iwork=NULL;
722 string tn = typeid(T).name();
723 cerr << " LapackServer::LeastSquareSolveSVD_DC(a,b) - Unsupported DataType T = " << tn << endl;
724 throw TypeMismatchExc("LapackServer::LeastSquareSolveSVD_DC(a,b) - Unsupported DataType (T)");
725 }
726
727 if(work) delete [] work; if(iwork) delete [] iwork;
728 if(info!=0 && Throw_On_Error) {
729 char serr[128]; sprintf(serr,"LeastSquareSolveSVD_DC_Error info=%d",info);
730 throw MathExc(serr);
731 }
732 return(info);
733#endif
734}
735
736
737////////////////////////////////////////////////////////////////////////////////////
738//! Interface to Lapack SVD driver s/d/c/zgesv().
739/*! Computes the vector of singular values of \b a. Input arrays
740 should have FortranMemoryMapping (column packed).
741 \param a : input m-by-n matrix
742 \param s : Vector of min(m,n) singular values (descending order)
743 \return : return code from lapack driver _gesvd()
744 */
745
746template <class T>
747int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s)
748{
749 return (SVDDriver(a, s, NULL, NULL) );
750}
751
752//! Interface to Lapack SVD driver s/d/c/zgesv().
753/*! Computes the vector of singular values of \b a, as well as
754 right and left singular vectors of \b a.
755 \f[
756 A = U \Sigma V^T , ( A = U \Sigma V^H \ complex)
757 \f]
758 \f[
759 A v_i = \sigma_i u_i \ and A^T u_i = \sigma_i v_i \ (A^H \ complex)
760 \f]
761 U and V are orthogonal (unitary) matrices.
762 \param a : input m-by-n matrix (in FortranMemoryMapping)
763 \param s : Vector of min(m,n) singular values (descending order)
764 \param u : m-by-m Matrix of left singular vectors
765 \param vt : Transpose of right singular vectors (n-by-n matrix).
766 \return : return code from lapack driver _gesvd()
767 */
768template <class T>
769int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s, TArray<T> & u, TArray<T> & vt)
770{
771 return (SVDDriver(a, s, &u, &vt) );
772}
773
774
775//! Interface to Lapack SVD driver s/d/c/zgesv().
776template <class T>
777int LapackServer<T>::SVDDriver(TArray<T>& a, TArray<T> & s, TArray<T>* up, TArray<T>* vtp)
778{
779 if ( ( a.NbDimensions() != 2 ) )
780 throw(SzMismatchError("LapackServer::SVDDriver(a, ...) a.NbDimensions() != 2"));
781
782 int_4 rowa = a.RowsKA();
783 int_4 cola = a.ColsKA();
784
785 if ( !a.IsPacked(rowa) )
786 throw(SzMismatchError("LapackServer::SVDDriver(a, ...) a Not Column Packed "));
787
788 int_4 m = a.Size(rowa);
789 int_4 n = a.Size(cola);
790 //unused: int_4 maxmn = (m > n) ? m : n;
791 int_4 minmn = (m < n) ? m : n;
792
793 char jobu, jobvt;
794 jobu = 'N';
795 jobvt = 'N';
796
797 sa_size_t sz[2];
798 if ( up != NULL) {
799 if ( dynamic_cast< TVector<T> * > (vtp) )
800 throw( TypeMismatchExc("LapackServer::SVDDriver() Wrong type (=TVector<T>) for u !") );
801 up->SetMemoryMapping(BaseArray::FortranMemoryMapping);
802 sz[0] = sz[1] = m;
803 up->ReSize(2, sz );
804 jobu = 'A';
805 }
806 else {
807 up = new TMatrix<T>(1,1);
808 jobu = 'N';
809 }
810 if ( vtp != NULL) {
811 if ( dynamic_cast< TVector<T> * > (vtp) )
812 throw( TypeMismatchExc("LapackServer::SVDDriver() Wrong type (=TVector<T>) for vt !") );
813 vtp->SetMemoryMapping(BaseArray::FortranMemoryMapping);
814 sz[0] = sz[1] = n;
815 vtp->ReSize(2, sz );
816 jobvt = 'A';
817 }
818 else {
819 vtp = new TMatrix<T>(1,1);
820 jobvt = 'N';
821 }
822
823 TVector<T> *vs = dynamic_cast< TVector<T> * > (&s);
824 if (vs) vs->ReSize(minmn);
825 else {
826 TMatrix<T> *ms = dynamic_cast< TMatrix<T> * > (&s);
827 if (ms) ms->ReSize(minmn,1);
828 else {
829 sz[0] = minmn; sz[1] = 1;
830 s.ReSize(1, sz);
831 }
832 }
833
834 int_4 lda = a.Step(a.ColsKA());
835 int_4 ldu = up->Step(up->ColsKA());
836 int_4 ldvt = vtp->Step(vtp->ColsKA());
837 int_4 info;
838
839 int_4 lwork = -1; // maxmn*5 *wspace_size_factor;
840 T * work = NULL; // = new T[lwork];
841 T wkget[2];
842
843 if (typeid(T) == typeid(r_4) ) {
844 sgesvd(&jobu, &jobvt, &m, &n, (r_4 *)a.Data(), &lda,
845 (r_4 *)s.Data(), (r_4 *) up->Data(), &ldu, (r_4 *)vtp->Data(), &ldvt,
846 (r_4 *)wkget, &lwork, &info);
847 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
848 sgesvd(&jobu, &jobvt, &m, &n, (r_4 *)a.Data(), &lda,
849 (r_4 *)s.Data(), (r_4 *) up->Data(), &ldu, (r_4 *)vtp->Data(), &ldvt,
850 (r_4 *)work, &lwork, &info);
851 } else if (typeid(T) == typeid(r_8) ) {
852 dgesvd(&jobu, &jobvt, &m, &n, (r_8 *)a.Data(), &lda,
853 (r_8 *)s.Data(), (r_8 *) up->Data(), &ldu, (r_8 *)vtp->Data(), &ldvt,
854 (r_8 *)wkget, &lwork, &info);
855 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
856 dgesvd(&jobu, &jobvt, &m, &n, (r_8 *)a.Data(), &lda,
857 (r_8 *)s.Data(), (r_8 *) up->Data(), &ldu, (r_8 *)vtp->Data(), &ldvt,
858 (r_8 *)work, &lwork, &info);
859 } else if (typeid(T) == typeid(complex<r_4>) ) {
860 r_4 * rwork = new r_4[5*minmn +GARDMEM];
861 r_4 * sloc = new r_4[minmn];
862 cgesvd(&jobu, &jobvt, &m, &n, (complex<r_4> *)a.Data(), &lda,
863 (r_4 *)sloc, (complex<r_4> *) up->Data(), &ldu,
864 (complex<r_4> *)vtp->Data(), &ldvt,
865 (complex<r_4> *)wkget, &lwork, (r_4 *)rwork, &info);
866 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
867 cgesvd(&jobu, &jobvt, &m, &n, (complex<r_4> *)a.Data(), &lda,
868 (r_4 *)sloc, (complex<r_4> *) up->Data(), &ldu,
869 (complex<r_4> *)vtp->Data(), &ldvt,
870 (complex<r_4> *)work, &lwork, (r_4 *)rwork, &info);
871 for(int_4 i=0;i<minmn;i++) s[i] = sloc[i];
872 delete [] rwork; delete [] sloc;
873 } else if (typeid(T) == typeid(complex<r_8>) ) {
874 r_8 * rwork = new r_8[5*minmn +GARDMEM];
875 r_8 * sloc = new r_8[minmn];
876 zgesvd(&jobu, &jobvt, &m, &n, (complex<r_8> *)a.Data(), &lda,
877 (r_8 *)sloc, (complex<r_8> *) up->Data(), &ldu,
878 (complex<r_8> *)vtp->Data(), &ldvt,
879 (complex<r_8> *)wkget, &lwork, (r_8 *)rwork, &info);
880 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
881 zgesvd(&jobu, &jobvt, &m, &n, (complex<r_8> *)a.Data(), &lda,
882 (r_8 *)sloc, (complex<r_8> *) up->Data(), &ldu,
883 (complex<r_8> *)vtp->Data(), &ldvt,
884 (complex<r_8> *)work, &lwork, (r_8 *)rwork, &info);
885 for(int_4 i=0;i<minmn;i++) s[i] = sloc[i];
886 delete [] rwork; delete [] sloc;
887 } else {
888 if(work) delete [] work; work=NULL;
889 if (jobu == 'N') delete up;
890 if (jobvt == 'N') delete vtp;
891 string tn = typeid(T).name();
892 cerr << " LapackServer::SVDDriver(...) - Unsupported DataType T = " << tn << endl;
893 throw TypeMismatchExc("LapackServer::SVDDriver(a,b) - Unsupported DataType (T)");
894 }
895
896 if(work) delete [] work;
897 if (jobu == 'N') delete up;
898 if (jobvt == 'N') delete vtp;
899 if(info!=0 && Throw_On_Error) {
900 char serr[128]; sprintf(serr,"SVDDriver_Error info=%d",info);
901 throw MathExc(serr);
902 }
903 return(info);
904}
905
906
907//! Interface to Lapack SVD driver s/d/c/zgesdd().
908/*! Same as SVD but with Divide and Conquer method */
909template <class T>
910int LapackServer<T>::SVD_DC(TMatrix<T>& a, TVector<r_8>& s, TMatrix<T>& u, TMatrix<T>& vt)
911{
912#ifdef LAPACK_V2_EXTSOP
913 throw NotAvailableOperation("LapackServer::SVD_DC(a,b) NOT implemented in LapackV2") ;
914#else
915 if ( !a.IsPacked() )
916 throw(SzMismatchError("LapackServer::SVD_DC(a, ...) a Not Packed "));
917
918 int_4 m = a.NRows();
919 int_4 n = a.NCols();
920 int_4 maxmn = (m > n) ? m : n;
921 int_4 minmn = (m < n) ? m : n;
922 int_4 supermax = 4*minmn*minmn+4*minmn; if(maxmn>supermax) supermax=maxmn;
923
924 char jobz = 'A';
925
926 s.ReSize(minmn);
927 u.ReSize(m,m);
928 vt.ReSize(n,n);
929
930 int_4 lda = m;
931 int_4 ldu = m;
932 int_4 ldvt = n;
933 int_4 info;
934 int_4 lwork=-1;
935 T * work = NULL;
936 int_4 * iwork = NULL;
937 T wkget[2];
938
939 if(typeid(T) == typeid(r_4) ) {
940 r_4* sloc = new r_4[minmn];
941 iwork = new int_4[8*minmn +GARDMEM];
942 sgesdd(&jobz,&m,&n,(r_4*)a.Data(),&lda,
943 (r_4*)sloc,(r_4*)u.Data(),&ldu,(r_4*)vt.Data(),&ldvt,
944 (r_4*)wkget,&lwork,(int_4*)iwork,&info);
945 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
946 sgesdd(&jobz,&m,&n,(r_4*)a.Data(),&lda,
947 (r_4*)sloc,(r_4*)u.Data(),&ldu,(r_4*)vt.Data(),&ldvt,
948 (r_4*)work,&lwork,(int_4*)iwork,&info);
949 for(int_4 i=0;i<minmn;i++) s(i) = (r_8) sloc[i];
950 delete [] sloc;
951 } else if(typeid(T) == typeid(r_8) ) {
952 iwork = new int_4[8*minmn +GARDMEM];
953 dgesdd(&jobz,&m,&n,(r_8*)a.Data(),&lda,
954 (r_8*)s.Data(),(r_8*)u.Data(),&ldu,(r_8*)vt.Data(),&ldvt,
955 (r_8*)wkget,&lwork,(int_4*)iwork,&info);
956 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
957 dgesdd(&jobz,&m,&n,(r_8*)a.Data(),&lda,
958 (r_8*)s.Data(),(r_8*)u.Data(),&ldu,(r_8*)vt.Data(),&ldvt,
959 (r_8*)work,&lwork,(int_4*)iwork,&info);
960 } else if(typeid(T) == typeid(complex<r_4>) ) {
961 r_4* sloc = new r_4[minmn];
962 r_4* rwork = new r_4[5*minmn*minmn+5*minmn +GARDMEM];
963 iwork = new int_4[8*minmn +GARDMEM];
964 cgesdd(&jobz,&m,&n,(complex<r_4>*)a.Data(),&lda,
965 (r_4*)sloc,(complex<r_4>*)u.Data(),&ldu,(complex<r_4>*)vt.Data(),&ldvt,
966 (complex<r_4>*)wkget,&lwork,(r_4*)rwork,(int_4*)iwork,&info);
967 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
968 cgesdd(&jobz,&m,&n,(complex<r_4>*)a.Data(),&lda,
969 (r_4*)sloc,(complex<r_4>*)u.Data(),&ldu,(complex<r_4>*)vt.Data(),&ldvt,
970 (complex<r_4>*)work,&lwork,(r_4*)rwork,(int_4*)iwork,&info);
971 for(int_4 i=0;i<minmn;i++) s(i) = (r_8) sloc[i];
972 delete [] sloc; delete [] rwork;
973 } else if(typeid(T) == typeid(complex<r_8>) ) {
974 r_8* rwork = new r_8[5*minmn*minmn+5*minmn +GARDMEM];
975 iwork = new int_4[8*minmn +GARDMEM];
976 zgesdd(&jobz,&m,&n,(complex<r_8>*)a.Data(),&lda,
977 (r_8*)s.Data(),(complex<r_8>*)u.Data(),&ldu,(complex<r_8>*)vt.Data(),&ldvt,
978 (complex<r_8>*)wkget,&lwork,(r_8*)rwork,(int_4*)iwork,&info);
979 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
980 zgesdd(&jobz,&m,&n,(complex<r_8>*)a.Data(),&lda,
981 (r_8*)s.Data(),(complex<r_8>*)u.Data(),&ldu,(complex<r_8>*)vt.Data(),&ldvt,
982 (complex<r_8>*)work,&lwork,(r_8*)rwork,(int_4*)iwork,&info);
983 delete [] rwork;
984 } else {
985 if(work) delete [] work; work=NULL;
986 if(iwork) delete [] iwork; iwork=NULL;
987 string tn = typeid(T).name();
988 cerr << " LapackServer::SVD_DC(...) - Unsupported DataType T = " << tn << endl;
989 throw TypeMismatchExc("LapackServer::SVD_DC - Unsupported DataType (T)");
990 }
991
992 if(work) delete [] work; if(iwork) delete [] iwork;
993 if(info!=0 && Throw_On_Error) {
994 char serr[128]; sprintf(serr,"SVD_DC_Error info=%d",info);
995 throw MathExc(serr);
996 }
997 return(info);
998#endif
999}
1000
1001
1002////////////////////////////////////////////////////////////////////////////////////
1003/*! Computes the eigen values and eigen vectors of a symetric (or hermitian) matrix \b a.
1004 Input arrays should have FortranMemoryMapping (column packed).
1005 \param a : input symetric (or hermitian) n-by-n matrix
1006 \param b : Vector of eigenvalues (descending order)
1007 \param eigenvector : if true compute eigenvectors, if not only eigenvalues
1008 \param a : on return array of eigenvectors (same order than eval, one vector = one column)
1009 \return : return code from lapack driver
1010 */
1011
1012template <class T>
1013int LapackServer<T>::LapackEigenSym(TArray<T>& a, TVector<r_8>& b, bool eigenvector)
1014{
1015 if ( a.NbDimensions() != 2 )
1016 throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a NbDimensions() != 2"));
1017 int_4 rowa = a.RowsKA();
1018 int_4 cola = a.ColsKA();
1019 if ( a.Size(rowa) != a.Size(cola))
1020 throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a Not a square Array"));
1021 if (!a.IsPacked(rowa))
1022 throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a Not Column Packed"));
1023
1024 char uplo='U';
1025 char jobz='N'; if(eigenvector) jobz='V';
1026
1027 int_4 n = a.Size(rowa);
1028 int_4 lda = a.Step(cola);
1029 int_4 info = 0;
1030 int_4 lwork = -1;
1031 T * work = NULL;
1032 T wkget[2];
1033
1034 b.ReSize(n); b = 0.;
1035
1036 if (typeid(T) == typeid(r_4) ) {
1037 r_4* w = new r_4[n];
1038 ssyev(&jobz,&uplo,&n,(r_4 *)a.Data(),&lda,(r_4 *)w,(r_4 *)wkget,&lwork,&info);
1039 lwork = type2i4(&wkget[0],4); /* 3*n-1;*/ work = new T[lwork +GARDMEM];
1040 ssyev(&jobz,&uplo,&n,(r_4 *)a.Data(),&lda,(r_4 *)w,(r_4 *)work,&lwork,&info);
1041 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
1042 delete [] w;
1043 } else if (typeid(T) == typeid(r_8) ) {
1044 r_8* w = new r_8[n];
1045 dsyev(&jobz,&uplo,&n,(r_8 *)a.Data(),&lda,(r_8 *)w,(r_8 *)wkget,&lwork,&info);
1046 lwork = type2i4(&wkget[0],8); /* 3*n-1;*/ work = new T[lwork +GARDMEM];
1047 dsyev(&jobz,&uplo,&n,(r_8 *)a.Data(),&lda,(r_8 *)w,(r_8 *)work,&lwork,&info);
1048 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
1049 delete [] w;
1050 } else if (typeid(T) == typeid(complex<r_4>) ) {
1051 r_4* rwork = new r_4[3*n-2 +GARDMEM]; r_4* w = new r_4[n];
1052 cheev(&jobz,&uplo,&n,(complex<r_4> *)a.Data(),&lda,(r_4 *)w
1053 ,(complex<r_4> *)wkget,&lwork,(r_4 *)rwork,&info);
1054 lwork = type2i4(&wkget[0],4); /* 2*n-1;*/ work = new T[lwork +GARDMEM];
1055 cheev(&jobz,&uplo,&n,(complex<r_4> *)a.Data(),&lda,(r_4 *)w
1056 ,(complex<r_4> *)work,&lwork,(r_4 *)rwork,&info);
1057 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
1058 delete [] rwork; delete [] w;
1059 } else if (typeid(T) == typeid(complex<r_8>) ) {
1060 r_8* rwork = new r_8[3*n-2 +GARDMEM]; r_8* w = new r_8[n];
1061 zheev(&jobz,&uplo,&n,(complex<r_8> *)a.Data(),&lda,(r_8 *)w
1062 ,(complex<r_8> *)wkget,&lwork,(r_8 *)rwork,&info);
1063 lwork = type2i4(&wkget[0],8); /* 2*n-1;*/ work = new T[lwork +GARDMEM];
1064 zheev(&jobz,&uplo,&n,(complex<r_8> *)a.Data(),&lda,(r_8 *)w
1065 ,(complex<r_8> *)work,&lwork,(r_8 *)rwork,&info);
1066 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
1067 delete [] rwork; delete [] w;
1068 } else {
1069 if(work) delete [] work; work=NULL;
1070 string tn = typeid(T).name();
1071 cerr << " LapackServer::LapackEigenSym(a,b) - Unsupported DataType T = " << tn << endl;
1072 throw TypeMismatchExc("LapackServer::LapackEigenSym(a,b) - Unsupported DataType (T)");
1073 }
1074
1075 if(work) delete [] work;
1076 if(info!=0 && Throw_On_Error) {
1077 char serr[128]; sprintf(serr,"LapackEigenSym_Error info=%d",info);
1078 throw MathExc(serr);
1079 }
1080 return(info);
1081}
1082
1083////////////////////////////////////////////////////////////////////////////////////
1084/*! Computes the eigen values and eigen vectors of a general squared matrix \b a.
1085 Input arrays should have FortranMemoryMapping (column packed).
1086 \param a : input general n-by-n matrix
1087 \param eval : Vector of eigenvalues (complex double precision)
1088 \param evec : Matrix of eigenvector (same order than eval, one vector = one column)
1089 \param eigenvector : if true compute (right) eigenvectors, if not only eigenvalues
1090 \param a : on return array of eigenvectors
1091 \return : return code from lapack driver
1092 \verbatim
1093 eval : contains the computed eigenvalues.
1094 --- For real matrices "a" :
1095 Complex conjugate pairs of eigenvalues appear consecutively
1096 with the eigenvalue having the positive imaginary part first.
1097 evec : the right eigenvectors v(j) are stored one after another
1098 in the columns of evec, in the same order as their eigenvalues.
1099 --- For real matrices "a" :
1100 If the j-th eigenvalue is real, then v(j) = evec(:,j),
1101 the j-th column of evec.
1102 If the j-th and (j+1)-st eigenvalues form a complex
1103 conjugate pair, then v(j) = evec(:,j) + i*evec(:,j+1) and
1104 v(j+1) = evec(:,j) - i*evec(:,j+1).
1105 \endverbatim
1106*/
1107
1108template <class T>
1109int LapackServer<T>::LapackEigen(TArray<T>& a, TVector< complex<r_8> >& eval, TMatrix<T>& evec, bool eigenvector)
1110{
1111 if ( a.NbDimensions() != 2 )
1112 throw(SzMismatchError("LapackServer::LapackEigen(a,b) a NbDimensions() != 2"));
1113 int_4 rowa = a.RowsKA();
1114 int_4 cola = a.ColsKA();
1115 if ( a.Size(rowa) != a.Size(cola))
1116 throw(SzMismatchError("LapackServer::LapackEigen(a,b) a Not a square Array"));
1117 if (!a.IsPacked(rowa))
1118 throw(SzMismatchError("LapackServer::LapackEigen(a,b) a Not Column Packed"));
1119
1120 char jobvl = 'N';
1121 char jobvr = 'N'; if(eigenvector) jobvr='V';
1122
1123 int_4 n = a.Size(rowa);
1124 int_4 lda = a.Step(cola);
1125 int_4 info = 0;
1126
1127 eval.ReSize(n); eval = complex<r_8>(0.,0.);
1128 if(eigenvector) {evec.ReSize(n,n); evec = (T) 0.;}
1129 int_4 ldvr = n, ldvl = 1;
1130
1131 int_4 lwork = -1;
1132 T * work = NULL;
1133 T wkget[2];
1134
1135 if (typeid(T) == typeid(r_4) ) {
1136 r_4* wr = new r_4[n]; r_4* wi = new r_4[n]; r_4* vl = NULL;
1137 sgeev(&jobvl,&jobvr,&n,(r_4 *)a.Data(),&lda,(r_4 *)wr,(r_4 *)wi,
1138 (r_4 *)vl,&ldvl,(r_4 *)evec.Data(),&ldvr,
1139 (r_4 *)wkget,&lwork,&info);
1140 lwork = type2i4(&wkget[0],4); /* 4*n;*/ work = new T[lwork +GARDMEM];
1141 sgeev(&jobvl,&jobvr,&n,(r_4 *)a.Data(),&lda,(r_4 *)wr,(r_4 *)wi,
1142 (r_4 *)vl,&ldvl,(r_4 *)evec.Data(),&ldvr,
1143 (r_4 *)work,&lwork,&info);
1144 if(info==0) for(int i=0;i<n;i++) eval(i) = complex<r_8>(wr[i],wi[i]);
1145 delete [] wr; delete [] wi;
1146 } else if (typeid(T) == typeid(r_8) ) {
1147 r_8* wr = new r_8[n]; r_8* wi = new r_8[n]; r_8* vl = NULL;
1148 dgeev(&jobvl,&jobvr,&n,(r_8 *)a.Data(),&lda,(r_8 *)wr,(r_8 *)wi,
1149 (r_8 *)vl,&ldvl,(r_8 *)evec.Data(),&ldvr,
1150 (r_8 *)wkget,&lwork,&info);
1151 lwork = type2i4(&wkget[0],8); /* 4*n;*/ work = new T[lwork +GARDMEM];
1152 dgeev(&jobvl,&jobvr,&n,(r_8 *)a.Data(),&lda,(r_8 *)wr,(r_8 *)wi,
1153 (r_8 *)vl,&ldvl,(r_8 *)evec.Data(),&ldvr,
1154 (r_8 *)work,&lwork,&info);
1155 if(info==0) for(int i=0;i<n;i++) eval(i) = complex<r_8>(wr[i],wi[i]);
1156 delete [] wr; delete [] wi;
1157 } else if (typeid(T) == typeid(complex<r_4>) ) {
1158 r_4* rwork = new r_4[2*n +GARDMEM]; r_4* vl = NULL; TVector< complex<r_4> > w(n);
1159 cgeev(&jobvl,&jobvr,&n,(complex<r_4> *)a.Data(),&lda,(complex<r_4> *)w.Data(),
1160 (complex<r_4> *)vl,&ldvl,(complex<r_4> *)evec.Data(),&ldvr,
1161 (complex<r_4> *)wkget,&lwork,(r_4 *)rwork,&info);
1162 lwork = type2i4(&wkget[0],4); /* 2*n;*/ work = new T[lwork +GARDMEM];
1163 cgeev(&jobvl,&jobvr,&n,(complex<r_4> *)a.Data(),&lda,(complex<r_4> *)w.Data(),
1164 (complex<r_4> *)vl,&ldvl,(complex<r_4> *)evec.Data(),&ldvr,
1165 (complex<r_4> *)work,&lwork,(r_4 *)rwork,&info);
1166 if(info==0) for(int i=0;i<n;i++) eval(i) = w(i);
1167 delete [] rwork;
1168 } else if (typeid(T) == typeid(complex<r_8>) ) {
1169 r_8* rwork = new r_8[2*n +GARDMEM]; r_8* vl = NULL;
1170 zgeev(&jobvl,&jobvr,&n,(complex<r_8> *)a.Data(),&lda,(complex<r_8> *)eval.Data(),
1171 (complex<r_8> *)vl,&ldvl,(complex<r_8> *)evec.Data(),&ldvr,
1172 (complex<r_8> *)wkget,&lwork,(r_8 *)rwork,&info);
1173 lwork = type2i4(&wkget[0],8); /* 2*n;*/ work = new T[lwork +GARDMEM];
1174 zgeev(&jobvl,&jobvr,&n,(complex<r_8> *)a.Data(),&lda,(complex<r_8> *)eval.Data(),
1175 (complex<r_8> *)vl,&ldvl,(complex<r_8> *)evec.Data(),&ldvr,
1176 (complex<r_8> *)work,&lwork,(r_8 *)rwork,&info);
1177 delete [] rwork;
1178 } else {
1179 if(work) delete [] work; work=NULL;
1180 string tn = typeid(T).name();
1181 cerr << " LapackServer::LapackEigen(a,b) - Unsupported DataType T = " << tn << endl;
1182 throw TypeMismatchExc("LapackServer::LapackEigen(a,b) - Unsupported DataType (T)");
1183 }
1184
1185 if(work) delete [] work;
1186 if(info!=0 && Throw_On_Error) {
1187 char serr[128]; sprintf(serr,"LapackEigen_Error info=%d",info);
1188 throw MathExc(serr);
1189 }
1190 return(info);
1191}
1192
1193
1194
1195
1196///////////////////////////////////////////////////////////////
// Template instantiation requests for compilers that use the
// "define_template" pragma (selected by __CXX_PRAGMA_TEMPLATES__).
#ifdef __CXX_PRAGMA_TEMPLATES__
#pragma define_template LapackServer<r_4>
#pragma define_template LapackServer<r_8>
#pragma define_template LapackServer< complex<r_4> >
#pragma define_template LapackServer< complex<r_8> >
#endif

// Explicit instantiation of LapackServer for the four element types that
// the methods of this file dispatch on through typeid(T):
// r_4, r_8, complex<r_4> and complex<r_8>.
#if defined(ANSI_TEMPLATES) || defined(GNU_TEMPLATES)
namespace SOPHYA {
template class LapackServer<r_4>;
template class LapackServer<r_8>;
template class LapackServer< complex<r_4> >;
template class LapackServer< complex<r_8> >;
}
#endif
1212
#if defined(Linux)
// Needed to link against f2c (libf2c.a) under Linux: provides the MAIN__
// symbol with C linkage. It is a link-time placeholder only; calling it
// prints a diagnostic on cerr and throws PError.
extern "C" void MAIN__()
{
  cerr << "MAIN__() function for linking with libf2c.a " << endl;
  cerr << " This function should never be called !!! " << endl;
  throw PError("MAIN__() should not be called - see intflapack.cc");
}
#endif
Note: See TracBrowser for help on using the repository browser.