/*

   CONVNC.C   .MEX file 
   Implements full N-D convolution.
   Inputs must be real, numeric, and double.

   Only one syntax is supported:
     C = CONVNC(A,B)

   Steven L. Eddins, January 1996
   Copyright 1984-2000 The MathWorks, Inc. 
*/

#include "mex.h"

/* RCS revision keyword string; read-only, and not referenced by the code in this file. */
static const char rcsid[] = "$Revision: 1.6 $";

/* Input and output arguments of C = CONVNC(A,B):                */
/* A and B are the two right-hand-side (input) arrays, and C is  */
/* the single left-hand-side (output) array.                     */
#define A (prhs[0])
#define B (prhs[1])
#define C (plhs[0])

/* Advance a zero-based subscript vector by one element in the        */
/* direction of linear indexing (the first subscript varies fastest). */
/* Whenever subs[p] passes size[p]-1 it is reset to 0 and the carry   */
/* goes into subs[p+1]. The last subscript is never wrapped, so after */
/* the final element it simply runs past its range; callers stop      */
/* iterating before using it.                                         */
/*                                                                    */
/* Implemented as a function so that each argument is evaluated once  */
/* and type-checked. A function also cannot collide with caller       */
/* variable names, whereas the former brace-block macro broke on e.g. */
/* INCREMENT_SUBSCRIPTS(subs, size, n). The macro wrapper keeps the   */
/* original statement-style call syntax and is safe in if/else.       */
static void increment_subscripts(int *subs, const int *size, int ndims)
{
  int p;

  subs[0] += 1;
  for (p = 0; p < ndims - 1; p++) {
    if (subs[p] > size[p] - 1) {
      subs[p] = 0;
      subs[p + 1] += 1;
    }
  }
}

#define INCREMENT_SUBSCRIPTS(SUBS,SIZE,NDIMS) \
  increment_subscripts((SUBS), (SIZE), (NDIMS))

/* Convert a zero-based subscript vector to a zero-based linear index */
/* for an array whose dimensions are pdims[0..ndims-1]. The first     */
/* subscript varies fastest:                                          */
/*   index = subs[0] + subs[1]*pdims[0] + subs[2]*pdims[0]*pdims[1] + ... */
/*                                                                    */
/* Implemented as a function so that each argument is evaluated once  */
/* and cannot collide with caller names (the former brace-block macro */
/* broke on e.g. SUBSCRIPTS_TO_LINEAR(subs, n, d, &index)). The macro */
/* stores the result through the INDEX pointer, as before, and        */
/* expands to a single expression, so it is safe in if/else.          */
static int subscripts_to_linear(const int *subs, int ndims, const int *pdims)
{
  int linear = 0;
  int factor = 1;   /* stride of dimension d, i.e. prod(pdims[0..d-1]) */
  int d;

  for (d = 0; d < ndims; d++) {
    linear += subs[d] * factor;
    factor *= pdims[d];
  }
  return linear;
}

#define SUBSCRIPTS_TO_LINEAR(SUBS, NDIMS, PDIMS, INDEX) \
  (*(INDEX) = subscripts_to_linear((SUBS), (NDIMS), (PDIMS)))


/*
 * convolve: accumulate the full N-D convolution of a and b into c.
 *
 *   c           output data, prod(sizeC) doubles. Partial products are
 *               added with +=, so c must start out zero-filled.
 *               mexFunction passes the data of the array it has just
 *               created.
 *   a, b        input data, prod(sizeA) and prod(sizeB) doubles, with
 *               the first subscript varying fastest
 *   sizeA/B/C   dimension vectors of a, b and c
 *   ndimsA/B/C  lengths of those vectors; all three must be equal
 *   newflops    receives 2*numel(a)*numel(b): one multiply and one add
 *               per element pair
 *
 * Every (a element, b element) pair contributes a[ia]*b[ib] to the
 * output element whose zero-based subscripts are subsA + subsB.
 * Inconsistent arguments and allocation failures are reported through
 * mexErrMsgTxt().
 */
static void convolve(double *c, const double *a, const double *b,
                     const int *sizeA, const int *sizeB, const int *sizeC,
                     int ndimsA, int ndimsB, int ndimsC, 
                     double *newflops)
{
  int numelA = 1;      /* number of elements in a */
  int numelB = 1;      /* number of elements in b */
  int *posA = NULL;    /* current zero-based subscripts into a */
  int *posB = NULL;    /* current zero-based subscripts into b */
  int *posC = NULL;    /* subscripts of the output element being updated */
  int outIndex = 0;    /* linear index equivalent of posC */
  int ia, ib, d;       /* loop counters */

  /* Internal consistency checks. A failure here means mexFunction() */
  /* passed something it never should have.                          */
  if (a == NULL || b == NULL || c == NULL ||
      sizeA == NULL || sizeB == NULL || sizeC == NULL ||
      ndimsA != ndimsB || ndimsA != ndimsC) {
    mexErrMsgTxt("Internal consistency error in convolve()");
  }

  /* Element counts of both inputs */
  for (d = 0; d < ndimsA; d++) {
    numelA *= sizeA[d];
  }
  for (d = 0; d < ndimsB; d++) {
    numelB *= sizeB[d];
  }

  /* Scratch subscript vectors */
  posA = mxCalloc(ndimsA, sizeof *posA);
  posB = mxCalloc(ndimsB, sizeof *posB);
  posC = mxCalloc(ndimsC, sizeof *posC);
  if (posA == NULL || posB == NULL || posC == NULL) {
    if (posA != NULL) {
      mxFree(posA);
    }
    if (posB != NULL) {
      mxFree(posB);
    }
    if (posC != NULL) {
      mxFree(posC);
    }
    mexErrMsgTxt("Memory allocation failure");
  }

  /* Start one step before a(0,0,...,0). The subscripts are advanced at */
  /* the top of every pass, so the first pass sees all zeros.           */
  posA[0] = -1;
  for (d = 1; d < ndimsA; d++) {
    posA[d] = 0;
  }

  for (ia = 0; ia < numelA; ia++) {
    INCREMENT_SUBSCRIPTS(posA, sizeA, ndimsA);

    /* Same trick for b, restarted for every element of a. */
    posB[0] = -1;
    for (d = 1; d < ndimsA; d++) {
      posB[d] = 0;
    }

    for (ib = 0; ib < numelB; ib++) {
      INCREMENT_SUBSCRIPTS(posB, sizeB, ndimsB);

      /* The product a[ia]*b[ib] lands at subscripts posA + posB ... */
      for (d = 0; d < ndimsA; d++) {
        posC[d] = posA[d] + posB[d];
      }
      /* ... expressed as a linear index into c. */
      SUBSCRIPTS_TO_LINEAR(posC, ndimsC, sizeC, &outIndex);

      c[outIndex] += a[ia] * b[ib];
    }
  }

  mxFree(posA);
  mxFree(posB);
  mxFree(posC);

  /* One multiply and one add for every (a, b) element pair */
  *newflops = 2.0 * numelA * numelB;
}

/*
 * MEX gateway for  C = CONVNC(A,B).
 *
 * Requires exactly two inputs, both real, numeric and double. If either
 * input is empty, C is a 0-by-0 double matrix.
 *
 * Otherwise the shorter dimension vector is padded with trailing
 * singleton (1) dimensions so that A and B have the same dimensionality.
 * C is created with size(A)+size(B)-1 along every dimension and filled
 * by convolve(). All temporary size vectors are released before
 * returning.
 *
 * NOTE(review): dimensions are handled as int, matching the MEX API this
 * file was written for. If it is built against 64-bit array dimensions
 * (mwSize == size_t), mxGetDimensions() does not return int data. Confirm
 * the build uses 32-bit (compatible) array dims, or port to mwSize.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  const int *sizeA=NULL;    /* Size of first input */
  const int *sizeB=NULL;    /* Size of second input */
  int *sizeC=NULL;          /* Size of output */
  int *adjustedSizeA=NULL;  /* Size of A padded to ndims entries */
  int *adjustedSizeB=NULL;  /* Size of B padded to ndims entries */
  int ndimsA=0;             /* Dimensionality of first input */
  int ndimsB=0;             /* Dimensionality of second input */
  int ndims=0;              /* Common dimensionality of A, B and C */
  double newflops=0.0;      /* Used to update flops count */
  int p=0;                  /* loop counter */

  /* Sanity check input */
  if (nrhs != 2) {
    mexErrMsgTxt("Two input arguments required");
  }
  if (!mxIsNumeric(A) || !mxIsDouble(A) || mxIsComplex(A) ||
      !mxIsNumeric(B) || !mxIsDouble(B) || mxIsComplex(B)) {
    mexErrMsgTxt("Inputs must be real, numeric, and double");
  }

  /* If either A or B is empty, return an empty matrix and go home. */
  if (mxIsEmpty(A) || mxIsEmpty(B)) {
    C = mxCreateDoubleMatrix(0,0,mxREAL);
    return;
  }

  ndimsA = mxGetNumberOfDimensions(A);
  ndimsB = mxGetNumberOfDimensions(B);
  sizeA = mxGetDimensions(A);
  sizeB = mxGetDimensions(B);

  /* A, B and C all use the larger of the two input dimensionalities. */
  ndims = (ndimsA > ndimsB) ? ndimsA : ndimsB;

  adjustedSizeA = mxCalloc(ndims, sizeof *adjustedSizeA);
  adjustedSizeB = mxCalloc(ndims, sizeof *adjustedSizeB);
  sizeC = mxCalloc(ndims, sizeof *sizeC);
  if (adjustedSizeA == NULL || adjustedSizeB == NULL || sizeC == NULL) {
    if (adjustedSizeA != NULL) mxFree(adjustedSizeA);
    if (adjustedSizeB != NULL) mxFree(adjustedSizeB);
    if (sizeC != NULL) mxFree(sizeC);
    mexErrMsgTxt("Memory allocation failure");
  }

  for (p = 0; p < ndims; p++) {
    /* Missing trailing dimensions are singletons. */
    adjustedSizeA[p] = (p < ndimsA) ? sizeA[p] : 1;
    adjustedSizeB[p] = (p < ndimsB) ? sizeB[p] : 1;
    /* Full convolution: each output extent is sizeA + sizeB - 1. */
    sizeC[p] = adjustedSizeA[p] + adjustedSizeB[p] - 1;
  }

  /* Initialize output */
  C = mxCreateNumericArray(ndims, sizeC, mxDOUBLE_CLASS, mxREAL);
  if (C == NULL) {
    mxFree(adjustedSizeA);
    mxFree(adjustedSizeB);
    mxFree(sizeC);
    mexErrMsgTxt("Memory allocation failure");
  }

  convolve(mxGetPr(C), mxGetPr(A), mxGetPr(B),
           adjustedSizeA, adjustedSizeB, sizeC,
           ndims, ndims, ndims, &newflops);

  /* The size vectors were only needed to create C and drive convolve(). */
  mxFree(adjustedSizeA);
  mxFree(adjustedSizeB);
  mxFree(sizeC);

  /* Update flops count */
  /* Uncomment this when mxAddFlops gets implemented */
#if 0
  mxAddFlops(newflops);
#endif /* 0 */
}
