source: Sophya/trunk/SophyaExt/LinAlg/intflapack.cc@ 1359

Last change on this file since 1359 was 1344, checked in by cmv, 25 years ago

Erreur de syntaxe rz 24/11/00

File size: 8.3 KB
Line 
1#include <iostream.h>
2#include "intflapack.h"
3#include "tvector.h"
4#include "tmatrix.h"
5#include <typeinfo>
6
7extern "C" {
8// Drivers pour resolution de systemes lineaires
9 void sgesv_(int_4* n, int_4* nrhs, r_4* a, int_4* lda,
10 int_4* ipiv, r_4* b, int_4* ldb, int_4* info);
11 void dgesv_(int_4* n, int_4* nrhs, r_8* a, int_4* lda,
12 int_4* ipiv, r_8* b, int_4* ldb, int_4* info);
13 void cgesv_(int_4* n, int_4* nrhs, complex<r_4>* a, int_4* lda,
14 int_4* ipiv, complex<r_4>* b, int_4* ldb, int_4* info);
15 void zgesv_(int_4* n, int_4* nrhs, complex<r_8>* a, int_4* lda,
16 int_4* ipiv, complex<r_8>* b, int_4* ldb, int_4* info);
17
18// Driver pour decomposition SVD
19 void sgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, r_4* a, int_4* lda,
20 r_4* s, r_4* u, int_4* ldu, r_4* vt, int_4* ldvt,
21 r_4* work, int_4* lwork, int_4* info);
22 void dgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, r_8* a, int_4* lda,
23 r_8* s, r_8* u, int_4* ldu, r_8* vt, int_4* ldvt,
24 r_8* work, int_4* lwork, int_4* info);
25 void cgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_4>* a, int_4* lda,
26 complex<r_4>* s, complex<r_4>* u, int_4* ldu, complex<r_4>* vt, int_4* ldvt,
27 complex<r_4>* work, int_4* lwork, int_4* info);
28 void zgesvd_(char* jobu, char* jobvt, int_4* m, int_4* n, complex<r_8>* a, int_4* lda,
29 complex<r_8>* s, complex<r_8>* u, int_4* ldu, complex<r_8>* vt, int_4* ldvt,
30 complex<r_8>* work, int_4* lwork, int_4* info);
31
32}
33
34
35// -------------- Classe LapackServer<T> --------------
36
37template <class T>
38LapackServer<T>::LapackServer()
39{
40 SetWorkSpaceSizeFactor();
41}
42
43template <class T>
44LapackServer<T>::~LapackServer()
45{
46}
47
48template <class T>
49int LapackServer<T>::LinSolve(TArray<T>& a, TArray<T> & b)
50{
51 if ( ( a.NbDimensions() != 2 ) || ( b.NbDimensions() != 2 ) )
52 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b NbDimensions() != 2"));
53
54 int_4 rowa = a.RowsKA();
55 int_4 cola = a.ColsKA();
56 int_4 rowb = b.RowsKA();
57 int_4 colb = b.ColsKA();
58 if ( a.Size(rowa) != a.Size(cola))
59 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Not a square Array"));
60 if ( a.Size(rowa) != b.Size(rowb))
61 throw(SzMismatchError("LapackServer::LinSolve(a,b) RowSize(a <> b) "));
62
63 if (!a.IsPacked(rowa) || !b.IsPacked(rowb))
64 throw(SzMismatchError("LapackServer::LinSolve(a,b) a Or b Not Column Packed"));
65
66 int_4 n = a.Size(rowa);
67 int_4 nrhs = b.Size(colb);
68 int_4 lda = a.Step(cola);
69 int_4 ldb = b.Step(colb);
70 int_4 info;
71 int_4* ipiv = new int_4[n];
72
73 if (typeid(T) == typeid(r_4) )
74 sgesv_(&n, &nrhs, (r_4 *)a.Data(), &lda, ipiv, (r_4 *)b.Data(), &ldb, &info);
75 else if (typeid(T) == typeid(r_8) )
76 dgesv_(&n, &nrhs, (r_8 *)a.Data(), &lda, ipiv, (r_8 *)b.Data(), &ldb, &info);
77 else if (typeid(T) == typeid(complex<r_4>) )
78 cgesv_(&n, &nrhs, (complex<r_4> *)a.Data(), &lda, ipiv,
79 (complex<r_4> *)b.Data(), &ldb, &info);
80 else if (typeid(T) == typeid(complex<r_8>) )
81 zgesv_(&n, &nrhs, (complex<r_8> *)a.Data(), &lda, ipiv,
82 (complex<r_8> *)b.Data(), &ldb, &info);
83 else {
84 delete[] ipiv;
85 string tn = typeid(T).name();
86 cerr << " LapackServer::LinSolve(a,b) - Unsupported DataType T = " << tn << endl;
87 throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
88 }
89 delete[] ipiv;
90 return(info);
91}
92
93template <class T>
94int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s)
95{
96 return (SVDDriver(a, s, NULL, NULL) );
97}
98
99template <class T>
100int LapackServer<T>::SVD(TArray<T>& a, TArray<T> & s, TArray<T> & u, TArray<T> & vt)
101{
102 return (SVDDriver(a, s, &u, &vt) );
103}
104
105template <class T>
106int LapackServer<T>::SVDDriver(TArray<T>& a, TArray<T> & s, TArray<T>* up, TArray<T>* vtp)
107{
108 if ( ( a.NbDimensions() != 2 ) )
109 throw(SzMismatchError("LapackServer::SVD(a, ...) a.NbDimensions() != 2"));
110
111 int_4 rowa = a.RowsKA();
112 int_4 cola = a.ColsKA();
113
114 if ( !a.IsPacked(rowa) )
115 throw(SzMismatchError("LapackServer::SVD(a, ...) a Not Column Packed "));
116
117 int_4 m = a.Size(rowa);
118 int_4 n = a.Size(cola);
119 int_4 maxmn = (m > n) ? m : n;
120 int_4 minmn = (m < n) ? m : n;
121
122 char jobu, jobvt;
123 jobu = 'N';
124 jobvt = 'N';
125
126 sa_size_t sz[2];
127 if ( up != NULL) {
128 if ( dynamic_cast< TVector<T> * > (vtp) )
129 throw( TypeMismatchExc("LapackServer::SVD() Wrong type (=TVector<T>) for u !") );
130 up->SetMemoryMapping(BaseArray::FortranMemoryMapping);
131 sz[0] = sz[1] = m;
132 up->ReSize(2, sz );
133 jobu = 'A';
134 }
135 else {
136 up = new TMatrix<T>(1,1);
137 jobu = 'N';
138 }
139 if ( vtp != NULL) {
140 if ( dynamic_cast< TVector<T> * > (vtp) )
141 throw( TypeMismatchExc("LapackServer::SVD() Wrong type (=TVector<T>) for vt !") );
142 vtp->SetMemoryMapping(BaseArray::FortranMemoryMapping);
143 sz[0] = sz[1] = n;
144 vtp->ReSize(2, sz );
145 jobvt = 'A';
146 }
147 else {
148 vtp = new TMatrix<T>(1,1);
149 jobvt = 'N';
150 }
151
152 TVector<T> *vs = dynamic_cast< TVector<T> * > (&s);
153 if (vs) vs->ReSize(minmn);
154 else {
155 TMatrix<T> *ms = dynamic_cast< TMatrix<T> * > (&s);
156 if (ms) ms->ReSize(minmn,1);
157 else {
158 sz[0] = minmn; sz[1] = 1;
159 s.ReSize(1, sz);
160 }
161 }
162
163 int_4 lda = a.Step(a.ColsKA());
164 int_4 ldu = up->Step(up->ColsKA());
165 int_4 ldvt = vtp->Step(vtp->ColsKA());
166
167 int_4 lwork = maxmn*5*wspace_size_factor;
168 T * work = new T[lwork];
169 int_4 info;
170
171 if (typeid(T) == typeid(r_4) )
172 sgesvd_(&jobu, &jobvt, &m, &n, (r_4 *)a.Data(), &lda,
173 (r_4 *)s.Data(), (r_4 *) up->Data(), &ldu, (r_4 *)vtp->Data(), &ldvt,
174 (r_4 *)work, &lwork, &info);
175 else if (typeid(T) == typeid(r_8) )
176 dgesvd_(&jobu, &jobvt, &m, &n, (r_8 *)a.Data(), &lda,
177 (r_8 *)s.Data(), (r_8 *) up->Data(), &ldu, (r_8 *)vtp->Data(), &ldvt,
178 (r_8 *)work, &lwork, &info);
179 else if (typeid(T) == typeid(complex<r_4>) )
180 cgesvd_(&jobu, &jobvt, &m, &n, (complex<r_4> *)a.Data(), &lda,
181 (complex<r_4> *)s.Data(), (complex<r_4> *) up->Data(), &ldu,
182 (complex<r_4> *)vtp->Data(), &ldvt,
183 (complex<r_4> *)work, &lwork, &info);
184 else if (typeid(T) == typeid(complex<r_8>) )
185 zgesvd_(&jobu, &jobvt, &m, &n, (complex<r_8> *)a.Data(), &lda,
186 (complex<r_8> *)s.Data(), (complex<r_8> *) up->Data(), &ldu,
187 (complex<r_8> *)vtp->Data(), &ldvt,
188 (complex<r_8> *)work, &lwork, &info);
189 else {
190 if (jobu == 'N') delete up;
191 if (jobvt == 'N') delete vtp;
192 string tn = typeid(T).name();
193 cerr << " LapackServer::SVDDriver(...) - Unsupported DataType T = " << tn << endl;
194 throw TypeMismatchExc("LapackServer::LinSolve(a,b) - Unsupported DataType (T)");
195 }
196
197 if (jobu == 'N') delete up;
198 if (jobvt == 'N') delete vtp;
199 return(info);
200}
201
202void rztest_lapack(TArray<r_4>& aa, TArray<r_4>& bb)
203{
204 if ( aa.NbDimensions() != 2 ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a Matrix"));
205 if ( aa.SizeX() != aa.SizeY()) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a square Matrix"));
206 if ( bb.NbDimensions() != 2 ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A Not a Matrix"));
207 if ( bb.SizeX() != aa.SizeX() ) throw(SzMismatchError("rztest_lapack(TMatrix<r_4> A <> B "));
208 if ( !bb.IsPacked() || !bb.IsPacked() )
209 throw(SzMismatchError("rztest_lapack(TMatrix<r_4> Not packed A or B "));
210
211 int_4 n = aa.SizeX();
212 int_4 nrhs = bb.SizeY();
213 int_4 lda = n;
214 int_4 ldb = bb.SizeX();
215 int_4 info;
216 int_4* ipiv = new int_4[n];
217 sgesv_(&n, &nrhs, aa.Data(), &lda, ipiv, bb.Data(), &ldb, &info);
218 delete[] ipiv;
219 cout << "rztest_lapack/Info= " << info << endl;
220 cout << aa << "\n" << bb << endl;
221 return;
222}
223
224///////////////////////////////////////////////////////////////
// Instantiation of LapackServer<T> for the four element types r_4, r_8,
// complex<r_4> and complex<r_8>.  Two mechanisms are provided and selected by
// preprocessor symbols.

// Compilers flagged by __CXX_PRAGMA_TEMPLATES__: instantiation requested
// through "#pragma define_template".
#ifdef __CXX_PRAGMA_TEMPLATES__
#pragma define_template LapackServer<r_4>
#pragma define_template LapackServer<r_8>
#pragma define_template LapackServer< complex<r_4> >
#pragma define_template LapackServer< complex<r_8> >
#endif

// Compilers flagged by ANSI_TEMPLATES or GNU_TEMPLATES: explicit
// instantiation with the "template class" syntax.
#if defined(ANSI_TEMPLATES) || defined(GNU_TEMPLATES)
template class LapackServer<r_4>;
template class LapackServer<r_8>;
template class LapackServer< complex<r_4> >;
template class LapackServer< complex<r_8> >;
#endif
238
#if defined(OS_LINUX)
// Needed for linking with f2c under Linux: provides a C-linkage MAIN__
// symbol.  The stub is not meant to be executed; if it ever is, it reports
// the problem on cerr and throws PError.
extern "C" void MAIN__();

void MAIN__()
{
  cerr << "MAIN__() function for linking with libf2c.a " << endl
       << " This function should never be called !!! " << endl;
  throw PError("MAIN__() should not be called - see intflapack.cc");
}
#endif
Note: See TracBrowser for help on using the repository browser.