/****************************************************************************
 *
 * BlkTri.c
 *
 * Abstract:
 *   Main file to perform block Lanczos tridiagonalization.
 *
 * Student:  Guohong Liu
 *
 * ID:       0385117
 *
 ****************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <malloc.h>
#include <memory.h>
#include <math.h>
#include "blklan.h"

/****************************************************************************
 *
 * Global variables
 *
 ****************************************************************************/

extern int n;                   /* Order of given matrix A.         */
extern int bs;                  /* Block size.                      */
/* Epsilon.  Scales the random noise injected into the W estimates
 * (e.g. eps*bs*0.6 in blkTri()) and, in detectW(), is the threshold on
 * |W_ab|^2 that signals loss of orthogonality.
 */
extern double eps;              /* Epsilon.                         */
/* Threshold on |W_ab|^2 used by orthInterval() to delimit the
 * orthogonalization interval (noted there as eps^7/4).
 */
extern double mideps;

extern DoubleComplexMat *A;     /* Matrix A. n * n                  */
extern DoubleComplexMat *Q;     /* Unitary. one block = n * bs      */  
extern DoubleComplexMat *M;     /* Diagonal blocks. bs * bs.        */
extern DoubleComplexMat *B;     /* Off-Diagonal blocks. bs * bs     */
extern DoubleComplexMat *WCur;  /* Detectors of current iter. bs*bs */
extern DoubleComplexMat *WOld;  /* Detectors of previous iter.bs*bs */
extern DoubleComplexMat *R;     /* Work matrix. n * bs              */
extern DoubleComplexMat *Work1; /* Work matrix. bs * bs             */
extern DoubleComplexMat *Work2; /* Work matrix. bs * bs             */

/* Backing storage for the matrices above.  Q_m holds all of Q; M_m, B_m,
 * WCur_m and WOld_m are arrays of per-block pointers.  blkTri() shifts
 * these pointer arrays so that they are indexed from 1 while it runs,
 * and restores them before returning.
 */
extern DoubleComplex *Q_m;      
extern DoubleComplex **M_m;     
extern DoubleComplex **B_m;
extern DoubleComplex **WCur_m;  
extern DoubleComplex **WOld_m;  

/* Number of block Lanczos steps: blkTri() produces M(1..steps) and
 * B(1..steps-1).
 */
extern int steps;
/* Upper end of the orthogonalization interval: orthR() orthogonalizes
 * the residual against Q(1), ..., Q(up).
 */
int up;
/* Set TRUE by detectW() when an estimate signals loss of orthogonality;
 * cleared by orthR().
 */
Boolean doOrtho;
/* TRUE when the next iteration must perform the second (follow-up)
 * orthogonalization scheduled by orthR().
 */
Boolean second;
/* Running total of blocks orthogonalized against (accumulated in orthR()). */
int nBlk;

#ifdef MATLAB_COMP
/* Fixed bs * bs matrix copied in place of randComplexMat() output;
 * presumably used to reproduce MATLAB reference results — confirm.
 */
extern DoubleComplex* RandMat;
#endif

/***************************************************************************
 *
 * Abstract
 *   Main function to perform block Lanczos tridiagonalization.
 *
 *   The global Q is walked through the n * bs blocks of Q_m (it is
 *   presumably expected to describe Q(1) = first block of Q_m on entry).
 *   Blocks Q(2..steps) are produced in place, the diagonal blocks
 *   M(1..steps) are written to M_m and the off-diagonal blocks
 *   B(1..steps-1) to B_m.  Loss of orthogonality is tracked through the
 *   W estimates (detectW / orthInterval / completeW) and repaired by
 *   orthR().  On exit Q is reset to describe the whole n * n matrix Q_m.
 *
 * Input
 *   None (reads the globals n, bs, steps, eps, A and Q).
 *
 * Output
 *   None (writes M_m, B_m, Q_m, the W estimates and the globals up,
 *   doOrtho, second and nBlk).
 *
 ***************************************************************************/

void blkTri ()
{
    /* The global Q points to the current block of Q_m; Q_pre and Q_next
     * describe the previous and next block respectively.
     */
    DoubleComplexMat Q_pre, Q_next;
    DoubleComplex **tmp;
    int j;
    
    /* Initialize the reorthogonalization state.   */
    up = 0;
    doOrtho = FALSE;
    second = FALSE;
    nBlk = 0;

    /* Set parameters of matrices: every M, B and W block is bs * bs.  */
    M->m = bs;
    M->n = bs;
    B->m = bs;
    B->n = bs;
    WCur->m = bs;
    WCur->n = bs;
    WOld->m = bs;
    WOld->n = bs;

    if (steps < 2) {
        /* With a single block there is no Lanczos recurrence: only
         * M(1) = Q(1)' * A * conjg (Q(1)) exists.  The general path below
         * would write Q(2), B(1) and W(1) (presumably not allocated here)
         * and would then overwrite M(1) with Q(2)' * A * conjg (Q(2)).
         * M_m is still 0-based at this point, so M(1) is M_m[0].
         */
        if (steps == 1) {
            mmult ("+", A, "N", Q, "C", NULL, R);
            M->mat = M_m[0];
            mmult ("+", Q, "H", R, "N", NULL, M);
        }
        Q->mat = Q_m;
        Q->m = n;
        Q->n = n;
        return;
    }

    /* Decrease the index of array so that from then on the index
     * of array starts from 1, not from 0 (matching the MATLAB reference).
     * detectW(), orthInterval(), completeW() and orthR() all rely on this.
     * NOTE(review): forming a pointer one element before an array is
     * technically undefined in ISO C; kept because of those callees.
     */
    --M_m;
    --B_m;
    --WCur_m;
    --WOld_m;

    Q_pre.m = Q->m;
    Q_pre.n = Q->n;
    /* No block precedes Q(1).  Q_pre is first used (and set) inside the
     * loop; computing Q->mat - Q->m * bs here would form a pointer before
     * the start of Q_m, which is undefined behaviour.
     */
    Q_pre.mat = NULL;
    Q_next.m = Q->m;
    Q_next.n = Q->n;
    Q_next.mat = Q->mat + Q->m * bs;

    /* R = A * conjg (Q(1))                 */
    mmult ("+", A, "N", Q, "C", NULL, R);      
    /* M(1) = Q(1)' * A * conjg (Q(1))      */
    M->mat = M_m[1];
    mmult ("+", Q, "H", R, "N", NULL, M);
    /* R = R - Q(1) * M(1)                  */
    mmult ("-", Q, "N", M, "N", R, R);
    /* [Q(2), B(1)] = qr (R)                */
    B->mat = B_m[1];
    qr (R, &Q_next, B);

    /* W(1) = (eps*bs*0.6*B(:,:,1).*(rand(bs,bs)+IM*rand(bs,bs)))/B(:,:,1) */
#ifdef MATLAB_COMP
    memcpy (Work1->mat, RandMat, bs * bs * sizeof(DoubleComplex));
#else
    randComplexMat (Work1);
#endif
    memult (eps*bs*0.6, B, Work1, NULL, Work1);
    WCur->mat = WCur_m[1];
    mdiv (Work1, B, WCur);

    /* Exchange the current / previous estimate buffers.   */
    tmp = WCur_m;
    WCur_m = WOld_m;
    WOld_m = tmp;
    for (j = 2; j <= steps - 1; j++) {
        /* Specify next Q block                     */
        Q->mat += Q->m * bs;
        Q_pre.mat = Q->mat - Q->m * bs;
        Q_next.mat = Q->mat + Q->m * bs;

        /* R = A * conjg(Q(j))                      */
        mmult ("+", A, "N", Q, "C", NULL, R);
        /* M(j) = Q(j)' * A * conjg(Q(j))           */
        M->mat = M_m[j];
        mmult ("+", Q, "H", R, "N", NULL, M);
        /* R = R - Q(j) * M(j) - Q(j-1) * B(j-1).'  */
        mmult ("-", Q, "N", M, "N", R, R);
        B->mat = B_m[j-1];
        mmult ("-", &Q_pre, "N", B, "T", R, R);
        /* [Q(j+1), B(j)] = qr (R, 0)               */
        B->mat = B_m[j];
        qr (R, &Q_next, B);
        
        if (second == FALSE) {
            /* Not second orthogonalization. 
             * Compute W(k,j+1), k = 1,...j        
             */
            detectW (j);
            if ((doOrtho == TRUE) && (up < j)) {
                /* If loss of orthogonality was found 
                 * and not at the last W.
                 */
                orthInterval (j);
            }
        }
        else {
            if (up < j) {
                /* This is second orthogonalization.
                 * Compute Ws in [up, j]
                 */
                completeW (j);
            }
        }
        
        /* Exchange the current / previous estimate buffers.   */
        tmp = WCur_m;
        WCur_m = WOld_m;
        WOld_m = tmp;

        if ((doOrtho == TRUE) || (second == TRUE)) {
            /* Orthogonalize Q(j+1) against Q(k), 
             * k is inside the interval.
             */
            orthR (j);
        }

#ifdef INFO
        if (j == steps - 1) { 
            blkTriInfo (j);
        }
#endif
    }

    /* The last iteration: only M(steps) remains.   */
    Q->mat += Q->m * bs;
    /* R = A * conjg(Q(steps))                      */
    mmult ("+", A, "N", Q, "C", NULL, R);
    /* M(steps) = Q(steps)' * A * conjg(Q(steps))   */
    M->mat = M_m[steps];
    mmult ("+", Q, "H", R, "N", NULL, M);
    
    /* Q will be used in lanTri (), so set Q to 
     * represent the whole matrix.
     */
    Q->mat = Q_m;
    Q->m = n;
    Q->n = n;

    /* Restore 0-based indexing: blkTriEnd () still uses element [0]
     * of these arrays to release them.
     */
    ++M_m;
    ++B_m;
    ++WCur_m;
    ++WOld_m;
}

/***************************************************************************
 *
 * Abstract
 *   Estimate W(k, j+1) for k = 1, 2, ..., j -- the detectors of loss of
 *   orthogonality between the newest block and the earlier ones -- and
 *   stop at the first k for which some entry w satisfies |w|^2 >= eps.
 *
 * Input
 *   iter   The iteration j.
 *
 * Output
 *   None.  Each visited estimate is stored in WOld_m[k].  When loss of
 *   orthogonality is detected, doOrtho becomes TRUE and up is set to
 *   that k.
 *
 ***************************************************************************/

void detectW (int iter)
{
    int blk, idx, nElem, last;
    DoubleComplex *entries;
    DoubleComplexMat Bpair;

    /* The C code follows the MATLAB reference once j is introduced. */
    last = iter;
    nElem = bs * bs;
    Bpair.m = bs;
    Bpair.n = bs;

    for (blk = 1; (blk <= last) && (doOrtho != TRUE); blk++) {
        if (blk != last) {
            /* W(:,:,k,old) = M(:,:,k)*conj(W(:,:,k,cur))
             *                - W(:,:,k,cur)*M(:,:,j)
             *                + (eps*0.3)*((B(:,:,1)+B(:,:,2))
             *                .*(randn(bs,bs)+IM*randn(bs,bs)));
             */
            M->mat = M_m[blk];
            WCur->mat = WCur_m[blk];
            WOld->mat = WOld_m[blk];
            mmult ("+", M, "N", WCur, "C", NULL, WOld);
            M->mat = M_m[last];
            mmult ("-", WCur, "N", M, "N", WOld, WOld);

            B->mat = B_m[1];
            Bpair.mat = B_m[2];
            mplus (B, &Bpair, Work1);
#ifdef MATLAB_COMP
            memcpy (Work2->mat, RandMat, nElem * sizeof (DoubleComplex));
#else
            randComplexMat (Work2);
#endif
            WOld->mat = WOld_m[blk];
            memult (eps * 0.3, Work1, Work2, WOld, WOld);

            if ((last > 2) && (blk > 1)) {
                /* W(:,:,k,old) += B(:,:,k-1)*conj(W(:,:,k-1,cur)); */
                B->mat = B_m[blk - 1];
                WCur->mat = WCur_m[blk - 1];
                WOld->mat = WOld_m[blk];
                mmult ("+", B, "N", WCur, "C", WOld, WOld);
            }
            if ((last > 2) && (blk < last - 1)) {
                /* W(:,:,k,old) += (B(:,:,k).')*conj(W(:,:,k+1,cur))
                 *                 - W(:,:,k,old)*(B(:,:,j-1).');
                 */
                B->mat = B_m[blk];
                WCur->mat = WCur_m[blk + 1];
                mmult ("+", B, "T", WCur, "C", NULL, Work1);
                WOld->mat = WOld_m[blk];
                B->mat = B_m[last - 1];
                mmult ("-", WOld, "N", B, "T", Work1, Work1);
                mplus (Work1, WOld, WOld);
            }
        }
        else {
            /* Newest pair, seeded with noise:
             * W(:,:,j,old) = (eps*bs*0.6)*(B(:,:,1)
             *                .*(randn(bs,bs)+IM*randn(bs,bs)));
             */
#ifdef MATLAB_COMP
            memcpy (Work1->mat, RandMat, nElem * sizeof (DoubleComplex));
#else
            randComplexMat (Work1);
#endif
            B->mat = B_m[1];
            WOld->mat = WOld_m[blk];
            memult (eps*bs*0.6, B, Work1, NULL, WOld);
        }

        /* W(:,:,k,old) = W(:,:,k,old)/B(:,:,j); */
        WOld->mat = WOld_m[blk];
        B->mat = B_m[last];
        mdiv (WOld, B, WOld);

        /* Scan every entry of W(:,:,k,j+1); the first with |w|^2 >= eps
         * marks loss of orthogonality at block k.
         */
        entries = WOld->mat;
        for (idx = 0; idx < nElem; idx++) {
            if (entries[idx].r * entries[idx].r
                + entries[idx].i * entries[idx].i >= eps) {
                doOrtho = TRUE;
                up = blk;
                break;
            }
        }
    }
}

/***************************************************************************
 *
 * Abstract
 *   Determine the orthogonalization intervals.
 *
 * Input
 *   iter   The iteration
 *
 * Output
 *   None
 *
 ***************************************************************************/

void orthInterval (int iter)
{
    Boolean thresh;
    DoubleComplex *mat;
    int k, i, j, len;
    DoubleComplexMat B2;
    DoubleComplex w;

    j = iter;
    B2.m = bs;
    B2.n = bs;
    len = bs * bs;
    thresh = FALSE;
    k = j;
    while ((k >= 2) && (thresh != TRUE)) {
        if (k == j) {
            /* W(:,:,k,old) = (eps*bs*0.6)*(B(:,:,1) 
             *                .*(randn(bs,bs)+IM*randn(bs,bs)));
             */
#ifdef MATLAB_COMP
            memcpy (Work1->mat, RandMat, bs * bs * sizeof(DoubleComplex));
#else
            randComplexMat (Work1);
#endif
            B->mat = B_m[1];
            WOld->mat = WOld_m[k];
            memult (eps*bs*0.6, B, Work1, NULL, WOld);           
        }
        else {
            /* W(:,:,k,old) = M(:,:,k)*conj(W(:,:,k,cur)) 
             *                + B(:,:,k-1)*conj(W(:,:,k-1,cur)) 
             *                - W(:,:,k,cur)*M(:,:,j) 
             *                + (eps*0.3)*((B(:,:,k)+B(:,:,j)) 
             *                .*(randn(bs,bs)+IM*randn(bs,bs)));
             */
            M->mat = M_m[k];
            WCur->mat = WCur_m[k];
            WOld->mat = WOld_m[k];
            mmult ("+", M, "N", WCur, "C", NULL, WOld);

            B->mat = B_m[k-1];
            WCur->mat = WCur_m[k-1];
            mmult ("+", B, "N", WCur, "C", WOld, WOld);

            WCur->mat = WCur_m[k];
            M->mat = M_m[j];
            mmult ("-", WCur, "N", M, "N", WOld, WOld);
        
            B->mat = B_m[k];
            B2.mat = B_m[j];
            mplus (B, &B2, Work1);
#ifdef MATLAB_COMP
            memcpy (Work2->mat, RandMat, bs * bs * sizeof(DoubleComplex));
#else
            randComplexMat (Work2);
#endif
            WOld->mat = WOld_m[k];
            memult (eps * 0.3, Work1, Work2, WOld, WOld);
            
            if (k < j - 1) {
                /* W(:,:,k,old) = W(:,:,k,old) 
                 *                + (B(:,:,k).')*conj(W(:,:,k+1,cur)) 
                 *                - W(:,:,k,old)*(B(:,:,j-1).');
                 */
                B->mat = B_m[k];
                WCur->mat = WCur_m[k+1];
                mmult ("+", B, "T", WCur, "C", NULL, Work1);

                WOld->mat = WOld_m[k];
                B->mat = B_m[j-1];
                mmult ("-", WOld, "N", B, "T", Work1, Work1);
                mplus (Work1, WOld, WOld);
            }
        }
        /* W(:,:,k,old) = W(:,:,k,old)/B(:,:,j);        */
        WOld->mat = WOld_m[k];
        B->mat = B_m[j];
        mdiv (WOld, B, WOld);

        /* Check if W_old[k] exceeds mideps.            */
        mat = WOld->mat;
        for (i = 0; i < len; i++) {
            /* mideps is already eps^7/4, not eps^7/8   */
            w.r = mat[i].r;
            w.i = mat[i].i;
            if (w.r * w.r + w.i * w.i >= mideps) {
                thresh = 1;
                up = k;
                break;
            }
        }
        k--;
    }
}

/***************************************************************************
 *
 * Abstract
 *   In the second orthogonalization stage, compute W's which are not
 *   computed in last iteration due to orthogonalization.
 *
 * Input
 *   iter   The iteration
 *
 * Output
 *   None
 *
 ***************************************************************************/

void completeW (int iter)
{
    int k, j;
    DoubleComplexMat B2;

    j = iter;
    B2.m = bs;
    B2.n = bs;

    for (k = up; k <= j; k++) {
        if (k == j) {
            /* W(:,:,k,old) = (eps*bs*0.6)*(B(:,:,1) 
             *                .*(randn(bs,bs)+IM*randn(bs,bs)));
             */
#ifdef MATLAB_COMP
            memcpy (Work1->mat, RandMat, bs * bs * sizeof(DoubleComplex));
#else
            randComplexMat (Work1);
#endif
            B->mat = B_m[1];
            WOld->mat = WOld_m[k];
            memult (eps*bs*0.6, B, Work1, NULL, WOld);           
        }
        else {
            /* W(:,:,k,old) = M(:,:,k)*conj(W(:,:,k,cur)) 
             *                + B(:,:,k-1)*conj(W(:,:,k-1,cur)) 
             *                - W(:,:,k,cur)*M(:,:,j) 
             *                + (eps*0.3)*(B(:,:,k)+B(:,:,j)) 
             *                .*(randn(bs,bs)+IM*randn(bs,bs));
             */
            M->mat = M_m[k];
            WCur->mat = WCur_m[k];
            WOld->mat = WOld_m[k];
            mmult ("+", M, "N", WCur, "C", NULL, WOld);

            B->mat = B_m[k-1];
            WCur->mat = WCur_m[k-1];
            mmult ("+", B, "N", WCur, "C", WOld, WOld);

            WCur->mat = WCur_m[k];
            M->mat = M_m[j];
            mmult ("-", WCur, "N", M, "N", WOld, WOld);

            B->mat = B_m[k];
            B2.mat = B_m[j];
            mplus (B, &B2, Work1);
#ifdef MATLAB_COMP
            memcpy (Work2->mat, RandMat, bs * bs * sizeof(DoubleComplex));
#else
            randComplexMat (Work2);
#endif
            memult (eps * 0.3, Work1, Work2, NULL, Work1);

            WOld->mat = WOld_m[k];
            mplus (Work1, WOld, WOld);
            
            if (k < j - 1) {
                /* W(:,:,k,old) = W(:,:,k,old) 
                 *                +(B(:,:,k).')*conj(W(:,:,k+1,cur)) 
                 *                - W(:,:,k,old)*(B(:,:,j-1).');
                 */
                B->mat = B_m[k];
                WCur->mat = WCur_m[k+1];
                mmult ("+", B, "T", WCur, "C", NULL, Work1);

                WOld->mat = WOld_m[k];
                B->mat = B_m[j-1];
                mmult ("-", WOld, "N", B, "T", Work1, Work1);
                mplus (Work1, WOld, WOld);
            }
        }
        /* W(:,:,k,old) = W(:,:,k,old)/B(:,:,j);    */
        WOld->mat = WOld_m[k];
        B->mat = B_m[j];
        mdiv (WOld, B, WOld);
    }
}

/***************************************************************************
 *
 * Abstract
 *   Orthogonalize the residual R, from which Q(j+1) is formed, against
 *   the blocks Q(1), ..., Q(up) of the orthogonalization interval.  Then
 *   recompute [Q(j+1), B(j)] = qr (R) and advance the two-stage
 *   reorthogonalization state.
 *
 * Input
 *   iter   The iteration j.
 *
 * Output
 *   None.  Overwrites Q(j+1), B(j) and WCur_m[1..up]; updates nBlk, up,
 *   second and doOrtho.
 *
 ***************************************************************************/

void orthR (int iter)
{
    int k, j;
    DoubleComplexMat Q1;

    j = iter;
    Q1.m = Q->m;
    Q1.n = Q->n;
    for (k = 1; k <= up; k++) {
        /* Q1 describes block Q(k) of Q_m.  The address is computed directly
         * instead of stepping forward from "one block before Q_m", which
         * would form an out-of-bounds pointer (undefined behaviour).
         */
        Q1.mat = Q_m + (k - 1) * Q1.m * bs;
        orthogonalize (R, &Q1);
    
        /* Reset orthogonality estimates: presumably
         * W(k,cur) = (eps*1.5) * random -- confirm mscal semantics.
         */
#ifdef MATLAB_COMP
        memcpy (Work1->mat, RandMat, bs * bs * sizeof(DoubleComplex));
#else
        randComplexMat (Work1);
#endif
        WCur->mat = WCur_m[k];
        mscal (eps * 1.5, Work1, WCur);
    }

    /* [Q(j+1), B(j)] = qr (R, 0); the global Q currently describes Q(j). */
    Q1.mat = Q->mat + Q->m * bs;
    B->mat = B_m[j];
    qr (R, &Q1, B);

    /* Update number of blocks selected for orthogonalization.  */
    nBlk = nBlk + up;
    if (second == TRUE) {
        /* The follow-up pass is done: return to normal monitoring. */
        second = FALSE;
        up = 0;
    }
    else {
        /* Schedule the second pass for the next iteration, extending the
         * interval by one block (but not past j+1).
         */
        second = TRUE;
        doOrtho = FALSE;
        up = min (j + 1, up + 1);
    }
}
