source: Sophya/trunk/SophyaExt/LinAlg/intflapack.cc@ 1494

Last change on this file since 1494 was 1494, checked in by ansari, 24 years ago

Ajout LeastSquareSolve - Reza 15/5/2001

File size: 13.8 KB
RevLine 
#include <iostream.h>
#include <typeinfo>
#include <vector>
#include "intflapack.h"
#include "tvector.h"
#include "tmatrix.h"
[775]6
[1424]7/*!
8 \defgroup LinAlg LinAlg module
9 This module contains classes and functions for complex linear
10 algebra on arrays. This module is intended mainly to have
11 classes implementing C++ interfaces between Sophya objects
12 and external linear algebra libraries, such as LAPACK.
13*/
14
/*!
  \class SOPHYA::LapackServer
  \ingroup LinAlg
  This class implements an interface to LAPACK library driver routines.
  LAPACK (Linear Algebra PACKage) is a collection of high-performance
  routines to solve common problems in numerical linear algebra.
  It is available from http://www.netlib.org.

  The present version of our LapackServer (Feb 2001, least squares
  solver added May 2001) provides only interfaces for the linear system
  solver, the least squares solver and the singular value
  decomposition (SVD). Only arrays with BaseArray::FortranMemoryMapping
  can be handled by LapackServer. LapackServer can be instantiated
  for single and double precision real or complex array types.

  The example below shows solving a linear system A*X = B

  \code
  #include "intflapack.h"
  // ...
  // Use FortranMemoryMapping as default
  BaseArray::SetDefaultMemoryMapping(BaseArray::FortranMemoryMapping);
  // Create and fill the arrays A and B
  int n = 20;
  Matrix A(n, n);
  A = RandomSequence();
  Vector X(n),B(n);
  X = RandomSequence();
  B = A*X;
  // Solve the linear system A*X = B
  LapackServer<r_8> lps;
  lps.LinSolve(A,B);
  // We get the result in B, which should be equal to X ...
  // Compute the difference B-X ;
  Vector diff = B-X;
  \endcode

*/
52
[775]53extern "C" {
[1342]54// Drivers pour resolution de systemes lineaires
55 void sgesv_(int_4* n, int_4* nrhs, r_4* a, int_4* lda,
56 int_4* ipiv, r_4* b, int_4* ldb, int_4* info);
57 void dgesv_(int_4* n, int_4* nrhs, r_8* a, int_4* lda,
58 int_4* ipiv, r_8* b, int_4* ldb, int_4* info);
59 void cgesv_(int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
60 int_4* ipiv, complex<r_4>* b, int_4* ldb, int_4* info);
61 void zgesv_(int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
62 int_4* ipiv, complex<r_8>* b, int_4* ldb, int_4* info);
63
[1494]64 // Driver pour resolution de systemes au sens de Xi2
65 void sgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, r_4* a, int_4* lda,
66 r_4* b, int_4* ldb, r_4* work, int_4* lwork, int_4* info);
67 void dgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, r_8* a, int_4* lda,
68 r_8* b, int_4* ldb, r_8* work, int_4* lwork, int_4* info);
69 void cgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
70 complex<r_4>* b, int_4* ldb, complex<r_4>* work, int_4* lwork, int_4* info);
71 void zgels_(char * trans, int_4* m, int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
72 complex<r_8>* b, int_4* ldb, complex<r_8>* work, int_4* lwork, int_4* info);
73
[1342]74// Driver pour decomposition SVD
75 void sgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, r_4* a, int_4* lda,
76 r_4* s, r_4* u, int_4* ldu, r_4* vt, int_4* ldvt,
77 r_4* work, int_4* lwork, int_4* info);
78 void dgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, r_8* a, int_4* lda,
79 r_8* s, r_8* u, int_4* ldu, r_8* vt, int_4* ldvt,
80 r_8* work, int_4* lwork, int_4* info);
81 void cgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_4>* a, int_4* lda,
82 complex<r_4>* s, complex<r_4>* u, int_4* ldu, complex<r_4>* vt, int_4* ldvt,
83 complex<r_4>* work, int_4* lwork, int_4* info);
84 void zgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_8>* a, int_4* lda,
85 complex<r_8>* s, complex<r_8>* u, int_4* ldu, complex<r_8>* vt, int_4* ldvt,
86 complex<r_8>* work, int_4* lwork, int_4* info);
87
[775]88}
89
[1342]90
91// -------------- Classe LapackServer<T> --------------
92
[814]93template <class T>
[1344]94LapackServer<T>::LapackServer()
[1342]95{
96 SetWorkSpaceSizeFactor();
97}
98
99template <class T>
[1344]100LapackServer<T>::~LapackServer()
[1342]101{
102}
103
[1424]104//! Interface to Lapack linear system solver driver s/d/c/zgesvd().
105/*! Solve the linear system a * x = b. Input arrays
106 should have FortranMemory mapping (column packed).
107 \param a : input matrix, overwritten on output
108 \param b : input-output, input vector b, contains x on exit
109 \return : return code from lapack driver _gesv()
110 */
[1342]111template <class T>
[1042]112int LapackServer<T>::LinSolve(TArray<T>& a, TArray<T> & b)
[814]113{
114 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
115 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b NbDimensions() != 2"));
116
[1342]117 int_4 rowa = a.RowsKA();
118 int_4 cola = a.ColsKA();
119 int_4 rowb = b.RowsKA();
120 int_4 colb = b.ColsKA();
[814]121 if ( a.Size(rowa) != a.Size(cola))
122 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Not a square Array"));
[1042]123 if ( a.Size(rowa) != b.Size(rowb))
[814]124 throw(SzMismatchError("LapackServer::LinSolve(a,b) RowSize(a <> b) "));
125
126 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
[1342]127 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b Not Column Packed"));
[814]128
129 int_4 n = a.Size(rowa);
130 int_4 nrhs = b.Size(colb);
131 int_4 lda = a.Step(cola);
132 int_4 ldb = b.Step(colb);
133 int_4 info;
134 int_4* ipiv = new int_4[n];
135
136 if (typeid(T) == typeid(r_4) )
137 sgesv_(&n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb, &info);
138 else if (typeid(T) == typeid(r_8) )
139 dgesv_(&n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb, &info);
140 else if (typeid(T) == typeid(complex<r_4>) )
141 cgesv_(&n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
142 (complex<r_4> *)b.Data(), &ldb, &info);
143 else if (typeid(T) == typeid(complex<r_8>) )
144 zgesv_(&n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
145 (complex<r_8> *)b.Data(), &ldb, &info);
146 else {
147 delete[] ipiv;
148 string tn = typeid(T).name();
149 cerr << " LapackServer::LinSolve(a,b) - Unsupported DataType T = " << tn << endl;
150 throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
151 }
152 delete[] ipiv;
[1042]153 return(info);
[814]154}
155
[1494]156template <class T>
157int LapackServer<T>::LeastSquareSolve(TArray<T>& a, TArray<T> & b)
158{
159 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
160 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b NbDimensions() != 2"));
161
162 int_4 rowa = a.RowsKA();
163 int_4 cola = a.ColsKA();
164 int_4 rowb = b.RowsKA();
165 int_4 colb = b.ColsKA();
166
167
168 if ( a.Size(rowa) != b.Size(rowb))
169 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) RowSize(a <> b) "));
170
171 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
172 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b Not Column Packed"));
173
174 if ( a.Size(rowa) < a.Size(cola))
175 throw(SzMismatchError("LapackServer::LeastSquareSolve(a,b) NRows<NCols "));
176
177 int_4 m = a.Size(rowa);
178 int_4 n = a.Size(cola);
179 int_4 nrhs = b.Size(colb);
180
181 int_4 lda = a.Step(cola);
182 int_4 ldb = b.Step(colb);
183 int_4 info;
184
185 int_4 minmn = (m < n) ? m : n;
186 int_4 maxmn = (m > n) ? m : n;
187 int_4 maxmnrhs = (nrhs > maxmn) ? nrhs : maxmn;
188 if (maxmnrhs < 1) maxmnrhs = 1;
189
190 int_4 lwork = minmn+maxmnrhs*5;
191 T * work = new T[lwork];
192
193 char trans = 'N';
194
195 if (typeid(T) == typeid(r_4) )
196 sgels_(&trans, &m, &n, &nrhs, (r_4 *)a.Data(), &lda,
197 (r_4 *)b.Data(), &ldb, (r_4 *)work, &lwork, &info);
198 else if (typeid(T) == typeid(r_8) )
199 dgels_(&trans, &m, &n, &nrhs, (r_8 *)a.Data(), &lda,
200 (r_8 *)b.Data(), &ldb, (r_8 *)work, &lwork, &info);
201 else if (typeid(T) == typeid(complex<r_4>) )
202 cgels_(&trans, &m, &n, &nrhs, (complex<r_4> *)a.Data(), &lda,
203 (complex<r_4> *)b.Data(), &ldb, (complex<r_4> *)work, &lwork, &info);
204 else if (typeid(T) == typeid(complex<r_8>) )
205 zgels_(&trans, &m, &n, &nrhs, (complex<r_8> *)a.Data(), &lda,
206 (complex<r_8> *)b.Data(), &ldb, (complex<r_8> *)work, &lwork, &info);
207 else {
208 delete[] work;
209 string tn = typeid(T).name();
210 cerr << " LapackServer::LeastSquareSolve(a,b) - Unsupported DataType T = " << tn << endl;
211 throw TypeMismatchExc("LapackServer::LeastSquareSolve(a,b) - Unsupported DataType (T)");
212 }
213 delete[] work;
214 return(info);
215}
216
217
[1424]218//! Interface to Lapack SVD driver s/d/c/zgesv().
219/*! Computes the vector of singular values of \b a. Input arrays
220 should have FortranMemoryMapping (column packed).
221 \param a : input m-by-n matrix
222 \param s : Vector of min(m,n) singular values (descending order)
223 \return : return code from lapack driver _gesvd()
224 */
225
[1342]226template <class T>
227int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s)
228{
229 return (SVDDriver(a, s, NULL, NULL) );
230}
231
[1424]232//! Interface to Lapack SVD driver s/d/c/zgesv().
233/*! Computes the vector of singular values of \b a, as well as
234 right and left singular vectors of \b a.
235 \f[
236 A = U \Sigma V^T , ( A = U \Sigma V^H \ complex)
237 \f]
238 \f[
239 A v_i = \sigma_i u_i \ and A^T u_i = \sigma_i v_i \ (A^H \ complex)
240 \f]
241 U and V are orthogonal (unitary) matrices.
242 \param a : input m-by-n matrix (in FotranMemoryMapping)
243 \param s : Vector of min(m,n) singular values (descending order)
244 \param u : Matrix of left singular vectors
245 \param vt : Transpose of right singular vectors.
246 \return : return code from lapack driver _gesvd()
247 */
[1342]248template <class T>
249int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s, TArray<T> & u, TArray<T> & vt)
250{
251 return (SVDDriver(a, s, &u, &vt) );
252}
253
[1424]254
255//! Interface to Lapack SVD driver s/d/c/zgesv().
[1342]256template <class T>
257int LapackServer<T>::SVDDriver(TArray<T>& a, TArray<T> & s, TArray<T>* up, TArray<T>* vtp)
258{
259 if ( ( a.NbDimensions() != 2 ) )
260 throw(SzMismatchError("LapackServer::SVD(a, ...) a.NbDimensions() != 2"));
261
262 int_4 rowa = a.RowsKA();
263 int_4 cola = a.ColsKA();
264
265 if ( !a.IsPacked(rowa) )
266 throw(SzMismatchError("LapackServer::SVD(a, ...) a Not Column Packed "));
267
268 int_4 m = a.Size(rowa);
269 int_4 n = a.Size(cola);
270 int_4 maxmn = (m > n) ? m : n;
271 int_4 minmn = (m < n) ? m : n;
272
273 char jobu, jobvt;
274 jobu = 'N';
275 jobvt = 'N';
276
277 sa_size_t sz[2];
278 if ( up != NULL) {
279 if ( dynamic_cast< TVector<T> * > (vtp) )
280 throw( TypeMismatchExc("LapackServer::SVD() Wrong type (=TVector<T>) for u !") );
281 up->SetMemoryMapping(BaseArray::FortranMemoryMapping);
282 sz[0] = sz[1] = m;
283 up->ReSize(2, sz );
284 jobu = 'A';
285 }
286 else {
287 up = new TMatrix<T>(1,1);
288 jobu = 'N';
289 }
290 if ( vtp != NULL) {
291 if ( dynamic_cast< TVector<T> * > (vtp) )
292 throw( TypeMismatchExc("LapackServer::SVD() Wrong type (=TVector<T>) for vt !") );
293 vtp->SetMemoryMapping(BaseArray::FortranMemoryMapping);
294 sz[0] = sz[1] = n;
295 vtp->ReSize(2, sz );
296 jobvt = 'A';
297 }
298 else {
299 vtp = new TMatrix<T>(1,1);
300 jobvt = 'N';
301 }
302
303 TVector<T> *vs = dynamic_cast< TVector<T> * > (&s);
304 if (vs) vs->ReSize(minmn);
305 else {
306 TMatrix<T> *ms = dynamic_cast< TMatrix<T> * > (&s);
307 if (ms) ms->ReSize(minmn,1);
308 else {
309 sz[0] = minmn; sz[1] = 1;
310 s.ReSize(1, sz);
311 }
312 }
313
314 int_4 lda = a.Step(a.ColsKA());
315 int_4 ldu = up->Step(up->ColsKA());
316 int_4 ldvt = vtp->Step(vtp->ColsKA());
317
318 int_4 lwork = maxmn*5*wspace_size_factor;
319 T * work = new T[lwork];
320 int_4 info;
321
322 if (typeid(T) == typeid(r_4) )
323 sgesvd_(&jobu, &jobvt, &m, &n, (r_4 *)a.Data(), &lda,
324 (r_4 *)s.Data(), (r_4 *) up->Data(), &ldu, (r_4 *)vtp->Data(), &ldvt,
325 (r_4 *)work, &lwork, &info);
326 else if (typeid(T) == typeid(r_8) )
327 dgesvd_(&jobu, &jobvt, &m, &n, (r_8 *)a.Data(), &lda,
328 (r_8 *)s.Data(), (r_8 *) up->Data(), &ldu, (r_8 *)vtp->Data(), &ldvt,
329 (r_8 *)work, &lwork, &info);
330 else if (typeid(T) == typeid(complex<r_4>) )
331 cgesvd_(&jobu, &jobvt, &m, &n, (complex<r_4> *)a.Data(), &lda,
332 (complex<r_4> *)s.Data(), (complex<r_4> *) up->Data(), &ldu,
333 (complex<r_4> *)vtp->Data(), &ldvt,
334 (complex<r_4> *)work, &lwork, &info);
335 else if (typeid(T) == typeid(complex<r_8>) )
336 zgesvd_(&jobu, &jobvt, &m, &n, (complex<r_8> *)a.Data(), &lda,
337 (complex<r_8> *)s.Data(), (complex<r_8> *) up->Data(), &ldu,
338 (complex<r_8> *)vtp->Data(), &ldvt,
339 (complex<r_8> *)work, &lwork, &info);
340 else {
341 if (jobu == 'N') delete up;
342 if (jobvt == 'N') delete vtp;
343 string tn = typeid(T).name();
344 cerr << " LapackServer::SVDDriver(...) - Unsupported DataType T = " << tn << endl;
345 throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
346 }
347
348 if (jobu == 'N') delete up;
349 if (jobvt == 'N') delete vtp;
350 return(info);
351}
352
[775]353void rztest_lapack(TArray<r_4>& aa, TArray<r_4>& bb)
354{
355 if ( aa.NbDimensions() != 2 ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a Matrix"));
356 if ( aa.SizeX() != aa.SizeY()) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a square Matrix"));
357 if ( bb.NbDimensions() != 2 ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a Matrix"));
[788]358 if ( bb.SizeX() != aa.SizeX() ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A <> B "));
[775]359 if ( !bb.IsPacked() || !bb.IsPacked() )
360 throw(SzMismatchError("rztest_lapack(TMatrix<r_4> Not packed A or B "));
361
[788]362 int_4 n = aa.SizeX();
363 int_4 nrhs = bb.SizeY();
[775]364 int_4 lda = n;
[788]365 int_4 ldb = bb.SizeX();
[775]366 int_4 info;
367 int_4* ipiv = new int_4[n];
368 sgesv_(&n, &nrhs, aa.Data(), &lda, ipiv, bb.Data(), &ldb, &info);
[814]369 delete[] ipiv;
[775]370 cout << "rztest_lapack/Info= " << info << endl;
371 cout << aa << "\n" << bb << endl;
372 return;
373}
[814]374
///////////////////////////////////////////////////////////////
// Explicit instantiation of LapackServer for the four element types its
// methods dispatch on: single/double precision real (r_4, r_8) and
// complex (complex<r_4>, complex<r_8>).
// The mechanism used depends on compiler configuration macros.
#ifdef __CXX_PRAGMA_TEMPLATES__
// #pragma-based instantiation, used when __CXX_PRAGMA_TEMPLATES__ is defined
#pragma define_template LapackServer<r_4>
#pragma define_template LapackServer<r_8>
#pragma define_template LapackServer< complex<r_4> >
#pragma define_template LapackServer< complex<r_8> >
#endif

#if defined(ANSI_TEMPLATES) || defined(GNU_TEMPLATES)
// Standard explicit instantiation syntax
template class LapackServer<r_4>;
template class LapackServer<r_8>;
template class LapackServer< complex<r_4> >;
template class LapackServer< complex<r_8> >;
#endif
389
#if defined(OS_LINUX)
// Linux only: stub symbol provided for linking with f2c's libf2c.a
// (presumably that library refers to MAIN__). It is never meant to run,
// so calling it reports the problem and throws.
extern "C" void MAIN__();

void MAIN__()
{
  cerr << "MAIN__() function for linking with libf2c.a " << endl
       << " This function should never be called !!! " << endl;
  throw PError("MAIN__() should not be called - see intflapack.cc");
}
#endif
Note: See TracBrowser for help on using the repository browser.