#include "fftwserver.h"
#include "FFTW/fftw.h"
#include "FFTW/rfftw.h"

class FFTWServerPlan{
public:
  FFTWServerPlan(int n, fftw_direction dir, bool fgreal=false);
  FFTWServerPlan(int nx, int ny, fftw_direction dir, bool fgreal=false);
  ~FFTWServerPlan();
  void Recreate(int n);
  void Recreate(int nx, int ny);

  int _n;
  int _nx, _ny;
  fftw_direction _dir;

  fftw_plan p;
  rfftw_plan rp;
  fftwnd_plan pnd;
  rfftwnd_plan rpnd;
   
};

FFTWServerPlan::FFTWServerPlan(int n, fftw_direction dir, bool fgreal)
{
  if (n < 1) 
    throw ParmError("FFTWServerPlan: Array size <= 0 !");
  p = NULL;
  rp = NULL;
  pnd = NULL;
  rpnd = NULL;
  _nx = _ny = -10;
  _n = n;  
  _dir = dir;
  if (fgreal) rp = rfftw_create_plan(n, dir, FFTW_ESTIMATE);
  else p = fftw_create_plan(n, dir, FFTW_ESTIMATE);
}
FFTWServerPlan::FFTWServerPlan(int nx, int ny, fftw_direction dir, bool fgreal)
{
  if ( (nx < 1) || (ny <1) ) 
    throw ParmError("FFTWServerPlan: Array size Nx or Ny <= 0 !");
  p = NULL;
  rp = NULL;
  pnd = NULL;
  rpnd = NULL;
  _n = -10;
  _nx = nx; 
  _ny = ny;
  _dir = dir;
  int sz[2];
  sz[0]= nx; sz[1] = ny;
  if (fgreal) rpnd = rfftwnd_create_plan(2, sz, dir, FFTW_ESTIMATE);
  else pnd = fftwnd_create_plan(2, sz, dir, FFTW_ESTIMATE);
}

FFTWServerPlan::~FFTWServerPlan()
{
  if (p) fftw_destroy_plan(p);
  if (rp) rfftw_destroy_plan(rp);
  if (pnd) fftwnd_destroy_plan(pnd);
  if (rpnd) rfftwnd_destroy_plan(rpnd);
}

void
FFTWServerPlan::Recreate(int n)
{
  if (n < 1) 
   throw ParmError("FFTWServerPlan::Recreate(n) n < 0 !");
  if ((_nx > 0) || (_ny > 0))  
   throw ParmError("FFTWServerPlan::Recreate(n) Nx or Ny > 0 !");
  if (n == _n) return;
  _n = n;  
  if (p) {
    fftw_destroy_plan(p);
    p = fftw_create_plan(n, _dir, FFTW_ESTIMATE);
  }
  else {
    rfftw_destroy_plan(rp);
    rp = rfftw_create_plan(n, _dir, FFTW_ESTIMATE);
  }
}

void
FFTWServerPlan::Recreate(int nx, int ny)
{
  if ( (nx < 1) || (ny <1) ) 
    throw ParmError("FFTWServerPlan:Recreate(nx, ny) size Nx or Ny <= 0 !");
  if (_n > 0)
    throw ParmError("FFTWServerPlan::Recreate(nx, ny) N > 0 !");
  if ((nx == _nx) && (ny == _ny)) return;
  _nx = nx; 
  _ny = ny;
  int sz[2];
  sz[0]= nx; sz[1] = ny;
  if (pnd) {
    fftwnd_destroy_plan(pnd);
    pnd = fftwnd_create_plan(2, sz,_dir, FFTW_ESTIMATE);
  }
  else {
    rfftwnd_destroy_plan(rpnd);
    rpnd = rfftwnd_create_plan(2, sz, _dir, FFTW_ESTIMATE);
  }

}


/* --Methode-- */
FFTWServer::FFTWServer()
  : FFTServerInterface("FFTServer using FFTW package")
{
  _p1df = NULL;
  _p1db = NULL;
  _p2df = NULL;
  _p2db = NULL;

  _p1drf = NULL;
  _p1drb = NULL;
  _p2drf = NULL;
  _p2drb = NULL;
}


/* --Methode-- */
FFTWServer::~FFTWServer()
{
  if (_p1df) delete _p1df ;
  if (_p1db) delete _p1db ;
  if (_p2df) delete _p2df ;
  if (_p2db) delete _p2db ;

  if (_p1drf) delete _p1drf ;
  if (_p1drb) delete _p1drb ;
  if (_p2drf) delete _p2drf ;
  if (_p2drb) delete _p2drb ;
}

/* --Methode-- */
FFTServerInterface * FFTWServer::Clone()
{
  return (new FFTWServer) ;
}

/* --Methode-- */
void 
FFTWServer::FFTForward(TVector< complex<double> > const & in, TVector< complex<double> > & out)
{
  if (_p1df) _p1df->Recreate(in.NElts());
  else _p1df = new FFTWServerPlan(in.NElts(), FFTW_FORWARD, false);
  out.ReSize(in.NElts());
  fftw_one(_p1df->p, (fftw_complex *)(in.Data()) , (fftw_complex *)(out.Data()) );
  if(this->getNormalize()) out=out/complex<double>(pow(in.NElts(),0.5),0.);
}
/* --Methode-- */
void FFTWServer::FFTBackward(TVector< complex<double> > const & in, TVector< complex<double> > & out)
{
  if (_p1db) _p1db->Recreate(in.NElts());
  else _p1db = new FFTWServerPlan(in.NElts(), FFTW_BACKWARD, false);
  out.ReSize(in.NElts());
  fftw_one(_p1db->p, (fftw_complex *)(in.Data()) , (fftw_complex *)(out.Data()) );
  if(this->getNormalize()) out=out/complex<double>(pow(in.NElts(),0.5),0.);

}


void FFTWServer::FFTForward(TVector< double > const & in, TVector< complex<double> > & out)
{  
  int size = in.NElts()/2;
  
  if(in.NElts()%2 != 0)  { size = in.NElts()/2 +1;}
  else  { size = in.NElts()/2 +1 ;}

  TVector< double > const outTemp(in.NElts());
  out.ReSize(size);
  if (_p1drf) _p1drf->Recreate(in.NElts());
  else _p1drf = new FFTWServerPlan(in.NElts(), FFTW_REAL_TO_COMPLEX, true);
  rfftw_one(_p1drf->rp, (fftw_real *)(in.Data()) , (fftw_real *)(outTemp.Data()));
  ReShapetoCompl(outTemp, out);
  if(this->getNormalize()) out=out/complex<double>(pow(in.NElts(),0.5),0.);
}



void FFTWServer::FFTBackward(TVector< complex<double> > const & in, TVector< double > & out)
{
  int size;
  if(in(in.NElts()).imag()  == 0) { size = 2*in.NElts()-2;}
  else { size = 2*in.NElts()-1;}
  
  TVector< double > inTemp(size);
  out.ReSize(size);

  if (_p1drb) _p1drb->Recreate(size);
  else _p1drb = new FFTWServerPlan(size, FFTW_COMPLEX_TO_REAL, true);

  ReShapetoReal(in, inTemp);
  rfftw_one(_p1drb->rp, (fftw_real *)(inTemp.Data()) , (fftw_real *)(out.Data()));
  if(this->getNormalize()) out=out/pow(size,0.5);
}

/* --Methode-- */
void FFTWServer::FFTForward(TMatrix< complex<double> > const & in, TMatrix< complex<double> > & out)
{
  out.ReSize(in.NRows(),in.NCols());

  if (_p2df) _p2df->Recreate( in.NRows(),in.NCols());
  else _p2df = new FFTWServerPlan( in.NCols(),in.NRows(), FFTW_FORWARD, false);
  
  fftwnd_one(_p2df->pnd, (fftw_complex *)(in.Data()) , (fftw_complex *)(out.Data()) );
  if(this->getNormalize()) out=out/complex<double>(pow(in.NRows()*in.NCols(),0.5),0.);  
}

/* --Methode-- */
void FFTWServer::FFTBackward(TMatrix< complex<double> > const & in, TMatrix< complex<double> > & out)
{
  if (_p2db) _p2db->Recreate(in.NCols(), in.NRows());
  else _p2db = new FFTWServerPlan(in.NCols(), in.NRows(), FFTW_BACKWARD, false);
  out.ReSize(in.NRows(), in.NCols());
  fftwnd_one(_p2db->pnd, (fftw_complex *)(in.Data()) , (fftw_complex *)(out.Data()) );
  if(this->getNormalize()) out=out/complex<double>(pow(in.NRows()*in.NCols(),0.5),0.);

}


/* --Methode-- */
void FFTWServer::FFTForward(TMatrix< double > const & in, TMatrix< complex<double> > & out)
{

  TMatrix< double > inNew(in.NCols(),in.NRows());
  for(int i=0; i<in.NRows(); i++)
    for(int j=0;j<in.NCols(); j++)
      inNew(j,i) = in(i,j);
  
  if (_p2drf) _p2drf->Recreate(inNew.NRows(),inNew.NCols());
  else _p2drf = new FFTWServerPlan(inNew.NRows(), inNew.NCols(),FFTW_REAL_TO_COMPLEX, true);
  //  rfftwnd_plan p;
  TMatrix< complex<double> > outTemp;
  outTemp.ReSize(in.NRows(),in.NCols());

  rfftwnd_one_real_to_complex(_p2drf->rpnd, (fftw_real *)(in.Data()) , (fftw_complex *)(out.Data()) );
}

/* --Methode-- */
void FFTWServer::FFTBackward(TMatrix< complex<double> > const & in, TMatrix< double > & out)
{

  TMatrix< complex<double> > inNew(in.NCols(),in.NRows());
  for(int i=0; i<in.NRows(); i++)
    for(int j=0;j<in.NCols(); j++)
      inNew(j,i) = in(i,j);
  
  if (_p2drb) _p2drb->Recreate(inNew.NRows(),inNew.NCols());
  else _p2drb = new FFTWServerPlan(inNew.NRows(), inNew.NCols(),FFTW_COMPLEX_TO_REAL, true);
  //  rfftwnd_plan p;
  out.ReSize(in.NRows(),in.NCols());

  rfftwnd_one_complex_to_real(_p2drb->rpnd, (fftw_complex *)(in.Data()) , (fftw_real *)(out.Data()) );
  cout << " in the function !!!" << endl;
  if(this->getNormalize()) 
    {
      double norm = (double)(in.NRows()*in.NCols());
      out=out/norm;
    }
}


/* --Methode-- */
void FFTWServer::ReShapetoReal( TVector< complex<double> > const & in, TVector< double >  & out)
{
  int N = in.NElts();
  int i;
  if (in(in.NElts()).imag() == 0) 
    {
      out(0) = in(0).real();
      for(i=1; i<in.NElts(); i++)
	{
	  out(i) = in(i).real();
	}
      for(i=1; i<in.NElts(); i++)
	{
	  out(i+in.NElts()-1) = in(in.NElts()-i-1).imag();
	}
    }
  else
    {
      out(0) = in(0).real();
      for(i=1; i<in.NElts(); i++)
	{
	  out(i) = in(i).real();
	}
      for(i=1; i<in.NElts(); i++)
	{
	  out(i+in.NElts()-1) = in(in.NElts()-i).imag();
	}
    }
  //  for(int k=0; k<out.NElts(); k++) cout << "ReShapetoReal out " << k << " " << out(k) << endl;
}


/* --Methode-- */
void FFTWServer::ReShapetoCompl(TVector< double > const & in, TVector< complex<double> > & out)
{
  int N = in.NElts();
  //  for(int k=0; k<in.NElts(); k++) cout << "ReShapetoCompl in " << k << " "  << in(k) << endl;
  out(0) = complex<double> (in(0),0.);
  if(in.NElts()%2 !=0)
    {
      for(int k=1;k<=N/2+1;k++)
	{
	  out(k) =  complex<double> (in(k),in(N-k));
	}
    }
  else 
    {
      for(int k=1;k<N/2;k++)
	{
	  out(k) =  complex<double> (in(k),in(N-k));
	}
      out(N/2) = complex<double> (in(N/2),0.);
    }
  //  for(int k=0; k<out.NElts(); k++) cout << "ReShapetoCompl out " << k << " "  << out(k) << endl;
}

