| 1 | #ifndef BZ_CGSOLVE_H | 
|---|
| 2 | #define BZ_CGSOLVE_H | 
|---|
| 3 |  | 
|---|
| 4 | BZ_NAMESPACE(blitz) | 
|---|
| 5 |  | 
|---|
| 6 | template<class T_numtype> | 
|---|
| 7 | void dump(const char* name, Array<T_numtype,3>& A) | 
|---|
| 8 | { | 
|---|
| 9 | T_numtype normA = 0; | 
|---|
| 10 |  | 
|---|
| 11 | for (int i=A.lbound(0); i <= A.ubound(0); ++i) | 
|---|
| 12 | { | 
|---|
| 13 | for (int j=A.lbound(1); j <= A.ubound(1); ++j) | 
|---|
| 14 | { | 
|---|
| 15 | for (int k=A.lbound(2); k <= A.ubound(2); ++k) | 
|---|
| 16 | { | 
|---|
| 17 | T_numtype tmp = A(i,j,k); | 
|---|
| 18 | normA += ::fabs(tmp); | 
|---|
| 19 | } | 
|---|
| 20 | } | 
|---|
| 21 | } | 
|---|
| 22 |  | 
|---|
| 23 | normA /= A.numElements(); | 
|---|
| 24 | cout << "Average magnitude of " << name << " is " << normA << endl; | 
|---|
| 25 | } | 
|---|
| 26 |  | 
|---|
| 27 | template<class T_stencil, class T_numtype, int N_rank, class T_BCs> | 
|---|
| 28 | int conjugateGradientSolver(T_stencil stencil, | 
|---|
| 29 | Array<T_numtype,N_rank>& x, | 
|---|
| 30 | Array<T_numtype,N_rank>& rhs, double haltrho, | 
|---|
| 31 | const T_BCs& boundaryConditions) | 
|---|
| 32 | { | 
|---|
| 33 | // NEEDS_WORK: only apply CG updates over interior; need to handle | 
|---|
| 34 | // BCs separately. | 
|---|
| 35 |  | 
|---|
| 36 | // x = unknowns being solved for (initial guess assumed) | 
|---|
| 37 | // r = residual | 
|---|
| 38 | // p = descent direction for x | 
|---|
| 39 | // q = descent direction for r | 
|---|
| 40 |  | 
|---|
| 41 | RectDomain<N_rank> interior = interiorDomain(stencil, x, rhs); | 
|---|
| 42 |  | 
|---|
| 43 | cout << "Interior: " << interior.lbound() << ", " << interior.ubound() | 
|---|
| 44 | << endl; | 
|---|
| 45 |  | 
|---|
| 46 | // Calculate initial residual | 
|---|
| 47 | Array<T_numtype,N_rank> r = rhs.copy(); | 
|---|
| 48 | r *= -1.0; | 
|---|
| 49 |  | 
|---|
| 50 | boundaryConditions.applyBCs(x); | 
|---|
| 51 |  | 
|---|
| 52 | applyStencil(stencil, r, x); | 
|---|
| 53 |  | 
|---|
| 54 | dump("r after stencil", r); | 
|---|
| 55 | cout << "Slice through r: " << endl << r(23,17,Range::all()) << endl; | 
|---|
| 56 | cout << "Slice through x: " << endl << x(23,17,Range::all()) << endl; | 
|---|
| 57 | cout << "Slice through rhs: " << endl << rhs(23,17,Range::all()) << endl; | 
|---|
| 58 |  | 
|---|
| 59 | r *= -1.0; | 
|---|
| 60 |  | 
|---|
| 61 | dump("r", r); | 
|---|
| 62 |  | 
|---|
| 63 | // Allocate the descent direction arrays | 
|---|
| 64 | Array<T_numtype,N_rank> p, q; | 
|---|
| 65 | allocateArrays(x.shape(), p, q); | 
|---|
| 66 |  | 
|---|
| 67 | int iteration = 0; | 
|---|
| 68 | int converged = 0; | 
|---|
| 69 | T_numtype rho = 0.; | 
|---|
| 70 | T_numtype oldrho = 0.; | 
|---|
| 71 |  | 
|---|
| 72 | const int maxIterations = 1000; | 
|---|
| 73 |  | 
|---|
| 74 | // Get views of interior of arrays (without boundaries) | 
|---|
| 75 | Array<T_numtype,N_rank> rint = r(interior); | 
|---|
| 76 | Array<T_numtype,N_rank> pint = p(interior); | 
|---|
| 77 | Array<T_numtype,N_rank> qint = q(interior); | 
|---|
| 78 | Array<T_numtype,N_rank> xint = x(interior); | 
|---|
| 79 |  | 
|---|
| 80 | while (iteration < maxIterations) | 
|---|
| 81 | { | 
|---|
| 82 | rho = sum(r * r); | 
|---|
| 83 |  | 
|---|
| 84 | if ((iteration % 20) == 0) | 
|---|
| 85 | cout << "CG: Iter " << iteration << "\t rho = " << rho << endl; | 
|---|
| 86 |  | 
|---|
| 87 | // Check halting condition | 
|---|
| 88 | if (rho < haltrho) | 
|---|
| 89 | { | 
|---|
| 90 | converged = 1; | 
|---|
| 91 | break; | 
|---|
| 92 | } | 
|---|
| 93 |  | 
|---|
| 94 | if (iteration == 0) | 
|---|
| 95 | { | 
|---|
| 96 | p = r; | 
|---|
| 97 | } | 
|---|
| 98 | else { | 
|---|
| 99 | T_numtype beta = rho / oldrho; | 
|---|
| 100 | p = beta * p + r; | 
|---|
| 101 | } | 
|---|
| 102 |  | 
|---|
| 103 | q = 0.; | 
|---|
| 104 | //        boundaryConditions.applyBCs(p); | 
|---|
| 105 | applyStencil(stencil, q, p); | 
|---|
| 106 |  | 
|---|
| 107 | T_numtype pq = sum(p*q); | 
|---|
| 108 |  | 
|---|
| 109 | T_numtype alpha = rho / pq; | 
|---|
| 110 |  | 
|---|
| 111 | x += alpha * p; | 
|---|
| 112 | r -= alpha * q; | 
|---|
| 113 |  | 
|---|
| 114 | oldrho = rho; | 
|---|
| 115 | ++iteration; | 
|---|
| 116 | } | 
|---|
| 117 |  | 
|---|
| 118 | if (!converged) | 
|---|
| 119 | cout << "Warning: CG solver did not converge" << endl; | 
|---|
| 120 |  | 
|---|
| 121 | return iteration; | 
|---|
| 122 | } | 
|---|
| 123 |  | 
|---|
| 124 | BZ_NAMESPACE_END | 
|---|
| 125 |  | 
|---|
| 126 | #endif // BZ_CGSOLVE_H | 
|---|