#include <math.h>
#include "toimanager.h"
#include "correl.h"
#include "wienerdecor.h"
extern "C" {
#include "nrutil.h"
}
extern "C" void dtoeplz(double r[], double x[], double y[], int n);

WienerDecorrelator::WienerDecorrelator(int n, int l) {
  nsamples = n;
  lcorr = l;
  doNotLookAt();
}

void WienerDecorrelator::init() {
  declareInput("signal");
  declareInput("probe");
  declareOutput("signal");
  declareOutput("noiseestim");
  name="WienerDecorrelator";
  setNeededHistory(nsamples+lcorr+1);
  lowExtra = lcorr;
}

void WienerDecorrelator::run() {
  int snb = getMinIn();
  int sne = getMaxIn();

  //  cout << "Wiener " << snb << " - " << sne << endl;

  CorrelEstimator corr(lcorr, nsamples), autocorr(lcorr, nsamples);

  double* r = new double[2*lcorr];  // autocorr toeplitz matrix
  double* w = new double[lcorr+1];  // filter
  double* y = new double[lcorr+1];  // corr vector
  double* window = new double[lcorr];
  uint_8* fwind  = new uint_8[lcorr];
  double* filter = new double[lcorr];
  for (int i=0; i<lcorr; i++) filter[i]=0;
  
  int sn = snb;
  int snstartcorr = -1;
  
  while (sn <= sne) {
    if (snstartcorr < 0 || 
	(snstartcorr + nsamples < sn && sn+nsamples < sne)) {
      // let's (re)compute the correlation
      snstartcorr = sn;
      corr.reset();
      autocorr.reset();
      cout << "computing correl " << sn << " -> " << sn+nsamples << endl;
      for (int i=sn; i<sn+nsamples; i++) {
	uint_8 flag1, flag2;
	double sig, prb;
	getData(0, i, sig, flag1);
	if (flag1 & flgNotLookAt) continue;
	getData(1, i, prb, flag2);
	if (flag2 & flgNotLookAt) continue;
	if ((i-sn)%100 == 0) {
	  //cout << " sig/prb : " << i << " : " << sig << " / " << prb 
	  //     << hex << " " << flag1 << " " << flag2 << dec << endl;
	}
	corr.push(i, sig, prb);
	autocorr.push(i, prb);
      }
      // correlation is recomputed, let's recompute the wiener filter from wiener equations
      {for (int i=0; i<lcorr; i++) {
	r[lcorr+i] = r[lcorr-i] = autocorr.correl(i);
	y[i+1] = corr.correl(i);
	//cout << "r " << lcorr+i << " " << lcorr -i << " = " << r[lcorr+i] 
	//    << "\n"
	//    << "y " << i+1 << " = " << y[i+1] << endl;
      }}
      dtoeplz(r,w,y,lcorr);
      if (!isnan(w[1])) {
	for (int i=0; i<lcorr; i++) {
	  filter[i] = w[i+1];
	}
      } else {
	cout << "Bad inversion, keeping previous filter\n";
      }
      cout << "Wiener filter : " << sn << "\n ";
      {for (int i=0; i<lcorr; i++) {
	cout << filter[i] << " ";
      }}
     cout << endl;
    }

    if (sn >= snb+lcorr-1) {
      getData(1, sn-lcorr+1, lcorr, window, fwind);
      uint_8 flag = 0;
      double outSig = 0;
      for (int i=0; i<lcorr; i++) {
	outSig += filter[i] * window[lcorr-1 - i];
	flag |= fwind[lcorr-1 -i];
      }
      putData(0, sn, getData(0, sn) - outSig, flag);
      putData(1, sn, outSig, flag);
    }
    sn++;
  }

  delete[] y;
  delete[] w;
  delete[] r;
  delete[] filter;
}
