#include "sopnamsp.h"
#include "machdefs.h"

#include <math.h>
#include <ctype.h>
#include <iostream>
#include <typeinfo>

#include "srandgen.h"
#include "matharr.h"
#include "fftpserver.h"
#include "fftmserver.h"
#include "fftmayer.h"
#include "fftwserver.h"
#include "ntoolsinit.h"

#include "timing.h"

/*  --------------------------------------------------------- */
/*  ----       Programme de test de calcul de FFT        ---  */
/*  ----          par les FFTServer de SOPHYA            ---  */
/*  2000-2005      -   C. Magneville,  R. Ansari              */
/*      Test FFT-1D real<>complex complex<>complex            */
/*  Aide/Liste d'arguments :                                  */
/*  csh> tfft                                                 */
/*  Tests de base     ( ==> Rc=0 )                            */
/*  csh> tfft 15 P                                            */
/*  csh> tfft 16 P                                            */  
/*  csh> tfft 15 W                                            */
/*  csh> tfft 16 W                                            */ 
/*  csh> tfft 1531 P D 0 50 0.0002                            */
/*  csh> tfft 1220 P D 0 50 0.0002                            */
/*  csh> tfft 1531 W D 0 50 0.0002                            */
/*  csh> tfft 1220 W D 0 50 0.0002                            */
/*  --------------------------------------------------------- */

static bool inp_typ_random = false ;  // true -> random input

template <class T>
inline T module(complex<T> c) 
{
  return (sqrt(c.real()*c.real()+c.imag()*c.imag()));
}

// Max Matrix print elts
static int nprt = 2;
static int nprtfc = 8;

static  int prtlev = 0;

//-------------------------------------------------
template <class T>
void TestFFTPack(T seuil, sa_size_t num)
{
  int i;
  T fact = 1./num;

  TVector< complex<T> > inc(num), bkc(num), difc(num);
  TVector< T > in(num), ino(num), bk(num),dif(num);
  TVector< complex<T> > outc(num);

  cout << " DBG/1 outc " << outc.NElts()  << endl; 
  outc.ReSize(32);
  cout << " DBG/2 outc " << outc.NElts()  << endl; 
  outc.ReSize(10);
  cout << " DBG/3 outc " << outc.NElts()  << endl; 
  outc.ReSize(48);
  cout << " DBG/4 outc " << outc.NElts()  << endl; 

  if (inp_typ_random) 
    for (i=0; i<num ; i++){
      ino[i] = in[i] = GaussianRand(1.,0.);
      inc[i] = complex<T> (in[i], 0.);
    }
  else for (i=0; i<num ; i++){
    ino[i] = in[i] = 0.5 + cos(2*M_PI*(double)i/(double)num) 
                         + 2*sin(4*M_PI*(double)i/(double)num);
    inc[i] = complex<T> (in[i], 0.);
  }
  

  cout << "Input / L = " << num << in; 
  cout << endl;

  cout << " >>>> Testing FFTPackServer "  << endl;
  FFTPackServer fftp;
  cout << " Testing FFTPackServer "  << endl;
  fftp.fftf(in.NElts(), in.Data());
  //  in /= (num/2.);
  cout << " fftp.fftf(in.NElts(), in.Data()) FORWARD: " << in << endl; 
  cout << endl;
  fftp.fftb(in.NElts(), in.Data());
  cout << " fftp.fftb(in.NElts(), in.Data()) BACKWARD: " << in <<endl; 
  cout << endl;
  dif = ino-in;
  cout << " dif , NElts= " << dif.NElts() << dif << endl;

  int ndiff = 0;
  T maxdif=0., vdif;
  for(i=0; i<num; i++) {
    vdif = fabs(dif(i));
    if (vdif > seuil)  ndiff++;
    if (vdif > maxdif) maxdif = vdif;
  }
  cout << " Difference, Seuil= " << seuil << " NDiff= " << ndiff 
       << " MaxDiff= " << maxdif << endl;

}

//-------------------------------------------------
template <class T>
int TestFFTS(T seuil, FFTServerInterface & ffts, sa_size_t num)
{

  cout <<" ===>  TestFFTS " << ffts.getInfo() << " ArrSz= " << num << endl;  
  int i;

  T fact = 1.;

  TVector< complex<T> > inc(num), bkc(num), difc(num);
  TVector< T > in(num), ino(num), bk(num),dif(num);
  TVector< complex<T> > outc(num);

  if (inp_typ_random) {
    cout << " TestFFTS/Random input vector ... " << endl;
    ino = in = RandomSequence() ; // Nombre aleatoire gaussienne - sigma=1, mean=0
    ComplexMathArray< T > cma;
    TVector< T > im(num);
    im = RandomSequence(RandomSequence::Flat);
    inc = cma.FillFrom(in, im);
  }
  else {
    cout << " TestFFTS/Random input vector = 0.5+cos(2x)+2sin(4x)... " << endl;
    for (i=0; i<num ; i++){
      ino[i] = in[i] = 0.5 + cos(2*M_PI*(double)i/(double)num) 
	+ 2*sin(4*M_PI*(double)i/(double)num);
      inc[i] = complex<T> (in[i], 0.);
    }
  }

  cout << " Testing FFTServer " << ffts.getInfo() << endl;

  cout << "Input / Length= " << num << endl;
  if (prtlev > 0) 
    cout << in << endl; 

  int ndiff = 0;
  int rc = 0;

  cout << "\n ----  Testing FFT-1D(T, complex<T>) ---- " << endl;
  ffts.FFTForward(in, outc);
  if (prtlev > 0) 
    cout << " FourierCoefs: outc= \n" << outc << endl;

  ffts.FFTBackward(outc, bk);
  if (prtlev > 0) 
    cout << " Backward: bk= \n" << bk << endl;
 
  dif = bk*fact - in;
  if (prtlev > 0) 
    cout << " Difference dif= \n" << dif << endl;

  ndiff = 0;
  T maxdif=0., vdif;
  for(i=0; i<num; i++) {
    vdif = fabs(dif(i));
    if (vdif > seuil)  ndiff++;
    if (vdif > maxdif) maxdif = vdif;
  }
  cout << " Difference, Seuil= " << seuil << " NDiff= " << ndiff 
       << " MaxDiff= " << maxdif << endl;
 
  if (ndiff != 0)  rc += 4;

  cout << "\n ----  Testing FFT-1D(complex<T>, complex<T>) ---- " << endl;
  ffts.FFTForward(inc, outc);
  if (prtlev > 0) 
    cout << " FourierCoef , outc= \n" << outc << endl;

  ffts.FFTBackward(outc, bkc);
  if (prtlev > 0) 
    cout << " Backward , bkc= \n " << bkc << endl;
 
  difc = bkc*complex<T>(fact,0.) - inc;
  if (prtlev > 0) 
    cout << " Difference , difc= \n " << difc << endl;

  ndiff = 0;
  maxdif=0.;
  for(i=0; i<num; i++) {
    vdif = fabs(module(difc(i)));
    if (vdif > seuil)  ndiff++;
    if (vdif > maxdif) maxdif = vdif;
  }
  cout << " Difference, Seuil= " << seuil << " NDiff= " << ndiff 
       << " MaxDiff= " << maxdif << endl;

  if (ndiff != 0)  rc += 8;
  return rc;
}

inline void MayerFFTForw(r_4* d, int sz) {
  fht_r4(d, sz);
}
inline void MayerFFTForw(r_8* d, int sz) {
  fht_r8(d, sz);
}

//-------------------------------------------------
template <class T>
int MultiFFTTest(T t, sa_size_t sz, int nloop, bool fgp)
{
  cout <<" ===> MultiFFTTest<T=r_" << sizeof(T) << "> NLoop= " << nloop << " ArrSz= " << sz << endl;  
  TVector< T > in(sz), incopie(sz);

  incopie = in = RandomSequence();
  FFTPackServer fftp(false);
  TVector< complex<T> > outc;
    
  PrtTim("MultiFFT-LoopStart");

  if (fgp) {
    cout << " --- Test effectue avec FFTPack " << endl;
    for(int kk=0; kk<nloop; kk++) {
      in = incopie;
      //      fftp.FFTForward(in, outc);
      fftp.fftf(sz, in.Data());
    }
  }
  else { 
    cout << " --- Test effectue avec FFTMayer " << endl;
    for(int kk=0; kk<nloop; kk++) {
      in = incopie;
      //      in += 0.01;
      MayerFFTForw(in.Data(), sz);
    }
  }
  PrtTim("MultiFFT-LoopEnd");
  cout << " ----- End ---- MultiFFTTest<T> ----" << endl;
  return 0;
}



//-------------------------------------------------
//-------------------------------------------------
//-------------------------------------------------
int main(int narg, char* arg[])
{

  SophyaInit();
  InitTim();   // Initializing the CPU timer

  if (narg < 3) {
    cout << "tfft/ args error - \n  Usage tfft size p/P/M/W [f/d/F/D  PrtLev=0 MaxNPrt=50 diffthr] \n" 
	 << "    OR  tfft size P/M FL/DL/FZ/DZ  [NFFT=10]  \n" 
	 << " 1D real/complex-FFT test (FFTServer) \n"
	 << " size: input vector length \n " 
	 << " p=FFTPackTest  P=FFTPack, M=FFTMayer, W= FFTWServer \n " 
         << " F/f:float, D/d:double F/D:random in_vector (default=D) \n" 
         << "  FL/DL : perform NFFT on a serie of float[size] or double[size] " 
         << "  FZ/DZ : perform NFFT on a series of complex<float>[size] or complex<double>[size] " 
         << " diffthr : Threshold for diff checks (=10^-6/10^-4 double/float)" 
	 << endl;
    return(1);
  }

  FFTPackServer fftp;
  FFTMayerServer fftm;

  FFTWServer fftw;

  sa_size_t sz = atol(arg[1]);
  int nprt = 50;
  if (narg > 4) prtlev = atoi(arg[4]);
  if (narg > 5) nprt = atoi(arg[5]);
  BaseArray::SetMaxPrint(nprt, prtlev);
  float fs = 1.e-4;
  double ds = 1.e-6;
  if (narg > 6) fs = ds = atof(arg[6]);

  if (sz < 2) sz = 2;

  char dtyp = 'D';
  if (narg > 3) dtyp = *arg[3];
  inp_typ_random = true;
  if (islower(dtyp))  inp_typ_random = false;
  dtyp = toupper(dtyp);


  FFTServerInterface * ffts;
  if (*arg[2] == 'M')  ffts = fftm.Clone();
  else if (*arg[2] == 'W')  ffts = fftw.Clone();
  else ffts = fftp.Clone();

  int nloop = 0;
  bool fgpack = true;
  if (arg[3][1] == 'L') {  // MultiFFT test 
    nloop = 10;
    if (narg > 4) nloop = atoi(arg[4]);
    if (*arg[2] == 'M')  fgpack = false;
  }


  cout << "\n ============================================= \n"
       << " ------ Testing FFTServer " << typeid(*ffts).name() << "\n"
       << "   VecSize= " << sz << " InputType= " 
       << ( (inp_typ_random ) ? " random " : " fixed(sin+cos) ") << "\n"
       << " =============================================== " << endl;

  int rc = 0;
  try {
    if (dtyp == 'D') {
      cout << "   ------ Testing FFTServer for double (r_8)----- " << endl;
      if (*arg[2] == 'p') TestFFTPack(ds, sz);
      else if (nloop > 0) rc = MultiFFTTest(ds, sz, nloop, fgpack);
      else rc = TestFFTS(ds, *ffts, sz);
    }
    else {
      cout << "   ------ Testing FFTServer for float (r_4)----- " << endl;
      if (*arg[2] == 'p') TestFFTPack(fs, sz);
      else if (nloop > 0) rc = MultiFFTTest(fs, sz, nloop, fgpack);
      else rc = TestFFTS(fs, *ffts, sz);
    }
  }
  catch(PThrowable& exc ) {
    cerr << "TestFFT-main() , Catched exception: \n" << exc.Msg() << endl;
    rc = 97;
  }
  catch(std::exception ex) {
    cerr << "TestFFT-main() , Catched exception ! " << (string)(ex.what()) << endl;
    rc = 98;
  }
  catch(...) {
    cerr << "TestFFT-main() , Catched ... exception ! " << endl;
    rc = 99;
  }
  PrtTim("End of tfft ");
  delete ffts;
  cout << " =========== End of tfft Rc= " << rc << " =============" << endl;
}
