/*
* File: sdspqreslv2.c
*
* Abstract:
*      DSP Blockset S-function for QR Solve.
*
* Copyright 1995-2000 The MathWorks, Inc.
* $Revision: 1.15 $ $Date: 2000/06/11 23:24:15 $
*/
#define S_FUNCTION_NAME  sdspqreslv2
#define S_FUNCTION_LEVEL 2

#include "dsp_sim.h"
#include "dspqrsl_rt.h"
#include "dspqrdc_rt.h"


/*
 * DWork indices.  All four are allocated when M > N; when M <= N the BX
 * vector is omitted and only the first MAX_NUM_DWORKS-1 are allocated.
 */
enum {
    QRAUX_IDX      = 0,  /* auxiliary QR data passed to dspqrdc_* / sdspqreslv_* */
    WORK_IDX       = 1,  /* work vector passed to dspqrdc_*                      */
    QR_IDX         = 2,  /* copy of A, overwritten by its QR factorization       */
    BX_IDX         = 3,  /* copy of B (M > N only); solution computed in place   */
    MAX_NUM_DWORKS = 4
};

/* IWork holds the column pivot indices (jpvt). */
enum { JPVT_IDX = 0, NUM_IWORKS = 1 };

/* Port indices: inputs A and B, output X. */
enum { INPORT_A = 0, INPORT_B = 1, NUM_INPORTS = 2 };
enum { OUTPORT_X = 0, NUM_OUTPORTS = 1 };

/* The block has no dialog parameters. */
enum { NUM_PARAMS = 0 };


/*
 * mdlInitializeSizes: declare the block interface.
 *   - no dialog parameters;
 *   - inputs A and B: double, dynamic dimensions, inherited frame status and
 *     complexity, direct feedthrough, reusable and overwritable buffers;
 *   - output X: double, dynamic dimensions, non-frame, inherited complexity,
 *     reusable buffer;
 *   - DWork/IWork counts left DYNAMICALLY_SIZED here (mdlSetWorkWidths sets
 *     them);
 *   - one inherited sample time.
 */
static void mdlInitializeSizes(SimStruct *S)
{
    int_T port;

    ssSetNumSFcnParams(S, NUM_PARAMS);
    if (ssGetSFcnParamsCount(S) != ssGetNumSFcnParams(S)) {
        return;
    }

    if (!ssSetNumInputPorts(S, NUM_INPORTS)) {
        return;
    }

    /* A and B are configured identically; visit them in the order A, B. */
    for (port = INPORT_A; port < NUM_INPORTS; port++) {
        if (!ssSetInputPortDimensionInfo(S, port, DYNAMIC_DIMENSION)) {
            return;
        }
        ssSetInputPortFrameData(        S, port, FRAME_INHERITED);
        ssSetInputPortDirectFeedThrough(S, port, 1);
        ssSetInputPortComplexSignal(    S, port, COMPLEX_INHERITED);
        ssSetInputPortReusable(         S, port, 1);
        ssSetInputPortOverWritable(     S, port, 1);
        ssSetInputPortDataType(         S, port, SS_DOUBLE);
    }

    if (!ssSetNumOutputPorts(S, NUM_OUTPORTS)) {
        return;
    }
    if (!ssSetOutputPortDimensionInfo(S, OUTPORT_X, DYNAMIC_DIMENSION)) {
        return;
    }
    ssSetOutputPortFrameData(    S, OUTPORT_X, FRAME_NO);
    ssSetOutputPortComplexSignal(S, OUTPORT_X, COMPLEX_INHERITED);
    ssSetOutputPortReusable(     S, OUTPORT_X, 1);
    ssSetOutputPortDataType(     S, OUTPORT_X, SS_DOUBLE);

    if (!ssSetNumDWork(S, DYNAMICALLY_SIZED)) {
        return;
    }
    ssSetNumIWork(S, DYNAMICALLY_SIZED);

    ssSetNumSampleTimes(S, 1);
    ssSetOptions(S, SS_OPTION_EXCEPTION_FREE_CODE | SS_OPTION_USE_TLC_WITH_ACCELERATOR);
}


/* mdlInitializeSampleTimes: the block's single sample time is inherited,
 * with zero offset. */
static void mdlInitializeSampleTimes(SimStruct *S)
{
    const int_T onlySampleTime = 0;

    ssSetSampleTime(S, onlySampleTime, INHERITED_SAMPLE_TIME);
    ssSetOffsetTime(S, onlySampleTime, 0.0);
}


/*
 * sdspqreslv_real -- finish a real minimum-residual solve of A*X = B from the
 * column-pivoted QR factorization of A (both A and B real).
 *
 *   m, n   rows / columns of A
 *   p      number of right-hand sides (columns of B and X)
 *   qr     m-by-n column-major QR factorization of A as produced by
 *          dspqrdc_real in mdlOutputs; R(j,j) is read from qr[j*(m+1)]
 *   bx     p columns spaced MAX(m,n) elements apart.  On entry the first m
 *          entries of each column hold a column of B; on return the first n
 *          entries hold the corresponding column of X.
 *   qraux  auxiliary factorization data from dspqrdc_real
 *   jpvt   0-based column pivot indices; left as the identity permutation
 *          on return
 *   x      when m > n, receives the n-by-p solution packed contiguously;
 *          not referenced otherwise
 *
 * The effective rank is one past the last diagonal index j with
 * |R(j,j)| > MAX(m,n) * eps * |R(0,0)|.  Only that many columns of the
 * factorization are used, and the remaining solution entries are zero.
 * MEX builds warn when the rank is below MIN(m,n).
 */
static void sdspqreslv_real(
    int_T   m,
    int_T   n,
    int_T   p,
    real_T *qr,
    real_T *bx,
    real_T *qraux,
    int_T  *jpvt,
    real_T *x
    )
{
    const int_T  minmn  = MIN(m, n);
    const int_T  ld     = MAX(m, n);      /* spacing between bx columns */
    const real_T rdiag0 = fabs(qr[0]);
    const real_T tol    = ((real_T) ld) * mxGetEps() * rdiag0;
    int_T rank = 0;
    int_T row, col;

    /* Effective rank: one past the last diagonal entry of R above tol. */
    for (row = 0; row < minmn; row++) {
        const real_T mag = fabs(qr[row * (m + 1)]);
        if (mag > tol) {
            rank = row + 1;
        }
    }

#ifdef MATLAB_MEX_FILE
    if (rank < minmn) {
        char msg[256];
        sprintf(msg, "Rank deficient, rank=%d, tol=%13.4e",rank,tol);
        mexWarnMsgTxt(msg);
    }
#endif

    /* Solve each column against the leading `rank` columns of the
     * factorization (work done by dspqrsl2_real). */
    for (col = 0; col < p; col++) {
        dspqrsl2_real(m, rank, qr, qraux, bx + col * ld);
    }

    /* Solution entries rank..n-1 of every column are zero. */
    for (col = 0; col < p; col++) {
        real_T *bcol = bx + col * ld;
        for (row = rank; row < n; row++) {
            bcol[row] = 0.0;
        }
    }

    /* Undo the column pivoting: row j of the pivoted solution belongs in row
     * jpvt[j] (MATLAB: X(E,:) = Y).  Each permutation cycle is followed with
     * pairwise row swaps across all p columns. */
    for (row = 0; row < n; row++) {
        int_T dest;
        while ((dest = jpvt[row]) != row) {
            for (col = 0; col < p; col++) {
                real_T *bcol = bx + col * ld;
                const real_T held = bcol[row];
                bcol[row]  = bcol[dest];
                bcol[dest] = held;
            }
            jpvt[row]  = jpvt[dest];
            jpvt[dest] = dest;
        }
    }

    /* Overdetermined case: pack the first n entries of each column into x. */
    if (m > n) {
        real_T *dst = x;
        for (col = 0; col < p; col++) {
            const real_T *src = bx + col * ld;
            for (row = 0; row < n; row++) {
                *dst++ = src[row];
            }
        }
    }
}


/*
 * sdspqreslv_cplx -- complex counterpart of sdspqreslv_real, used when A is
 * complex (B is then supplied as a complex copy, and X is complex).
 *
 *   m, n   rows / columns of A
 *   p      number of right-hand sides
 *   qr     m-by-n column-major complex QR factorization of A as produced by
 *          dspqrdc_cplx in mdlOutputs; R(j,j) is read from qr[j*(m+1)]
 *   bx     p complex columns spaced MAX(m,n) elements apart; B on entry
 *          (first m entries per column), X on return (first n entries)
 *   qraux  complex auxiliary factorization data from dspqrdc_cplx
 *   jpvt   0-based column pivot indices; left as the identity permutation
 *   x      when m > n, receives the n-by-p solution packed contiguously;
 *          not referenced otherwise
 *
 * Magnitudes for the rank test use CQABS.  The effective rank is one past
 * the last diagonal index j with CQABS(R(j,j)) > MAX(m,n)*eps*CQABS(R(0,0)).
 */
static void sdspqreslv_cplx(
    int_T    m,
    int_T    n,
    int_T    p,
    creal_T *qr,
    creal_T *bx,
    creal_T *qraux,
    int_T   *jpvt,
    creal_T *x
    )
{
    const int_T  minmn  = MIN(m, n);
    const int_T  ld     = MAX(m, n);      /* spacing between bx columns */
    const real_T rdiag0 = CQABS(qr[0]);
    const real_T tol    = ((real_T) ld) * mxGetEps() * rdiag0;
    int_T rank = 0;
    int_T row, col;

    /* Effective rank: one past the last diagonal entry of R above tol. */
    for (row = 0; row < minmn; row++) {
        const real_T mag = CQABS(qr[row * (m + 1)]);
        if (mag > tol) {
            rank = row + 1;
        }
    }

#ifdef MATLAB_MEX_FILE
    if (rank < minmn) {
        char msg[256];
        sprintf(msg, "Rank deficient, rank=%d, tol=%13.4e",rank,tol);
        mexWarnMsgTxt(msg);
    }
#endif

    /* Solve each column against the leading `rank` columns of the
     * factorization (work done by dspqrsl2_cplx). */
    for (col = 0; col < p; col++) {
        dspqrsl2_cplx(m, rank, qr, qraux, bx + col * ld);
    }

    /* Solution entries rank..n-1 of every column are zero. */
    for (col = 0; col < p; col++) {
        creal_T *bcol = bx + col * ld;
        for (row = rank; row < n; row++) {
            bcol[row].re = 0.0;
            bcol[row].im = 0.0;
        }
    }

    /* Undo the column pivoting: row j of the pivoted solution belongs in row
     * jpvt[j] (MATLAB: X(E,:) = Y), following each cycle with row swaps. */
    for (row = 0; row < n; row++) {
        int_T dest;
        while ((dest = jpvt[row]) != row) {
            for (col = 0; col < p; col++) {
                creal_T *bcol = bx + col * ld;
                const creal_T held = bcol[row];
                bcol[row]  = bcol[dest];
                bcol[dest] = held;
            }
            jpvt[row]  = jpvt[dest];
            jpvt[dest] = dest;
        }
    }

    /* Overdetermined case: pack the first n entries of each column into x. */
    if (m > n) {
        creal_T *dst = x;
        for (col = 0; col < p; col++) {
            const creal_T *src = bx + col * ld;
            for (row = 0; row < n; row++) {
                *dst++ = src[row];
            }
        }
    }
}


/*
 * sdspqreslv_mixd -- mixed-type variant of sdspqreslv_real, used when A is
 * real but B (and therefore X) is complex.
 *
 *   m, n   rows / columns of A
 *   p      number of right-hand sides
 *   qr     m-by-n column-major real QR factorization of A as produced by
 *          dspqrdc_real in mdlOutputs; R(j,j) is read from qr[j*(m+1)]
 *   bx     p complex columns spaced MAX(m,n) elements apart; B on entry
 *          (first m entries per column), X on return (first n entries)
 *   qraux  real auxiliary factorization data from dspqrdc_real
 *   jpvt   0-based column pivot indices; left as the identity permutation
 *   x      when m > n, receives the complex n-by-p solution packed
 *          contiguously; not referenced otherwise
 *
 * The effective rank is one past the last diagonal index j with
 * |R(j,j)| > MAX(m,n) * eps * |R(0,0)|.
 */
static void sdspqreslv_mixd(
    int_T    m,
    int_T    n,
    int_T    p,
    real_T  *qr,
    creal_T *bx,
    real_T  *qraux,
    int_T   *jpvt,
    creal_T *x
    )
{
    const int_T  minmn  = MIN(m, n);
    const int_T  ld     = MAX(m, n);      /* spacing between bx columns */
    const real_T rdiag0 = fabs(qr[0]);
    const real_T tol    = ((real_T) ld) * mxGetEps() * rdiag0;
    int_T rank = 0;
    int_T row, col;

    /* Effective rank: one past the last diagonal entry of R above tol. */
    for (row = 0; row < minmn; row++) {
        const real_T mag = fabs(qr[row * (m + 1)]);
        if (mag > tol) {
            rank = row + 1;
        }
    }

#ifdef MATLAB_MEX_FILE
    if (rank < minmn) {
        char msg[256];
        sprintf(msg, "Rank deficient, rank=%d, tol=%13.4e",rank,tol);
        mexWarnMsgTxt(msg);
    }
#endif

    /* Solve each column against the leading `rank` columns of the
     * factorization (work done by dspqrsl2_mixd). */
    for (col = 0; col < p; col++) {
        dspqrsl2_mixd(m, rank, qr, qraux, bx + col * ld);
    }

    /* Solution entries rank..n-1 of every column are zero. */
    for (col = 0; col < p; col++) {
        creal_T *bcol = bx + col * ld;
        for (row = rank; row < n; row++) {
            bcol[row].re = 0.0;
            bcol[row].im = 0.0;
        }
    }

    /* Undo the column pivoting: row j of the pivoted solution belongs in row
     * jpvt[j] (MATLAB: X(E,:) = Y), following each cycle with row swaps. */
    for (row = 0; row < n; row++) {
        int_T dest;
        while ((dest = jpvt[row]) != row) {
            for (col = 0; col < p; col++) {
                creal_T *bcol = bx + col * ld;
                const creal_T held = bcol[row];
                bcol[row]  = bcol[dest];
                bcol[dest] = held;
            }
            jpvt[row]  = jpvt[dest];
            jpvt[dest] = dest;
        }
    }

    /* Overdetermined case: pack the first n entries of each column into x. */
    if (m > n) {
        creal_T *dst = x;
        for (col = 0; col < p; col++) {
            const creal_T *src = bx + col * ld;
            for (row = 0; row < n; row++) {
                *dst++ = src[row];
            }
        }
    }
}


/*
 * mdlOutputs: compute the minimum-norm-residual solution X to A*X = B using
 * the economy-size QR factorization (with column pivoting) of the M-by-N
 * input A.  MATLAB equivalent:
 *     [Q,R,E] = qr(A,0);
 *     QTB     = Q'*B;
 *     X(E,:)  = R \ QTB(1:n,:);
 *
 * Data flow:
 *   - A is copied to DWork QR, which is then overwritten by its QR
 *     factorization.
 *   - If M > N, B is copied to DWork BX and X is computed there in place.
 *     At the end, the first N entries of each column are copied to output X.
 *   - If M <= N, each column of B is copied into the first M entries of the
 *     corresponding N-entry column of output X, where X is computed in place.
 *   - If Simulink has assigned input B's buffer to output X, the copy of B is
 *     skipped.
 *
 * X (and BX) are complex whenever A or B is complex.
 */
static void mdlOutputs(SimStruct *S, int_T tid)
{
    const boolean_T cA = (boolean_T)(ssGetInputPortComplexSignal(S, INPORT_A) == COMPLEX_YES);
    const boolean_T cB = (boolean_T)(ssGetInputPortComplexSignal(S, INPORT_B) == COMPLEX_YES);
    const boolean_T cX = (boolean_T)(cA || cB);

    const int_T  numDimsA = ssGetInputPortNumDimensions(S, INPORT_A);
    const int_T  numDimsB = ssGetInputPortNumDimensions(S, INPORT_B);
    const int_T *dimsA    = ssGetInputPortDimensions(S, INPORT_A);
    const int_T *dimsB    = ssGetInputPortDimensions(S, INPORT_B);
    const int_T  M        = dimsA[0];
    const int_T  N        = (numDimsA == 2) ? dimsA[1] : 1;
    const int_T  P        = (numDimsB == 2) ? dimsB[1] : 1;
    const int_T  MN       = ssGetInputPortWidth(S, INPORT_A);

    /*
     * Workspace in which X is computed in place: DWork BX when M > N,
     * otherwise output X itself.  Its P columns are MAX(M,N) elements apart.
     * Computed once so the copy and solve stages agree.
     */
    void *pBX = (M > N) ? ssGetDWork(S, BX_IDX) : ssGetOutputPortSignal(S, OUTPORT_X);
    int_T i, j;

    /* Copy input A to DWork QR to be overwritten by its QR factorization. */
    if (cA) {
        InputPtrsType pA  = ssGetInputPortSignalPtrs(S, INPORT_A);
        creal_T      *pQR = (creal_T *)ssGetDWork(S, QR_IDX);
        for (i = 0; i < MN; i++) {
            pQR[i] = *((const creal_T *)pA[i]);
        }
    } else {
        InputRealPtrsType pA  = ssGetInputPortRealSignalPtrs(S, INPORT_A);
        real_T           *pQR = (real_T *)ssGetDWork(S, QR_IDX);
        for (i = 0; i < MN; i++) {
            pQR[i] = *pA[i];
        }
    }

    /*
     * Copy B into the workspace column by column.  When M < N, the
     * MAX(N-M,0) trailing entries of each workspace column are left untouched
     * here; the solver zeroes them.
     * Skipped when output X already shares input B's buffer.
     * NOTE(review): that reuse presumably only happens when M == N (equal
     * port widths), so B's layout already matches; confirm.
     */
    if (ssGetInputPortBufferDstPort(S, INPORT_B) != OUTPORT_X) {
        const int_T colPad = MAX(N - M, 0);

        if (cB) {
            /* complex B -> complex workspace */
            InputPtrsType pB  = ssGetInputPortSignalPtrs(S, INPORT_B);
            creal_T      *dst = (creal_T *)pBX;
            for (j = 0; j < P; j++) {
                for (i = 0; i < M; i++) {
                    *dst++ = *((const creal_T *)(*pB++));
                }
                dst += colPad;
            }
        } else if (cX) {
            /* real B with complex A -> complex workspace, zero imaginary part */
            InputRealPtrsType pB  = ssGetInputPortRealSignalPtrs(S, INPORT_B);
            creal_T          *dst = (creal_T *)pBX;
            for (j = 0; j < P; j++) {
                for (i = 0; i < M; i++) {
                    dst->re = **pB++;
                    dst->im = 0.0;
                    dst++;
                }
                dst += colPad;
            }
        } else {
            /* real B with real A -> real workspace */
            InputRealPtrsType pB  = ssGetInputPortRealSignalPtrs(S, INPORT_B);
            real_T           *dst = (real_T *)pBX;
            for (j = 0; j < P; j++) {
                for (i = 0; i < M; i++) {
                    *dst++ = **pB++;
                }
                dst += colPad;
            }
        }
    }

    /* Find a minimum norm residual solution to A*X=B. */
    {
        void  *pQR    = ssGetDWork(S, QR_IDX);
        void  *pqraux = ssGetDWork(S, QRAUX_IDX);
        void  *pwork  = ssGetDWork(S, WORK_IDX);
        int_T *pjpvt  = ssGetIWork(S);
        /* Only when M > N is the solution copied out of BX into output X. */
        void  *pX     = (M > N) ? ssGetOutputPortSignal(S, OUTPORT_X) : (void *)0;

        /* Reset the pivot indices before each factorization. */
        memset(pjpvt, 0, N * sizeof(int_T));

        if (cA) {
            /* Complex A: QR, qraux, work, BX and X are all complex. */
            dspqrdc_cplx(M, N, (creal_T *)pQR, (creal_T *)pqraux, pjpvt, (creal_T *)pwork);
            sdspqreslv_cplx(M, N, P, (creal_T *)pQR, (creal_T *)pBX,
                            (creal_T *)pqraux, pjpvt, (creal_T *)pX);
        } else {
            /* Real A: overwrite QR with the real factorization. */
            dspqrdc_real(M, N, (real_T *)pQR, (real_T *)pqraux, pjpvt, (real_T *)pwork);

            if (cB) {
                /* Complex B => complex BX and X with a real factorization. */
                sdspqreslv_mixd(M, N, P, (real_T *)pQR, (creal_T *)pBX,
                                (real_T *)pqraux, pjpvt, (creal_T *)pX);
            } else {
                /* Real B => everything real. */
                sdspqreslv_real(M, N, P, (real_T *)pQR, (real_T *)pBX,
                                (real_T *)pqraux, pjpvt, (real_T *)pX);
            }
        }
    }
}


/* mdlTerminate: nothing to release.  The code in this file keeps its work
 * storage in Simulink DWork/IWork and allocates no memory of its own. */
static void mdlTerminate(SimStruct *S)
{
    (void)S;  /* unused */
}

/* Dimension propagation for the QR solver, A * x = b:
 *
 * A is M x N
 * b is M x P
 * x is N x P
 */
#if defined(MATLAB_MEX_FILE)
#define MDL_SET_INPUT_PORT_DIMENSION_INFO
/*
 * mdlSetInputPortDimensionInfo: record the dimensions propagated to input
 * `port`.  Once both A and B are sized, validate them and size output X.
 *
 *   A (INPORT_A): any dimensions; a 1-D (unoriented) input is treated as an
 *                 M-by-1 column.
 *   B (INPORT_B): M-by-P; a scalar or 1-D input is treated as an M-by-1
 *                 column.
 *   X (OUTPORT_X): N-by-P.
 *
 * The A/B row check runs whenever both inputs are sized, even if the output
 * was already sized (e.g. by back-propagation).  In that case the output
 * width is verified against N*P instead of being set.  Previously both checks
 * were skipped, so mdlOutputs could read past B's input pointers.
 */
static void mdlSetInputPortDimensionInfo(SimStruct        *S,
                                         int_T             port,
                                         const DimsInfo_T *dimsInfo)
{
    if (!ssSetInputPortDimensionInfo(S, port, dimsInfo)) return;

    if ( (!isInputDynamicallySized(S, INPORT_A)) &&
         (!isInputDynamicallySized(S, INPORT_B)) ) {

        const int_T  numDimsA = ssGetInputPortNumDimensions(S, INPORT_A);
        const int_T  numDimsB = ssGetInputPortNumDimensions(S, INPORT_B);
        const int_T *dimsA    = ssGetInputPortDimensions(S, INPORT_A);
        const int_T *dimsB    = ssGetInputPortDimensions(S, INPORT_B);
        const int_T  M        = dimsA[0];
        const int_T  N        = (numDimsA == 2) ? dimsA[1] : 1;
        const int_T  P        = (numDimsB == 2) ? dimsB[1] : 1;

        if (M != dimsB[0]) {
            THROW_ERROR(S,"Number of rows of input A must match number of rows of input B.");
        }

        if (isOutputDynamicallySized(S, OUTPORT_X)) {
            if (!ssSetOutputPortMatrixDimensions(S, OUTPORT_X, N, P)) return;
        } else if (ssGetOutputPortWidth(S, OUTPORT_X) != N * P) {
            THROW_ERROR(S,"Number of elements of output X must equal (columns of A) x (columns of B).");
        }
    }
}


#define MDL_SET_OUTPUT_PORT_DIMENSION_INFO
/* mdlSetOutputPortDimensionInfo: accept the dimensions propagated to the
 * output port as-is; no checking is done here. */
static void mdlSetOutputPortDimensionInfo(SimStruct *S, int_T port, const DimsInfo_T *dimsInfo)
{
    (void)ssSetOutputPortDimensionInfo(S, port, dimsInfo);
}


#define MDL_SET_INPUT_PORT_FRAME_DATA
/*
 * mdlSetInputPortFrameData: accept whatever frame status is propagated to an
 * input port.  The output's frame status (FRAME_NO) is fixed in
 * mdlInitializeSizes and is not updated here.
 */
static void mdlSetInputPortFrameData(SimStruct *S, int_T port, Frame_T frameData)
{
    ssSetInputPortFrameData(S, port, frameData);
}


#define MDL_SET_WORK_WIDTHS
/*
 * mdlSetWorkWidths: size the work vectors used by mdlOutputs for an M-by-N
 * input A and an M-by-P input B.
 *
 *   DWork QRAUX_IDX : N elements,          complex iff A is complex
 *   DWork WORK_IDX  : N elements,          complex iff A is complex
 *   DWork QR_IDX    : M*N elements,        complex iff A is complex
 *   DWork BX_IDX    : MAX(M,N)*P elements, complex iff A or B is complex;
 *                     only allocated when M > N (otherwise mdlOutputs
 *                     solves in place in the output buffer)
 *   IWork           : N elements (column pivot indices)
 *
 * Complexity is decided with "== COMPLEX_YES", the same test mdlOutputs uses.
 * The old code OR-ed CSignal_T values, which depended on the enumerators'
 * numeric encoding and could disagree with mdlOutputs.
 */
static void mdlSetWorkWidths(SimStruct *S)
{
    const int_T  numDimsA = ssGetInputPortNumDimensions(S, INPORT_A);
    const int_T  numDimsB = ssGetInputPortNumDimensions(S, INPORT_B);
    const int_T *dimsA    = ssGetInputPortDimensions(S, INPORT_A);
    const int_T *dimsB    = ssGetInputPortDimensions(S, INPORT_B);
    const int_T  M        = dimsA[0];
    const int_T  N        = (numDimsA == 2) ? dimsA[1] : 1;
    const int_T  P        = (numDimsB == 2) ? dimsB[1] : 1;

    /* BX is the last DWork and is only needed when M > N. */
    const int_T N_DWORKS = (M > N) ? MAX_NUM_DWORKS : MAX_NUM_DWORKS-1;

    const boolean_T aIsCplx = (boolean_T)(ssGetInputPortComplexSignal(S, INPORT_A) == COMPLEX_YES);
    const boolean_T bIsCplx = (boolean_T)(ssGetInputPortComplexSignal(S, INPORT_B) == COMPLEX_YES);
    const CSignal_T cA      = aIsCplx ? COMPLEX_YES : COMPLEX_NO;
    const CSignal_T cX      = (aIsCplx || bIsCplx) ? COMPLEX_YES : COMPLEX_NO;

    if (!ssSetNumDWork(S, N_DWORKS)) return;

    ssSetDWorkWidth(        S, QRAUX_IDX, N);
    ssSetDWorkDataType(     S, QRAUX_IDX, SS_DOUBLE);
    ssSetDWorkComplexSignal(S, QRAUX_IDX, cA);

    ssSetDWorkWidth(        S, WORK_IDX, N);
    ssSetDWorkDataType(     S, WORK_IDX, SS_DOUBLE);
    ssSetDWorkComplexSignal(S, WORK_IDX, cA);

    ssSetDWorkWidth(        S, QR_IDX, M*N);
    ssSetDWorkDataType(     S, QR_IDX, SS_DOUBLE);
    ssSetDWorkComplexSignal(S, QR_IDX, cA);

    if (N_DWORKS == MAX_NUM_DWORKS) {
        ssSetDWorkWidth(        S, BX_IDX, MAX(M,N)*P);
        ssSetDWorkDataType(     S, BX_IDX, SS_DOUBLE);
        ssSetDWorkComplexSignal(S, BX_IDX, cX);
    }

    ssSetNumIWork(S, N);
}
#endif

#include "dsp_cplxhs21.c"

#include "dsp_trailer.c"

/* [EOF] sdspqreslv2.c */
