/* $Revision: 1.5 $ */
/* Copyright 1984-2000 The MathWorks, Inc.  */

/*
 * B = BITSLICE(A, STARTBIT, ENDBIT)
 *
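 * Extracts ("slices") bits STARTBIT through ENDBIT, inclusive, from each
 * element of A and shifts them down so that bit STARTBIT becomes bit 1.
 * Bits are numbered from 1 (least significant) to 8 (most significant),
 * and STARTBIT must not exceed ENDBIT.  For example,
 * bitslice(uint8(255),5,8) is 15, and bitslice(uint8(128),7,8) is 2.
 *
 * A must be a uint8 array, and STARTBIT and ENDBIT must be double scalars.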
 *
 * B is a uint8 array that is the same size as A.
 *
 */

/* Version-control ($Id$ keyword) identification string.  It is not
 * referenced by any code in this file; keep the literal unchanged so the
 * revision-control tool can keep expanding it. */
static char rcsid[] = "$Id: bitslice.c,v 1.5 2000/06/01 04:17:11 joeya Exp $";

#include <math.h>
#include "mex.h"

/*
 * ValidateInput --- input argument checking
 *
 * Checks argument counts and classes for BITSLICE and extracts the bit
 * range to slice.  Every failure is reported through mexErrMsgTxt, which
 * aborts the MEX call, so the outputs are only written for valid input.
 *
 * Inputs:   nlhs     --- number of left-side arguments (at most 1)
 *           nrhs     --- number of right-side arguments (exactly 3)
 *           prhs     --- array of right-side arguments:
 *                        prhs[0] uint8 data, prhs[1]/prhs[2] double
 *
 * Outputs:  startBit --- least-significant bit of the slice, in 1..8
 *           endBit   --- most-significant bit of the slice, in startBit..8
 *
 * Return:   none
 */
void ValidateInput(int nlhs, int nrhs, const mxArray *prhs[],
                    int *startBit, int *endBit)
{
    int i;
    double startValue;
    double endValue;

    if (nrhs > 3)
    {
        mexErrMsgTxt("Too many input arguments");
    }
    if (nrhs < 3)
    {
        mexErrMsgTxt("Too few input arguments");
    }
    if (nlhs > 1)
    {
        mexErrMsgTxt("Too many output arguments");
    }

    /* first arg must be uint8, others must be double */
    if (!mxIsUint8(prhs[0]))
    {
        mexErrMsgTxt("First input must be uint8");
    }
    if (!mxIsDouble(prhs[1]))
    {
        mexErrMsgTxt("Second input must be double");
    }
    if (!mxIsDouble(prhs[2]))
    {
        mexErrMsgTxt("Third input must be double");
    }

    /* Warn once for each complex argument; only real parts are used. */
    for (i = 0; i < nrhs; i++)
    {
        if (mxIsComplex(prhs[i]))
        {
            mexWarnMsgTxt("Ignoring imaginary part of input");
        }
    }

    /* mxGetScalar has no meaningful value to return for an empty array. */
    if (mxIsEmpty(prhs[1]))
    {
        mexErrMsgTxt("Second input must not be empty");
    }
    if (mxIsEmpty(prhs[2]))
    {
        mexErrMsgTxt("Third input must not be empty");
    }

    /* Non-scalar bit positions are tolerated; only the first element is used. */
    if (mxGetNumberOfElements(prhs[1]) > 1)
    {
        mexWarnMsgTxt("Second input should be a scalar");
    }

    if (mxGetNumberOfElements(prhs[2]) > 1)
    {
        mexWarnMsgTxt("Third input should be a scalar");
    }

    startValue = floor(mxGetScalar(prhs[1]));
    endValue = floor(mxGetScalar(prhs[2]));

    /*
     * Range-check while still in double: converting NaN, Inf, or a value
     * outside int's range to int is undefined behavior.  The negated form
     * also rejects NaN, for which every comparison is false.
     */
    if (!(startValue >= 1.0 && startValue <= 8.0) ||
        !(endValue >= 1.0 && endValue <= 8.0))
    {
        mexErrMsgTxt("STARTBIT and ENDBIT should be integers "
                     "between 1 and 8");
    }

    *startBit = (int) startValue;
    *endBit = (int) endValue;

    /* A single-bit slice (endBit == startBit) is allowed. */
    if (*endBit < *startBit)
    {
        mexErrMsgTxt("ENDBIT must be greater than or equal to STARTBIT");
    }
}



/*
 * mexFunction --- MEX gateway for B = BITSLICE(A, STARTBIT, ENDBIT)
 *
 * Inputs:   nlhs --- number of left-side arguments (at most 1)
 *           nrhs --- number of right-side arguments (exactly 3)
 *           prhs --- {A (uint8 array), STARTBIT (double), ENDBIT (double)}
 *
 * Outputs:  plhs[0] --- uint8 array with the same dimensions as A; each
 *                       element holds bits STARTBIT..ENDBIT of the matching
 *                       element of A, shifted down so STARTBIT is bit 1.
 */
void mexFunction( int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray *prhs[] )
{
    int startBit, endBit;
    const mxArray *A;
    mxArray *B;
    unsigned int mask;
    int shift;
    const uint8_T *prA;
    uint8_T *prB;
    size_t numElements;
    size_t k;

    /* Errors out (via mexErrMsgTxt) unless 1 <= startBit <= endBit <= 8. */
    ValidateInput(nlhs, nrhs, prhs, &startBit, &endBit);

    A = prhs[0];

    B = mxCreateNumericArray(mxGetNumberOfDimensions(A),
                             mxGetDimensions(A),
                             mxUINT8_CLASS,
                             mxREAL);

    /*
     * Set bits startBit..endBit (1-based, inclusive).  The field is at most
     * 8 bits wide, so these shifts stay far inside unsigned int.
     */
    mask = ((1u << (endBit - startBit + 1)) - 1u) << (startBit - 1);
    shift = startBit - 1;

    /* uint8 data is accessed with mxGetData; mxGetPr is for double arrays. */
    prA = (const uint8_T *) mxGetData(A);
    prB = (uint8_T *) mxGetData(B);

    /* Total element count across all dimensions, in an unsigned size type
     * so large arrays cannot overflow the loop bound. */
    numElements = mxGetNumberOfElements(A);

    for (k = 0; k < numElements; k++)
    {
        prB[k] = (uint8_T) ((prA[k] & mask) >> shift);
    }

    plhs[0] = B;
}
