#include <math.h>
#include "syn2fast.h"
#include "lambuilder.h"
#include "fftserver.h"
#ifdef __MWERKS__
#include "unixmac.h"
#endif

extern "C" {
  //  void fft_gpd_(long double* ,int& ,int& ,int& ,int& ,long double*);
  float gasdev2_(int& idum);
	   }

void a2lm2map(const int nsmax,const int nlmax,const int nmmax,
	     const vector< vector< complex<double> > >&alme,
	     const vector< vector< complex<double> > >&almb,
	     vector<float>& mapq,
	     vector<float>& mapu,
	     vector< complex<long double> > b_northp,
	     vector< complex<long double> > b_southp,
	     vector< complex<long double> > bwp,
	     vector< complex<long double> > b_northm,
	     vector< complex<long double> > b_southm,
	     vector< complex<long double> > bwm){
  /*=======================================================================
    computes a map form its alm for the HEALPIX pixelisation
    map(theta,phi) = sum_l_m a_lm Y_lm(theta,phi)
    = sum_m {e^(i*m*phi) sum_l a_lm*lambda_lm(theta)}
    
    where Y_lm(theta,phi) = lambda(theta) * e^(i*m*phi)
    
    * the recurrence of Ylm is the standard one (cf Num Rec)
    * the sum over m is done by FFT
    
    =======================================================================*/
  vector< complex<float> > mapp;
  vector< complex<float> > mapm;

  //convert the alm from e/b to +/-
  complex<double> im(0,1);
  vector< vector< complex<double> > >almp=alme;
  vector< vector< complex<double> > >almm=almb;
  for (int i=0;i<(signed) almp.size();i++){
    for (int j=0;j<(signed) almp[i].size();j++){
      almp[i][j]=-(alme[i][j]+im*almb[i][j]);
      almm[i][j]=-(alme[i][j]-im*almb[i][j]);
    }
  }
  //for (int i=0;i<(signed) almp.size();i++){cout<<alme[i][0]<<" "<<almb[i][0]<<" "<<almp[i][0]<<" "<<almm[i][0]<<endl;}
  mapp.resize(12*nsmax*nsmax);                       
  b_northp.resize(2*nmmax+1); //index m corresponds to nmmax+m
  b_southp.resize(2*nmmax+1);
  bwp.resize(4*nsmax);
  mapm.resize(12*nsmax*nsmax);                       
  b_northm.resize(2*nmmax+1); //index m corresponds to nmmax+m
  b_southm.resize(2*nmmax+1);
  bwm.resize(4*nsmax);


  /*      INTEGER l, indl */
  
  int istart_north = 0;
  int istart_south = 12*nsmax*nsmax;

  double dth1 = 1. / (3.*nsmax*nsmax);
  double  dth2 = 2. / (3.*nsmax);
  double  dst1 = 1. / (sqrt(6.) * nsmax);

  for (int ith = 1; ith <= 2*nsmax;ith++){
    int nph, kphi0;
    double cth, sth, sth2;
    if (ith <= nsmax-1){      /* north polar cap */
      nph = 4*ith;
      kphi0 = 1; 
      cth = 1.  - dth1*ith*ith; /* cos(theta) */
      sth = sin( 2. * asin( ith * dst1 ) ) ;  /* sin(theta) */
      sth2 = sth*sth;
    } else { /* tropical band + equat. */
      nph = 4*nsmax;
      kphi0 = (ith+1-nsmax) % 2;
      cth = (2.*nsmax-ith) * dth2;
      sth = sqrt((1.-cth)*(1.+cth)); /* ! sin(theta)*/
      sth2=(1.-cth)*(1.+cth);
    }

    /*        -----------------------------------------------------
	      for each theta, and each m, computes
	      b(m,theta) = sum_over_l>m (lambda_l_m(theta) * a_l_m) 
	      ------------------------------------------------------
	      lambda_mm tends to go down when m increases (risk of underflow)
	      lambda_lm tends to go up   when l increases (risk of overflow)*/
    Lambda2Builder l2b(acos(cth),nlmax,nmmax);
    for (int m = 0; m <= nmmax; m++){
      complex<double> b_np,b_sp,b_nm,b_sm;
      // cout <<l2b.lam2lmp(m,m)<<endl;
      b_np = l2b.lam2lmp(m,m)* almp[m][m];
      b_sp = l2b.lam2lmp(m,m,-1)* almp[m][m];
      b_nm = l2b.lam2lmm(m,m)* almm[m][m];
      b_sm = l2b.lam2lmm(m,m,-1)* almm[m][m];
      for (int l = m+1; l<= nlmax; l++){
	//cout<<l2b.lam2lmp(l,m)<<endl;
	b_np += l2b.lam2lmp(l,m)*almp[l][m];
	b_sp += l2b.lam2lmp(l,m,-1)*almp[l][m];
	b_nm += l2b.lam2lmm(l,m)*almm[l][m];
	b_sm += l2b.lam2lmm(l,m,-1)*almm[l][m];
      }
      
      b_northp[m+nmmax] = b_np;
      b_southp[m+nmmax] = b_sp;
      b_northm[m+nmmax] = b_nm;
      b_southm[m+nmmax] = b_sm;
    }
    
    //        obtains the negative m of b(m,theta) (= complex conjugate)
    for (int m=-nmmax;m<=-1;m++){
      int fac = 1;//(int) pow(-1,m);//  ! either 1 or -1
      complex<long double>
        shit(b_northp[-m+nmmax].real(),-b_northp[-m+nmmax].imag());
      complex<long double>
        shit2(b_southp[-m+nmmax].real(),-b_southp[-m+nmmax].imag());
      shit*=fac; shit2*=fac;
      b_northm[m+nmmax] = shit;
      b_southm[m+nmmax] = shit2;
      shit=complex<long double>(b_northm[-m+nmmax].real(),-b_northm[-m+nmmax].imag());
      shit2=complex<long double>(b_southm[-m+nmmax].real(),-b_southm[-m+nmmax].imag());
      shit*=fac; shit2*=fac;
      b_northp[m+nmmax] = shit;
      b_southp[m+nmmax] = shit2;
    }
    
    // for (int i=0;i<2*nmmax+1;i++){    cout<< b_northp[i]<<" "<<b_southp[i]<<" "<<b_northm[i]
    //				<<" "<<b_southm[i]<<endl;}
      /*        ---------------------------------------------------------------
	      sum_m  b(m,theta)*exp(i*m*phi)   -> f(phi,theta)
        ---------------------------------------------------------------*/
    syn2_phas(nsmax,nlmax,nmmax,b_northp,nph,mapp,istart_north,kphi0,bwp); // north hemisph. + equator
    syn2_phas(nsmax,nlmax,nmmax,b_northm,nph,mapm,istart_north,kphi0,bwm);
    istart_north = istart_north + nph;
    if (ith < 2*nsmax){
      istart_south=istart_south-nph;
      syn2_phas(nsmax,nlmax,nmmax,b_southp,nph,mapp,istart_south,kphi0,bwp); // south hemisph. w/o equat
      syn2_phas(nsmax,nlmax,nmmax,b_southm,nph,mapm,istart_south,kphi0,bwm);
    }
  }
  for (int i=0;i< (signed)mapp.size();i++){
    //cout << mapp[i]<<" "<<mapm[i]<<endl;
    mapq[i]=(mapp[i]+mapm[i]).real()/2;
    mapu[i]=(mapp[i]-mapm[i]).imag()/2;
    //cout << mapq[i]<<" "<<mapu[i]<<endl;
  }
}

void syn2_phas(const int nsmax,const int nlmax,const int nmmax,
	      const vector< complex<long double> >& datain,
	      int nph,vector< complex<float> >& dataout, const int start,
	      int kphi0,
	      vector< complex<long double> >& bw){

  /*=======================================================================
     dataout(j) = sum_m datain(m) * exp(i*m*phi(j)) 
     with phi(j) = j*2pi/nph + kphi0*pi/nph and kphi0 =0 or 1

     as the set of frequencies {m} is larger than nph, 
     we wrap frequencies within {0..nph-1}
     ie  m = k*nph + m' with m' in {0..nph-1}
     then
     noting bw(m') = exp(i*m'*phi0) 
                   * sum_k (datain(k*nph+m') exp(i*k*pi*kphi0))
        with bw(nph-m') = CONJ(bw(m')) (if datain(-m) = CONJ(datain(m)))
     dataout(j) = sum_m' [ bw(m') exp (i*j*m'*2pi/nph) ]
                = Fourier Transform of bw
        is real

         NB nph is not necessarily a power of 2

=======================================================================*/
  int ksign =  1;
  complex<double>* data= new complex<double>[4*nsmax];

  for (int iw=0;iw<nph;iw++){bw[iw]=0;}

  double phi0 = kphi0*M_PI/nph;
  int kshift = (int) pow(-1,kphi0);//  ! either 1 or -1
  //     all frequencies [-m,m] are wrapped in [0,nph-1]
  for (int i=1;i<=2*nmmax+1;i++){
    int m=i-nmmax-1; //in -nmmax, nmmax
    int iw=((m % nph) +nph) % nph; //between 0 and nph = m'
    int k=(m-iw)/nph; //number of 'turns'
    complex<long double> shit(pow(kshift,k));
    bw[iw]+=datain[i-1]*shit;//complex number
  }

  //  kshift**k = 1       for even turn numbers
  //            = 1 or -1 for odd  turn numbers : results from the shift in space

  //     applies the shift in position <-> phase factor in Fourier space
  for (int iw=1;iw<=nph;iw++){
    int m=ksign*(iw-1);
    complex<long double> shit(cos(m*phi0),sin(m*phi0));
    data[iw-1]=bw[iw-1]*shit;
  }
  for (int i = nph; i< 4*nsmax;i++){
    data[i] = 0;
  }
  //cout<<endl;
 //fft_gpd_(data,nph,dum0,ksign,dum1,work); // complex to complex 
  FFTServer fft;
  fft.fftf(nph,data);
  
  for (int iw=0;iw<nph;iw++){
    dataout[iw+start]=data[iw];
  }
  delete[] data;
}


void create_a2lm(const int nsmax,const int nlmax,const int nmmax,
		 vector< vector< complex<double> > >& a2lme,
		 vector< vector< complex<double> > >& a2lmb,
		 vector<float>& cls_e,
		 vector<float>& cls_b,
		 int& iseed,const float fwhm,
		 vector<float> lread){


  /*=======================================================================
     creates the a_lm from the power spectrum, 
     assuming they are gaussian  and complex 
     with a variance given by C(l)

     the input file should contain : l and C(l) with *consecutive* l's
     (missing C(l) are put to 0.)

     because the map is real we have : a_l-m = (-)^m conjug(a_lm)
     so we actually compute them only for m >= 0

 modifie G. Le Meur (13/01/98) :
 on ne lit plus les C(l) sur un fichier. Le tableau est entre en argument
 (cls_tt). Ce tableau doit contenir les valeur de C(l) par ordre 
 SEQUENTIEL de l (de l=0 a l=nlmax)
=======================================================================*/
/*      CHARACTER*128 filename

      integer unit,n_l,j,il,im
      real    fs,quadrupole,xxxtmp,correct
      real    fs_tt
      real    zeta1_r, zeta1_i
      LOGICAL ok
      character*20 string_quad

      real*4    prefact(0:2)
      logical   prefact_bool

      integer lneffct
      real    gasdev2
      external lneffct, gasdev2

c-----------------------------------------------------------------------*/
  a2lme.resize(nlmax+1);
  a2lmb.resize(nlmax+1);
  for (int i=0; i< (signed) a2lme.size();i++){
    a2lme[i].resize(nmmax+1);
    a2lmb[i].resize(nmmax+1);
  }


  iseed = -abs(iseed);
  if (iseed == 0){iseed=-1;}
  int idum = iseed;

  float sig_smooth = fwhm/sqrt(8.*log(2.))/(60.*180.)* M_PI;

  for (int i=0;i<nlmax+1;i++){lread[i]=0;}

  for (int i=0;i <= nlmax;i++){lread[i]  = i;}

  int n_l = nlmax+1;

  cout<<lread[0]<<" <= l <= "<<lread[n_l-1]<<endl;

  //    --- smoothes the initial power spectrum ---

  for (int i=0;i<n_l;i++){
    int l= (int) lread[i];
    float gauss=exp(-l*(l+1.)*sig_smooth*sig_smooth);
    cls_e[i]*=gauss;
    cls_b[i]*=gauss;
  }
  

  //     --- generates randomly the alm according to their power spectrum ---
  float hsqrt2 = 1.0 / sqrt(2.0);

  for (int i=0;i<n_l;i++){
    int l=(int) lread[i];
    float rms_e=sqrt(cls_e[i]);
    float rms_b=sqrt(cls_b[i]);
    //        ------ m = 0 ------
    complex<float> zeta1(gasdev2_(idum));
    a2lme[l][0]   = zeta1 * rms_e;
    zeta1=complex<float>(gasdev2_(idum));
    a2lmb[l][0]   = zeta1 * rms_b;

    //------ m > 0 ------
    for (int m=1;m<=l;m++){
      complex<float> shit1(hsqrt2);
      complex<float> shit2(gasdev2_(idum),gasdev2_(idum));
      zeta1=shit1*shit2;
      a2lme[l][m]=rms_e*zeta1;
      shit2= complex<float>(gasdev2_(idum),gasdev2_(idum));
      zeta1=shit1*shit2;
      a2lmb[l][m]=rms_b*zeta1;
    }
  }
}
