[221] | 1 | #ifndef BZ_CGSOLVE_H
|
---|
| 2 | #define BZ_CGSOLVE_H
|
---|
| 3 |
|
---|
| 4 | BZ_NAMESPACE(blitz)
|
---|
| 5 |
|
---|
| 6 | template<class T_numtype>
|
---|
| 7 | void dump(const char* name, Array<T_numtype,3>& A)
|
---|
| 8 | {
|
---|
| 9 | T_numtype normA = 0;
|
---|
| 10 |
|
---|
| 11 | for (int i=A.lbound(0); i <= A.ubound(0); ++i)
|
---|
| 12 | {
|
---|
| 13 | for (int j=A.lbound(1); j <= A.ubound(1); ++j)
|
---|
| 14 | {
|
---|
| 15 | for (int k=A.lbound(2); k <= A.ubound(2); ++k)
|
---|
| 16 | {
|
---|
| 17 | T_numtype tmp = A(i,j,k);
|
---|
| 18 | normA += ::fabs(tmp);
|
---|
| 19 | }
|
---|
| 20 | }
|
---|
| 21 | }
|
---|
| 22 |
|
---|
| 23 | normA /= A.numElements();
|
---|
| 24 | cout << "Average magnitude of " << name << " is " << normA << endl;
|
---|
| 25 | }
|
---|
| 26 |
|
---|
| 27 | template<class T_stencil, class T_numtype, int N_rank, class T_BCs>
|
---|
| 28 | int conjugateGradientSolver(T_stencil stencil,
|
---|
| 29 | Array<T_numtype,N_rank>& x,
|
---|
| 30 | Array<T_numtype,N_rank>& rhs, double haltrho,
|
---|
| 31 | const T_BCs& boundaryConditions)
|
---|
| 32 | {
|
---|
| 33 | // NEEDS_WORK: only apply CG updates over interior; need to handle
|
---|
| 34 | // BCs separately.
|
---|
| 35 |
|
---|
| 36 | // x = unknowns being solved for (initial guess assumed)
|
---|
| 37 | // r = residual
|
---|
| 38 | // p = descent direction for x
|
---|
| 39 | // q = descent direction for r
|
---|
| 40 |
|
---|
| 41 | RectDomain<N_rank> interior = interiorDomain(stencil, x, rhs);
|
---|
| 42 |
|
---|
| 43 | cout << "Interior: " << interior.lbound() << ", " << interior.ubound()
|
---|
| 44 | << endl;
|
---|
| 45 |
|
---|
| 46 | // Calculate initial residual
|
---|
| 47 | Array<T_numtype,N_rank> r = rhs.copy();
|
---|
| 48 | r *= -1.0;
|
---|
| 49 |
|
---|
| 50 | boundaryConditions.applyBCs(x);
|
---|
| 51 |
|
---|
| 52 | applyStencil(stencil, r, x);
|
---|
| 53 |
|
---|
| 54 | dump("r after stencil", r);
|
---|
| 55 | cout << "Slice through r: " << endl << r(23,17,Range::all()) << endl;
|
---|
| 56 | cout << "Slice through x: " << endl << x(23,17,Range::all()) << endl;
|
---|
| 57 | cout << "Slice through rhs: " << endl << rhs(23,17,Range::all()) << endl;
|
---|
| 58 |
|
---|
| 59 | r *= -1.0;
|
---|
| 60 |
|
---|
| 61 | dump("r", r);
|
---|
| 62 |
|
---|
| 63 | // Allocate the descent direction arrays
|
---|
| 64 | Array<T_numtype,N_rank> p, q;
|
---|
| 65 | allocateArrays(x.shape(), p, q);
|
---|
| 66 |
|
---|
| 67 | int iteration = 0;
|
---|
| 68 | int converged = 0;
|
---|
| 69 | T_numtype rho = 0.;
|
---|
| 70 | T_numtype oldrho = 0.;
|
---|
| 71 |
|
---|
| 72 | const int maxIterations = 1000;
|
---|
| 73 |
|
---|
| 74 | // Get views of interior of arrays (without boundaries)
|
---|
| 75 | Array<T_numtype,N_rank> rint = r(interior);
|
---|
| 76 | Array<T_numtype,N_rank> pint = p(interior);
|
---|
| 77 | Array<T_numtype,N_rank> qint = q(interior);
|
---|
| 78 | Array<T_numtype,N_rank> xint = x(interior);
|
---|
| 79 |
|
---|
| 80 | while (iteration < maxIterations)
|
---|
| 81 | {
|
---|
| 82 | rho = sum(r * r);
|
---|
| 83 |
|
---|
| 84 | if ((iteration % 20) == 0)
|
---|
| 85 | cout << "CG: Iter " << iteration << "\t rho = " << rho << endl;
|
---|
| 86 |
|
---|
| 87 | // Check halting condition
|
---|
| 88 | if (rho < haltrho)
|
---|
| 89 | {
|
---|
| 90 | converged = 1;
|
---|
| 91 | break;
|
---|
| 92 | }
|
---|
| 93 |
|
---|
| 94 | if (iteration == 0)
|
---|
| 95 | {
|
---|
| 96 | p = r;
|
---|
| 97 | }
|
---|
| 98 | else {
|
---|
| 99 | T_numtype beta = rho / oldrho;
|
---|
| 100 | p = beta * p + r;
|
---|
| 101 | }
|
---|
| 102 |
|
---|
| 103 | q = 0.;
|
---|
| 104 | // boundaryConditions.applyBCs(p);
|
---|
| 105 | applyStencil(stencil, q, p);
|
---|
| 106 |
|
---|
| 107 | T_numtype pq = sum(p*q);
|
---|
| 108 |
|
---|
| 109 | T_numtype alpha = rho / pq;
|
---|
| 110 |
|
---|
| 111 | x += alpha * p;
|
---|
| 112 | r -= alpha * q;
|
---|
| 113 |
|
---|
| 114 | oldrho = rho;
|
---|
| 115 | ++iteration;
|
---|
| 116 | }
|
---|
| 117 |
|
---|
| 118 | if (!converged)
|
---|
| 119 | cout << "Warning: CG solver did not converge" << endl;
|
---|
| 120 |
|
---|
| 121 | return iteration;
|
---|
| 122 | }
|
---|
| 123 |
|
---|
| 124 | BZ_NAMESPACE_END
|
---|
| 125 |
|
---|
| 126 | #endif // BZ_CGSOLVE_H
|
---|