#ifndef BZ_ARRAYWHERE_H #define BZ_ARRAYWHERE_H #ifndef BZ_ARRAYEXPR_H #error must be included via #endif BZ_NAMESPACE(blitz) template class _bz_ArrayWhere { public: typedef P_expr1 T_expr1; typedef P_expr2 T_expr2; typedef P_expr3 T_expr3; typedef _bz_typename T_expr2::T_numtype T_numtype2; typedef _bz_typename T_expr3::T_numtype T_numtype3; typedef BZ_PROMOTE(T_numtype2, T_numtype3) T_numtype; typedef T_expr1 T_ctorArg1; typedef T_expr2 T_ctorArg2; typedef T_expr3 T_ctorArg3; enum { numArrayOperands = BZ_ENUM_CAST(P_expr1::numArrayOperands) + BZ_ENUM_CAST(P_expr2::numArrayOperands) + BZ_ENUM_CAST(P_expr3::numArrayOperands), numIndexPlaceholders = BZ_ENUM_CAST(P_expr1::numIndexPlaceholders) + BZ_ENUM_CAST(P_expr2::numIndexPlaceholders) + BZ_ENUM_CAST(P_expr3::numIndexPlaceholders), rank = _bz_meta_max<_bz_meta_max::max, P_expr3::rank>::max }; _bz_ArrayWhere(const _bz_ArrayWhere& a) : iter1_(a.iter1_), iter2_(a.iter2_), iter3_(a.iter3_) { } template _bz_ArrayWhere(T1 a, T2 b, T3 c) : iter1_(a), iter2_(b), iter3_(c) { } T_numtype operator*() { return (*iter1_) ? (*iter2_) : (*iter3_); } template T_numtype operator()(const TinyVector& i) { return iter1_(i) ? iter2_(i) : iter3_(i); } int lbound(int rank) { return bounds::compute_lbound(rank, bounds::compute_lbound( rank, iter1_.lbound(rank), iter2_.lbound(rank)), iter3_.lbound(rank)); } int ubound(int rank) { return bounds::compute_ubound(rank, bounds::compute_ubound( rank, iter1_.ubound(rank), iter2_.ubound(rank)), iter3_.ubound(rank)); } void push(int position) { iter1_.push(position); iter2_.push(position); iter3_.push(position); } void pop(int position) { iter1_.pop(position); iter2_.pop(position); iter3_.pop(position); } void advance() { iter1_.advance(); iter2_.advance(); iter3_.advance(); } void advance(int n) { iter1_.advance(n); iter2_.advance(n); iter3_.advance(n); } void loadStride(int rank) { iter1_.loadStride(rank); iter2_.loadStride(rank); iter3_.loadStride(rank); } _bz_bool isUnitStride(int rank) const { return iter1_.isUnitStride(rank) && iter2_.isUnitStride(rank) && iter3_.isUnitStride(rank); } void advanceUnitStride() { iter1_.advanceUnitStride(); iter2_.advanceUnitStride(); iter3_.advanceUnitStride(); } _bz_bool canCollapse(int outerLoopRank, int innerLoopRank) const { // BZ_DEBUG_MESSAGE("_bz_ArrayExprOp<>::canCollapse"); return iter1_.canCollapse(outerLoopRank, innerLoopRank) && iter2_.canCollapse(outerLoopRank, innerLoopRank) && iter3_.canCollapse(outerLoopRank, innerLoopRank); } template void moveTo(const TinyVector& i) { iter1_.moveTo(i); iter2_.moveTo(i); iter3_.moveTo(i); } T_numtype operator[](int i) { return iter1_[i] ? iter2_[i] : iter3_[i]; } T_numtype fastRead(int i) { return iter1_.fastRead(i) ? iter2_.fastRead(i) : iter3_.fastRead(i); } int suggestStride(int rank) const { int stride1 = iter1_.suggestStride(rank); int stride2 = iter2_.suggestStride(rank); int stride3 = iter3_.suggestStride(rank); return max(max(stride1,stride2),stride3); } _bz_bool isStride(int rank, int stride) const { return iter1_.isStride(rank,stride) && iter2_.isStride(rank,stride) && iter3_.isStride(rank,stride); } void prettyPrint(string& str, prettyPrintFormat& format) const { str += "[WHERE]"; // NEEDS_WORK } template _bz_bool shapeCheck(const T_shape& shape) { return iter1_.shapeCheck(shape) && iter2_.shapeCheck(shape) && iter3_.shapeCheck(shape); } private: _bz_ArrayWhere() { } T_expr1 iter1_; T_expr2 iter2_; T_expr3 iter3_; }; template inline _bz_ArrayExpr<_bz_ArrayWhere<_bz_typename asExpr::T_expr, _bz_typename asExpr::T_expr, _bz_typename asExpr::T_expr> > where(const T1& a, const T2& b, const T3& c) { return _bz_ArrayExpr<_bz_ArrayWhere<_bz_typename asExpr::T_expr, _bz_typename asExpr::T_expr, _bz_typename asExpr::T_expr> >(a,b,c); } BZ_NAMESPACE_END #endif // BZ_ARRAYWHERE_H