source: Sophya/trunk/SophyaExt/LinAlg/intflapack.cc@ 1424

Last change on this file since 1424 was 1424, checked in by ansari, 25 years ago

MAJ documentation - Reza 23/2/2001

File size: 11.1 KB
Line 
#include <iostream.h>
#include "intflapack.h"
#include "tvector.h"
#include "tmatrix.h"
#include <typeinfo>
#include <vector>
6
7/*!
8 \defgroup LinAlg LinAlg module
9 This module contains classes and functions for complex linear
10 algebra on arrays. This module is intended mainly to have
11 classes implementing C++ interfaces between Sophya objects
12 and external linear algebra libraries, such as LAPACK.
13*/
14
15/*!
16 \class SOPHYA::LapackServer
17 \ingroup LinAlg
18 This class implements an interface to LAPACK library driver routines.
 The LAPACK (Linear Algebra PACKage) is a collection of high-performance
 routines to solve common problems in numerical linear algebra.
 It is available from http://www.netlib.org.
22
23 The present version of our LapackServer (Feb 2001) provides only
24 interfaces for the linear system solver and singular value
25 decomposition (SVD). Only arrays with BaseArray::FortranMemoryMapping
 can be handled by LapackServer. LapackServer can be instantiated
 for single and double precision real or complex array types.
28
29 The example below shows solving a linear system A*X = B
30
31 \code
32 #include "intflapack.h"
33 // ...
34 // Use FortranMemoryMapping as default
35 BaseArray::SetDefaultMemoryMapping(BaseArray::FortranMemoryMapping);
 // Create and fill the arrays A and B
37 int n = 20;
38 Matrix A(n, n);
39 A = RandomSequence();
40 Vector X(n),B(n);
41 X = RandomSequence();
42 B = A*X;
43 // Solve the linear system A*X = B
44 LapackServer<r_8> lps;
45 lps.LinSolve(A,B);
46 // We get the result in B, which should be equal to X ...
47 // Compute the difference B-X ;
48 Vector diff = B-X;
49 \endcode
50
51*/
52
53extern "C" {
54// Drivers pour resolution de systemes lineaires
55 void sgesv_(int_4* n, int_4* nrhs, r_4* a, int_4* lda,
56 int_4* ipiv, r_4* b, int_4* ldb, int_4* info);
57 void dgesv_(int_4* n, int_4* nrhs, r_8* a, int_4* lda,
58 int_4* ipiv, r_8* b, int_4* ldb, int_4* info);
59 void cgesv_(int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
60 int_4* ipiv, complex<r_4>* b, int_4* ldb, int_4* info);
61 void zgesv_(int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
62 int_4* ipiv, complex<r_8>* b, int_4* ldb, int_4* info);
63
64// Driver pour decomposition SVD
65 void sgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, r_4* a, int_4* lda,
66 r_4* s, r_4* u, int_4* ldu, r_4* vt, int_4* ldvt,
67 r_4* work, int_4* lwork, int_4* info);
68 void dgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, r_8* a, int_4* lda,
69 r_8* s, r_8* u, int_4* ldu, r_8* vt, int_4* ldvt,
70 r_8* work, int_4* lwork, int_4* info);
71 void cgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_4>* a, int_4* lda,
72 complex<r_4>* s, complex<r_4>* u, int_4* ldu, complex<r_4>* vt, int_4* ldvt,
73 complex<r_4>* work, int_4* lwork, int_4* info);
74 void zgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_8>* a, int_4* lda,
75 complex<r_8>* s, complex<r_8>* u, int_4* ldu, complex<r_8>* vt, int_4* ldvt,
76 complex<r_8>* work, int_4* lwork, int_4* info);
77
78}
79
80
81// -------------- Classe LapackServer<T> --------------
82
83template <class T>
84LapackServer<T>::LapackServer()
85{
86 SetWorkSpaceSizeFactor();
87}
88
89template <class T>
90LapackServer<T>::~LapackServer()
91{
92}
93
94//! Interface to Lapack linear system solver driver s/d/c/zgesvd().
95/*! Solve the linear system a * x = b. Input arrays
96 should have FortranMemory mapping (column packed).
97 \param a : input matrix, overwritten on output
98 \param b : input-output, input vector b, contains x on exit
99 \return : return code from lapack driver _gesv()
100 */
101template <class T>
102int LapackServer<T>::LinSolve(TArray<T>& a, TArray<T> & b)
103{
104 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
105 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b NbDimensions() != 2"));
106
107 int_4 rowa = a.RowsKA();
108 int_4 cola = a.ColsKA();
109 int_4 rowb = b.RowsKA();
110 int_4 colb = b.ColsKA();
111 if ( a.Size(rowa) != a.Size(cola))
112 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Not a square Array"));
113 if ( a.Size(rowa) != b.Size(rowb))
114 throw(SzMismatchError("LapackServer::LinSolve(a,b) RowSize(a <> b) "));
115
116 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
117 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b Not Column Packed"));
118
119 int_4 n = a.Size(rowa);
120 int_4 nrhs = b.Size(colb);
121 int_4 lda = a.Step(cola);
122 int_4 ldb = b.Step(colb);
123 int_4 info;
124 int_4* ipiv = new int_4[n];
125
126 if (typeid(T) == typeid(r_4) )
127 sgesv_(&n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb, &info);
128 else if (typeid(T) == typeid(r_8) )
129 dgesv_(&n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb, &info);
130 else if (typeid(T) == typeid(complex<r_4>) )
131 cgesv_(&n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
132 (complex<r_4> *)b.Data(), &ldb, &info);
133 else if (typeid(T) == typeid(complex<r_8>) )
134 zgesv_(&n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
135 (complex<r_8> *)b.Data(), &ldb, &info);
136 else {
137 delete[] ipiv;
138 string tn = typeid(T).name();
139 cerr << " LapackServer::LinSolve(a,b) - Unsupported DataType T = " << tn << endl;
140 throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
141 }
142 delete[] ipiv;
143 return(info);
144}
145
146//! Interface to Lapack SVD driver s/d/c/zgesv().
147/*! Computes the vector of singular values of \b a. Input arrays
148 should have FortranMemoryMapping (column packed).
149 \param a : input m-by-n matrix
150 \param s : Vector of min(m,n) singular values (descending order)
151 \return : return code from lapack driver _gesvd()
152 */
153
154template <class T>
155int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s)
156{
157 return (SVDDriver(a, s, NULL, NULL) );
158}
159
160//! Interface to Lapack SVD driver s/d/c/zgesv().
161/*! Computes the vector of singular values of \b a, as well as
162 right and left singular vectors of \b a.
163 \f[
164 A = U \Sigma V^T , ( A = U \Sigma V^H \ complex)
165 \f]
166 \f[
167 A v_i = \sigma_i u_i \ and A^T u_i = \sigma_i v_i \ (A^H \ complex)
168 \f]
169 U and V are orthogonal (unitary) matrices.
170 \param a : input m-by-n matrix (in FotranMemoryMapping)
171 \param s : Vector of min(m,n) singular values (descending order)
172 \param u : Matrix of left singular vectors
173 \param vt : Transpose of right singular vectors.
174 \return : return code from lapack driver _gesvd()
175 */
176template <class T>
177int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s, TArray<T> & u, TArray<T> & vt)
178{
179 return (SVDDriver(a, s, &u, &vt) );
180}
181
182
183//! Interface to Lapack SVD driver s/d/c/zgesv().
184template <class T>
185int LapackServer<T>::SVDDriver(TArray<T>& a, TArray<T> & s, TArray<T>* up, TArray<T>* vtp)
186{
187 if ( ( a.NbDimensions() != 2 ) )
188 throw(SzMismatchError("LapackServer::SVD(a, ...) a.NbDimensions() != 2"));
189
190 int_4 rowa = a.RowsKA();
191 int_4 cola = a.ColsKA();
192
193 if ( !a.IsPacked(rowa) )
194 throw(SzMismatchError("LapackServer::SVD(a, ...) a Not Column Packed "));
195
196 int_4 m = a.Size(rowa);
197 int_4 n = a.Size(cola);
198 int_4 maxmn = (m > n) ? m : n;
199 int_4 minmn = (m < n) ? m : n;
200
201 char jobu, jobvt;
202 jobu = 'N';
203 jobvt = 'N';
204
205 sa_size_t sz[2];
206 if ( up != NULL) {
207 if ( dynamic_cast< TVector<T> * > (vtp) )
208 throw( TypeMismatchExc("LapackServer::SVD() Wrong type (=TVector<T>) for u !") );
209 up->SetMemoryMapping(BaseArray::FortranMemoryMapping);
210 sz[0] = sz[1] = m;
211 up->ReSize(2, sz );
212 jobu = 'A';
213 }
214 else {
215 up = new TMatrix<T>(1,1);
216 jobu = 'N';
217 }
218 if ( vtp != NULL) {
219 if ( dynamic_cast< TVector<T> * > (vtp) )
220 throw( TypeMismatchExc("LapackServer::SVD() Wrong type (=TVector<T>) for vt !") );
221 vtp->SetMemoryMapping(BaseArray::FortranMemoryMapping);
222 sz[0] = sz[1] = n;
223 vtp->ReSize(2, sz );
224 jobvt = 'A';
225 }
226 else {
227 vtp = new TMatrix<T>(1,1);
228 jobvt = 'N';
229 }
230
231 TVector<T> *vs = dynamic_cast< TVector<T> * > (&s);
232 if (vs) vs->ReSize(minmn);
233 else {
234 TMatrix<T> *ms = dynamic_cast< TMatrix<T> * > (&s);
235 if (ms) ms->ReSize(minmn,1);
236 else {
237 sz[0] = minmn; sz[1] = 1;
238 s.ReSize(1, sz);
239 }
240 }
241
242 int_4 lda = a.Step(a.ColsKA());
243 int_4 ldu = up->Step(up->ColsKA());
244 int_4 ldvt = vtp->Step(vtp->ColsKA());
245
246 int_4 lwork = maxmn*5*wspace_size_factor;
247 T * work = new T[lwork];
248 int_4 info;
249
250 if (typeid(T) == typeid(r_4) )
251 sgesvd_(&jobu, &jobvt, &m, &n, (r_4 *)a.Data(), &lda,
252 (r_4 *)s.Data(), (r_4 *) up->Data(), &ldu, (r_4 *)vtp->Data(), &ldvt,
253 (r_4 *)work, &lwork, &info);
254 else if (typeid(T) == typeid(r_8) )
255 dgesvd_(&jobu, &jobvt, &m, &n, (r_8 *)a.Data(), &lda,
256 (r_8 *)s.Data(), (r_8 *) up->Data(), &ldu, (r_8 *)vtp->Data(), &ldvt,
257 (r_8 *)work, &lwork, &info);
258 else if (typeid(T) == typeid(complex<r_4>) )
259 cgesvd_(&jobu, &jobvt, &m, &n, (complex<r_4> *)a.Data(), &lda,
260 (complex<r_4> *)s.Data(), (complex<r_4> *) up->Data(), &ldu,
261 (complex<r_4> *)vtp->Data(), &ldvt,
262 (complex<r_4> *)work, &lwork, &info);
263 else if (typeid(T) == typeid(complex<r_8>) )
264 zgesvd_(&jobu, &jobvt, &m, &n, (complex<r_8> *)a.Data(), &lda,
265 (complex<r_8> *)s.Data(), (complex<r_8> *) up->Data(), &ldu,
266 (complex<r_8> *)vtp->Data(), &ldvt,
267 (complex<r_8> *)work, &lwork, &info);
268 else {
269 if (jobu == 'N') delete up;
270 if (jobvt == 'N') delete vtp;
271 string tn = typeid(T).name();
272 cerr << " LapackServer::SVDDriver(...) - Unsupported DataType T = " << tn << endl;
273 throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
274 }
275
276 if (jobu == 'N') delete up;
277 if (jobvt == 'N') delete vtp;
278 return(info);
279}
280
281void rztest_lapack(TArray<r_4>& aa, TArray<r_4>& bb)
282{
283 if ( aa.NbDimensions() != 2 ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a Matrix"));
284 if ( aa.SizeX() != aa.SizeY()) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a square Matrix"));
285 if ( bb.NbDimensions() != 2 ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a Matrix"));
286 if ( bb.SizeX() != aa.SizeX() ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A <> B "));
287 if ( !bb.IsPacked() || !bb.IsPacked() )
288 throw(SzMismatchError("rztest_lapack(TMatrix<r_4> Not packed A or B "));
289
290 int_4 n = aa.SizeX();
291 int_4 nrhs = bb.SizeY();
292 int_4 lda = n;
293 int_4 ldb = bb.SizeX();
294 int_4 info;
295 int_4* ipiv = new int_4[n];
296 sgesv_(&n, &nrhs, aa.Data(), &lda, ipiv, bb.Data(), &ldb, &info);
297 delete[] ipiv;
298 cout << "rztest_lapack/Info= " << info << endl;
299 cout << aa << "\n" << bb << endl;
300 return;
301}
302
303///////////////////////////////////////////////////////////////
304#ifdef __CXX_PRAGMA_TEMPLATES__
305#pragma define_template LapackServer<r_4>
306#pragma define_template LapackServer<r_8>
307#pragma define_template LapackServer< complex<r_4> >
308#pragma define_template LapackServer< complex<r_8> >
309#endif
310
311#if defined(ANSI_TEMPLATES) || defined(GNU_TEMPLATES)
312template class LapackServer<r_4>;
313template class LapackServer<r_8>;
314template class LapackServer< complex<r_4> >;
315template class LapackServer< complex<r_8> >;
316#endif
317
318#if defined(OS_LINUX)
319// Pour le link avec f2c sous Linux
320extern "C" {
321 void MAIN__();
322}
323
324void MAIN__()
325{
326 cerr << "MAIN__() function for linking with libf2c.a " << endl;
327 cerr << " This function should never be called !!! " << endl;
328 throw PError("MAIN__() should not be called - see intflapack.cc");
329}
330#endif
Note: See TracBrowser for help on using the repository browser.