source: Sophya/trunk/SophyaExt/LinAlg/intflapack.cc@ 3619

Last change on this file since 3619 was 3619, checked in by cmv, 16 years ago

add various #include<> for g++ 4.3 (jaunty 9.04), cmv 05/05/2009

File size: 48.7 KB
Line 
1#include <iostream>
2#include <string.h>
3#include <math.h>
4#include "sopnamsp.h"
5#include "intflapack.h"
6#include "sspvflags.h"
7
8#include "tvector.h"
9#include "tmatrix.h"
10#include <typeinfo>
11
12#define GARDMEM 5
13
14/*************** Pour memoire (Christophe) ***************
15Les dispositions memoires (FORTRAN) pour les vecteurs et matrices LAPACK:
16
171./ --- REAL X(N):
18 if an array X of dimension (N) holds a vector x,
19 then X(i) holds "x_i" for i=1,...,N
20
212./ --- REAL A(LDA,N):
22 if a two-dimensional array A of dimension (LDA,N) holds an m-by-n matrix A,
23 then A(i,j) holds "a_ij" for i=1,...,m et j=1,...,n (LDA must be at least m).
24 Note that array arguments are usually declared in the software as assumed-size
25 arrays (last dimension *), for example: REAL A(LDA,*)
26 --- Rangement en memoire:
27 | 11 12 13 14 |
28 Ex: Real A(4,4): A = | 21 22 23 24 |
29 | 31 32 33 34 |
30 | 41 42 43 44 |
31 memoire: {11 21 31 41} {12 22 32 42} {13 23 33 43} {14 24 34 44}
32 First indice (line) "i" varies then the second (column):
33 (put all the first column, then put all the second column,
34 ..., then put all the last column)
35***********************************************************/
36
37/*!
38 \defgroup LinAlg LinAlg module
39 This module contains classes and functions for complex linear
40 algebra on arrays. This module is intended mainly to have
41 classes implementing C++ interfaces between Sophya objects
42 and external linear algebra libraries, such as LAPACK.
43*/
44
45/*!
46 \class SOPHYA::LapackServer
47 \ingroup LinAlg
48 This class implements an interface to LAPACK library driver routines.
49 The LAPACK (Linear Algebra PACKage) is a collection of high performance
50 routines to solve common problems in numerical linear algebra.
51 It is available from http://www.netlib.org.
52
53 The present version of LapackServer (Feb 2005) provides
54 interfaces for the linear system solver, singular value
55 decomposition (SVD), Least square solver and
56 eigen value / eigen vector decomposition.
57 Only arrays with BaseArray::FortranMemoryMapping
58 can be handled by LapackServer. LapackServer can be instantiated
59 for single and double precision real or complex array types.
60 \warning The input array is overwritten in most cases.
61 The example below shows solving a linear system A*X = B
62
63 \code
64 #include "intflapack.h"
65 // ...
66 // Use FortranMemoryMapping as default
67 BaseArray::SetDefaultMemoryMapping(BaseArray::FortranMemoryMapping);
68 // Create and fill the arrays A and B
69 int n = 20;
70 Matrix A(n, n);
71 A = RandomSequence();
72 Vector X(n),B(n);
73 X = RandomSequence();
74 B = A*X;
75 // Solve the linear system A*X = B
76 LapackServer<r_8> lps;
77 lps.LinSolve(A,B);
78 // We get the result in B, which should be equal to X ...
79 // Compute the difference B-X ;
80 Vector diff = B-X;
81 \endcode
82
83*/
84
85/*
86 Decembre 2005 : Suite portage AIX xlC
87 On declare des noms en majuscule pour les routines fortran -
88 avec ou sans underscore _ , suivant les systemes
89*/
90#ifdef AIX
91
92#define ilaenv ilaenv
93
94#define sgesv sgesv
95#define dgesv dgesv
96#define cgesv cgesv
97#define zgesv zgesv
98
99#define ssysv ssysv
100#define dsysv dsysv
101#define csysv csysv
102#define zsysv zsysv
103
104#define sgels sgels
105#define dgels dgels
106#define cgels cgels
107#define zgels zgels
108
109#define sgelsd sgelsd
110#define dgelsd dgelsd
111#define cgelsd cgelsd
112#define zgelsd zgelsd
113
114#define sgesvd sgesvd
115#define dgesvd dgesvd
116#define cgesvd cgesvd
117#define zgesvd zgesvd
118
119#define sgesdd sgesdd
120#define dgesdd dgesdd
121#define cgesdd cgesdd
122#define zgesdd zgesdd
123
124#define ssyev ssyev
125#define dsyev dsyev
126#define cheev cheev
127#define zheev zheev
128
129#define sgeev sgeev
130#define dgeev dgeev
131#define cgeev cgeev
132#define zgeev zgeev
133
134#else
135#define ilaenv ilaenv_
136
137#define sgesv sgesv_
138#define dgesv dgesv_
139#define cgesv cgesv_
140#define zgesv zgesv_
141
142#define ssysv ssysv_
143#define dsysv dsysv_
144#define csysv csysv_
145#define zsysv zsysv_
146
147#define sgels sgels_
148#define dgels dgels_
149#define cgels cgels_
150#define zgels zgels_
151
152#define sgelsd sgelsd_
153#define dgelsd dgelsd_
154#define cgelsd cgelsd_
155#define zgelsd zgelsd_
156
157#define sgesvd sgesvd_
158#define dgesvd dgesvd_
159#define cgesvd cgesvd_
160#define zgesvd zgesvd_
161
162#define sgesdd sgesdd_
163#define dgesdd dgesdd_
164#define cgesdd cgesdd_
165#define zgesdd zgesdd_
166
167#define ssyev ssyev_
168#define dsyev dsyev_
169#define cheev cheev_
170#define zheev zheev_
171
172#define sgeev sgeev_
173#define dgeev dgeev_
174#define cgeev cgeev_
175#define zgeev zgeev_
176
177#endif
178////////////////////////////////////////////////////////////////////////////////////
179extern "C" {
180// Le calculateur de workingspace
181 int_4 ilaenv(int_4 *ispec,const char *name,const char *opts,int_4 *n1,int_4 *n2,int_4 *n3,int_4 *n4,
182 int_4 nc1,int_4 nc2);
183
184// Drivers pour resolution de systemes lineaires
185 void sgesv(int_4* n, int_4* nrhs, r_4* a, int_4* lda,
186 int_4* ipiv, r_4* b, int_4* ldb, int_4* info);
187 void dgesv(int_4* n, int_4* nrhs, r_8* a, int_4* lda,
188 int_4* ipiv, r_8* b, int_4* ldb, int_4* info);
189 void cgesv(int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
190 int_4* ipiv, complex<r_4>* b, int_4* ldb, int_4* info);
191 void zgesv(int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
192 int_4* ipiv, complex<r_8>* b, int_4* ldb, int_4* info);
193
194// Drivers pour resolution de systemes lineaires symetriques
195 void ssysv(char* uplo, int_4* n, int_4* nrhs, r_4* a, int_4* lda,
196 int_4* ipiv, r_4* b, int_4* ldb,
197 r_4* work, int_4* lwork, int_4* info);
198 void dsysv(char* uplo, int_4* n, int_4* nrhs, r_8* a, int_4* lda,
199 int_4* ipiv, r_8* b, int_4* ldb,
200 r_8* work, int_4* lwork, int_4* info);
201 void csysv(char* uplo, int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
202 int_4* ipiv, complex<r_4>* b, int_4* ldb,
203 complex<r_4>* work, int_4* lwork, int_4* info);
204 void zsysv(char* uplo, int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
205 int_4* ipiv, complex<r_8>* b, int_4* ldb,
206 complex<r_8>* work, int_4* lwork, int_4* info);
207
208// Driver pour resolution de systemes au sens de Xi2
209 void sgels(char * trans, int_4* m, int_4* n, int_4* nrhs, r_4* a, int_4* lda,
210 r_4* b, int_4* ldb, r_4* work, int_4* lwork, int_4* info);
211 void dgels(char * trans, int_4* m, int_4* n, int_4* nrhs, r_8* a, int_4* lda,
212 r_8* b, int_4* ldb, r_8* work, int_4* lwork, int_4* info);
213 void cgels(char * trans, int_4* m, int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
214 complex<r_4>* b, int_4* ldb, complex<r_4>* work, int_4* lwork, int_4* info);
215 void zgels(char * trans, int_4* m, int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
216 complex<r_8>* b, int_4* ldb, complex<r_8>* work, int_4* lwork, int_4* info);
217
218// Driver pour resolution de systemes au sens de Xi2 par SVD Divide & Conquer
219 void sgelsd(int_4* m,int_4* n,int_4* nrhs,r_4* a,int_4* lda,
220 r_4* b,int_4* ldb,r_4* s,r_4* rcond,int_4* rank,
221 r_4* work,int_4* lwork,int_4* iwork,int_4* info);
222 void dgelsd(int_4* m,int_4* n,int_4* nrhs,r_8* a,int_4* lda,
223 r_8* b,int_4* ldb,r_8* s,r_8* rcond,int_4* rank,
224 r_8* work,int_4* lwork,int_4* iwork,int_4* info);
225 void cgelsd(int_4* m,int_4* n,int_4* nrhs,complex<r_4>* a,int_4* lda,
226 complex<r_4>* b,int_4* ldb,r_4* s,r_4* rcond,int_4* rank,
227 complex<r_4>* work,int_4* lwork,r_4* rwork,int_4* iwork,int_4* info);
228 void zgelsd(int_4* m,int_4* n,int_4* nrhs,complex<r_8>* a,int_4* lda,
229 complex<r_8>* b,int_4* ldb,r_8* s,r_8* rcond,int_4* rank,
230 complex<r_8>* work,int_4* lwork,r_8* rwork,int_4* iwork,int_4* info);
231
232// Driver pour decomposition SVD
233 void sgesvd(char* jobu, char* jobvt, int_4* m, int_4* n, r_4* a, int_4* lda,
234 r_4* s, r_4* u, int_4* ldu, r_4* vt, int_4* ldvt,
235 r_4* work, int_4* lwork, int_4* info);
236 void dgesvd(char* jobu, char* jobvt, int_4* m, int_4* n, r_8* a, int_4* lda,
237 r_8* s, r_8* u, int_4* ldu, r_8* vt, int_4* ldvt,
238 r_8* work, int_4* lwork, int_4* info);
239 void cgesvd(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_4>* a, int_4* lda,
240 r_4* s, complex<r_4>* u, int_4* ldu, complex<r_4>* vt, int_4* ldvt,
241 complex<r_4>* work, int_4* lwork, r_4* rwork, int_4* info);
242 void zgesvd(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_8>* a, int_4* lda,
243 r_8* s, complex<r_8>* u, int_4* ldu, complex<r_8>* vt, int_4* ldvt,
244 complex<r_8>* work, int_4* lwork, r_8* rwork, int_4* info);
245
246// Driver pour decomposition SVD Divide and Conquer
247 void sgesdd(char* jobz, int_4* m, int_4* n, r_4* a, int_4* lda,
248 r_4* s, r_4* u, int_4* ldu, r_4* vt, int_4* ldvt,
249 r_4* work, int_4* lwork, int_4* iwork, int_4* info);
250 void dgesdd(char* jobz, int_4* m, int_4* n, r_8* a, int_4* lda,
251 r_8* s, r_8* u, int_4* ldu, r_8* vt, int_4* ldvt,
252 r_8* work, int_4* lwork, int_4* iwork, int_4* info);
253 void cgesdd(char* jobz, int_4* m, int_4* n, complex<r_4>* a, int_4* lda,
254 r_4* s, complex<r_4>* u, int_4* ldu, complex<r_4>* vt, int_4* ldvt,
255 complex<r_4>* work, int_4* lwork, r_4* rwork, int_4* iwork, int_4* info);
256 void zgesdd(char* jobz, int_4* m, int_4* n, complex<r_8>* a, int_4* lda,
257 r_8* s, complex<r_8>* u, int_4* ldu, complex<r_8>* vt, int_4* ldvt,
258 complex<r_8>* work, int_4* lwork, r_8* rwork, int_4* iwork, int_4* info);
259
260// Driver pour eigen decomposition for symetric/hermitian matrices
261 void ssyev(char* jobz, char* uplo, int_4* n, r_4* a, int_4* lda, r_4* w,
262 r_4* work, int_4 *lwork, int_4* info);
263 void dsyev(char* jobz, char* uplo, int_4* n, r_8* a, int_4* lda, r_8* w,
264 r_8* work, int_4 *lwork, int_4* info);
265 void cheev(char* jobz, char* uplo, int_4* n, complex<r_4>* a, int_4* lda, r_4* w,
266 complex<r_4>* work, int_4 *lwork, r_4* rwork, int_4* info);
267 void zheev(char* jobz, char* uplo, int_4* n, complex<r_8>* a, int_4* lda, r_8* w,
268 complex<r_8>* work, int_4 *lwork, r_8* rwork, int_4* info);
269
270// Driver pour eigen decomposition for general squared matrices
271 void sgeev(char* jobl, char* jobvr, int_4* n, r_4* a, int_4* lda, r_4* wr, r_4* wi,
272 r_4* vl, int_4* ldvl, r_4* vr, int_4* ldvr,
273 r_4* work, int_4 *lwork, int_4* info);
274 void dgeev(char* jobl, char* jobvr, int_4* n, r_8* a, int_4* lda, r_8* wr, r_8* wi,
275 r_8* vl, int_4* ldvl, r_8* vr, int_4* ldvr,
276 r_8* work, int_4 *lwork, int_4* info);
277 void cgeev(char* jobl, char* jobvr, int_4* n, complex<r_4>* a, int_4* lda, complex<r_4>* w,
278 complex<r_4>* vl, int_4* ldvl, complex<r_4>* vr, int_4* ldvr,
279 complex<r_4>* work, int_4 *lwork, r_4* rwork, int_4* info);
280 void zgeev(char* jobl, char* jobvr, int_4* n, complex<r_8>* a, int_4* lda, complex<r_8>* w,
281 complex<r_8>* vl, int_4* ldvl, complex<r_8>* vr, int_4* ldvr,
282 complex<r_8>* work, int_4 *lwork, r_8* rwork, int_4* info);
283
284}
285
286// -------------- Classe LapackServer<T> --------------
287
288////////////////////////////////////////////////////////////////////////////////////
289template <class T>
290LapackServer<T>::LapackServer(bool throw_on_error)
291 : Throw_On_Error(throw_on_error)
292{
293 SetWorkSpaceSizeFactor();
294}
295
296template <class T>
297LapackServer<T>::~LapackServer()
298{
299}
300
301////////////////////////////////////////////////////////////////////////////////////
302template <class T>
303int_4 LapackServer<T>::ilaenv_en_C(int_4 ispec,const char *name,const char *opts,int_4 n1,int_4 n2,int_4 n3,int_4 n4)
304{
305 int_4 nc1 = strlen(name), nc2 = strlen(opts), rc=0;
306 rc = ilaenv(&ispec,name,opts,&n1,&n2,&n3,&n4,nc1,nc2);
307 //cout<<"ilaenv_en_C("<<ispec<<","<<name<<"("<<nc1<<"),"<<opts<<"("<<nc2<<"),"
308 // <<n1<<","<<n2<<","<<n3<<","<<n4<<") = "<<rc<<endl;
309 return rc;
310}
311
312template <class T>
313int_4 LapackServer<T>::type2i4(void *val,int nbytes)
314// Retourne un entier contenant la valeur contenue dans val
315// - nbytes = nombre de bytes dans le contenu de val
316// ex: r_4 x = 3.4; type2i4(&x,4) -> 3
317// ex: r_8 x = 3.4; type2i4(&x,8) -> 3
318// ex: complex<r_4> x(3.4,7.8); type2i4(&x,4) -> 3
319// ex: complex<r_8> x(3.4,7.8); type2i4(&x,8) -> 3
320{
321 r_4* x4; r_8* x8; int_4 lw=0;
322 if(nbytes==4) {x4 = (r_4*)val; lw = (int_4)(*x4);}
323 else {x8 = (r_8*)val; lw = (int_4)(*x8);}
324 return lw;
325}
326
327////////////////////////////////////////////////////////////////////////////////////
328//! Interface to Lapack linear system solver driver s/d/c/zgesv().
329/*! Solve the linear system a * x = b using LU factorization.
330 Input arrays should have FortranMemory mapping (column packed).
331 \param a : input matrix, overwritten on output
332 \param b : input-output, input vector b, contains x on exit
333 \return : return code from lapack driver _gesv()
334 */
335template <class T>
336int LapackServer<T>::LinSolve(TArray<T>& a, TArray<T> & b)
337{
338 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
339 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b NbDimensions() != 2"));
340
341 int_4 rowa = a.RowsKA();
342 int_4 cola = a.ColsKA();
343 int_4 rowb = b.RowsKA();
344 int_4 colb = b.ColsKA();
345 if ( a.Size(rowa) != a.Size(cola))
346 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Not a square Array"));
347 if ( a.Size(rowa) != b.Size(rowb))
348 throw(SzMismatchError("LapackServer::LinSolve(a,b) RowSize(a <> b) "));
349
350 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
351 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b Not Column Packed"));
352
353 int_4 n = a.Size(rowa);
354 int_4 nrhs = b.Size(colb);
355 int_4 lda = a.Step(cola);
356 int_4 ldb = b.Step(colb);
357 int_4 info;
358 int_4* ipiv = new int_4[n];
359
360 if (typeid(T) == typeid(r_4) )
361 sgesv(&n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb, &info);
362 else if (typeid(T) == typeid(r_8) )
363 dgesv(&n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb, &info);
364 else if (typeid(T) == typeid(complex<r_4>) )
365 cgesv(&n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
366 (complex<r_4> *)b.Data(), &ldb, &info);
367 else if (typeid(T) == typeid(complex<r_8>) )
368 zgesv(&n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
369 (complex<r_8> *)b.Data(), &ldb, &info);
370 else {
371 delete[] ipiv;
372 string tn = typeid(T).name();
373 cerr << " LapackServer::LinSolve(a,b) - Unsupported DataType T = " << tn << endl;
374 throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
375 }
376 delete[] ipiv;
377 if(info!=0 && Throw_On_Error) {
378 char serr[128]; sprintf(serr,"LinSolve_Error info=%d",info);
379 throw MathExc(serr);
380 }
381 return(info);
382}
383
384////////////////////////////////////////////////////////////////////////////////////
385//! Interface to Lapack linear system solver driver s/d/c/zsysv().
386/*! Solve the linear system a * x = b with a symetric matrix using LU factorization.
387 Input arrays should have FortranMemory mapping (column packed).
388 \param a : input matrix symetric , overwritten on output
389 \param b : input-output, input vector b, contains x on exit
390 \return : return code from lapack driver _sysv()
391 */
392template <class T>
393int LapackServer<T>::LinSolveSym(TArray<T>& a, TArray<T> & b)
394// --- REMARQUES DE CMV ---
395// 1./ contrairement a ce qui est dit dans la doc, il s'agit
396// de matrices SYMETRIQUES complexes et non HERMITIENNES !!!
397// 2./ pourquoi les routines de LinSolve pour des matrices symetriques
398// sont plus de deux fois plus lentes que les LinSolve generales sur OSF
399// et sensiblement plus lentes sous Linux ???
400{
401 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
402 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) a Or b NbDimensions() != 2"));
403 int_4 rowa = a.RowsKA();
404 int_4 cola = a.ColsKA();
405 int_4 rowb = b.RowsKA();
406 int_4 colb = b.ColsKA();
407 if ( a.Size(rowa) != a.Size(cola))
408 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) a Not a square Array"));
409 if ( a.Size(rowa) != b.Size(rowb))
410 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) RowSize(a <> b) "));
411
412 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
413 throw(SzMismatchError("LapackServer::LinSolveSym(a,b) a Or b Not Column Packed"));
414
415 int_4 n = a.Size(rowa);
416 int_4 nrhs = b.Size(colb);
417 int_4 lda = a.Step(cola);
418 int_4 ldb = b.Step(colb);
419 int_4 info = 0;
420 int_4* ipiv = new int_4[n];
421 int_4 lwork = -1;
422 T * work = NULL;
423 T wkget[2];
424
425 char uplo = 'U'; // char uplo = 'L';
426 char struplo[5]; struplo[0] = uplo; struplo[1] = '\0';
427
428 if (typeid(T) == typeid(r_4) ) {
429 ssysv(&uplo, &n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb,
430 (r_4 *)wkget, &lwork, &info);
431 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
432 ssysv(&uplo, &n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb,
433 (r_4 *)work, &lwork, &info);
434 } else if (typeid(T) == typeid(r_8) ) {
435 dsysv(&uplo, &n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb,
436 (r_8 *)wkget, &lwork, &info);
437 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
438 dsysv(&uplo, &n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb,
439 (r_8 *)work, &lwork, &info);
440 } else if (typeid(T) == typeid(complex<r_4>) ) {
441 csysv(&uplo, &n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
442 (complex<r_4> *)b.Data(), &ldb,
443 (complex<r_4> *)wkget, &lwork, &info);
444 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
445 csysv(&uplo, &n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
446 (complex<r_4> *)b.Data(), &ldb,
447 (complex<r_4> *)work, &lwork, &info);
448 } else if (typeid(T) == typeid(complex<r_8>) ) {
449 zsysv(&uplo, &n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
450 (complex<r_8> *)b.Data(), &ldb,
451 (complex<r_8> *)wkget, &lwork, &info);
452 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
453 zsysv(&uplo, &n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
454 (complex<r_8> *)b.Data(), &ldb,
455 (complex<r_8> *)work, &lwork, &info);
456 } else {
457 if(work) delete[] work;
458 delete[] ipiv;
459 string tn = typeid(T).name();
460 cerr << " LapackServer::LinSolveSym(a,b) - Unsupported DataType T = " << tn << endl;
461 throw TypeMismatchExc("LapackServer::LinSolveSym(a,b) - Unsupported DataType (T)");
462 }
463 if(work) delete[] work;
464 delete[] ipiv;
465 if(info!=0 && Throw_On_Error) {
466 char serr[128]; sprintf(serr,"LinSolveSym_Error info=%d",info);
467 throw MathExc(serr);
468 }
469 return(info);
470}
471
472////////////////////////////////////////////////////////////////////////////////////
473//! Interface to Lapack least squares solver driver s/d/c/zgels().
474/*! Solves the linear least squares problem defined by an m-by-n matrix
475 \b a and an m element vector \b b , using QR or LQ factorization .
476 A solution \b x to the overdetermined system of linear equations
477 b = a * x is computed, minimizing the norm of b-a*x.
478 Underdetermined systems (m<n) are not yet handled.
479 Inout arrays should have FortranMemory mapping (column packed).
480 \param a : input matrix, overwritten on output
481 \param b : input-output, input vector b, contains x on exit.
482 \return : return code from lapack driver _gels()
483 \warning : b is not resized.
484 */
485/*
486 $CHECK$ - A faire - cas m<n
487 If the linear system is underdetermined, the minimum norm
488 solution is computed.
489*/
490
491template <class T>
492int LapackServer<T>::LeastSquareSolve(TArray<T>& a, TArray<T> & b)
493{
494 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
495 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) a Or b NbDimensions() != 2"));
496
497 int_4 rowa = a.RowsKA();
498 int_4 cola = a.ColsKA();
499 int_4 rowb = b.RowsKA();
500 int_4 colb = b.ColsKA();
501
502
503 if ( a.Size(rowa) != b.Size(rowb))
504 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) RowSize(a <> b) "));
505
506 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
507 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) a Or b Not Column Packed"));
508
509 if ( a.Size(rowa) < a.Size(cola)) { // $CHECK$ - m<n a changer
510 cout << " LapackServer<T>::LeastSquareSolve() - m<n - Not yet implemented for "
511 << " underdetermined systems ! " << endl;
512 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) NRows<NCols - "));
513 }
514 int_4 m = a.Size(rowa);
515 int_4 n = a.Size(cola);
516 int_4 nrhs = b.Size(colb);
517
518 int_4 lda = a.Step(cola);
519 int_4 ldb = b.Step(colb);
520 int_4 info;
521
522 //unused: int_4 minmn = (m < n) ? m : n;
523 int_4 maxmn = (m > n) ? m : n;
524 int_4 maxmnrhs = (nrhs > maxmn) ? nrhs : maxmn;
525 if (maxmnrhs < 1) maxmnrhs = 1;
526
527 int_4 lwork = -1; //minmn+maxmnrhs*5;
528 T * work = NULL;
529 T wkget[2];
530
531 char trans = 'N';
532
533 if (typeid(T) == typeid(r_4) ) {
534 sgels(&trans, &m, &n, &nrhs, (r_4 *)a.Data(), &lda,
535 (r_4 *)b.Data(), &ldb, (r_4 *)wkget, &lwork, &info);
536 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
537 sgels(&trans, &m, &n, &nrhs, (r_4 *)a.Data(), &lda,
538 (r_4 *)b.Data(), &ldb, (r_4 *)work, &lwork, &info);
539 } else if (typeid(T) == typeid(r_8) ) {
540 dgels(&trans, &m, &n, &nrhs, (r_8 *)a.Data(), &lda,
541 (r_8 *)b.Data(), &ldb, (r_8 *)wkget, &lwork, &info);
542 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
543 dgels(&trans, &m, &n, &nrhs, (r_8 *)a.Data(), &lda,
544 (r_8 *)b.Data(), &ldb, (r_8 *)work, &lwork, &info);
545 } else if (typeid(T) == typeid(complex<r_4>) ) {
546 cgels(&trans, &m, &n, &nrhs, (complex<r_4> *)a.Data(), &lda,
547 (complex<r_4> *)b.Data(), &ldb, (complex<r_4> *)wkget, &lwork, &info);
548 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
549 cgels(&trans, &m, &n, &nrhs, (complex<r_4> *)a.Data(), &lda,
550 (complex<r_4> *)b.Data(), &ldb, (complex<r_4> *)work, &lwork, &info);
551 } else if (typeid(T) == typeid(complex<r_8>) ) {
552 zgels(&trans, &m, &n, &nrhs, (complex<r_8> *)a.Data(), &lda,
553 (complex<r_8> *)b.Data(), &ldb, (complex<r_8> *)wkget, &lwork, &info);
554 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
555 zgels(&trans, &m, &n, &nrhs, (complex<r_8> *)a.Data(), &lda,
556 (complex<r_8> *)b.Data(), &ldb, (complex<r_8> *)work, &lwork, &info);
557 } else {
558 if(work) delete [] work; work=NULL;
559 string tn = typeid(T).name();
560 cerr << " LapackServer::LeastSquareSolve(a,b) - Unsupported DataType T = " << tn << endl;
561 throw TypeMismatchExc("LapackServer::LeastSquareSolve(a,b) - Unsupported DataType (T)");
562 }
563 if(work) delete [] work;
564 if(info!=0 && Throw_On_Error) {
565 char serr[128]; sprintf(serr,"LeastSquareSolve_Error info=%d",info);
566 throw MathExc(serr);
567 }
568 return(info);
569}
570
571////////////////////////////////////////////////////////////////////////////////////
572//! Square matrix inversion using Lapack linear system solver
573/*! Compute the inverse of a square matrix using linear system solver routine
574 Input arrays should have FortranMemory mapping (column packed).
575 \param a : input matrix, overwritten on output
576 \param ainv : output matrix, contains inverse(a) on exit.
577 ainv is allocated if it has size 0
578 If not allocated, ainv is automatically
579 \return : return code from LapackServer::LinSolve()
580 \sa LapackServer::LinSolve()
581 */
582template <class T>
583int LapackServer<T>::ComputeInverse(TMatrix<T>& a, TMatrix<T> & ainv)
584{
585 if ( a.NbDimensions() != 2 )
586 throw(SzMismatchError("LapackServer::Inverse() NDim(a) != 2"));
587 if ( a.GetMemoryMapping() != BaseArray::FortranMemoryMapping )
588 throw(SzMismatchError("LapackServer::Inverse() a NOT in FortranMemoryMapping"));
589 if ( a.NRows() != a.NCols() )
590 throw(SzMismatchError("LapackServer::Inverse() a NOT square matrix (a.NRows!=a.NCols)"));
591 if (ainv.IsAllocated()) {
592 bool smo, ssz;
593 ssz = a.CompareSizes(ainv, smo);
594 if ( (ssz == false) || (smo == false) )
595 throw(SzMismatchError("LapackServer::Inverse() ainv<>a Size/MemOrg mismatch "));
596 }
597 else ainv.SetSize(a.NRows(), a.NCols(), BaseArray::FortranMemoryMapping, false);
598 ainv = IdentityMatrix();
599 return LinSolve(a, ainv);
600}
601
602////////////////////////////////////////////////////////////////////////////////////
603//! Interface to Lapack least squares solver driver s/d/c/zgelsd().
604/*! Solves the linear least squares problem defined by an m-by-n matrix
605 \b a and an m element vector \b b , using SVD factorization Divide and Conquer.
606 Inout arrays should have FortranMemory mapping (column packed).
607 \param rcond : definition of zero value (S(i) <= RCOND*S(0) are treated as zero).
608 If RCOND < 0, machine precision is used instead.
609 \param a : input matrix, overwritten on output
610 \param b : input vector b overwritten by solution on output (beware of size changing)
611 \param x : output matrix of solutions.
612 \param rank : output the rank of the matrix.
613 \return : return code from lapack driver _gelsd()
614 \warning : b is not resized.
615 */
616template <class T>
617int LapackServer<T>::LeastSquareSolveSVD_DC(TMatrix<T>& a,TMatrix<T>& b,TVector<r_8>& s,int_4& rank,r_8 rcond)
618{
619#ifdef LAPACK_V2_EXTSOP
620 throw NotAvailableOperation("LapackServer::LeastSquareSolveSVD_DC(a,b) NOT implemented in LapackV2") ;
621#else
622 if ( ( a.NbDimensions() != 2 ) )
623 throw(SzMismatchError("LapackServer::LeastSquareSolveSVD_DC(a,b) a != 2"));
624
625 if (!a.IsPacked() || !b.IsPacked())
626 throw(SzMismatchError("LapackServer::LeastSquareSolveSVD_DC(a,b) a Or b Not Packed"));
627
628 int_4 m = a.NRows();
629 int_4 n = a.NCols();
630
631 if(b.NRows() != m)
632 throw(SzMismatchError("LapackServer::LeastSquareSolveSVD_DC(a,b) bad matching dim between a and b"));
633
634 int_4 nrhs = b.NCols();
635 int_4 minmn = (m < n) ? m : n;
636 int_4 maxmn = (m > n) ? m : n;
637
638 int_4 lda = m;
639 int_4 ldb = maxmn;
640 int_4 info;
641
642 { // Use {} for automatic des-allocation of "bsave"
643 TMatrix<T> bsave(m,nrhs); bsave.SetMemoryMapping(BaseArray::FortranMemoryMapping);
644 bsave = b;
645 b.ReSize(maxmn,nrhs); b = (T) 0;
646 for(int i=0;i<m;i++) for(int j=0;j<nrhs;j++) b(i,j) = bsave(i,j);
647 } // Use {} for automatic des-allocation of "bsave"
648 s.ReSize(minmn);
649
650 int_4 smlsiz = 25; // Normallement ilaenv_en_C(9,...) renvoie toujours 25
651 if(typeid(T) == typeid(r_4) ) smlsiz = ilaenv_en_C(9,"SGELSD"," ",0,0,0,0);
652 else if(typeid(T) == typeid(r_8) ) smlsiz = ilaenv_en_C(9,"DGELSD"," ",0,0,0,0);
653 else if(typeid(T) == typeid(complex<r_4>) ) smlsiz = ilaenv_en_C(9,"CGELSD"," ",0,0,0,0);
654 else if(typeid(T) == typeid(complex<r_8>) ) smlsiz = ilaenv_en_C(9,"ZGELSD"," ",0,0,0,0);
655 if(smlsiz<0) smlsiz = 0;
656 r_8 dum = log((r_8)minmn/(r_8)(smlsiz+1.)) / log(2.);
657 int_4 nlvl = int_4(dum) + 1; if(nlvl<0) nlvl = 0;
658
659 T * work = NULL;
660 int_4 * iwork = NULL;
661 int_4 lwork=-1, lrwork;
662 T wkget[2];
663
664 if(typeid(T) == typeid(r_4) ) {
665 r_4* sloc = new r_4[minmn];
666 r_4 srcond = rcond;
667 iwork = new int_4[3*minmn*nlvl+11*minmn +GARDMEM];
668 sgelsd(&m,&n,&nrhs,(r_4*)a.Data(),&lda,
669 (r_4*)b.Data(),&ldb,(r_4*)sloc,&srcond,&rank,
670 (r_4*)wkget,&lwork,(int_4*)iwork,&info);
671 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
672 sgelsd(&m,&n,&nrhs,(r_4*)a.Data(),&lda,
673 (r_4*)b.Data(),&ldb,(r_4*)sloc,&srcond,&rank,
674 (r_4*)work,&lwork,(int_4*)iwork,&info);
675 for(int_4 i=0;i<minmn;i++) s(i) = sloc[i];
676 delete [] sloc;
677 } else if(typeid(T) == typeid(r_8) ) {
678 iwork = new int_4[3*minmn*nlvl+11*minmn +GARDMEM];
679 dgelsd(&m,&n,&nrhs,(r_8*)a.Data(),&lda,
680 (r_8*)b.Data(),&ldb,(r_8*)s.Data(),&rcond,&rank,
681 (r_8*)wkget,&lwork,(int_4*)iwork,&info);
682 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
683 dgelsd(&m,&n,&nrhs,(r_8*)a.Data(),&lda,
684 (r_8*)b.Data(),&ldb,(r_8*)s.Data(),&rcond,&rank,
685 (r_8*)work,&lwork,(int_4*)iwork,&info);
686 } else if(typeid(T) == typeid(complex<r_4>) ) {
687 // Cf meme remarque que ci-dessous (complex<r_8)
688 lrwork = 10*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + 3*smlsiz*nrhs + (smlsiz+1)*(smlsiz+1);
689 int_4 lrwork_d = 12*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + minmn*nrhs + (smlsiz+1)*(smlsiz+1);
690 if(lrwork_d > lrwork) lrwork = lrwork_d;
691 r_4* rwork = new r_4[lrwork +GARDMEM];
692 iwork = new int_4[3*minmn*nlvl+11*minmn +GARDMEM];
693 r_4* sloc = new r_4[minmn];
694 r_4 srcond = rcond;
695 cgelsd(&m,&n,&nrhs,(complex<r_4>*)a.Data(),&lda,
696 (complex<r_4>*)b.Data(),&ldb,(r_4*)sloc,&srcond,&rank,
697 (complex<r_4>*)wkget,&lwork,(r_4*)rwork,(int_4*)iwork,&info);
698 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
699 cgelsd(&m,&n,&nrhs,(complex<r_4>*)a.Data(),&lda,
700 (complex<r_4>*)b.Data(),&ldb,(r_4*)sloc,&srcond,&rank,
701 (complex<r_4>*)work,&lwork,(r_4*)rwork,(int_4*)iwork,&info);
702 for(int_4 i=0;i<minmn;i++) s(i) = sloc[i];
703 delete [] sloc; delete [] rwork;
704 } else if(typeid(T) == typeid(complex<r_8>) ) {
705 // CMV: Bizarrement, la formule donnee dans zgelsd() plante pour des N grands (500)
706 // On prend (par analogie) la formule pour "lwork" de dgelsd()
707 lrwork = 10*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + 3*smlsiz*nrhs + (smlsiz+1)*(smlsiz+1);
708 int_4 lrwork_d = 12*minmn + 2*minmn*smlsiz + 8*minmn*nlvl + minmn*nrhs + (smlsiz+1)*(smlsiz+1);
709 if(lrwork_d > lrwork) lrwork = lrwork_d;
710 r_8* rwork = new r_8[lrwork +GARDMEM];
711 iwork = new int_4[3*minmn*nlvl+11*minmn +GARDMEM];
712 zgelsd(&m,&n,&nrhs,(complex<r_8>*)a.Data(),&lda,
713 (complex<r_8>*)b.Data(),&ldb,(r_8*)s.Data(),&rcond,&rank,
714 (complex<r_8>*)wkget,&lwork,(r_8*)rwork,(int_4*)iwork,&info);
715 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
716 zgelsd(&m,&n,&nrhs,(complex<r_8>*)a.Data(),&lda,
717 (complex<r_8>*)b.Data(),&ldb,(r_8*)s.Data(),&rcond,&rank,
718 (complex<r_8>*)work,&lwork,(r_8*)rwork,(int_4*)iwork,&info);
719 delete [] rwork;
720 } else {
721 if(work) delete [] work; work=NULL;
722 if(iwork) delete [] iwork; iwork=NULL;
723 string tn = typeid(T).name();
724 cerr << " LapackServer::LeastSquareSolveSVD_DC(a,b) - Unsupported DataType T = " << tn << endl;
725 throw TypeMismatchExc("LapackServer::LeastSquareSolveSVD_DC(a,b) - Unsupported DataType (T)");
726 }
727
728 if(work) delete [] work; if(iwork) delete [] iwork;
729 if(info!=0 && Throw_On_Error) {
730 char serr[128]; sprintf(serr,"LeastSquareSolveSVD_DC_Error info=%d",info);
731 throw MathExc(serr);
732 }
733 return(info);
734#endif
735}
736
737
738////////////////////////////////////////////////////////////////////////////////////
739//! Interface to Lapack SVD driver s/d/c/zgesv().
740/*! Computes the vector of singular values of \b a. Input arrays
741 should have FortranMemoryMapping (column packed).
742 \param a : input m-by-n matrix
743 \param s : Vector of min(m,n) singular values (descending order)
744 \return : return code from lapack driver _gesvd()
745 */
746
747template <class T>
748int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s)
749{
750 return (SVDDriver(a, s, NULL, NULL) );
751}
752
753//! Interface to Lapack SVD driver s/d/c/zgesv().
754/*! Computes the vector of singular values of \b a, as well as
755 right and left singular vectors of \b a.
756 \f[
757 A = U \Sigma V^T , ( A = U \Sigma V^H \ complex)
758 \f]
759 \f[
760 A v_i = \sigma_i u_i \ and A^T u_i = \sigma_i v_i \ (A^H \ complex)
761 \f]
762 U and V are orthogonal (unitary) matrices.
763 \param a : input m-by-n matrix (in FortranMemoryMapping)
764 \param s : Vector of min(m,n) singular values (descending order)
765 \param u : m-by-m Matrix of left singular vectors
766 \param vt : Transpose of right singular vectors (n-by-n matrix).
767 \return : return code from lapack driver _gesvd()
768 */
769template <class T>
770int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s, TArray<T> & u, TArray<T> & vt)
771{
772 return (SVDDriver(a, s, &u, &vt) );
773}
774
775
776//! Interface to Lapack SVD driver s/d/c/zgesv().
777template <class T>
778int LapackServer<T>::SVDDriver(TArray<T>& a, TArray<T> & s, TArray<T>* up, TArray<T>* vtp)
779{
780 if ( ( a.NbDimensions() != 2 ) )
781 throw(SzMismatchError("LapackServer::SVDDriver(a, ...) a.NbDimensions() != 2"));
782
783 int_4 rowa = a.RowsKA();
784 int_4 cola = a.ColsKA();
785
786 if ( !a.IsPacked(rowa) )
787 throw(SzMismatchError("LapackServer::SVDDriver(a, ...) a Not Column Packed "));
788
789 int_4 m = a.Size(rowa);
790 int_4 n = a.Size(cola);
791 //unused: int_4 maxmn = (m > n) ? m : n;
792 int_4 minmn = (m < n) ? m : n;
793
794 char jobu, jobvt;
795 jobu = 'N';
796 jobvt = 'N';
797
798 sa_size_t sz[2];
799 if ( up != NULL) {
800 if ( dynamic_cast< TVector<T> * > (vtp) )
801 throw( TypeMismatchExc("LapackServer::SVDDriver() Wrong type (=TVector<T>) for u !") );
802 up->SetMemoryMapping(BaseArray::FortranMemoryMapping);
803 sz[0] = sz[1] = m;
804 up->ReSize(2, sz );
805 jobu = 'A';
806 }
807 else {
808 up = new TMatrix<T>(1,1);
809 jobu = 'N';
810 }
811 if ( vtp != NULL) {
812 if ( dynamic_cast< TVector<T> * > (vtp) )
813 throw( TypeMismatchExc("LapackServer::SVDDriver() Wrong type (=TVector<T>) for vt !") );
814 vtp->SetMemoryMapping(BaseArray::FortranMemoryMapping);
815 sz[0] = sz[1] = n;
816 vtp->ReSize(2, sz );
817 jobvt = 'A';
818 }
819 else {
820 vtp = new TMatrix<T>(1,1);
821 jobvt = 'N';
822 }
823
824 TVector<T> *vs = dynamic_cast< TVector<T> * > (&s);
825 if (vs) vs->ReSize(minmn);
826 else {
827 TMatrix<T> *ms = dynamic_cast< TMatrix<T> * > (&s);
828 if (ms) ms->ReSize(minmn,1);
829 else {
830 sz[0] = minmn; sz[1] = 1;
831 s.ReSize(1, sz);
832 }
833 }
834
835 int_4 lda = a.Step(a.ColsKA());
836 int_4 ldu = up->Step(up->ColsKA());
837 int_4 ldvt = vtp->Step(vtp->ColsKA());
838 int_4 info;
839
840 int_4 lwork = -1; // maxmn*5 *wspace_size_factor;
841 T * work = NULL; // = new T[lwork];
842 T wkget[2];
843
844 if (typeid(T) == typeid(r_4) ) {
845 sgesvd(&jobu, &jobvt, &m, &n, (r_4 *)a.Data(), &lda,
846 (r_4 *)s.Data(), (r_4 *) up->Data(), &ldu, (r_4 *)vtp->Data(), &ldvt,
847 (r_4 *)wkget, &lwork, &info);
848 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
849 sgesvd(&jobu, &jobvt, &m, &n, (r_4 *)a.Data(), &lda,
850 (r_4 *)s.Data(), (r_4 *) up->Data(), &ldu, (r_4 *)vtp->Data(), &ldvt,
851 (r_4 *)work, &lwork, &info);
852 } else if (typeid(T) == typeid(r_8) ) {
853 dgesvd(&jobu, &jobvt, &m, &n, (r_8 *)a.Data(), &lda,
854 (r_8 *)s.Data(), (r_8 *) up->Data(), &ldu, (r_8 *)vtp->Data(), &ldvt,
855 (r_8 *)wkget, &lwork, &info);
856 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
857 dgesvd(&jobu, &jobvt, &m, &n, (r_8 *)a.Data(), &lda,
858 (r_8 *)s.Data(), (r_8 *) up->Data(), &ldu, (r_8 *)vtp->Data(), &ldvt,
859 (r_8 *)work, &lwork, &info);
860 } else if (typeid(T) == typeid(complex<r_4>) ) {
861 r_4 * rwork = new r_4[5*minmn +GARDMEM];
862 r_4 * sloc = new r_4[minmn];
863 cgesvd(&jobu, &jobvt, &m, &n, (complex<r_4> *)a.Data(), &lda,
864 (r_4 *)sloc, (complex<r_4> *) up->Data(), &ldu,
865 (complex<r_4> *)vtp->Data(), &ldvt,
866 (complex<r_4> *)wkget, &lwork, (r_4 *)rwork, &info);
867 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
868 cgesvd(&jobu, &jobvt, &m, &n, (complex<r_4> *)a.Data(), &lda,
869 (r_4 *)sloc, (complex<r_4> *) up->Data(), &ldu,
870 (complex<r_4> *)vtp->Data(), &ldvt,
871 (complex<r_4> *)work, &lwork, (r_4 *)rwork, &info);
872 for(int_4 i=0;i<minmn;i++) s[i] = sloc[i];
873 delete [] rwork; delete [] sloc;
874 } else if (typeid(T) == typeid(complex<r_8>) ) {
875 r_8 * rwork = new r_8[5*minmn +GARDMEM];
876 r_8 * sloc = new r_8[minmn];
877 zgesvd(&jobu, &jobvt, &m, &n, (complex<r_8> *)a.Data(), &lda,
878 (r_8 *)sloc, (complex<r_8> *) up->Data(), &ldu,
879 (complex<r_8> *)vtp->Data(), &ldvt,
880 (complex<r_8> *)wkget, &lwork, (r_8 *)rwork, &info);
881 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
882 zgesvd(&jobu, &jobvt, &m, &n, (complex<r_8> *)a.Data(), &lda,
883 (r_8 *)sloc, (complex<r_8> *) up->Data(), &ldu,
884 (complex<r_8> *)vtp->Data(), &ldvt,
885 (complex<r_8> *)work, &lwork, (r_8 *)rwork, &info);
886 for(int_4 i=0;i<minmn;i++) s[i] = sloc[i];
887 delete [] rwork; delete [] sloc;
888 } else {
889 if(work) delete [] work; work=NULL;
890 if (jobu == 'N') delete up;
891 if (jobvt == 'N') delete vtp;
892 string tn = typeid(T).name();
893 cerr << " LapackServer::SVDDriver(...) - Unsupported DataType T = " << tn << endl;
894 throw TypeMismatchExc("LapackServer::SVDDriver(a,b) - Unsupported DataType (T)");
895 }
896
897 if(work) delete [] work;
898 if (jobu == 'N') delete up;
899 if (jobvt == 'N') delete vtp;
900 if(info!=0 && Throw_On_Error) {
901 char serr[128]; sprintf(serr,"SVDDriver_Error info=%d",info);
902 throw MathExc(serr);
903 }
904 return(info);
905}
906
907
908//! Interface to Lapack SVD driver s/d/c/zgesdd().
909/*! Same as SVD but with Divide and Conquer method */
910template <class T>
911int LapackServer<T>::SVD_DC(TMatrix<T>& a, TVector<r_8>& s, TMatrix<T>& u, TMatrix<T>& vt)
912{
913#ifdef LAPACK_V2_EXTSOP
914 throw NotAvailableOperation("LapackServer::SVD_DC(a,b) NOT implemented in LapackV2") ;
915#else
916 if ( !a.IsPacked() )
917 throw(SzMismatchError("LapackServer::SVD_DC(a, ...) a Not Packed "));
918
919 int_4 m = a.NRows();
920 int_4 n = a.NCols();
921 int_4 maxmn = (m > n) ? m : n;
922 int_4 minmn = (m < n) ? m : n;
923 int_4 supermax = 4*minmn*minmn+4*minmn; if(maxmn>supermax) supermax=maxmn;
924
925 char jobz = 'A';
926
927 s.ReSize(minmn);
928 u.ReSize(m,m);
929 vt.ReSize(n,n);
930
931 int_4 lda = m;
932 int_4 ldu = m;
933 int_4 ldvt = n;
934 int_4 info;
935 int_4 lwork=-1;
936 T * work = NULL;
937 int_4 * iwork = NULL;
938 T wkget[2];
939
940 if(typeid(T) == typeid(r_4) ) {
941 r_4* sloc = new r_4[minmn];
942 iwork = new int_4[8*minmn +GARDMEM];
943 sgesdd(&jobz,&m,&n,(r_4*)a.Data(),&lda,
944 (r_4*)sloc,(r_4*)u.Data(),&ldu,(r_4*)vt.Data(),&ldvt,
945 (r_4*)wkget,&lwork,(int_4*)iwork,&info);
946 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
947 sgesdd(&jobz,&m,&n,(r_4*)a.Data(),&lda,
948 (r_4*)sloc,(r_4*)u.Data(),&ldu,(r_4*)vt.Data(),&ldvt,
949 (r_4*)work,&lwork,(int_4*)iwork,&info);
950 for(int_4 i=0;i<minmn;i++) s(i) = (r_8) sloc[i];
951 delete [] sloc;
952 } else if(typeid(T) == typeid(r_8) ) {
953 iwork = new int_4[8*minmn +GARDMEM];
954 dgesdd(&jobz,&m,&n,(r_8*)a.Data(),&lda,
955 (r_8*)s.Data(),(r_8*)u.Data(),&ldu,(r_8*)vt.Data(),&ldvt,
956 (r_8*)wkget,&lwork,(int_4*)iwork,&info);
957 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
958 dgesdd(&jobz,&m,&n,(r_8*)a.Data(),&lda,
959 (r_8*)s.Data(),(r_8*)u.Data(),&ldu,(r_8*)vt.Data(),&ldvt,
960 (r_8*)work,&lwork,(int_4*)iwork,&info);
961 } else if(typeid(T) == typeid(complex<r_4>) ) {
962 r_4* sloc = new r_4[minmn];
963 r_4* rwork = new r_4[5*minmn*minmn+5*minmn +GARDMEM];
964 iwork = new int_4[8*minmn +GARDMEM];
965 cgesdd(&jobz,&m,&n,(complex<r_4>*)a.Data(),&lda,
966 (r_4*)sloc,(complex<r_4>*)u.Data(),&ldu,(complex<r_4>*)vt.Data(),&ldvt,
967 (complex<r_4>*)wkget,&lwork,(r_4*)rwork,(int_4*)iwork,&info);
968 lwork = type2i4(&wkget[0],4); work = new T[lwork +GARDMEM];
969 cgesdd(&jobz,&m,&n,(complex<r_4>*)a.Data(),&lda,
970 (r_4*)sloc,(complex<r_4>*)u.Data(),&ldu,(complex<r_4>*)vt.Data(),&ldvt,
971 (complex<r_4>*)work,&lwork,(r_4*)rwork,(int_4*)iwork,&info);
972 for(int_4 i=0;i<minmn;i++) s(i) = (r_8) sloc[i];
973 delete [] sloc; delete [] rwork;
974 } else if(typeid(T) == typeid(complex<r_8>) ) {
975 r_8* rwork = new r_8[5*minmn*minmn+5*minmn +GARDMEM];
976 iwork = new int_4[8*minmn +GARDMEM];
977 zgesdd(&jobz,&m,&n,(complex<r_8>*)a.Data(),&lda,
978 (r_8*)s.Data(),(complex<r_8>*)u.Data(),&ldu,(complex<r_8>*)vt.Data(),&ldvt,
979 (complex<r_8>*)wkget,&lwork,(r_8*)rwork,(int_4*)iwork,&info);
980 lwork = type2i4(&wkget[0],8); work = new T[lwork +GARDMEM];
981 zgesdd(&jobz,&m,&n,(complex<r_8>*)a.Data(),&lda,
982 (r_8*)s.Data(),(complex<r_8>*)u.Data(),&ldu,(complex<r_8>*)vt.Data(),&ldvt,
983 (complex<r_8>*)work,&lwork,(r_8*)rwork,(int_4*)iwork,&info);
984 delete [] rwork;
985 } else {
986 if(work) delete [] work; work=NULL;
987 if(iwork) delete [] iwork; iwork=NULL;
988 string tn = typeid(T).name();
989 cerr << " LapackServer::SVD_DC(...) - Unsupported DataType T = " << tn << endl;
990 throw TypeMismatchExc("LapackServer::SVD_DC - Unsupported DataType (T)");
991 }
992
993 if(work) delete [] work; if(iwork) delete [] iwork;
994 if(info!=0 && Throw_On_Error) {
995 char serr[128]; sprintf(serr,"SVD_DC_Error info=%d",info);
996 throw MathExc(serr);
997 }
998 return(info);
999#endif
1000}
1001
1002
1003////////////////////////////////////////////////////////////////////////////////////
1004/*! Computes the eigen values and eigen vectors of a symetric (or hermitian) matrix \b a.
1005 Input arrays should have FortranMemoryMapping (column packed).
1006 \param a : input symetric (or hermitian) n-by-n matrix
1007 \param b : Vector of eigenvalues (descending order)
1008 \param eigenvector : if true compute eigenvectors, if not only eigenvalues
1009 \param a : on return array of eigenvectors (same order than eval, one vector = one column)
1010 \return : return code from lapack driver
1011 */
1012
1013template <class T>
1014int LapackServer<T>::LapackEigenSym(TArray<T>& a, TVector<r_8>& b, bool eigenvector)
1015{
1016 if ( a.NbDimensions() != 2 )
1017 throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a NbDimensions() != 2"));
1018 int_4 rowa = a.RowsKA();
1019 int_4 cola = a.ColsKA();
1020 if ( a.Size(rowa) != a.Size(cola))
1021 throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a Not a square Array"));
1022 if (!a.IsPacked(rowa))
1023 throw(SzMismatchError("LapackServer::LapackEigenSym(a,b) a Not Column Packed"));
1024
1025 char uplo='U';
1026 char jobz='N'; if(eigenvector) jobz='V';
1027
1028 int_4 n = a.Size(rowa);
1029 int_4 lda = a.Step(cola);
1030 int_4 info = 0;
1031 int_4 lwork = -1;
1032 T * work = NULL;
1033 T wkget[2];
1034
1035 b.ReSize(n); b = 0.;
1036
1037 if (typeid(T) == typeid(r_4) ) {
1038 r_4* w = new r_4[n];
1039 ssyev(&jobz,&uplo,&n,(r_4 *)a.Data(),&lda,(r_4 *)w,(r_4 *)wkget,&lwork,&info);
1040 lwork = type2i4(&wkget[0],4); /* 3*n-1;*/ work = new T[lwork +GARDMEM];
1041 ssyev(&jobz,&uplo,&n,(r_4 *)a.Data(),&lda,(r_4 *)w,(r_4 *)work,&lwork,&info);
1042 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
1043 delete [] w;
1044 } else if (typeid(T) == typeid(r_8) ) {
1045 r_8* w = new r_8[n];
1046 dsyev(&jobz,&uplo,&n,(r_8 *)a.Data(),&lda,(r_8 *)w,(r_8 *)wkget,&lwork,&info);
1047 lwork = type2i4(&wkget[0],8); /* 3*n-1;*/ work = new T[lwork +GARDMEM];
1048 dsyev(&jobz,&uplo,&n,(r_8 *)a.Data(),&lda,(r_8 *)w,(r_8 *)work,&lwork,&info);
1049 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
1050 delete [] w;
1051 } else if (typeid(T) == typeid(complex<r_4>) ) {
1052 r_4* rwork = new r_4[3*n-2 +GARDMEM]; r_4* w = new r_4[n];
1053 cheev(&jobz,&uplo,&n,(complex<r_4> *)a.Data(),&lda,(r_4 *)w
1054 ,(complex<r_4> *)wkget,&lwork,(r_4 *)rwork,&info);
1055 lwork = type2i4(&wkget[0],4); /* 2*n-1;*/ work = new T[lwork +GARDMEM];
1056 cheev(&jobz,&uplo,&n,(complex<r_4> *)a.Data(),&lda,(r_4 *)w
1057 ,(complex<r_4> *)work,&lwork,(r_4 *)rwork,&info);
1058 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
1059 delete [] rwork; delete [] w;
1060 } else if (typeid(T) == typeid(complex<r_8>) ) {
1061 r_8* rwork = new r_8[3*n-2 +GARDMEM]; r_8* w = new r_8[n];
1062 zheev(&jobz,&uplo,&n,(complex<r_8> *)a.Data(),&lda,(r_8 *)w
1063 ,(complex<r_8> *)wkget,&lwork,(r_8 *)rwork,&info);
1064 lwork = type2i4(&wkget[0],8); /* 2*n-1;*/ work = new T[lwork +GARDMEM];
1065 zheev(&jobz,&uplo,&n,(complex<r_8> *)a.Data(),&lda,(r_8 *)w
1066 ,(complex<r_8> *)work,&lwork,(r_8 *)rwork,&info);
1067 if(info==0) for(int i=0;i<n;i++) b(i) = w[i];
1068 delete [] rwork; delete [] w;
1069 } else {
1070 if(work) delete [] work; work=NULL;
1071 string tn = typeid(T).name();
1072 cerr << " LapackServer::LapackEigenSym(a,b) - Unsupported DataType T = " << tn << endl;
1073 throw TypeMismatchExc("LapackServer::LapackEigenSym(a,b) - Unsupported DataType (T)");
1074 }
1075
1076 if(work) delete [] work;
1077 if(info!=0 && Throw_On_Error) {
1078 char serr[128]; sprintf(serr,"LapackEigenSym_Error info=%d",info);
1079 throw MathExc(serr);
1080 }
1081 return(info);
1082}
1083
1084////////////////////////////////////////////////////////////////////////////////////
1085/*! Computes the eigen values and eigen vectors of a general squared matrix \b a.
1086 Input arrays should have FortranMemoryMapping (column packed).
1087 \param a : input general n-by-n matrix
1088 \param eval : Vector of eigenvalues (complex double precision)
1089 \param evec : Matrix of eigenvector (same order than eval, one vector = one column)
1090 \param eigenvector : if true compute (right) eigenvectors, if not only eigenvalues
1091 \param a : on return array of eigenvectors
1092 \return : return code from lapack driver
1093 \verbatim
1094 eval : contains the computed eigenvalues.
1095 --- For real matrices "a" :
1096 Complex conjugate pairs of eigenvalues appear consecutively
1097 with the eigenvalue having the positive imaginary part first.
1098 evec : the right eigenvectors v(j) are stored one after another
1099 in the columns of evec, in the same order as their eigenvalues.
1100 --- For real matrices "a" :
1101 If the j-th eigenvalue is real, then v(j) = evec(:,j),
1102 the j-th column of evec.
1103 If the j-th and (j+1)-st eigenvalues form a complex
1104 conjugate pair, then v(j) = evec(:,j) + i*evec(:,j+1) and
1105 v(j+1) = evec(:,j) - i*evec(:,j+1).
1106 \endverbatim
1107*/
1108
1109template <class T>
1110int LapackServer<T>::LapackEigen(TArray<T>& a, TVector< complex<r_8> >& eval, TMatrix<T>& evec, bool eigenvector)
1111{
1112 if ( a.NbDimensions() != 2 )
1113 throw(SzMismatchError("LapackServer::LapackEigen(a,b) a NbDimensions() != 2"));
1114 int_4 rowa = a.RowsKA();
1115 int_4 cola = a.ColsKA();
1116 if ( a.Size(rowa) != a.Size(cola))
1117 throw(SzMismatchError("LapackServer::LapackEigen(a,b) a Not a square Array"));
1118 if (!a.IsPacked(rowa))
1119 throw(SzMismatchError("LapackServer::LapackEigen(a,b) a Not Column Packed"));
1120
1121 char jobvl = 'N';
1122 char jobvr = 'N'; if(eigenvector) jobvr='V';
1123
1124 int_4 n = a.Size(rowa);
1125 int_4 lda = a.Step(cola);
1126 int_4 info = 0;
1127
1128 eval.ReSize(n); eval = complex<r_8>(0.,0.);
1129 if(eigenvector) {evec.ReSize(n,n); evec = (T) 0.;}
1130 int_4 ldvr = n, ldvl = 1;
1131
1132 int_4 lwork = -1;
1133 T * work = NULL;
1134 T wkget[2];
1135
1136 if (typeid(T) == typeid(r_4) ) {
1137 r_4* wr = new r_4[n]; r_4* wi = new r_4[n]; r_4* vl = NULL;
1138 sgeev(&jobvl,&jobvr,&n,(r_4 *)a.Data(),&lda,(r_4 *)wr,(r_4 *)wi,
1139 (r_4 *)vl,&ldvl,(r_4 *)evec.Data(),&ldvr,
1140 (r_4 *)wkget,&lwork,&info);
1141 lwork = type2i4(&wkget[0],4); /* 4*n;*/ work = new T[lwork +GARDMEM];
1142 sgeev(&jobvl,&jobvr,&n,(r_4 *)a.Data(),&lda,(r_4 *)wr,(r_4 *)wi,
1143 (r_4 *)vl,&ldvl,(r_4 *)evec.Data(),&ldvr,
1144 (r_4 *)work,&lwork,&info);
1145 if(info==0) for(int i=0;i<n;i++) eval(i) = complex<r_8>(wr[i],wi[i]);
1146 delete [] wr; delete [] wi;
1147 } else if (typeid(T) == typeid(r_8) ) {
1148 r_8* wr = new r_8[n]; r_8* wi = new r_8[n]; r_8* vl = NULL;
1149 dgeev(&jobvl,&jobvr,&n,(r_8 *)a.Data(),&lda,(r_8 *)wr,(r_8 *)wi,
1150 (r_8 *)vl,&ldvl,(r_8 *)evec.Data(),&ldvr,
1151 (r_8 *)wkget,&lwork,&info);
1152 lwork = type2i4(&wkget[0],8); /* 4*n;*/ work = new T[lwork +GARDMEM];
1153 dgeev(&jobvl,&jobvr,&n,(r_8 *)a.Data(),&lda,(r_8 *)wr,(r_8 *)wi,
1154 (r_8 *)vl,&ldvl,(r_8 *)evec.Data(),&ldvr,
1155 (r_8 *)work,&lwork,&info);
1156 if(info==0) for(int i=0;i<n;i++) eval(i) = complex<r_8>(wr[i],wi[i]);
1157 delete [] wr; delete [] wi;
1158 } else if (typeid(T) == typeid(complex<r_4>) ) {
1159 r_4* rwork = new r_4[2*n +GARDMEM]; r_4* vl = NULL; TVector< complex<r_4> > w(n);
1160 cgeev(&jobvl,&jobvr,&n,(complex<r_4> *)a.Data(),&lda,(complex<r_4> *)w.Data(),
1161 (complex<r_4> *)vl,&ldvl,(complex<r_4> *)evec.Data(),&ldvr,
1162 (complex<r_4> *)wkget,&lwork,(r_4 *)rwork,&info);
1163 lwork = type2i4(&wkget[0],4); /* 2*n;*/ work = new T[lwork +GARDMEM];
1164 cgeev(&jobvl,&jobvr,&n,(complex<r_4> *)a.Data(),&lda,(complex<r_4> *)w.Data(),
1165 (complex<r_4> *)vl,&ldvl,(complex<r_4> *)evec.Data(),&ldvr,
1166 (complex<r_4> *)work,&lwork,(r_4 *)rwork,&info);
1167 if(info==0) for(int i=0;i<n;i++) eval(i) = w(i);
1168 delete [] rwork;
1169 } else if (typeid(T) == typeid(complex<r_8>) ) {
1170 r_8* rwork = new r_8[2*n +GARDMEM]; r_8* vl = NULL;
1171 zgeev(&jobvl,&jobvr,&n,(complex<r_8> *)a.Data(),&lda,(complex<r_8> *)eval.Data(),
1172 (complex<r_8> *)vl,&ldvl,(complex<r_8> *)evec.Data(),&ldvr,
1173 (complex<r_8> *)wkget,&lwork,(r_8 *)rwork,&info);
1174 lwork = type2i4(&wkget[0],8); /* 2*n;*/ work = new T[lwork +GARDMEM];
1175 zgeev(&jobvl,&jobvr,&n,(complex<r_8> *)a.Data(),&lda,(complex<r_8> *)eval.Data(),
1176 (complex<r_8> *)vl,&ldvl,(complex<r_8> *)evec.Data(),&ldvr,
1177 (complex<r_8> *)work,&lwork,(r_8 *)rwork,&info);
1178 delete [] rwork;
1179 } else {
1180 if(work) delete [] work; work=NULL;
1181 string tn = typeid(T).name();
1182 cerr << " LapackServer::LapackEigen(a,b) - Unsupported DataType T = " << tn << endl;
1183 throw TypeMismatchExc("LapackServer::LapackEigen(a,b) - Unsupported DataType (T)");
1184 }
1185
1186 if(work) delete [] work;
1187 if(info!=0 && Throw_On_Error) {
1188 char serr[128]; sprintf(serr,"LapackEigen_Error info=%d",info);
1189 throw MathExc(serr);
1190 }
1191 return(info);
1192}
1193
1194
1195
1196
1197///////////////////////////////////////////////////////////////
// Explicit instantiation of LapackServer<T> for the four element types
// handled by the typeid() dispatch in this file. Which mechanism is used
// depends on the macros defined by the build configuration.
#if defined(__CXX_PRAGMA_TEMPLATES__)
// Pragma-based template instantiation.
#pragma define_template LapackServer<r_4>
#pragma define_template LapackServer<r_8>
#pragma define_template LapackServer< complex<r_4> >
#pragma define_template LapackServer< complex<r_8> >
#endif

#if defined(ANSI_TEMPLATES) || defined(GNU_TEMPLATES)
// Standard C++ explicit instantiation, inside the library namespace.
namespace SOPHYA {
  template class LapackServer<r_4>;
  template class LapackServer<r_8>;
  template class LapackServer< complex<r_4> >;
  template class LapackServer< complex<r_8> >;
}
#endif
1213
#ifdef Linux
// Stub needed for linking with f2c on Linux (presumably libf2c.a references
// a MAIN__ symbol - confirm). It only reports the problem and throws.
extern "C" void MAIN__();

void MAIN__()
{
  cerr << "MAIN__() function for linking with libf2c.a " << endl
       << "   This function should never be called !!! " << endl;
  throw PError("MAIN__() should not be called - see intflapack.cc");
}
#endif
Note: See TracBrowser for help on using the repository browser.