source: Sophya/trunk/SophyaExt/LinAlg/intflapack.cc@ 2322

Last change on this file since 2322 was 2322, checked in by cmv, 23 years ago
  • passage xxstream.h en xxstream
  • compile avec gcc_3.2, gcc_2.96 et cxx En 3.2 le seek from ::end semble marcher (voir Eval/COS/pbseekios.cc)

rz+cmv 11/2/2003

File size: 14.7 KB
Line 
#include <iostream>
#include "intflapack.h"
#include "tvector.h"
#include "tmatrix.h"
#include <typeinfo>
#include <vector>
6
7/*!
8 \defgroup LinAlg LinAlg module
9 This module contains classes and functions for complex linear
10 algebra on arrays. This module is intended mainly to have
11 classes implementing C++ interfaces between Sophya objects
12 and external linear algebra libraries, such as LAPACK.
13*/
14
15/*!
16 \class SOPHYA::LapackServer
17 \ingroup LinAlg
18 This class implements an interface to LAPACK library driver routines.
 The LAPACK (Linear Algebra PACKage) is a collection of high performance
 routines to solve common problems in numerical linear algebra.
 It is available from http://www.netlib.org.
22
 The present version of our LapackServer (Feb 2001) provides only
 interfaces for the linear system solver (LinSolve), the linear least
 squares solver (LeastSquareSolve) and singular value decomposition (SVD).
 Only arrays with BaseArray::FortranMemoryMapping can be handled by
 LapackServer. LapackServer can be instantiated for single and double
 precision real or complex array types.
28
29 The example below shows solving a linear system A*X = B
30
31 \code
32 #include "intflapack.h"
33 // ...
34 // Use FortranMemoryMapping as default
35 BaseArray::SetDefaultMemoryMapping(BaseArray::FortranMemoryMapping);
 // Create and fill the arrays A and B
37 int n = 20;
38 Matrix A(n, n);
39 A = RandomSequence();
40 Vector X(n),B(n);
41 X = RandomSequence();
42 B = A*X;
43 // Solve the linear system A*X = B
44 LapackServer<r_8> lps;
45 lps.LinSolve(A,B);
46 // We get the result in B, which should be equal to X ...
47 // Compute the difference B-X ;
48 Vector diff = B-X;
49 \endcode
50
51*/
52
53extern "C" {
54// Drivers pour resolution de systemes lineaires
55 void sgesv_(int_4* n, int_4* nrhs, r_4* a, int_4* lda,
56 int_4* ipiv, r_4* b, int_4* ldb, int_4* info);
57 void dgesv_(int_4* n, int_4* nrhs, r_8* a, int_4* lda,
58 int_4* ipiv, r_8* b, int_4* ldb, int_4* info);
59 void cgesv_(int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
60 int_4* ipiv, complex<r_4>* b, int_4* ldb, int_4* info);
61 void zgesv_(int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
62 int_4* ipiv, complex<r_8>* b, int_4* ldb, int_4* info);
63
64 // Driver pour resolution de systemes au sens de Xi2
65 void sgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, r_4* a, int_4* lda,
66 r_4* b, int_4* ldb, r_4* work, int_4* lwork, int_4* info);
67 void dgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, r_8* a, int_4* lda,
68 r_8* b, int_4* ldb, r_8* work, int_4* lwork, int_4* info);
69 void cgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
70 complex<r_4>* b, int_4* ldb, complex<r_4>* work, int_4* lwork, int_4* info);
71 void zgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
72 complex<r_8>* b, int_4* ldb, complex<r_8>* work, int_4* lwork, int_4* info);
73
74// Driver pour decomposition SVD
75 void sgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, r_4* a, int_4* lda,
76 r_4* s, r_4* u, int_4* ldu, r_4* vt, int_4* ldvt,
77 r_4* work, int_4* lwork, int_4* info);
78 void dgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, r_8* a, int_4* lda,
79 r_8* s, r_8* u, int_4* ldu, r_8* vt, int_4* ldvt,
80 r_8* work, int_4* lwork, int_4* info);
81 void cgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_4>* a, int_4* lda,
82 complex<r_4>* s, complex<r_4>* u, int_4* ldu, complex<r_4>* vt, int_4* ldvt,
83 complex<r_4>* work, int_4* lwork, int_4* info);
84 void zgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_8>* a, int_4* lda,
85 complex<r_8>* s, complex<r_8>* u, int_4* ldu, complex<r_8>* vt, int_4* ldvt,
86 complex<r_8>* work, int_4* lwork, int_4* info);
87
88}
89
90
91// -------------- Classe LapackServer<T> --------------
92
93template <class T>
94LapackServer<T>::LapackServer()
95{
96 SetWorkSpaceSizeFactor();
97}
98
99template <class T>
100LapackServer<T>::~LapackServer()
101{
102}
103
104//! Interface to Lapack linear system solver driver s/d/c/zgesvd().
105/*! Solve the linear system a * x = b. Input arrays
106 should have FortranMemory mapping (column packed).
107 \param a : input matrix, overwritten on output
108 \param b : input-output, input vector b, contains x on exit
109 \return : return code from lapack driver _gesv()
110 */
111template <class T>
112int LapackServer<T>::LinSolve(TArray<T>& a, TArray<T> & b)
113{
114 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
115 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b NbDimensions() != 2"));
116
117 int_4 rowa = a.RowsKA();
118 int_4 cola = a.ColsKA();
119 int_4 rowb = b.RowsKA();
120 int_4 colb = b.ColsKA();
121 if ( a.Size(rowa) != a.Size(cola))
122 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Not a square Array"));
123 if ( a.Size(rowa) != b.Size(rowb))
124 throw(SzMismatchError("LapackServer::LinSolve(a,b) RowSize(a <> b) "));
125
126 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
127 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b Not Column Packed"));
128
129 int_4 n = a.Size(rowa);
130 int_4 nrhs = b.Size(colb);
131 int_4 lda = a.Step(cola);
132 int_4 ldb = b.Step(colb);
133 int_4 info;
134 int_4* ipiv = new int_4[n];
135
136 if (typeid(T) == typeid(r_4) )
137 sgesv_(&n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb, &info);
138 else if (typeid(T) == typeid(r_8) )
139 dgesv_(&n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb, &info);
140 else if (typeid(T) == typeid(complex<r_4>) )
141 cgesv_(&n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
142 (complex<r_4> *)b.Data(), &ldb, &info);
143 else if (typeid(T) == typeid(complex<r_8>) )
144 zgesv_(&n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
145 (complex<r_8> *)b.Data(), &ldb, &info);
146 else {
147 delete[] ipiv;
148 string tn = typeid(T).name();
149 cerr << " LapackServer::LinSolve(a,b) - Unsupported DataType T = " << tn << endl;
150 throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
151 }
152 delete[] ipiv;
153 return(info);
154}
155
156//! Interface to Lapack least squares solver driver s/d/c/zgels().
157/*! Solves the linear least squares problem defined by an m-by-n matrix
158 \b a and an m element vector \b b .
159 A solution \b x to the overdetermined system of linear equations
160 b = a * x is computed, minimizing the norm of b-a*x.
161 Underdetermined systems (m<n) are not yet handled.
162 Inout arrays should have FortranMemory mapping (column packed).
163 \param a : input matrix, overwritten on output
164 \param b : input-output, input vector b, contains x on exit.
165 \return : return code from lapack driver _gels()
166 \warning : b is not resized.
167 */
168/*
169 $CHECK$ - A faire - cas m<n
170 If the linear system is underdetermined, the minimum norm
171 solution is computed.
172*/
173
174template <class T>
175int LapackServer<T>::LeastSquareSolve(TArray<T>& a, TArray<T> & b)
176{
177 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
178 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b NbDimensions() != 2"));
179
180 int_4 rowa = a.RowsKA();
181 int_4 cola = a.ColsKA();
182 int_4 rowb = b.RowsKA();
183 int_4 colb = b.ColsKA();
184
185
186 if ( a.Size(rowa) != b.Size(rowb))
187 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) RowSize(a <> b) "));
188
189 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
190 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) a Or b Not Column Packed"));
191
192 if ( a.Size(rowa) < a.Size(cola)) { // $CHECK$ - m<n a changer
193 cout << " LapackServer<T>::LeastSquareSolve() - m<n - Not yet implemented for "
194 << " underdetermined systems ! " << endl;
195 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) NRows<NCols - "));
196 }
197 int_4 m = a.Size(rowa);
198 int_4 n = a.Size(cola);
199 int_4 nrhs = b.Size(colb);
200
201 int_4 lda = a.Step(cola);
202 int_4 ldb = b.Step(colb);
203 int_4 info;
204
205 int_4 minmn = (m < n) ? m : n;
206 int_4 maxmn = (m > n) ? m : n;
207 int_4 maxmnrhs = (nrhs > maxmn) ? nrhs : maxmn;
208 if (maxmnrhs < 1) maxmnrhs = 1;
209
210 int_4 lwork = minmn+maxmnrhs*5;
211 T * work = new T[lwork];
212
213 char trans = 'N';
214
215 if (typeid(T) == typeid(r_4) )
216 sgels_(&trans, &m, &n, &nrhs, (r_4 *)a.Data(), &lda,
217 (r_4 *)b.Data(), &ldb, (r_4 *)work, &lwork, &info);
218 else if (typeid(T) == typeid(r_8) )
219 dgels_(&trans, &m, &n, &nrhs, (r_8 *)a.Data(), &lda,
220 (r_8 *)b.Data(), &ldb, (r_8 *)work, &lwork, &info);
221 else if (typeid(T) == typeid(complex<r_4>) )
222 cgels_(&trans, &m, &n, &nrhs, (complex<r_4> *)a.Data(), &lda,
223 (complex<r_4> *)b.Data(), &ldb, (complex<r_4> *)work, &lwork, &info);
224 else if (typeid(T) == typeid(complex<r_8>) )
225 zgels_(&trans, &m, &n, &nrhs, (complex<r_8> *)a.Data(), &lda,
226 (complex<r_8> *)b.Data(), &ldb, (complex<r_8> *)work, &lwork, &info);
227 else {
228 delete[] work;
229 string tn = typeid(T).name();
230 cerr << " LapackServer::LeastSquareSolve(a,b) - Unsupported DataType T = " << tn << endl;
231 throw TypeMismatchExc("LapackServer::LeastSquareSolve(a,b) - Unsupported DataType (T)");
232 }
233 delete[] work;
234 return(info);
235}
236
237
238//! Interface to Lapack SVD driver s/d/c/zgesv().
239/*! Computes the vector of singular values of \b a. Input arrays
240 should have FortranMemoryMapping (column packed).
241 \param a : input m-by-n matrix
242 \param s : Vector of min(m,n) singular values (descending order)
243 \return : return code from lapack driver _gesvd()
244 */
245
246template <class T>
247int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s)
248{
249 return (SVDDriver(a, s, NULL, NULL) );
250}
251
252//! Interface to Lapack SVD driver s/d/c/zgesv().
253/*! Computes the vector of singular values of \b a, as well as
254 right and left singular vectors of \b a.
255 \f[
256 A = U \Sigma V^T , ( A = U \Sigma V^H \ complex)
257 \f]
258 \f[
259 A v_i = \sigma_i u_i \ and A^T u_i = \sigma_i v_i \ (A^H \ complex)
260 \f]
261 U and V are orthogonal (unitary) matrices.
262 \param a : input m-by-n matrix (in FotranMemoryMapping)
263 \param s : Vector of min(m,n) singular values (descending order)
264 \param u : Matrix of left singular vectors
265 \param vt : Transpose of right singular vectors.
266 \return : return code from lapack driver _gesvd()
267 */
268template <class T>
269int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s, TArray<T> & u, TArray<T> & vt)
270{
271 return (SVDDriver(a, s, &u, &vt) );
272}
273
274
275//! Interface to Lapack SVD driver s/d/c/zgesv().
276template <class T>
277int LapackServer<T>::SVDDriver(TArray<T>& a, TArray<T> & s, TArray<T>* up, TArray<T>* vtp)
278{
279 if ( ( a.NbDimensions() != 2 ) )
280 throw(SzMismatchError("LapackServer::SVD(a, ...) a.NbDimensions() != 2"));
281
282 int_4 rowa = a.RowsKA();
283 int_4 cola = a.ColsKA();
284
285 if ( !a.IsPacked(rowa) )
286 throw(SzMismatchError("LapackServer::SVD(a, ...) a Not Column Packed "));
287
288 int_4 m = a.Size(rowa);
289 int_4 n = a.Size(cola);
290 int_4 maxmn = (m > n) ? m : n;
291 int_4 minmn = (m < n) ? m : n;
292
293 char jobu, jobvt;
294 jobu = 'N';
295 jobvt = 'N';
296
297 sa_size_t sz[2];
298 if ( up != NULL) {
299 if ( dynamic_cast< TVector<T> * > (vtp) )
300 throw( TypeMismatchExc("LapackServer::SVD() Wrong type (=TVector<T>) for u !") );
301 up->SetMemoryMapping(BaseArray::FortranMemoryMapping);
302 sz[0] = sz[1] = m;
303 up->ReSize(2, sz );
304 jobu = 'A';
305 }
306 else {
307 up = new TMatrix<T>(1,1);
308 jobu = 'N';
309 }
310 if ( vtp != NULL) {
311 if ( dynamic_cast< TVector<T> * > (vtp) )
312 throw( TypeMismatchExc("LapackServer::SVD() Wrong type (=TVector<T>) for vt !") );
313 vtp->SetMemoryMapping(BaseArray::FortranMemoryMapping);
314 sz[0] = sz[1] = n;
315 vtp->ReSize(2, sz );
316 jobvt = 'A';
317 }
318 else {
319 vtp = new TMatrix<T>(1,1);
320 jobvt = 'N';
321 }
322
323 TVector<T> *vs = dynamic_cast< TVector<T> * > (&s);
324 if (vs) vs->ReSize(minmn);
325 else {
326 TMatrix<T> *ms = dynamic_cast< TMatrix<T> * > (&s);
327 if (ms) ms->ReSize(minmn,1);
328 else {
329 sz[0] = minmn; sz[1] = 1;
330 s.ReSize(1, sz);
331 }
332 }
333
334 int_4 lda = a.Step(a.ColsKA());
335 int_4 ldu = up->Step(up->ColsKA());
336 int_4 ldvt = vtp->Step(vtp->ColsKA());
337
338 int_4 lwork = maxmn*5*wspace_size_factor;
339 T * work = new T[lwork];
340 int_4 info;
341
342 if (typeid(T) == typeid(r_4) )
343 sgesvd_(&jobu, &jobvt, &m, &n, (r_4 *)a.Data(), &lda,
344 (r_4 *)s.Data(), (r_4 *) up->Data(), &ldu, (r_4 *)vtp->Data(), &ldvt,
345 (r_4 *)work, &lwork, &info);
346 else if (typeid(T) == typeid(r_8) )
347 dgesvd_(&jobu, &jobvt, &m, &n, (r_8 *)a.Data(), &lda,
348 (r_8 *)s.Data(), (r_8 *) up->Data(), &ldu, (r_8 *)vtp->Data(), &ldvt,
349 (r_8 *)work, &lwork, &info);
350 else if (typeid(T) == typeid(complex<r_4>) )
351 cgesvd_(&jobu, &jobvt, &m, &n, (complex<r_4> *)a.Data(), &lda,
352 (complex<r_4> *)s.Data(), (complex<r_4> *) up->Data(), &ldu,
353 (complex<r_4> *)vtp->Data(), &ldvt,
354 (complex<r_4> *)work, &lwork, &info);
355 else if (typeid(T) == typeid(complex<r_8>) )
356 zgesvd_(&jobu, &jobvt, &m, &n, (complex<r_8> *)a.Data(), &lda,
357 (complex<r_8> *)s.Data(), (complex<r_8> *) up->Data(), &ldu,
358 (complex<r_8> *)vtp->Data(), &ldvt,
359 (complex<r_8> *)work, &lwork, &info);
360 else {
361 if (jobu == 'N') delete up;
362 if (jobvt == 'N') delete vtp;
363 string tn = typeid(T).name();
364 cerr << " LapackServer::SVDDriver(...) - Unsupported DataType T = " << tn << endl;
365 throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
366 }
367
368 if (jobu == 'N') delete up;
369 if (jobvt == 'N') delete vtp;
370 return(info);
371}
372
373void rztest_lapack(TArray<r_4>& aa, TArray<r_4>& bb)
374{
375 if ( aa.NbDimensions() != 2 ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a Matrix"));
376 if ( aa.SizeX() != aa.SizeY()) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a square Matrix"));
377 if ( bb.NbDimensions() != 2 ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a Matrix"));
378 if ( bb.SizeX() != aa.SizeX() ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A <> B "));
379 if ( !bb.IsPacked() || !bb.IsPacked() )
380 throw(SzMismatchError("rztest_lapack(TMatrix<r_4> Not packed A or B "));
381
382 int_4 n = aa.SizeX();
383 int_4 nrhs = bb.SizeY();
384 int_4 lda = n;
385 int_4 ldb = bb.SizeX();
386 int_4 info;
387 int_4* ipiv = new int_4[n];
388 sgesv_(&n, &nrhs, aa.Data(), &lda, ipiv, bb.Data(), &ldb, &info);
389 delete[] ipiv;
390 cout << "rztest_lapack/Info= " << info << endl;
391 cout << aa << "\n" << bb << endl;
392 return;
393}
394
395///////////////////////////////////////////////////////////////
396#ifdef __CXX_PRAGMA_TEMPLATES__
397#pragma define_template LapackServer<r_4>
398#pragma define_template LapackServer<r_8>
399#pragma define_template LapackServer< complex<r_4> >
400#pragma define_template LapackServer< complex<r_8> >
401#endif
402
403#if defined(ANSI_TEMPLATES) || defined(GNU_TEMPLATES)
404template class LapackServer<r_4>;
405template class LapackServer<r_8>;
406template class LapackServer< complex<r_4> >;
407template class LapackServer< complex<r_8> >;
408#endif
409
410#if defined(OS_LINUX)
411// Pour le link avec f2c sous Linux
412extern "C" {
413 void MAIN__();
414}
415
416void MAIN__()
417{
418 cerr << "MAIN__() function for linking with libf2c.a " << endl;
419 cerr << " This function should never be called !!! " << endl;
420 throw PError("MAIN__() should not be called - see intflapack.cc");
421}
422#endif
Note: See TracBrowser for help on using the repository browser.