#include "mem.h"
#include "random.h"

/* Floor applied to denominators by SAFE_DIVIDE; also the "effectively zero"
   threshold for the inner products tested in update_w(). */
#define  EPS  (1.0E-20)
/* Relative tolerance: update_w() exits normally once upper <= (1+TOL)*lower,
   find_param() flags omega when ABS(omega-1) > TOL, and TOL*TOL/12 is added
   to var there. */
#define  TOL  (0.1)
/* Stand-in for "unbounded": the initial upper bound in update_w() and the
   initial alpha/beta of a new run in mem(). */
#define  LARGE_NUMBER  (1.0E20)

/* Capacity of the a_list / o_list / v_list history arrays. */
#define  LIST_SIZE  8

/* a / b with the denominator clamped from below by EPS.  Assuming MAX (defined
   elsewhere) is the usual maximum macro, any b below EPS -- including a
   negative b -- is replaced by EPS, so the result never changes sign through
   the denominator.  Arguments may be evaluated more than once if MAX does so;
   pass plain variables. */
#define  SAFE_DIVIDE(a, b)  ((a) / MAX((b), EPS))

/* Element-wise map: v1[k] = func(v2[k]) for k in [0, n).  v1 and v2 may be
   the same array.  Expands to a bare { } block (not do/while), so it must be
   used as a full statement inside braces, and the caller's arguments must not
   refer to a variable named I. */
#define  VECTOR_OPERATE(v1, v2, n, func) \
	 {   int I;  for (I = 0; I < (n); I++)  v1[I] = func(v2[I]);   }

/* Convenience wrappers applying exp / cosh / sinh element by element. */
#define  EXPONENTIATE_VECTOR(v1, v2, n)  VECTOR_OPERATE(v1, v2, n, exp)
#define  COSH_VECTOR(v1, v2, n)  VECTOR_OPERATE(v1, v2, n, cosh)
#define  SINH_VECTOR(v1, v2, n)  VECTOR_OPERATE(v1, v2, n, sinh)

/* Sum of the first n elements of v, accumulated directly into the lvalue w
   (w is reset to 0 first).  Same bare-block and name-I caveats as above. */
#define  WEIGHT_VECTOR(w, v, n) \
	 {   int I;  w = 0;  for (I = 0; I < (n); I++)  w += v[I];   }

/* Problem dimensions and iteration cap.  ndata, nimage and max_niter are
   copied from init_mem()'s n_in, n_out and max_it; nlist is the number of
   entries currently in use in a_list / o_list / v_list. */
static int ndata;
static int nimage;
static int nlist;
static int max_niter;

/* positive selects the exp() image form in find_image() and the plain
   image[] metric in apply_metric(); otherwise the sinh/cosh forms are used.
   new_run is TRUE only for the first mem() pass of a do_mem() call. */
static Bool positive;
static Bool new_run;

/* Optional log stream; every use is guarded by "if (log_file)". */
static FILE *log_file;

/* Scalar state shared between the passes:
     alpha    - current regularisation value (reset in mem(), updated in
                find_alpha())
     beta     - alpha, multiplied by t in update_alpha() when t > 1; not read
                anywhere else in this file
     omega    - SAFE_DIVIDE(ndata, r2), set in find_param()
     accuracy - 1 / sigma (init_mem())
     def      - def_mult * sigma * sqrt(n_in) / n_out (init_mem())
     weight   - sum computed by find_image()
     var      - variance estimate set in find_param()
     rate     - max_rate from init_mem()
     w2, wr, r2 - inner products computed in find_param() */
static float alpha;
static float beta;
static float omega;
static float accuracy;
static float def;
static float weight;
static float var;
static float rate;
static float w2;
static float wr;
static float r2;

/* History of (log(alpha), log(omega), var) samples maintained by
   insert_values(), which keeps a_list in ascending order. */
static float a_list[LIST_SIZE];
static float o_list[LIST_SIZE];
static float v_list[LIST_SIZE];

/* Work vectors.  g, w, cum_g, cum_Ag, residuals and data are separate
   allocations made in check_mem_alloc().  Ag, scaled_g, scaled_Ag, scaled_w,
   delta_g, delta_w and mock_data are NOT separate buffers: check_mem_alloc()
   points every one of them at temp_image, so each is valid only until the
   next write through any of the others. */
static float *g = NULL;
static float *Ag;
static float *w = NULL;
static float *cum_g = NULL;
static float *cum_Ag = NULL;
static float *scaled_g;
static float *scaled_Ag;
static float *scaled_w;
static float *delta_g;
static float *delta_w;
static float *residuals = NULL;
static float *data = NULL;
static float *mock_data;

/* image points at the caller's buffer passed to do_mem(); temp_image is the
   shared scratch buffer owned by this file. */
static float *image;
static float *temp_image = NULL;

/* Transform callbacks from init_mem().  In this file they are invoked as
   (*opus)(data_space_out, image_space_in) and
   (*tropus)(image_space_out, data_space_in). */
static Transform opus;
static Transform tropus;

/* Factor applied to the final image at the end of do_mem(). */
static float scale;

/*
 * Recompute g from the current residuals and w.
 *
 * Assuming SCALE_VECTOR(dst, src, s, n) stores s*src and
 * SUBTRACT_VECTORS(dst, x, y, n) stores x - y (both macros are defined
 * elsewhere), the result is g = residuals - (alpha/accuracy) * w.
 * scaled_w is an alias of the temp_image scratch buffer, so the previous
 * contents of temp_image are overwritten.
 */
static void find_g()
{
    SCALE_VECTOR(scaled_w, w, alpha/accuracy, ndata);
    SUBTRACT_VECTORS(g, residuals, scaled_w, ndata);
}

/*
 * Multiply temp_image, element by element, by the metric derived from the
 * current image: image[k] itself in the positive case, otherwise
 * sqrt(image[k]^2 + 4*def^2).
 */
static void apply_metric()
{
    int k;

    if (!positive)
    {
	for (k = 0; k < nimage; k++)
	    temp_image[k] *= sqrt(image[k]*image[k] + 4*def*def);

	return;
    }

    for (k = 0; k < nimage; k++)
	temp_image[k] *= image[k];
}

/* Maximum number of inner iterations performed by update_w(). */
#define  NITER  40

/*
 * Iteratively refine w (and g alongside it) for the current alpha.
 *
 * Each pass applies tropus, apply_metric() and opus to g (scaled by accuracy
 * on the way in and on the way out) to obtain Ag, accumulates g/p and Ag/p
 * into cum_g and cum_Ag with p = <g, Ag>, and then updates
 *     g += -q * (cum_Ag + alpha*cum_g),   w += q*accuracy * cum_g,
 * with q = 1 / (alpha*s + t), s = <cum_g, cum_Ag>, t = <cum_Ag, cum_Ag>
 * (this reading assumes the usual meaning of the SCALE_VECTOR, ADD_VECTORS
 * and INNER_PRODUCT macros, which are defined elsewhere).
 *
 * lower/upper bracket a quantity used only for the stopping test; they are
 * local and not reported to the caller.
 *
 * Statement order matters: scaled_g, Ag, scaled_Ag, delta_g and delta_w are
 * all aliases of temp_image (see check_mem_alloc()), so each scratch result
 * is only valid until the next one is written.
 *
 * Returns:
 *   0 - normal exit (bounds agree to within TOL), or p, s or t fell below EPS
 *   1 - NITER passes completed without meeting any exit condition
 *   2 - the candidate upper bound r was <= 0
 *   3 - lower overtook upper
 */
static int update_w()
{
    int status, iter;
    float lower, upper, p, q, r, s, t;

    ZERO_VECTOR(cum_g, ndata);
    ZERO_VECTOR(cum_Ag, ndata);

    /* stays 1 only if the loop runs to completion */
    status = 1;

    lower = 0;
    upper = LARGE_NUMBER;

    for (iter = 0; iter < NITER; iter++)
    {
	/* Ag = accuracy * opus(metric * tropus(accuracy * g)), built in place
	   in the shared scratch buffer */
	SCALE_VECTOR(scaled_g, g, accuracy, ndata);
	(*tropus)(temp_image, scaled_g);
	apply_metric();
	(*opus)(Ag, temp_image);
	SCALE_VECTOR(Ag, Ag, accuracy, ndata);

	INNER_PRODUCT(p, g, Ag, ndata);

	if (p < EPS)
	{
	    upper = lower;
	    status = 0;
	    break;
	}

	/* from here on p holds 1 / <g, Ag> */
	p = 1 / p;

	SCALE_VECTOR(scaled_Ag, Ag, p, ndata);
	ADD_VECTORS(cum_Ag, cum_Ag, scaled_Ag, ndata);

	SCALE_VECTOR(scaled_g, g, p, ndata);
	ADD_VECTORS(cum_g, cum_g, scaled_g, ndata);

	INNER_PRODUCT(s, cum_g, cum_Ag, ndata);
	INNER_PRODUCT(t, cum_Ag, cum_Ag, ndata);

	if ((s < EPS) || (t < EPS))
	{
	    upper = lower;
	    status = 0;
	    break;
	}

	q = 1 / (alpha*s + t);
	/* r = <g, g> must be taken before g is modified below */
	INNER_PRODUCT(r, g, g, ndata);

	SCALE_VECTOR(delta_g, cum_g, alpha, ndata);
	ADD_VECTORS(delta_g, cum_Ag, delta_g, ndata);
	SCALE_VECTOR(delta_g, delta_g, -q, ndata);
	ADD_VECTORS(g, g, delta_g, ndata);

	SCALE_VECTOR(delta_w, cum_g, q*accuracy, ndata);
	ADD_VECTORS(w, w, delta_w, ndata);

/*
	printf("p = %f\n", p);
	printf("q = %f\n", q);
	printf("r = %f\n", r);
	printf("s = %f\n", s);
	printf("t = %f\n", t);
	printf("alpha = %f\n", alpha);
*/

	/* candidate upper bound; p is already inverted here */
	r = lower + MIN(r, 1/(alpha*p));

	if (r <= 0)
	{
	    lower = 0;
	    status = 2;
	    break;
	}

	upper = MIN(upper, r);
	lower += q;

	if (lower > upper)
	{
	    lower = upper;
	    status = 3;
	    break;
	}

/*
	printf("iter = %d, lower = %f, upper = %f\n", iter, lower, upper);
*/

	if (upper <= (1+TOL)*lower)  /* normal exit */
	{
	    status = 0;
	    break;
	}
    }

    return  status;
}

/*
 * Turn the image-space vector held in temp_image into the image and its
 * total weight.
 *
 *   positive:   temp_image <- exp(temp_image), image = def * temp_image,
 *               weight = sum(image)
 *   otherwise:  image = 2*def * sinh(temp_image),
 *               temp_image <- cosh(temp_image),
 *               weight = 2*def * sum(temp_image)
 *
 * (The scaling steps assume SCALE_VECTOR(dst, src, s, n) stores s*src.)
 */
static void find_image()
{
    int k;

    if (!positive)
    {
	for (k = 0; k < nimage; k++)
	    image[k] = sinh(temp_image[k]);

	for (k = 0; k < nimage; k++)
	    temp_image[k] = cosh(temp_image[k]);

	SCALE_VECTOR(image, image, 2*def, nimage);

	weight = 0;
	for (k = 0; k < nimage; k++)
	    weight += temp_image[k];

	weight *= 2 * def;

	return;
    }

    for (k = 0; k < nimage; k++)
	temp_image[k] = exp(temp_image[k]);

    SCALE_VECTOR(image, temp_image, def, nimage);

    weight = 0;
    for (k = 0; k < nimage; k++)
	weight += image[k];
}

/*
 * residuals = accuracy * (data - mock_data), assuming SUBTRACT_VECTORS(dst,
 * x, y, n) stores x - y and SCALE_VECTOR(dst, src, s, n) stores s*src.
 * mock_data is an alias of temp_image, so this must run before temp_image is
 * reused.
 */
static void find_residuals()
{
    SUBTRACT_VECTORS(residuals, data, mock_data, ndata);
    SCALE_VECTOR(residuals, residuals, accuracy, ndata);
}

/*
 * Recompute the summary statistics w2, wr, r2, var and omega from the
 * current w and residuals, and log grade and omega when logging is enabled.
 *
 * Returns a small bit-style code: +1 when grade exceeds 1.00001, +2 when
 * omega differs from 1 by more than TOL.
 */
static int find_param()
{
    int flags;
    float norm, cosine, grade;

    INNER_PRODUCT(w2, w, w, ndata);
    INNER_PRODUCT(wr, w, residuals, ndata);
    INNER_PRODUCT(r2, residuals, residuals, ndata);

    /* w carries a factor of accuracy; divide it back out */
    w2 /= accuracy * accuracy;
    wr /= accuracy;

    /* wr / sqrt(w2*r2): normalised overlap between w and the residuals */
    norm = sqrt(w2 * r2);
    cosine = SAFE_DIVIDE(wr, norm);
    grade = 1 - cosine;

    flags = (grade > 1.00001) ? 1 : 0;

    var = SAFE_DIVIDE(grade, cosine) + TOL*TOL/12;

    omega = SAFE_DIVIDE(ndata, r2);

    if (ABS(omega-1) > TOL)
	flags += 2;

    if (log_file)
	fprintf(log_file, "grade = %6.4f, omega = %8.6f\n", grade, omega);

    return  flags;
}

/*
 * Record the current (log(alpha), log(omega), var) sample in the history
 * lists, keeping a_list in ascending order.
 *
 *  - If an entry with exactly the same log(alpha) exists, the new sample is
 *    merged into it (variance-weighted mean of omega, combined variance).
 *  - If the lists are full, the entry with the largest
 *    v_list[k] + (a_list[k] - log(alpha))^4 is evicted, provided that score
 *    exceeds the new sample's variance; otherwise the new sample is dropped.
 */
static void insert_values()
{
    int k, pos, worst;
    float dist, score, worst_score;
    float log_alpha, log_omega, new_var;

    log_alpha = log(alpha);
    log_omega = log(omega);
    new_var = var;

    /* merge with an existing sample at the same abscissa */
    for (k = 0; k < nlist; k++)
    {
	if (a_list[k] != log_alpha)
	    continue;

	o_list[k] = (log_omega*v_list[k] + o_list[k]*new_var) / (new_var+v_list[k]);
	v_list[k] = new_var * (v_list[k] / (new_var+v_list[k]));

	return;
    }

    /* make room when the lists are full */
    if (nlist == LIST_SIZE)
    {
	worst = -1;
	worst_score = new_var;

	for (k = 0; k < nlist; k++)
	{
	    dist = a_list[k] - log_alpha;
	    score = v_list[k] + dist*dist*dist*dist;

	    if (score > worst_score)
	    {
		worst_score = score;
		worst = k;
	    }
	}

	/* every stored sample is at least as good as the new one */
	if (worst < 0)
	    return;

	for (k = worst+1; k < nlist; k++)
	{
	    a_list[k-1] = a_list[k];
	    o_list[k-1] = o_list[k];
	    v_list[k-1] = v_list[k];
	}

	nlist--;
    }

    /* slot just after the last entry that is smaller than the new abscissa */
    for (pos = nlist; pos > 0; pos--)
	if (log_alpha > a_list[pos-1])  break;

    for (k = nlist; k > pos; k--)
    {
	a_list[k] = a_list[k-1];
	o_list[k] = o_list[k-1];
	v_list[k] = v_list[k-1];
    }

    a_list[pos] = log_alpha;
    o_list[pos] = log_omega;
    v_list[pos] = new_var;

    nlist++;
}

/*
 * Estimate log(omega) at abscissa log(a) from the history lists with a
 * weighted straight-line fit.
 *
 * Each sample k gets weight 1 / (v_list[k] + d^4), d = log(a) - a_list[k],
 * so samples far from a or with large variance count less.
 *
 *   a - alpha value at which to evaluate (its log is taken here)
 *   o - receives the fitted value
 *   v - receives 1 / sqrt(sum of weights); 0 when only one sample exists,
 *       in which case *o is simply that sample's value
 */
static void interpolate(float a, float *o, float *v)
{
    int k;
    double log_a;
    float wt, dev;
    float sum_w, sum_wa, sum_wo, sum_waa, sum_wao, mean_a, mean_o;

    if (nlist == 1)
    {
	*o = o_list[0];
	*v = 0;

	return;
    }

    log_a = log(a);

    /* first pass: weighted means of the offsets and of the ordinates */
    sum_w = sum_wa = sum_wo = 0;

    for (k = 0; k < nlist; k++)
    {
	dev = log_a - a_list[k];
	wt = v_list[k] + dev*dev*dev*dev;
	wt = SAFE_DIVIDE(1.0, wt);

	sum_w += wt;
	sum_wa += wt * dev;
	sum_wo += wt * o_list[k];
    }

    mean_a = sum_wa / sum_w;
    mean_o = sum_wo / sum_w;

    /* second pass: centred second moments for the slope */
    sum_waa = sum_wao = 0;

    for (k = 0; k < nlist; k++)
    {
	dev = log_a - a_list[k];
	wt = v_list[k] + dev*dev*dev*dev;
	wt = SAFE_DIVIDE(1.0, wt);
	dev -= mean_a;

	sum_waa += wt * dev * dev;
	sum_wao += wt * dev * (o_list[k] - mean_o);
    }

    /* evaluate the fitted line at zero offset, i.e. at log(a) itself */
    *o = mean_o - mean_a * SAFE_DIVIDE(sum_wao, sum_waa);
    *v = 1 / sqrt(sum_w);
}

static void find_alpha()
{
    int i;
    float t, o, o1, o2, alpha1, alpha2, v;

    t = wr + weight * rate / 2.002;

    if ((t * ABS(t)) > (w2*r2))
    {
	t += sqrt(t*t - w2*r2);
	alpha1 = SAFE_DIVIDE(t, w2);
	alpha2 = SAFE_DIVIDE(r2, t);
    }
    else
    {
	alpha2 = alpha1 = SAFE_DIVIDE(sqrt(r2), sqrt(w2));
    }

    alpha1 = MAX(alpha1, alpha);
    alpha2 = MIN(alpha2, alpha);

    interpolate(alpha1, &o1, &v);

    if (o1 > 0)
    {
	alpha = alpha1;
    }
    else
    {
	interpolate(alpha2, &o2, &v);

	if (o2 < 0)
	{
	    alpha = alpha2;
	}
	else
	{
	    for (i = 0; i < 30; i++)
	    {
		alpha = (alpha1 + alpha2) / 2;
		interpolate(alpha, &o, &v);

		if (o < 0)
		{
		    alpha1 = alpha;
		    o1 = o;
		}
		else
		{
		    alpha2 = alpha;
		    o2 = o;
		}
	    }

	    if (o1 != o2)
	    {
		t = o1 / (o1 - o2);
		alpha = alpha1 + (alpha2-alpha1) * t;
	    }
	}
    }
}

/*
 * Record the latest sample in the history lists, decide whether alpha needs
 * to move, and derive beta from it.
 *
 * Returns:
 *   0 - interpolated log(omega) at the current alpha is within its
 *       uncertainty of zero; alpha kept
 *   1 - alpha was recomputed by find_alpha()
 *   2 - the step measure t exceeded 1, so beta was set to alpha * t
 *       (overrides 0 or 1)
 */
static int update_alpha()
{
    int status = 0;
    float o_fit, o_err, excess;

    insert_values();

    interpolate(alpha, &o_fit, &o_err);

    if (!(ABS(o_fit) < o_err))
    {
	status = 1;
	find_alpha();
    }

    excess = (w2*alpha - 2*wr + SAFE_DIVIDE(r2, alpha)) / (rate * weight);

    beta = alpha;
    if (excess > 1)
    {
	status = 2;
	beta *= excess;
    }

    return  status;
}

/*
 * One outer pass: update w (or reset it on a new run), rebuild the image
 * from w, recompute mock data and residuals, then refresh the statistics
 * and alpha.
 *
 * Returns update_w() + 10*find_param() + 100*update_alpha(), with the
 * update_w() term taken as 0 on a new run.
 */
static int mem()
{
    int w_status = 0, param_status, alpha_status;

    if (!new_run)
    {
	find_g();
	w_status = update_w();
    }
    else
    {
	ZERO_VECTOR(w, ndata);
	nlist = 0;
	beta = alpha = LARGE_NUMBER;
    }

    /* mock_data aliases temp_image, so this order is fixed */
    (*tropus)(temp_image, w);
    find_image();
    (*opus)(mock_data, image);
    find_residuals();

    param_status = find_param();
    alpha_status = update_alpha();

    return  w_status + 10 * param_status + 100 * alpha_status;
}

/*
 * Run the reconstruction.
 *
 *   data_in_out - on entry the first ndata elements are copied into the
 *                 internal data vector; the same buffer is then used as the
 *                 image and, on return, holds nimage elements multiplied by
 *                 scale.
 *
 * Iterates mem() up to max_niter times, stopping early when mem() returns 0
 * or when the very first pass already yields omega >= 1.  init_mem() must
 * have been called first.
 */
void do_mem(float *data_in_out)
{
    int pass, status;

    COPY_VECTOR(data, data_in_out, ndata);
    image = data_in_out;

    new_run = TRUE;

    for (pass = 0; pass < max_niter; pass++)
    {
	if (log_file)
	    fprintf(log_file, "iteration %d\n", pass);

	status = mem();

	if (log_file)
	    fprintf(log_file, "status = %d\n\n", status);

	/* converged, or the starting image is already good enough */
	if ((status == 0) || ((pass == 0) && (omega >= 1)))
	    break;

	new_run = FALSE;

	if (log_file)
	    fflush(log_file);
    }

    SCALE_VECTOR(image, image, scale, nimage);
}

/*
 * Numerical check that tropus behaves as the transpose of opus.
 *
 * Fills a data-space vector d and an image-space vector x with uniform01()
 * values, computes <d, opus(x)> and <x, tropus(d)> (equal when tropus is the
 * transpose of opus), and stores their absolute difference divided by
 * sqrt(|d|^2 |opus(x)|^2 |x|^2 |tropus(d)|^2) in *err.  No guard is applied
 * to that divisor, so *err is not finite if any of the four norms is zero.
 *
 *   ndata, nimage - vector lengths in data and image space (these parameters
 *                   shadow the file-level statics of the same names)
 *   opus, tropus  - the transform pair to test
 *   err           - receives the mismatch measure
 *
 * The random generator is seeded once, on the first call.  The scratch
 * vectors are kept between calls and are grown whenever a call asks for more
 * elements than are currently allocated; previously they were sized by the
 * first call only, so a later, larger request overran them.
 *
 * Returns OK.  (MALLOC is defined elsewhere; presumably it returns ERROR
 * from this function on allocation failure -- confirm against its
 * definition.)
 */
Status check_opus_tropus(int ndata, int nimage, Transform opus,
					Transform tropus, float *err)
{
    int i;
    float a, b, c, d, e, f;
    static float *data1 = NULL, *data2 = NULL;
    static float *image1 = NULL, *image2 = NULL;
    static int data_alloc = 0, image_alloc = 0;
    static Bool first_pass = TRUE;

    if (first_pass)
    {
	set_seed(12345);
	first_pass = FALSE;
    }

    if (ndata > data_alloc)
    {
	/* mark as empty first so a failed MALLOC is retried on the next call */
	data_alloc = 0;

	FREE(data1, float);
	FREE(data2, float);
	data1 = data2 = NULL;

	MALLOC(data1, float, ndata);
	MALLOC(data2, float, ndata);

	data_alloc = ndata;
    }

    if (nimage > image_alloc)
    {
	image_alloc = 0;

	FREE(image1, float);
	FREE(image2, float);
	image1 = image2 = NULL;

	MALLOC(image1, float, nimage);
	MALLOC(image2, float, nimage);

	image_alloc = nimage;
    }

    for (i = 0; i < ndata; i++)
	data1[i] = uniform01();

    for (i = 0; i < nimage; i++)
	image1[i] = uniform01();

    (*opus)(data2, image1);
    (*tropus)(image2, data1);

    INNER_PRODUCT(a, data1, data1, ndata);
    INNER_PRODUCT(b, data2, data2, ndata);
    INNER_PRODUCT(c, image1, image1, nimage);
    INNER_PRODUCT(d, image2, image2, nimage);
    INNER_PRODUCT(e, data1, data2, ndata);
    INNER_PRODUCT(f, image1, image2, nimage);

    *err = ABS(e - f) / sqrt(a*b*c*d);

    return  OK;
}

/*
 * Release the seven work vectors allocated by check_mem_alloc().
 *
 * It is called before the first allocation, when all of these pointers are
 * still NULL, so FREE (defined elsewhere) is relied on to accept NULL.
 * The alias pointers (Ag, scaled_g, scaled_Ag, scaled_w, delta_g, delta_w,
 * mock_data) are not touched here; after this call they still hold the old
 * temp_image address until check_mem_alloc() reassigns them.
 */
static void free_mem_alloc()
{
    FREE(data, float);
    FREE(cum_g, float);
    FREE(cum_Ag, float);
    FREE(g, float);
    FREE(w, float);
    FREE(residuals, float);

    /* the one image-space scratch buffer */
    FREE(temp_image, float);
}

static Status check_mem_alloc()
{
    free_mem_alloc();

    MALLOC(data, float, ndata);
    MALLOC(cum_g, float, ndata);
    MALLOC(cum_Ag, float, ndata);
    MALLOC(g, float, ndata);
    MALLOC(w, float, ndata);
    MALLOC(residuals, float, ndata);

    MALLOC(temp_image, float, nimage);

    scaled_g = scaled_Ag = scaled_w = temp_image;
    delta_g = delta_w = mock_data = Ag = temp_image;

    return  OK;
}

/*
 * Configure the module and allocate its work vectors.  Must be called before
 * do_mem().
 *
 *   n_in      - number of data points (ndata); must be > 0
 *   n_out     - number of image points (nimage); must be > 0
 *   max_it    - maximum number of outer iterations in do_mem()
 *   pos       - TRUE selects the positive (exp) image form
 *   sigma     - noise level; must be > 0 (accuracy is stored as 1/sigma)
 *   max_rate  - stored as rate
 *   def_mult  - multiplier for def = def_mult * sigma * sqrt(n_in) / n_out
 *   sc        - factor applied to the final image
 *   lg        - log stream, or NULL for no logging
 *   op, trop  - transform callbacks (opus and tropus)
 *   error_msg - not referenced directly in this function; presumably written
 *               by RETURN_ERROR_MSG (defined elsewhere)
 *
 * Returns OK on success.  On invalid sizes/sigma or on allocation failure
 * control leaves through RETURN_ERROR_MSG (presumably returning ERROR).
 * Invalid arguments are rejected before any module state is modified:
 * sigma == 0 would otherwise make accuracy infinite and n_out == 0 would
 * divide by zero when computing def.
 */
Status init_mem(int n_in, int n_out, int max_it, Bool pos, float sigma,
		float max_rate, float def_mult, float sc, FILE *lg,
		Transform op, Transform trop, String error_msg)
{
    /* !(sigma > 0) also rejects NaN */
    if ((n_in <= 0) || (n_out <= 0) || !(sigma > 0))
	RETURN_ERROR_MSG("'maxent': sizes and sigma must be positive");

    ndata = n_in;
    nimage = n_out;
    max_niter = max_it;
    positive = pos;
    accuracy = 1 / sigma;
    rate = max_rate;
    def = def_mult * sigma * sqrt((double) n_in) / n_out;
    scale = sc;
    log_file = lg;
    opus = op;
    tropus = trop;

    if (check_mem_alloc() == ERROR)
	RETURN_ERROR_MSG("'maxent': allocating memory");

    return  OK;
}
