#include "toimanager.h"
#include "correl.h"
#include "wienerdecor.h"
extern "C" {
#include "nrutil.h"
}
extern "C" void dtoeplz(double r[], double x[], double y[], int n);

WienerDecorrelator::WienerDecorrelator(int n, int l) {
  nsamples = n;
  lcorr = l;
}

void WienerDecorrelator::init() {
  declareInput("signal");
  declareInput("probe");
  declareOutput("signal");
  declareOutput("noiseestim");
  name="WienerDecorrelator";
  setNeededHistory(nsamples+lcorr+1);
  lowExtra = lcorr;
}

void WienerDecorrelator::run() {
  int snb = getMinIn();
  int sne = getMaxIn();

  //  cout << "Wiener " << snb << " - " << sne << endl;

  CorrelEstimator corr(lcorr, nsamples), autocorr(lcorr, nsamples);

  double* r = new double[2*lcorr];  // autocorr toeplitz matrix
  double* w = new double[lcorr+1];  // filter
  double* y = new double[lcorr+1];  // corr vector
  double* window = new double[lcorr];
  double* filter = new double[lcorr];
  for (int i=0; i<lcorr; i++) filter[i]=0;
  
  int sn = snb;
  int snstartcorr = -1;
  
  while (sn < sne) {
    if (snstartcorr < 0 || 
	(snstartcorr + nsamples < sn && sn+nsamples < sne)) {
      // let's (re)compute the correlation
      snstartcorr = sn;
      corr.reset();
      autocorr.reset();
      int i;
      for (i=sn; i<sn+nsamples; i++) {
	double sig = getData(0, i);
	double prb = getData(1, i);
	corr.push(i, sig, prb);
	autocorr.push(i, prb);
      }
      // correlation is recomputed, let's recompute the wiener filter from wiener equations
      for (i=0; i<lcorr; i++) {
	r[lcorr+i] = r[lcorr-i] = autocorr.correl(i);
	y[i+1] = corr.correl(i);
      }
      dtoeplz(r,w,y,lcorr);
      if (!isnan(w[1])) {
	for (int i=0; i<lcorr; i++) {
	  filter[i] = w[i+1];
	}
      } else {
	cout << "Bad inversion, keeping previous filter\n";
      }
      cout << "Wiener filter : " << sn << "\n ";
      for (i=0; i<lcorr; i++) {
	cout << filter[i] << " ";
      }
     cout << endl;
    }

    if (sn >= snb+lcorr-1) {
      getData(1, sn-lcorr+1, lcorr, window);
      double outSig = 0;
      for (int i=0; i<lcorr; i++) {
	outSig += filter[i] * window[lcorr-1 - i];
      }
      putData(0, sn, getData(0, sn) - outSig);
      putData(1, sn, outSig);
    }
    sn++;
  }

  delete[] y;
  delete[] w;
  delete[] r;
  delete[] filter;
}
