#ifndef FLUC_INCLUDE
#include "fluctuate.h"
#endif

#ifndef FLUC_MODELLIKE_INCLUDE
#include "fluc_modellike.h"
#endif

#ifdef DMALLOC_FUNC_CHECK
#include "/usr/local/include/dmalloc.h"
#endif

#include "coal_modellike.h"

#define USETRUE false /* include Mary's true-theta/true-g tests */

/***************************************************************************
 *  ALPHA                                                                  *
 *  version 1.10. (c) Copyright 1986, 1991, 1992 by the University of      *
 *  Washington and Joseph Felsenstein.  Written by Joseph Felsenstein,     *
 *  Mary K. Kuhner and Jon A. Yamato, with some additional grunt work by   *
 *  Sean T. Lamont.  Permission is granted to copy and use this program    *
 *   provided no fee is charged for it and provided that this copyright    *
 *  notice is not removed.                                                 *
 *                                                                         *
 ***************************************************************************/

/************************************************************
   This file contains all the functions used in calculating
   P(D|theta), the model_likelihood, in the no-growth case.
 ***********************************************************/

/* Program-wide state defined in other translation units. */
extern FILE *thetafile, *simlog, *outfile;
extern long locus, numseq, totchains, numloci, numtrees;
extern double watttheta, *ne_ratio;
extern tree *curtree;
extern boolean **sametree;
extern option_struct *op;
extern treerec ***sum;

/* Theta tables indexed [locus][chain] (theti, lntheti, savethetai) and the
   multi-locus curve bookkeeping (fixed, locuslike, numfix) used below. */
extern double **growi, **theti, **lntheti, *fixed, *locuslike, **savethetai,
   xinterval;
extern long numfix;

/*********************************************************************
 * theta_confidence prints approximate 95% confidence limits on      *
 * theta around the point estimate "maxtheta".  A limit is a theta   *
 * whose ln-likelihood lies 1.9205 below lnL(maxtheta).              *
 *                                                                   *
 * chain    - last chain evaluated; model_likelihood() is given the  *
 *            chain range op->numchains[0] .. chain.                 *
 * maxtheta - the point estimate the limits bracket.                 *
 * lowcus   - locus forwarded to model_likelihood(); pass < 0 for    *
 *            multiple-locus evaluation over the long chains.        *
 * lthetai  - per-chain ln-likelihood tables forwarded unchanged to  *
 *            model_likelihood() (presumably indexed [locus][chain]  *
 *            -- TODO confirm).                                      *
 *                                                                   *
 * Each limit is bracketed by stepping away from maxtheta until lnL  *
 * drops to the cutoff, then stepping back with a halved step until  *
 * it is at or above the cutoff again, halving once more each time,  *
 * until lnL is within epsilon of the cutoff.  If lnL(epsilon) is    *
 * already above the cutoff the lower limit is reported as 0.        *
 * The limits are written to outfile; nothing is returned.           */
void theta_confidence(long chain, double maxtheta, long lowcus, double
  **lthetai)
{
double target, upper, lower, likel, step;
long firstlong, lastlong;

firstlong = op->numchains[0];
lastlong = chain;

step = 1.0;
/* critical value for 95% conf interval in lnL is 1.9205 (presumably half
   of the ~3.841 chi-square 95% point with 1 degree of freedom) */
target = model_likelihood(maxtheta,0.0,firstlong,lastlong,lowcus,lthetai) -
1.9205;

/* avoid calculating lower bound if it would be < 0 */
if(model_likelihood(epsilon,0.0,firstlong,lastlong,lowcus,lthetai) > target)
  lower = 0.0;
else {
  lower = maxtheta;
  do {
    /* walk down until lnL is at or below the cutoff */
    do {
      lower -= step;
      /* don't let it get below 0 during calculation either: clamp at
         epsilon and shorten the step to the distance actually moved */
      if (lower < epsilon) {
         step = lower + step - epsilon;
         lower = epsilon;
      }
      likel = model_likelihood(lower,0.0,firstlong,lastlong,lowcus,lthetai);
    } while (likel>target);
    step /= 2.0;
    /* walk back up with the halved step until lnL reaches the cutoff */
    do {
      lower += step;
      likel = model_likelihood(lower,0.0,firstlong,lastlong,lowcus,lthetai);
    } while (likel<target);
    step /= 2.0;
  } while (fabs(target-likel)>epsilon);
}
/* upper limit: same bracketing, walking up from maxtheta first */
upper = maxtheta;
step = 1.0;
do {
  do {
    upper += step;
    likel = model_likelihood(upper,0.0,firstlong,lastlong,lowcus,lthetai);
  } while (likel>target);
  step /= 2.0;
  do {
    upper -= step;
    likel = model_likelihood(upper,0.0,firstlong,lastlong,lowcus,lthetai);
  } while (likel<target);
  step /= 2.0;
} while (fabs(target-likel) > epsilon);

fprintf(outfile,"confidence limits on theta:  lower %12.8f, upper %12.8f\n",
  lower,upper);
} /* theta_confidence */

/*
   functions to be used in Newton-Raphson iteration of single chain
   point estimate.
*/

/*********************************************************************
 * fnx: first derivative, with respect to theta, of the likelihood   *
 * curve built from the trees of chain "chain" of the current locus, *
 * evaluated at theta = "theval".                                    *
 *                                                                   *
 * The per-tree terms are summed in log space (shifted by the        *
 * largest term before exponentiating) and the result is returned as *
 * Ln|derivative|.  Its sign goes to *fxplus; *fxzero is set when    *
 * the sum is numerically zero, in which case 0.0 is returned.       *
 *                                                                   *
 * EQN: kk      = sum_over_intervals [k*(k-1) * inter_lngth]         *
 *                (k = # of active lineages in the interval)         *
 *      thchain = theta under which the chain was run                *
 *      numints = # of intervals in tree                             *
 *                                                                   *
 *      answer  = sum_over_trees [(kk/theval - numints) *            *
 *                   exp(-kk * (1.0/theval - 1.0/thchain))]          */
double fnx (long chain, double theval, long numints, long *fxplus,
            boolean *fxzero)
{
  const double runtheta = theti[locus][chain];
  const long ref = REF_CHAIN(chain);
  const long ntree = op->numout[TYPE_CHAIN(chain)];
  double *lnterm, peak, total, answer;
  long *sgn, t;

  lnterm = (double *)calloc(ntree, sizeof(double));
  sgn = (long *)calloc(ntree, sizeof(long));

  /* pass 1: log magnitude and sign of every tree's term, tracking the max */
  peak = NEGMAX;
  for (t = 0; t < ntree; t++) {
     const treerec *rec = &sum[locus][ref][t];
     const double slope = rec->kend[0] / theval - numints;

     sgn[t] = (slope > 0) - (slope < 0);
     lnterm[t] = log(fabs(slope)) +
        (-rec->kend[0]) * (1.0/theval - 1.0/runtheta);
     if (lnterm[t] > peak)
        peak = lnterm[t];
  }

  /* pass 2: signed sum relative to the max; negligible terms are skipped */
  total = 0.0;
  for (t = 0; t < ntree; t++)
     if (lnterm[t] - peak > EXPMIN)
        total += sgn[t] * exp(lnterm[t] - peak);

  *fxzero = zerocheck(total);
  *fxplus = whatsign(total);
  answer = *fxzero ? 0.0 : log(fabs(total)) + peak;

  free(sgn);
  free(lnterm);
  return answer;
} /* fnx */

/**********************************************************************
 * dfnx: second derivative, with respect to theta, of the likelihood  *
 * curve built from the trees of chain "chain" of the current locus,  *
 * evaluated at theta = "theval".                                     *
 *                                                                    *
 * Summed in log space like fnx and returned as Ln|derivative|; the   *
 * sign goes to *dfxplus and *dfxzero is set (and 0.0 returned) when  *
 * the sum is numerically zero.                                       *
 *                                                                    *
 * EQN: kk      = sum_over_intervals [k*(k-1) * inter_lngth]          *
 *      thchain = theta under which the chain was run                 *
 *      numints = # of intervals in tree                              *
 *      f1x     = kk/theval                                           *
 *      f2x     = kk/(theval**2)                                      *
 *                                                                    *
 *      answer  = sum_over_trees [                                    *
 *                   exp(-kk * (1.0/theval - 1.0/thchain)) *          *
 *                   (f2x*(-2*numints-2) + (numints*(numints+1))/theval
 *                    + f1x*f2x) ]                                    */
double dfnx (long chain, double theval, long numints, long *dfxplus,
             boolean *dfxzero)
{
  const double runtheta = theti[locus][chain];
  const long ref = REF_CHAIN(chain);
  const long ntree = op->numout[TYPE_CHAIN(chain)];
  double *lnterm, peak, total, answer;
  long *sgn, t;

  lnterm = (double *)calloc(ntree, sizeof(double));
  sgn = (long *)calloc(ntree, sizeof(long));

  /* pass 1: log magnitude and sign of every tree's term, tracking the max */
  peak = NEGMAX;
  for (t = 0; t < ntree; t++) {
     const double kk = (double)sum[locus][ref][t].kend[0];
     const double lin = kk/theval;
     const double quad = kk/(theval*theval);
     const double curv = quad*(-2*numints-2) + (numints*(numints+1))/theval
                         + lin*quad;

     sgn[t] = (curv > 0) - (curv < 0);
     lnterm[t] = -lin + lin*theval/runtheta + log(fabs(curv));
     if (lnterm[t] > peak)
        peak = lnterm[t];
  }

  /* pass 2: signed sum relative to the max; negligible terms are skipped */
  total = 0.0;
  for (t = 0; t < ntree; t++)
     if (lnterm[t] - peak > EXPMIN)
        total += sgn[t] * exp(lnterm[t] - peak);

  *dfxzero = zerocheck(total);
  *dfxplus = whatsign(total);
  answer = *dfxzero ? 0.0 : log(fabs(total)) + peak;

  free(sgn);
  free(lnterm);
  return answer;
} /* dfnx */

/*********************************************************************
 * coal_singlechain: point estimate of theta from the trees of chain *
 * "chain" of the current locus, found by a modified Newton-Raphson  *
 * ascent started at watttheta.                                      *
 *                                                                   *
 * EQN:  SUMs are over all intervals in the tree.                    *
 *       k = # of active (coalesceable) lineages within an interval. *
 *       f1(x) = SUM[k*(k-1)*t/x]                                    *
 *       f2(x) = SUM[k*(k-1)*t/x**2]                                 *
 *       f3(x) = -SUM[k*(k-1)*t] * (1/x - 1/chaintheta)              *
 *       f4(x) = f2(x)*(-2*#intervals-2) +                           *
 *               (#intervals**2+#intervals)/x + f2(x)*f1(x)          *
 *                                                                   *
 *       first  = SUM_trees exp[f3(old)] * (f1(old) - #intervals)    *
 *       second = SUM_trees exp[f3(old)] * f4(old)                   *
 *                                                                   *
 *       change = sign(first) * |first| / |second|  if second < 0    *
 *              = sign(first) * old/2               otherwise        *
 *                                                                   *
 * newtheta = oldtheta + change                                      *
 *                                                                   *
 * "first" and "second" come from fnx and dfnx.  Unlike plain        *
 * Newton-Raphson the step always follows the sign of the first      *
 * derivative, so the search climbs toward a relative maximum.  A    *
 * step that would make theta <= 0, or that lowers lnL, is halved    *
 * until it does neither (or until it vanishes, in which case theta  *
 * is kept).  The search stops when the first derivative is zero or  *
 * |newtheta - oldtheta| < epsilon.                                  *
 *                                                                   *
 * chend - false: only report the estimate to thetafile (when        *
 *         thetaout) and return it.  true: also store it in          *
 *         theti/lntheti[locus][chain+1]; if it is not positive the  *
 *         values of chain "chain" are copied instead and a warning  *
 *         goes to ERRFILE.                                          *
 * rend  - when true (with chend) the estimate, or its failure, is   *
 *         also written to outfile.                                  *
 * Returns the estimate, or theti[locus][chain] on failure.          */
double coal_singlechain(long chain, boolean chend, boolean rend)
{
  int numloop; /* type "int" because of library function demands! */
  long numintervals, fxplus, dfxplus;
  double theta, newtheta, oldlike, newlike, fx, dfx, change;
  boolean fxzero, dfxzero;

  /* point estimate of theta, starting from watttheta */
  theta = watttheta;
  numintervals = numseq - 1; /* WARNING--Wrong for recombination */
  oldlike = model_llike(theta,0.0,chain,locus);

  /* solve by modified Newton Raphson.  The theta > 0 test also stops a
     non-positive or NaN value from looping forever; with chend set such
     a value takes the "estimate failed" path below. */
  while (theta > 0.0) {
     fx = fnx(chain,theta,numintervals,&fxplus,&fxzero);
     dfx = dfnx(chain,theta,numintervals,&dfxplus,&dfxzero);

     if (fxzero) /* found a maximum! at theta! */
        break;
     if (dfxzero) {
        /* flat second derivative: nudge theta and re-evaluate there */
        theta += epsilon;
        /* oldlike must describe the point we step from; if it stayed at
           the pre-nudge value the overshoot search below could never
           succeed and would loop forever */
        oldlike = model_llike(theta,0.0,chain,locus);
        fx = fnx(chain,theta,numintervals,&fxplus,&fxzero);
        dfx = dfnx(chain,theta,numintervals,&dfxplus,&dfxzero);
        if (fxzero) /* the nudged point is itself a maximum */
           break;
     }
     if (dfxplus < 0) change = fxplus * exp(fx - dfx);
     else change = fxplus * theta/2.0;
     /* x - x == 0 only for finite x: an overflowed or NaN Newton step
        falls back to the conservative half-theta step */
     if (!(change - change == 0.0)) change = fxplus * theta/2.0;

     newtheta = theta + change;

     /* now deal with negative or zero theta values */
     numloop = 0;
     while (newtheta <= 0) {
        numloop++;
        newtheta = theta + ldexp(change,-numloop);
     }
     newlike = model_llike(newtheta,0.0,chain,locus);
     if (newlike < oldlike) {
        /* in case we overshoot the maximum, don't jump so far... */
        numloop = 0;
        while (1) {
           numloop++;
           newtheta = theta + ldexp(change,-numloop);
           if (newtheta == theta) {
              /* the step shrank below floating-point resolution without
                 improving lnL: stay at theta, which ends the iteration */
              newlike = oldlike;
              break;
           }
           if (newtheta <= 0) continue;
           newlike = model_llike(newtheta,0.0,chain,locus);
           if (newlike >= oldlike) break;
        }
     }

     oldlike = newlike;
     if (fabs(newtheta - theta) < epsilon) {
        theta = newtheta;
        break;
     }
     theta = newtheta;
  }


  if (!chend) {
    if (thetaout) fprintf(thetafile, 
      "within chain %ld point: %10.7f log likelihood %10.7f\n",
      chain+1, theta, oldlike);
    return(theta);
  }
  if (theta > 0) {
    theti[locus][chain+1] = theta;
    lntheti[locus][chain+1] = log(theta);
    if (thetaout) fprintf(thetafile, 
      "chain %ld point: %10.7f log likelihood: %10.7f\n",
      chain+1, theta, oldlike);
    if (rend) { 
      fprintf(outfile,
        "Single chain point estimate of theta (from final chain)=%12.8f\n",
        theta);
    }
    return(theta);
  }
  /* the estimate failed */
  theti[locus][chain+1] = theti[locus][chain];
  lntheti[locus][chain+1] = lntheti[locus][chain];
  fprintf(ERRFILE,"WARNING, point estimate of theta failed!\n");
  fprintf(ERRFILE,"using previous iteration theta estimate\n");
  if (rend) 
    fprintf(outfile,"Single chain point estimate of theta failed\n");
  return(theti[locus][chain]);
}  /* coal_singlechain */

double combined_fnx (double theval, long numints, double **treewt,
   long firstlong, long lastlong, long *fxplus, boolean *fxzero,
   long lowcus)
/* this is the first derivative of the combined likelihood of several
   chains (firstlong..lastlong) of locus "lowcus", in log form: it returns
   Ln|derivative|, puts the sign in *fxplus, and sets *fxzero (returning
   0.0) when the derivative is numerically zero. */
/* EQN:  k = # of active lineages within an interval
         inter_lngth = length of an interval
         thgiven = theta of interest
         numints = # of intervals in tree

         kk = sum_over_intervals (k*(k - 1) * inter_lngth)
         w  = exp(treewt[chain][tree]), the per-tree weight built by
              combined_locus_weights (1/denom of combined_llike)

         answer = sum_over_chains sum_over_trees [
                     w * exp(-kk/thgiven) * (kk/thgiven - numints)
                  ]

   NOTE(review): compared with d/dtheta of combined_llike's summand
   (thgiven**-numints * exp(-kk/thgiven) * w) this omits the positive
   factor thgiven**-(numints+1); combined_dfnx omits the same factor, so
   their ratio (the Newton step) is unaffected.
*/
{
  long i, j, k, refchain, ntrees, nterms;
  double *temp, *sign, kk, other, maxtemp, result;

  /* one log-magnitude and one sign per (chain, tree) term, stored flat and
     sized from the chain range actually visited so any firstlong..lastlong
     is safe */
  ntrees = op->numout[1];
  nterms = (lastlong >= firstlong && ntrees > 0) ?
     (lastlong - firstlong + 1) * ntrees : 0;
  temp = (double *)calloc(nterms > 0 ? nterms : 1, sizeof(double));
  sign = (double *)calloc(nterms > 0 ? nterms : 1, sizeof(double));

  maxtemp = NEGMAX;
  k = 0;
  for(i = firstlong; i <= lastlong; i++) {
     refchain = REF_CHAIN(i);
     for(j = 0; j < ntrees; j++, k++) {
        kk = sum[lowcus][refchain][j].kend[0];
        other = kk/theval - numints;
        sign[k] = whatsign(other);
        temp[k] = treewt[i][j] - kk/theval + log(fabs(other));
        if (temp[k] > maxtemp) maxtemp = temp[k];
     }
  }

  /* signed sum relative to the largest term; negligible terms skipped */
  result = 0.0;
  for(k = 0; k < nterms; k++)
     if (temp[k] - maxtemp > EXPMIN)
        result += sign[k] * exp(temp[k] - maxtemp);

  *fxzero = zerocheck(result);
  *fxplus = whatsign(result);

  /* like fnx, report 0.0 rather than a raw near-zero linear sum */
  if (!*fxzero) result = log(fabs(result)) + maxtemp;
  else result = 0.0;

  free(temp);
  free(sign);

  return result;
} /* combined_fnx */

double combined_dfnx (double theval, long numints, double **treewt,
   long firstlong, long lastlong, long *dfxplus, boolean *dfxzero,
   long lowcus)
/* this is the second derivative of the combined likelihood of several
   chains (firstlong..lastlong) of locus "lowcus", in log form: it returns
   Ln|derivative|, puts the sign in *dfxplus, and sets *dfxzero (returning
   0.0) when the derivative is numerically zero. */
/* EQN:  k = # of active lineages within an interval
         inter_lngth = length of an interval
         thgiven = theta of interest
         numints = # of intervals in tree

         kk = sum_over_intervals (k*(k - 1) * inter_lngth)
         f1x = kk/thgiven
         f2x = kk/(thgiven**2)
         w  = exp(treewt[chain][tree]), the per-tree weight built by
              combined_locus_weights (1/denom of combined_llike)

         answer = sum_over_chains sum_over_trees [
                     w * exp(-kk/thgiven) *
                     (f2x*(-2*numints-2) + (numints*(numints+1))/thgiven +
                     f1x*f2x)
                  ]

   NOTE(review): relative to the exact second derivative this carries the
   same omitted positive factor thgiven**-(numints+1) as combined_fnx, so
   their ratio (the Newton step) is unaffected.
*/
{
  long i, j, k, refchain, ntrees, nterms;
  double f1x, f2x, kk, other, *temp, *sign, maxtemp, result;

  /* one log-magnitude and one sign per (chain, tree) term, stored flat and
     sized from the chain range actually visited so any firstlong..lastlong
     is safe */
  ntrees = op->numout[1];
  nterms = (lastlong >= firstlong && ntrees > 0) ?
     (lastlong - firstlong + 1) * ntrees : 0;
  temp = (double *)calloc(nterms > 0 ? nterms : 1, sizeof(double));
  sign = (double *)calloc(nterms > 0 ? nterms : 1, sizeof(double));

  maxtemp = NEGMAX;
  k = 0;
  for(i = firstlong; i <= lastlong; i++) {
     refchain = REF_CHAIN(i);
     for(j = 0; j < ntrees; j++, k++) {
        kk = sum[lowcus][refchain][j].kend[0];
        f1x = kk/theval;
        f2x = kk/(theval*theval);
        other = f2x*(-2*numints-2) + (numints*(numints+1))/theval + f1x*f2x;
        sign[k] = whatsign(other);
        temp[k] = treewt[i][j] - kk/theval + log(fabs(other));
        if (temp[k] > maxtemp) maxtemp = temp[k];
     }
  }

  /* signed sum relative to the largest term; negligible terms skipped */
  result = 0.0;
  for(k = 0; k < nterms; k++)
     if (temp[k] - maxtemp > EXPMIN)
        result += sign[k] * exp(temp[k] - maxtemp);

  *dfxzero = zerocheck(result);
  *dfxplus = whatsign(result);

  /* like dfnx, report 0.0 rather than a raw near-zero linear sum */
  if (!*dfxzero) result = log(fabs(result)) + maxtemp;
  else result = 0.0;

  free(temp);
  free(sign);

  return result;
} /* combined_dfnx */


double combined_locus_fnx(double theta, long numintervals, double ***treewt,
   long firstlong, long lastlong, boolean *return_zero, long *return_plus)
/* First derivative of the multi-locus likelihood L(theta), the product over
   loci of L_l = exp(combined_llike), in log form:

      L' = L * sum_over_loci [ L_l' / L_l ]      (L_l' from combined_fnx)

   Returns Ln|L'|; the sign goes to *return_plus, and *return_zero is set
   (and 0.0 returned) when the sum is numerically zero.  Per-locus ratios
   whose log is not above EXPMIN are dropped. */
{
long lowcus, dfxplus;
double lnfn, lndfn, lnlike, ratiosum, result;
boolean dfxzero;

/* only per-locus scalars are needed, so no scratch arrays are allocated */
lnlike = 0.0;
ratiosum = 0.0;

for(lowcus = 0; lowcus < numloci; lowcus++) {
   lnfn =
      combined_llike(theta,firstlong,lastlong,savethetai[lowcus],lowcus);
   lndfn =
      combined_fnx(theta,numintervals,treewt[lowcus],firstlong,
      lastlong,&dfxplus,&dfxzero,lowcus);
   lnlike += lnfn;
   if (!dfxzero && (lndfn - lnfn) > EXPMIN)
      ratiosum += dfxplus * exp(lndfn - lnfn);
}

*return_zero = zerocheck(ratiosum);
*return_plus = whatsign(ratiosum);

if (!*return_zero) result = lnlike + log(fabs(ratiosum));
else result = 0.0;

return (result);

} /* combined_locus_fnx */

double combined_locus_dfnx(double theta, long numintervals, double ***treewt,
   long firstlong, long lastlong, boolean *return_zero, long *return_plus)
/* Second derivative of the multi-locus likelihood L(theta), the product
   over loci of L_l = exp(combined_llike), in log form.  With
   r_l = L_l'/L_l (combined_fnx) and q_l = L_l''/L_l (combined_dfnx):

      L'' = L * [ (sum r_l)**2 + sum q_l - sum r_l**2 ]

   Returns Ln|L''|; the sign goes to *return_plus, and *return_zero is set
   (and 0.0 returned) when the bracket is numerically zero.  Per-locus
   terms whose log is not above EXPMIN are dropped.

   NOTE(review): combined_fnx/combined_dfnx both return their derivatives
   scaled by the same positive factor theta**(numints+1).  With one locus
   the bracket reduces to q_1 and the Newton step built from this and
   combined_locus_fnx is exact; with several loci the r_l**2 terms carry
   that factor squared while q_l carries it once. */
{
long lowcus, dfxplus, ddfxplus;
double lnfn, lndfn, lnddfn, lnratio, lnlike, firstsum, secondsum,
   squaresum, bracket, result;
boolean dfxzero, ddfxzero;

lnlike = 0.0;
firstsum = 0.0;   /* sum r_l  */
secondsum = 0.0;  /* sum q_l  */
squaresum = 0.0;  /* -sum r_l**2 */

for(lowcus = 0; lowcus < numloci; lowcus++) {
   lnfn =
      combined_llike(theta,firstlong,lastlong,savethetai[lowcus],lowcus);
   lndfn =
      combined_fnx(theta,numintervals,treewt[lowcus],firstlong,
      lastlong,&dfxplus,&dfxzero,lowcus);
   lnddfn =
      combined_dfnx(theta,numintervals,treewt[lowcus],firstlong,
      lastlong,&ddfxplus,&ddfxzero,lowcus);
   lnlike += lnfn;
   if (!dfxzero) {
      lnratio = lndfn - lnfn;          /* Ln|r_l| */
      if (lnratio > EXPMIN)
         firstsum += dfxplus * exp(lnratio);
      /* r_l**2 comes from d(1/L_l) = -L_l'/L_l**2, so it is always
         subtracted; its log magnitude is 2*Ln|r_l| (the old code mixed
         log-space values into a linear-space formula here) */
      if (2.0 * lnratio > EXPMIN)
         squaresum -= exp(2.0 * lnratio);
   }
   if (!ddfxzero && (lnddfn - lnfn) > EXPMIN)
      secondsum += ddfxplus * exp(lnddfn - lnfn);
}

bracket = firstsum*firstsum + secondsum + squaresum;

*return_zero = zerocheck(bracket);
*return_plus = whatsign(bracket);

if (!*return_zero) result = lnlike + log(fabs(bracket));
else result = 0.0;

return (result);

} /* combined_locus_dfnx */

void combined_locus_weights(double **treewt, long firstlong, long lastlong, long
   lowcus, double *lthetai)
{
long i, dataset, trii, refchain;
double maxwt, *tempwt;

tempwt = (double *)calloc(1,(totchains + 1) * sizeof(double));

for (dataset = firstlong; dataset <= lastlong; dataset++) {
   refchain = REF_CHAIN(dataset);
   for (trii = 0; trii < op->numout[1]; trii++) {
      maxwt = NEGMAX;
      treewt[dataset][trii] = 0.0;
      for (i = firstlong; i <= lastlong; i++) {
         tempwt[i] =  log((double)op->numout[1]) - 
         (sum[lowcus][refchain][trii].kend[0]) / theti[lowcus][i] -
         lthetai[i] + (1 - numseq) * lntheti[lowcus][i];
         if (tempwt[i] > maxwt) maxwt = tempwt[i];
      }
      for (i = firstlong; i <= lastlong; i++)
        if (tempwt[i] - maxwt > EXPMIN)
           treewt[dataset][trii] += exp(tempwt[i] - maxwt);
      treewt[dataset][trii] = -(log(treewt[dataset][trii]) + maxwt);
   }
}

free(tempwt);
} /* combined_locus_weights */

/***********************************************************************
 * combined_llike returns the Ln(like) of a given theta, "thgiven",    *
 * over a set of chains, numbered "firstlong" to "lastlong", of locus  *
 * "lowcus", with a pre-computed set of Ln(likelihoods) for each chain *
 * contained in "lthetai".                                             *
 * EQN: k = # of active (coalesceable) lineages within an interval.    *
 *      inter_lngth = length of an interval                            *
 *      thgiven = theta of interest                                    *
 *      thchain = theta under which the chain was run                  *
 *      numints = # of intervals in tree (numseq - 1)                  *
 *      numtree = # of trees in chain                                  *
 *      plike = provisional likelihood of a chain                      *
 *                                                                     *
 *      f1(x) = sum_over_intervals [(k*(k-1) * inter_lngth)/x]         *
 *                                                                     *
 *      numer = ((1/thgiven)**numints) * exp(-f1(thgiven))             *
 *      denom = sum_over_chains [numtree * exp(-f1(thchain)) /         *
 *                               (plike * (thchain**numints))]         *
 *                                                                     *
 *      likelihood = sum_over_all_trees [numer / denom]                *
 *                                                                     *
 * Both sums are taken in log space, shifted by their largest term;    *
 * terms whose shifted log is not above EXPMIN are dropped.            */
double combined_llike(double thgiven, long firstlong, long lastlong, 
   double *lthetai, long lowcus) 
{
  long j, k, refchain, dataset, trii, ntrees, nterms;
  double numer, denom, lnthgiven, lnnumout, *num, bigsum, maxdenom, maxnum,
     *tempdenom;

  /* one log-term per (chain, tree) pair, stored flat and sized from the
     chain range actually visited so any firstlong..lastlong is safe */
  ntrees = op->numout[1];
  nterms = (lastlong >= firstlong && ntrees > 0) ?
     (lastlong - firstlong + 1) * ntrees : 0;
  num = (double *)calloc(nterms > 0 ? nterms : 1, sizeof(double));
  tempdenom = (double *)calloc(lastlong >= 0 ? lastlong + 1 : 1,
     sizeof(double));

  bigsum = 0.0;
  lnthgiven = log(thgiven);
  lnnumout = log((double)ntrees);
  maxnum = NEGMAX;
  k = 0;
  for (dataset = firstlong; dataset <= lastlong; dataset++) {
    refchain = REF_CHAIN(dataset);
    for (trii = 0; trii < ntrees; trii++, k++) {
      numer = (1 - numseq) * lnthgiven - 
              (sum[lowcus][refchain][trii].kend[0]) / thgiven;
      denom = 0.0;
      maxdenom = NEGMAX;
      for (j = firstlong; j <= lastlong; j++) {
        tempdenom[j] = lnnumout - 
            (sum[lowcus][refchain][trii].kend[0]) / theti[lowcus][j] -
            lthetai[j] + (1 - numseq) * lntheti[lowcus][j];
        if (tempdenom[j] > maxdenom) maxdenom = tempdenom[j];
      }
      for (j = firstlong; j <= lastlong; j++)
        if (tempdenom[j] - maxdenom > EXPMIN)
          denom += exp(tempdenom[j] - maxdenom);
      num[k] = numer - log(denom) - maxdenom;
      if (num[k] > maxnum) maxnum = num[k];
    }
  }
  /* log-sum-exp of num[], in the same chain-major order as before */
  for (k = 0; k < nterms; k++)
    if (num[k] - maxnum > EXPMIN)
      bigsum += exp(num[k] - maxnum);

  free(num);
  free(tempdenom);
  return (log(bigsum) + maxnum);
}  /* combined_llike */

/***********************************************************************
 * sum_combined_llike: Ln(like) of "theta" summed over every locus in  *
 * the input data, i.e. Ln of the product of the per-locus combined    *
 * likelihoods, each built from chains firstlong..lastlong with that   *
 * locus' table lthetai[locus].                                        */
double sum_combined_llike(double theta, long firstlong, long lastlong, 
   double **lthetai)
{
double total = 0.0;
long loc = 0;

while (loc < numloci) {
   total += combined_llike(theta, firstlong, lastlong, lthetai[loc], loc);
   loc++;
}

return total;

} /* sum_combined_llike */

double coal_combined_locus_estimate()
/* Multi-locus point estimate of theta from the long chains
   (op->numchains[0] .. op->numchains[0] + op->numchains[1] - 1) of every
   locus.  Builds the per-tree weights for each locus, then runs the same
   modified Newton-Raphson ascent as coal_singlechain on the summed
   ln-likelihood (sum_combined_llike), once from each locus' own final
   estimate theti[locus][totchains].  The ascent ending at the highest lnL
   wins; it is printed to outfile and returned. */
{
int numloop; /* type "int" because of library function demands! */
long i, j, lowcus, firstlong, lastlong, numintervals,
   fxplus, dfxplus, chosen;
double ***treewt, newtheta, theta, oldlike, newlike, besttheta, 
   bestlike, fx, dfx, change;
boolean fxzero, dfxzero;

firstlong = op->numchains[0];
lastlong = op->numchains[0] + op->numchains[1] - 1;
numintervals = numseq - 1; /* WARNING: wrong for recombination!!! */

/* treewt[locus][chain][tree]: pointer tables over one contiguous block of
   numloci*totchains*numtrees doubles */
treewt = (double ***)calloc(1,numloci * sizeof(double **));
treewt[0] = (double **)calloc(1,numloci*totchains * sizeof(double *));
for(i = 1; i < numloci; i++)
   treewt[i] = treewt[0] + i*totchains;
treewt[0][0] = (double *)calloc(1,numloci*totchains*numtrees*sizeof(double));
for(i = 0; i < numloci; i++)
   for(j = 0; j < totchains; j++)
      treewt[i][j] = treewt[0][0] + i*totchains*numtrees + j*numtrees;

/* first calculate the weighting factor for each tree */
for (lowcus = 0; lowcus < numloci; lowcus++)
   combined_locus_weights(treewt[lowcus],firstlong,lastlong,lowcus,
      savethetai[lowcus]);

besttheta = NEGMAX;
bestlike = NEGMAX;

/* one ascent from each locus' own final estimate; keep the best */
for (chosen = 0; chosen < numloci; chosen++) {
  theta = theti[chosen][totchains];
  oldlike = sum_combined_llike(theta,firstlong,lastlong,savethetai); 

  /* the theta > 0 test keeps a non-positive or NaN start (or a NaN
     produced along the way) from looping forever */
  while (theta > 0.0) {
     fx =
     combined_locus_fnx(theta,numintervals,treewt,firstlong,lastlong,&fxzero,&fxplus);
     dfx = 
     combined_locus_dfnx(theta,numintervals,treewt,firstlong,lastlong,&dfxzero,&dfxplus);

     if (fxzero) /* found a maximum! at theta! */
        break;
     if (dfxzero) {
        /* flat second derivative: nudge theta and re-evaluate there */
        theta += epsilon;
        /* oldlike must describe the point we step from; if it stayed at
           the pre-nudge value the overshoot search below could never
           succeed and would loop forever */
        oldlike = sum_combined_llike(theta,firstlong,lastlong,savethetai);
        fx =
    combined_locus_fnx(theta,numintervals,treewt,firstlong,lastlong,&fxzero,&fxplus);
        dfx = 
  combined_locus_dfnx(theta,numintervals,treewt,firstlong,lastlong,&dfxzero,&dfxplus);
        if (fxzero) /* the nudged point is itself a maximum */
           break;
     }
     if (dfxplus < 0) change = fxplus * exp(fx - dfx);
     else change = fxplus * theta/2.0;
     /* x - x == 0 only for finite x: an overflowed or NaN Newton step
        falls back to the conservative half-theta step */
     if (!(change - change == 0.0)) change = fxplus * theta/2.0;

     newtheta = theta + change;

    /* now deal with negative or zero theta values */
     numloop = 1;
     while (newtheta <= 0) {
        numloop++;
        newtheta = theta + ldexp(change,-numloop);
     }
     newlike = sum_combined_llike(newtheta,firstlong,lastlong,savethetai);
     if(newlike < oldlike) {
    /* in case we overshoot the maximum, don't jump so far...*/
        numloop = 1;
        while(1) {
           numloop++;
           newtheta = theta + ldexp(change,-numloop);
           if (newtheta == theta) {
              /* the step shrank below floating-point resolution without
                 improving lnL: stay at theta, which ends the iteration */
              newlike = oldlike;
              break;
           }
           if (newtheta <= 0) continue;
           newlike = sum_combined_llike(newtheta,firstlong,lastlong,
              savethetai);
           if(newlike >= oldlike) break;
        }
     }

     oldlike = newlike;
     if(fabs(newtheta - theta) < epsilon) {
        theta = newtheta;
        break;
     }
     theta = newtheta;
  }

  /* oldlike is always the lnL at the final theta (newlike may never have
     been set if the ascent did not run) */
  if (theta > 0.0 && oldlike > bestlike) {
     besttheta = theta;
     bestlike = oldlike;
  }
}

fprintf(outfile,"Point estimate using %ld loci = %12.9f,",
        numloci,besttheta);
fprintf(outfile,"   with lnL = %12.9f\n",bestlike);

free(treewt[0][0]);
free(treewt[0]);
free(treewt);

return(besttheta);

} /* coal_combined_locus_estimate */


void coal_curveplot()
/* the multi-locus LnLike curve constructor. */
{
  long i, j, temp, firstlong, lastlong, numpoints, *sorted;
  double *printth, maxtheta;

  /* allocate local arrays */
  firstlong = op->numchains[0];
  lastlong = totchains - 1;
  numpoints = numloci + numfix;
  sorted = (long *)calloc(1,numpoints * sizeof(long));
  printth = (double *)calloc(1,numpoints * sizeof(double));

  fprintf(outfile,"\n---------------------------------------\n");
  fprintf(outfile,"Combined likelihood over all loci\n\n");
  fprintf(outfile,"---------------------------------------\n");
  for(i = 0; i < numloci; i++) printth[i] = theti[i][totchains];
  for(i = 0; i < numfix; i++) printth[i+numloci] = fixed[i];
  
  /* first calculate point estimate */
  maxtheta = coal_combined_locus_estimate();

  if (numloci > 1) {
  /* first calculated summed Lnlike for each locus' final estimate of
     theta */
  for (i = 0; i < numloci; i++)
     for(j = 0; j < numloci; j++)
        if (i != j)
           locuslike[i] += combined_llike(theti[i][totchains],firstlong,
              lastlong,savethetai[j],j);
  } else { /* single locus case */
     locuslike[0] = combined_llike(theti[0][totchains],firstlong,lastlong,
        savethetai[0],0);
  }

  /* now calculate each "fixed" point's Lnlike for each chain, and
     sum that up */
  for(i = 0; i < numfix; i++)
     locuslike[numloci + i] = sum_combined_llike(fixed[i],firstlong,
        lastlong,savethetai);

  /* sorted table of theti */
    for (i = 0; i < numpoints; i++) sorted[i] = i;

    for (i = 0; i < numpoints; i++)
      for (j = 1; j < numpoints; j++)
	if (printth[sorted[j - 1]] > printth[sorted[j]]) {
	  temp = sorted[j - 1];
	  sorted[j - 1] = sorted[j];
	  sorted[j] = temp;
	}

    fprintf(outfile, " There were %ld loci examined\n\n", numloci);
    fprintf(outfile, "   Theta       LnL\n");
    fprintf(outfile, "   -----       ---\n");

    for (i = 0; i < numpoints; i++) {
      fprintf(outfile, "%12.8f  %12.8f\n",printth[sorted[i]],
	      locuslike[sorted[i]]);
      }
    putc('\n', outfile);

  /* confidence interval on theta */
    theta_confidence(lastlong, maxtheta, -1L, savethetai);

  free(printth);
  free(sorted);

} /* coal_curveplot */
