#include "fblp.h"

#include "csvd.h"
#include "poly_roots.h"

/*
static void do_print(String msg, float *z, int n)
{
    int i;
    Line line;
 
    printf("%s:", msg);
    for (i = 0; i < n; i++)
        printf(" %6.3f+i%6.3f", z[2*i], z[2*i+1]);
    printf("\n");
}

static void do_print2(String msg, float *z, int m, int n)
{
    int i;
    Line line;
 
    printf("%s\n", msg);
    for (i = 0; i < n; i++)
    {
        sprintf(line, "%s %d", msg, i);
        do_print(line, z+2*i*m, m);
    }
}
*/

/*
  Move every root lying outside the unit circle onto its mirror image
  inside it: z -> z / |z|^2 (same phase, magnitude 1/|z|).  Roots whose
  squared magnitude is not greater than 1 are left as they are.
  (Assumes COMPLEX_SQR_P yields |z|^2 -- confirm in the Complex header.)
*/
static void reflect_roots(Complex *z, int n)
{
    for (; n > 0; n--, z++)
    {
        float mag2 = COMPLEX_SQR_P(z);

        if (mag2 > 1)
        {
            float inv = 1 / mag2;

            COMPLEX_SCALE_P(z, inv);
        }
    }
}

/*
  Rebuild linear-prediction coefficients from n roots.  On return

    (x-c_roots[0]) * (x-c_roots[1]) * ... * (x-c_roots[n-1])
    = x^n - c[n-1]*x^(n-1) - ... - c[1]*x - c[0]

  i.e. c[0..n-1] hold the NEGATED non-leading coefficients of the monic
  polynomial.  This is the prediction convention
  x[t] = c[0]*x[t-n] + ... + c[n-1]*x[t-1].
  Only c[0..n-1] are written; for n < 1 nothing is read or written.
*/
static void poly_coeffs(Complex *c, int n, Complex *c_roots)
{
    int i, j;
    Complex z;

    /* empty product: there are no coefficients and no root to read */
    if (n < 1)
        return;

    /* degree 1: x - r0 = x^1 - c[0] with c[0] = r0 */
    COMPLEX_SET_P(c, c_roots);

    for (i = 1; i < n; i++)
    {
        /* multiply the degree-i polynomial by (x - c_roots[i]) */
        COMPLEX_ADD_P(c+i, c+i-1, c_roots+i);

        /* descending j so each update still sees the old c[j-1] */
        for (j = i-1; j > 0; j--)
        {
            COMPLEX_MULTIPLY_P(&z, c+j, c_roots+i);
            COMPLEX_SUBTRACT_P(c+j, c+j-1, &z);
        }

        /* constant term: c[0] <- -(c_roots[i] * c[0]) */
        COMPLEX_MULTIPLY_P(&z, c_roots+i, c);
        COMPLEX_SET_P(c, &z);
        COMPLEX_SCALE_P(c, -1);
    }
}

/*
  Flip the sign of the first n complex values of c in place.
  Values are stored as interleaved re/im floats, so 2*n floats are touched.
  Non-positive n leaves c unchanged.
*/
static void negate_coeffs(float *c, int n)
{
    int k = 2 * n;

    while (k-- > 0)
        c[k] = -c[k];
}

/*
  Reverse, in place, the order of the first n complex values
  (interleaved re/im floats) of c.
*/
static void reverse_coeffs(float *c, int n)
{
    int i;

    for (i = 0; i < n/2; i++)
    {
        SWAP(c[2*i], c[2*(n-i-1)], float);
        SWAP(c[2*i+1], c[2*(n-i-1)+1], float);
    }
}

/*
  Stabilize n prediction coefficients c[0..n-1] in place.

  1. Form the monic characteristic polynomial
     x^n - c[n-1]*x^(n-1) - ... - c[0], with coefficients laid out
     ascending in c[0..n].  c[n] is set to 1, so c needs room for n+1
     complex values.
  2. Find its roots with poly_roots.
  3. Reflect roots outside the unit circle back inside.
  4. Rebuild c[0..n-1] from the reflected roots.

  roots needs n complex values.  work is scratch for poly_roots; the
  original layout passes t1 (n+1 complex values).
*/
static void stabilize_coeffs(float *c, int n, float *roots, float *work)
{
    c[2*n] = 1;
    c[2*n+1] = 0;
    negate_coeffs(c, n);
    poly_roots((Complex *) c, n, (Complex *) roots, TRUE, (Complex *) work);
    reflect_roots((Complex *) roots, n);
    poly_coeffs((Complex *) c, n, (Complex *) roots);
}

/*
  Forward-backward linear prediction; see Zhu and Bax, JMR 100 (1992) 202-207.

  x holds npoints complex samples (interleaved re/im floats) and must have
  room for npred more.  On return x[npoints..npoints+npred-1] hold the
  extrapolated samples; the input samples are not modified.
  length is the prediction order.

  let m = npoints - length - 1, n = length
  as complex arrays (so double for real) require:
  x[npoints+npred], u[m*n], v[n*n], w[n],
  t1[n+1], t2[max(m,n+1)], t3[n+1], t4[n], t5[n]

  cutoff is passed straight to csvd_refine.

  If m < 1, length < 1, or the SVD does not converge, the npred predicted
  samples are set to zero.
*/
void forward_backward_lp(float *x, int npoints, int length, int npred,
		float *u, float *v, float *w, float *t1, float *t2,
		float *t3, float *t4, float *t5, float cutoff)
{
    int m = npoints - length - 1, n = length, i, j;
    float *c = t2, *d = t3, *c_roots = t4, *d_roots = t5;
    Complex s, z;

    /* too few samples for a single fit row, or no coefficients: the
       data matrix would be empty, so do not hand it to csvd */
    if (m < 1 || n < 1)
    {
        ZERO_VECTOR(x+2*npoints, 2*npred);
        return;
    }

    /* data matrix u[i][j] = x[i+j+1], i < m, j < n;
       csvd has Fortran conventions on 2D array storage (column-major) */
    for (i = 0; i < m; i++)
    {
        for (j = 0; j < n; j++)
        {
            u[2*(i+m*j)] = x[2*(i+j+1)];
            u[2*(i+m*j)+1] = x[2*(i+j+1)+1];
        }
    }

    if (!csvd(u, v, w, t1, t2, m, n))
    {
        printf("not converged\n");
        ZERO_VECTOR(x+2*npoints, 2*npred);
        return;
    }

    csvd_refine(w, cutoff, n);

    /* forward fit against x[n+1..n+m], i.e. x[i+n+1] ~ sum_j c[j]*x[i+j+1]
       (presumably a least-squares solve inside csvd_fit) */
    csvd_fit(u, v, w, c, x+2*(n+1), t1, m, n);
    stabilize_coeffs(c, n, c_roots, t1);

    /* backward fit against x[0..m-1], i.e. x[i] ~ sum_j d[j]*x[i+j+1] */
    csvd_fit(u, v, w, d, x, t1, m, n);

    d[2*n] = 1;
    d[2*n+1] = 0;
    negate_coeffs(d, n);
    reverse_coeffs(d, n);   /* ordering is backwards */

    poly_roots((Complex *) d, n, (Complex *) d_roots, TRUE, (Complex *) t1);

    /* In exact arithmetic a backward root sits at 1/z for a forward root z.
       Conjugating gives z/|z|^2, and reflecting brings it back to z. */
    for (i = 0; i < n; i++)
        COMPLEX_CONJ_P((Complex *) (d_roots+2*i));

    reflect_roots((Complex *) d_roots, n);
    poly_coeffs((Complex *) d, n, (Complex *) d_roots);

    /* average the forward and backward estimates, then stabilize again */
    COMPLEX_ADD_VEC((Complex *) c, (Complex *) c, (Complex *) d, n);
    COMPLEX_SCALE_VEC((Complex *) c, 0.5, n);

    /* stabilize_coeffs re-sets the leading coefficient c[n] = 1 itself,
       rather than relying on it surviving the earlier root-finding */
    stabilize_coeffs(c, n, c_roots, t1);

    /* extrapolate: x[t] = sum_j c[j]*x[t-n+j] for t = npoints..npoints+npred-1 */
    for (i = 0; i < npred; i++)
    {
        COMPLEX_ZERO(s);

        for (j = 0; j < n; j++)
        {
            COMPLEX_MULTIPLY_P(&z, (Complex *) (c+2*j),
                               (Complex *) (x+2*(npoints-n+i+j)));
            COMPLEX_ADD_P(&s, &s, &z);
        }

        x[2*(npoints+i)] = COMPLEX_REAL(s);
        x[2*(npoints+i)+1] = COMPLEX_IMAG(s);
    }
}

