/*
* File: sdspbsub2.c
*
* Abstract:
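*      S-function for solving UX=B by backward substitution, where B may
*      hold several right-hand sides (one per column).
*      U is an upper (or unit upper) triangular full matrix.
*      The entries in the lower triangle are ignored; in unit-upper mode the
*      diagonal is ignored as well and taken to be 1.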
*
* Copyright 1995-2000 The MathWorks, Inc.
* $Revision: 1.10 $ $Date: 2000/06/11 23:24:12 $
*/
#define S_FUNCTION_NAME  sdspbsub2
#define S_FUNCTION_LEVEL 2

#include "dsp_sim.h"

/* Input ports: U is the N x N triangular matrix, B holds the N x P
 * right-hand sides. */
enum {
    INPORT_U = 0,
    INPORT_B,
    NUM_INPORTS
};

/* Output port: X is the N x P solution. */
enum {
    OUTPORT_X = 0,
    NUM_OUTPORTS
};

/* Dialog parameters: the unit-upper override flag (validated as 0 or 1). */
enum {
    UNIT_ARGC = 0,
    NUM_PARAMS
};

/* mxArray holding the unit-upper override parameter (requires S in scope). */
#define UNIT_ARG (ssGetSFcnParam(S,UNIT_ARGC))

#ifdef MATLAB_MEX_FILE
#define MDL_CHECK_PARAMETERS
/* Function: mdlCheckParameters
 * Validate the single dialog parameter: the unit-upper override must
 * satisfy IS_FLINT_IN_RANGE(...,0,1) (presumably "an integer-valued
 * number in [0,1]" -- see dsp_sim.h to confirm).
 */
static void mdlCheckParameters(SimStruct *S)
{
    if (IS_FLINT_IN_RANGE(UNIT_ARG,0,1)) return;

    THROW_ERROR(S, "Unit upper override option must be 0 or 1.");
}
#endif


/* Function: mdlInitializeSizes
 * Register the dialog parameter count, two dynamically sized inputs
 * (U and B), one dynamically sized output (X), one sample time and the
 * block options.
 */
static void mdlInitializeSizes(SimStruct *S)
{
    int_T inport;

    ssSetNumSFcnParams(S, NUM_PARAMS);

#if defined(MATLAB_MEX_FILE)
    /* Return early on a parameter-count mismatch or an invalid parameter. */
    if (ssGetSFcnParamsCount(S) != ssGetNumSFcnParams(S)) return;
    mdlCheckParameters(S);
    if (ssGetErrorStatus(S) != NULL) return;
#endif

    /* U and B get identical setup: inherited size, frame status and
     * complexity, direct feedthrough, reusable buffers. The loop visits
     * INPORT_U then INPORT_B, the same order as a hand-written sequence. */
    if (!ssSetNumInputPorts(S, NUM_INPORTS)) return;
    for (inport = INPORT_U; inport < NUM_INPORTS; inport++) {
        if (!ssSetInputPortDimensionInfo(S, inport, DYNAMIC_DIMENSION)) return;
        ssSetInputPortFrameData(        S, inport, FRAME_INHERITED);
        ssSetInputPortDirectFeedThrough(S, inport, 1);
        ssSetInputPortComplexSignal(    S, inport, COMPLEX_INHERITED);
        ssSetInputPortReusable(         S, inport, 1);
    }
    /* B's buffer may be shared with the output X. */
    ssSetInputPortOverWritable(S, INPORT_B, 1);

    /* X: size from propagation, never frame-based, complexity inherited. */
    if (!ssSetNumOutputPorts(S, NUM_OUTPORTS)) return;
    if (!ssSetOutputPortDimensionInfo(S, OUTPORT_X, DYNAMIC_DIMENSION)) return;
    ssSetOutputPortFrameData(    S, OUTPORT_X, FRAME_NO);
    ssSetOutputPortComplexSignal(S, OUTPORT_X, COMPLEX_INHERITED);

    ssSetNumSampleTimes(S, 1);
    ssSetOptions(S, SS_OPTION_EXCEPTION_FREE_CODE | SS_OPTION_USE_TLC_WITH_ACCELERATOR);
}


/* Function: mdlInitializeSampleTimes
 * The block's only sample time (index 0) is inherited, with zero offset.
 */
static void mdlInitializeSampleTimes(SimStruct *S)
{
    const int_T onlySampleTime = 0;

    ssSetSampleTime(S, onlySampleTime, INHERITED_SAMPLE_TIME);
    ssSetOffsetTime(S, onlySampleTime, 0.0);
}


#define MDL_START
/* Function: mdlStart
 * Final dimension checks, MEX builds only (the body is empty when
 * MATLAB_MEX_FILE is not defined):
 *   - U must be a square matrix;
 *   - U, B and X must have the same number of rows;
 *   - B and X must have the same number of columns (a signal that is not
 *     2-D counts as one column).
 */
static void mdlStart(SimStruct *S)
{
#ifdef MATLAB_MEX_FILE
    const int_T      *dims_U     = ssGetInputPortDimensions(S,INPORT_U);
    const int_T      nrows_U     = dims_U[0];

    const int_T      numdims_B   = ssGetInputPortNumDimensions(S,INPORT_B);
    const int_T      *dims_B     = ssGetInputPortDimensions(S,INPORT_B);
    const int_T      ncols_B     = (numdims_B == 2) ? dims_B[1] : 1;
    const int_T      nrows_B     = dims_B[0];

    const int_T      numdims_X   = ssGetOutputPortNumDimensions(S,OUTPORT_X);
    const int_T      *dims_X     = ssGetOutputPortDimensions(S,OUTPORT_X);
    const int_T      ncols_X     = (numdims_X == 2) ? dims_X[1] : 1;
    const int_T      nrows_X     = dims_X[0];

    ErrorIfInputIsNotSquareMatrix(S, INPORT_U);
    /* If the square-matrix check flagged an error (the helper may not
     * return from here by itself), stop so the row check below cannot
     * replace that more specific message. */
    if (ssGetErrorStatus(S) != NULL) return;

    /* All rows must be the same */
    if ((nrows_U != nrows_B) || (nrows_U != nrows_X)) {
        THROW_ERROR(S, "Number of rows must be the same.");
    }

    if (ncols_B != ncols_X) {
        THROW_ERROR(S, "Number of columns in B must be equal to columns in X.");
    }
#endif
}


/* Function: mdlOutputs
 * Solve U*X = B for X by backward substitution.
 *
 *   U (INPORT_U)  : N x N, addressed column-major; only the upper triangle
 *                   is read. When the unit-upper parameter equals 1 the
 *                   diagonal is not read either and is taken to be 1.
 *   B (INPORT_B)  : N x P right-hand sides (a non-2-D input gives P = 1).
 *   X (OUTPORT_X) : N x P solution; written as creal_T whenever U or B
 *                   is complex.
 *   tid           : unused.
 *
 * For every column of X, rows are solved bottom-up (r = N-1-i on pass i):
 *     x(r) = (b(r) - sum_{c>r} U(r,c) * x(c)) / U(r,r)
 * Columns are processed last to first, and B is walked backward from its
 * last element, so b(r) is always read before x(r) is written. This keeps
 * the result correct if X shares storage with B (INPORT_B is overwritable).
 */
static void mdlOutputs(SimStruct *S, int_T tid)
{
    const boolean_T cB         = (boolean_T)(ssGetInputPortComplexSignal(S,INPORT_B) == COMPLEX_YES);
    const boolean_T cU         = (boolean_T)(ssGetInputPortComplexSignal(S,INPORT_U) == COMPLEX_YES);

    const int_T      *dims_U   = ssGetInputPortDimensions(S,INPORT_U);
    const int_T      N         = dims_U[0];                          /* order of U          */
    const int_T      N2        = ssGetInputPortWidth(S,INPORT_U);    /* elements in U (N*N) */

    const int_T      numdims_B = ssGetInputPortNumDimensions(S,INPORT_B);
    const int_T      *dims_B   = ssGetInputPortDimensions(S,INPORT_B);
    const int_T      P         = (numdims_B == 2) ? dims_B[1] : 1;  /* Number of right hand sides in B */
    const int_T      NP        = ssGetInputPortWidth(S,INPORT_B);    /* elements in B (N*P) */

    const boolean_T unit_upper = (boolean_T)(mxGetPr(UNIT_ARG)[0] == 1.0);

    if (!cB && !cU) {
        /* Real U, real B: real X. */
        InputRealPtrsType  pU = ssGetInputPortRealSignalPtrs(S, INPORT_U)  + N2-1;
        InputRealPtrsType  pb = ssGetInputPortRealSignalPtrs(S, INPORT_B)  + NP-1;
        real_T            *x  = ssGetOutputPortRealSignal(   S, OUTPORT_X);
        int_T              i,k;

        for(k=P; k>0; k--) {
            /* pUcol walks up the last column of U: U(N-1-i, N-1) on pass i. */
            InputRealPtrsType     pUcol = pU;
            for(i=0; i<N; i++) {
                real_T           *xj    = x + k*N-1;   /* last entry of X column k-1 */
                real_T            s     = 0.0;
                InputRealPtrsType pUrow = pUcol--;     /* access current row of U */

                {
                    /* s = sum of U(r,c)*x(c) over the i already-solved c > r,
                     * moving left along row r. */
                    int_T j = i;
                    while(j-- > 0) {
                        s += **pUrow * *xj--;
                        pUrow -= N;
                    }
                }

                /* pUrow now addresses U(r,r) and xj addresses x(r). */
                if (unit_upper) {
                    *xj = **pb-- - s;
                } else {
                    *xj = (**pb-- - s) / **pUrow;
                }
            }
        }

    } else if (cB && !cU) {
        /* Complex B, real U: complex X. */
        InputRealPtrsType  pU = ssGetInputPortRealSignalPtrs(    S, INPORT_U)  + N2-1;
        InputPtrsType      pb = ssGetInputPortSignalPtrs(        S, INPORT_B)  + NP-1;
        creal_T           *x  = (creal_T *)ssGetOutputPortSignal(S, OUTPORT_X);
        int_T              i,k;

        for(k=P; k>0; k--) {
            InputRealPtrsType     pUcol = pU;
            for(i=0; i<N; i++) {
                creal_T          *xj    = x + k*N-1;
                creal_T           s     = {0.0, 0.0};
                InputRealPtrsType pUrow = pUcol--;  /* access current row of U */

                {
                    /* Real U(r,c) scales both parts of x(c). */
                    int_T j = i;
                    while(j-- > 0) {
                        s.re += **pUrow * xj->re;
                        s.im += **pUrow * (xj--)->im;
                        pUrow -= N;
                    }
                }

                {
                    const creal_T cb = *((const creal_T *)(*pb--));
                    if (unit_upper) {
                        xj->re = cb.re - s.re;
                        xj->im = cb.im - s.im;
                    } else {
                        xj->re = (cb.re - s.re) / **pUrow;
                        xj->im = (cb.im - s.im) / **pUrow;
                    }
                }
            }
        }

    } else if (!cB && cU) {
        /* Real B, complex U: complex X. */
        InputPtrsType     pU = ssGetInputPortSignalPtrs(        S,INPORT_U)  + N2-1;
        InputRealPtrsType pb = ssGetInputPortRealSignalPtrs(    S,INPORT_B) + NP-1;
        creal_T          *x  = (creal_T *)ssGetOutputPortSignal(S,OUTPORT_X);
        int_T             i, k;

        for(k=P; k>0; k--) {
            InputPtrsType      pUcol = pU;
            for(i=0; i<N; i++) {
                creal_T       *xj    = x + k*N-1;
                creal_T        s     = {0.0, 0.0};
                InputPtrsType  pUrow = pUcol--;

                {
                    int_T j = i;
                    while(j-- > 0) {
                        /* Compute: s += U(r,c) * x(c), in complex */
                        const creal_T u_rc = *((const creal_T *)(*pUrow));
                        pUrow -= N;

                        s.re += CMULT_RE(u_rc, *xj);
                        s.im += CMULT_IM(u_rc, *xj);
                        xj--;
                    }
                }

                {
                    /* b(r) is real, so b(r) - s has imaginary part -s.im. */
                    const real_T cb = **pb--;
                    if (unit_upper) {
                        xj->re = cb - s.re;
                        xj->im = -s.im;
                    } else {
                        /* Complex divide: x(r) = cdiff / U(r,r) */
                        const creal_T u_diag = *((const creal_T *)(*pUrow));
                        creal_T       cdiff;
                        cdiff.re = cb - s.re;
                        cdiff.im = -s.im;

                        CDIV(cdiff, u_diag, *xj);
                    }
                }
            }
        }

    } else {
        /* Complex U, complex B: complex X. */
        InputPtrsType  pU = ssGetInputPortSignalPtrs(        S,INPORT_U)  + N2-1;
        InputPtrsType  pb = ssGetInputPortSignalPtrs(        S,INPORT_B)  + NP-1;
        creal_T       *x  = (creal_T *)ssGetOutputPortSignal(S,OUTPORT_X);
        int_T          i, k;

        for(k=P; k>0; k--) {
            InputPtrsType      pUcol = pU;
            for(i=0; i<N; i++) {
                creal_T       *xj    = x + k*N-1;
                creal_T        s     = {0.0, 0.0};
                InputPtrsType  pUrow = pUcol--;

                {
                    int_T j = i;
                    while(j-- > 0) {
                        /* Compute: s += U(r,c) * x(c), in complex */
                        const creal_T u_rc = *((const creal_T *)(*pUrow));
                        pUrow -= N;

                        s.re += CMULT_RE(u_rc, *xj);
                        s.im += CMULT_IM(u_rc, *xj);
                        xj--;
                    }
                }

                {
                    const creal_T cb = *((const creal_T *)(*pb--));
                    if (unit_upper) {
                        xj->re = cb.re - s.re;
                        xj->im = cb.im - s.im;
                    } else {
                        /* Complex divide: x(r) = cdiff / U(r,r) */
                        const creal_T u_diag = *((const creal_T *)(*pUrow));
                        creal_T       cdiff;
                        cdiff.re = cb.re - s.re;
                        cdiff.im = cb.im - s.im;

                        CDIV(cdiff, u_diag, *xj);
                    }
                }
            }
        }
    }
}


/* Function: mdlTerminate
 * Nothing to release: no per-run resources are allocated in this file.
 */
static void mdlTerminate(SimStruct *S)
{
    (void)S; /* unused */
}



/* Dimension propagation for backward substitution, U * X = B:
 *
 * U is N x N
 * B is N x P
 * X is N x P
 *
 * In the checks below a signal that is not 2-D counts as one column,
 * with dims[0] as its row count.
 */
#if defined(MATLAB_MEX_FILE)
#define MDL_SET_INPUT_PORT_DIMENSION_INFO
/* Function: mdlSetInputPortDimensionInfo
 * Accept the proposed dimensions for input `port`. Then check them against
 * every other port whose size is already known:
 *   - U must be square and have as many rows as B and X;
 *   - B must have as many rows as U and X, and as many columns as X.
 * When B arrives while X is still dynamically sized, X is set to B's
 * N x P shape.
 */
static void mdlSetInputPortDimensionInfo(SimStruct *S, 
                                      int_T port,
                                      const DimsInfo_T *dimsInfo)
{
    /* Row/column counts of the proposed dimensions. */
    const int_T nrows_in = dimsInfo->dims[0];
    const int_T ncols_in = (dimsInfo->numDims == 2) ? dimsInfo->dims[1] : 1;

    if ((port != INPORT_U) && (port != INPORT_B)) {
        THROW_ERROR(S,"Invalid call to input port dimension info propagation.");
    }

    if(!ssSetInputPortDimensionInfo(S, port, dimsInfo)) return;

    if (port == INPORT_U) {
        ErrorIfInputIsNotSquareMatrix(S, port);
        /* Do not run the row checks on a U that was already rejected. */
        if (ssGetErrorStatus(S) != NULL) return;

        /* Check the number of rows are equal for the inputs */
        if (ssGetInputPortWidth(S,INPORT_B) != DYNAMICALLY_SIZED) {
            const int_T *dims_B = ssGetInputPortDimensions(S,INPORT_B);

            if (nrows_in != dims_B[0]) {
                THROW_ERROR(S, "Number of rows must be the same.");
            }
        }

        if (ssGetOutputPortWidth(S,OUTPORT_X) != DYNAMICALLY_SIZED) {
            const int_T *dims_X = ssGetOutputPortDimensions(S,OUTPORT_X);

            if (nrows_in != dims_X[0]) {
                THROW_ERROR(S, "Number of rows must be the same.");
            }
        }

    } else {
        /* INPORT_B */

        if (ssGetInputPortWidth(S,INPORT_U) != DYNAMICALLY_SIZED) {
            const int_T *dims_U = ssGetInputPortDimensions(S,INPORT_U);

            if (nrows_in != dims_U[0]) {
                THROW_ERROR(S, "Number of rows must be the same.");
            }
        }

        if (ssGetOutputPortWidth(S,OUTPORT_X) == DYNAMICALLY_SIZED) {
            /* X takes the N x P shape of B. */
            if (!ssSetOutputPortMatrixDimensions(S, OUTPORT_X, nrows_in, ncols_in)) return;

        } else {
            const int_T  numdims_X = ssGetOutputPortNumDimensions(S,OUTPORT_X);
            const int_T *dims_X    = ssGetOutputPortDimensions(S,OUTPORT_X);
            const int_T  ncols_X   = (numdims_X == 2) ? dims_X[1] : 1;
            const int_T  nrows_X   = dims_X[0];

            if (nrows_in != nrows_X) {
                THROW_ERROR(S, "Number of rows must be the same.");
            }
            /* Same column rule mdlStart enforces, caught at propagation. */
            if (ncols_in != ncols_X) {
                THROW_ERROR(S, "Number of columns in B must be equal to columns in X.");
            }
        }
    }
}


#define MDL_SET_OUTPUT_PORT_DIMENSION_INFO
/* Function: mdlSetOutputPortDimensionInfo
 * Accept the proposed (oriented) dimensions for X, then:
 *   - if B is still dynamically sized, give B the same dimensions as X;
 *   - otherwise require B to have X's row and column counts;
 *   - if U is already sized, require it to have X's row count.
 * A signal that is not 2-D counts as one column.
 */
static void mdlSetOutputPortDimensionInfo(SimStruct        *S, 
                                          int_T            port,
                                          const DimsInfo_T *dimsInfo)
{
    /* Row/column counts of the proposed X dimensions. */
    const int_T nrows_X = dimsInfo->dims[0];
    const int_T ncols_X = (dimsInfo->numDims == 2) ? dimsInfo->dims[1] : 1;

    if (port != OUTPORT_X) {
        THROW_ERROR(S,"Invalid call to output port dimension info propagation.");
    }

    if(!ssSetOutputPortDimensionInfo(S, port, dimsInfo)) return;

    ErrorIfOutputIsUnoriented(S, port);
    /* Do not propagate dimensions that were just rejected. */
    if (ssGetErrorStatus(S) != NULL) return;

    if (ssGetInputPortWidth(S,INPORT_B) == DYNAMICALLY_SIZED) {
        if (!ssSetInputPortDimensionInfo(S, INPORT_B, dimsInfo)) return;

    } else {
        const int_T  numdims_B = ssGetInputPortNumDimensions(S,INPORT_B);
        const int_T *dims_B    = ssGetInputPortDimensions(S,INPORT_B);
        const int_T  ncols_B   = (numdims_B == 2) ? dims_B[1] : 1;
        const int_T  nrows_B   = dims_B[0];

        if (nrows_X != nrows_B) {
            THROW_ERROR(S, "Number of rows must be the same.");
        }
        /* Same column rule mdlStart enforces, caught at propagation. */
        if (ncols_X != ncols_B) {
            THROW_ERROR(S, "Number of columns in B must be equal to columns in X.");
        }
    }

    if (ssGetInputPortWidth(S,INPORT_U) != DYNAMICALLY_SIZED) {
        const int_T *dims_U = ssGetInputPortDimensions(S,INPORT_U);

        if (nrows_X != dims_U[0]) {
            THROW_ERROR(S, "Number of rows must be the same.");
        }
    }
}

#define MDL_SET_INPUT_PORT_FRAME_DATA
/* Function: mdlSetInputPortFrameData
 * Accept whatever frame status is proposed for an input port, unchanged.
 */
static void mdlSetInputPortFrameData(SimStruct *S,
                                     int_T      port,
                                     Frame_T    frameData)
{
    ssSetInputPortFrameData(S, port, frameData);
}
#endif

#include "dsp_cplxhs21.c"

#include "dsp_trailer.c"

/* [EOF] sdspbsub2.c */
