/****************************************************************************
 *
 * LanTri.c
 *
 * Abstract 
 *   Main file to perform Lanczos tridiagonalization.
 *
 * Student:  Guohong Liu
 *
 * ID:       0385117
 *
 ****************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <malloc.h>
#include <math.h>
#include "blklan.h"

/****************************************************************************
 *
 * Global variables
 *
 *   The matrix/vector descriptors below are defined in another
 *   translation unit and shared with this file.
 *
 ****************************************************************************/

extern DoubleComplexMat *Q;     /* Unitary. One block = n * bs.         */
extern DoubleComplexVec *p;     /* One column of matrix P, size n.      */
extern DoubleComplexMat *M;     /* Diagonal blocks. bs * bs.            */
extern DoubleComplexMat *B;     /* Off-diagonal blocks. bs * bs.        */
extern DoubleComplexVec *a;     /* Main diagonal of tridiagonal. size n */
extern DoubleRealVec *b;        /* Subdiagonal of tridiagonal. n-1      */
extern DoubleComplexVec *wCur;  /* Orthogonality-loss detector, current
                                 * iteration.                           */
extern DoubleComplexVec *wOld;  /* Orthogonality-loss detector, previous
                                 * iteration.                           */
extern DoubleComplexVec *r;     /* Starting vector and work vector. n   */
extern DoubleComplexVec *workZ; /* Work complex vector. size = n        */
extern DoubleRealVec *workD;    /* Work real vector.                    */

/* Backing storage for the corresponding matrices and vectors above.
 * While lanTri() runs, every one of these except Q_m is temporarily
 * decremented by one element so it can be indexed from 1.
 */
extern DoubleComplex *Q_m;
extern DoubleComplex *P_m;
extern DoubleComplex **M_m;
extern DoubleComplex **B_m;
extern DoubleComplex *a_m;
extern DoubleReal *b_m;
extern DoubleComplex *wCur_m;
extern DoubleComplex *wOld_m;
extern DoubleComplex *r_m;
extern DoubleComplex *workZ_m;
extern DoubleReal *workD_m;

extern int n;                   /* Length of the Lanczos vectors.       */
extern int bs;                  /* Block size (rows/cols of M and B).   */
extern int blks;                /* Number of diagonal blocks M.         */
extern double eps;              /* Precision threshold (presumably
                                 * machine epsilon - confirm).          */
extern int *up2, *low2;         /* Reorthogonalization intervals
                                 * [low2[k], up2[k]], filled by
                                 * detectw(), 1-based while in use.     */

double mideps2;                 /* eps^1.5; threshold on |w|^2 for a
                                 * "nearly lost" estimate.              */
int interNum;                   /* Number of intervals in low2/up2.     */
Boolean doOrtho2;               /* TRUE when detectw() found an
                                 * estimate with |w|^2 > eps.           */
Boolean second2;                /* TRUE when orthr() must run again at
                                 * the next step (second pass).         */
int nVec;                       /* Running count of vectors used in
                                 * reorthogonalization.                 */

#ifdef MATLAB_COMP
/* Used in place of random numbers when MATLAB_COMP is defined
 * (presumably to reproduce a MATLAB reference run - confirm). */
extern DoubleComplex* randVec;
#else
/* Scratch holder for the result of randComplexNum(). */
DoubleComplex randComplex;
#endif

/***************************************************************************
 *
 * Abstract 
 *   Main procedural in tridiagonalization stage.
 *
 * Input    
 *   None
 *
 * Output
 *   None
 *
 ***************************************************************************/

int lanTri ()
{
    DoubleComplex zval, *tmp;
    DoubleComplexMat P, final;
    int j;

    /* Initialize parameters.   */
    mideps2 = pow (sqrt (eps), 3);
    interNum = 0;
    doOrtho2 = FALSE;
    second2 = FALSE;
    nVec = 0;

    /* Decrease the index of array so that from then on the index
     * of array starts from 1, not from 0. 
     */
    M_m--;
    B_m--;
    P_m--;
    a_m--;
    b_m--;
    wOld_m--;
    wCur_m--;
    workZ_m--;
    workD_m--;
    r_m--;
    up2--;
    low2--;

    wOld_m[1].r = 1.0;
    wOld_m[1].i = 0.0;
    
    /* Set the starting vector for tridiagonalization stage.    */
    r->vec = &r_m[1];
    r->n = n;
    ones (r->vec, r->n);

    /* P(:,1) = r / norm(r) */
    p->vec = &P_m[1];       
    vscal (1/vecnorm (r), r, p);

    for (j = 1; j <= n; j++) {
        /* r = J * conjg(p(j)). Band multiplication.*/
        p->vec = P_m + (j - 1) * n + 1;      
        sbmvmul (j);
        
        /* a(j) = P(:,j)' * r                       */
        r->vec = &r_m[1];
        r->n = n;
        vmult (p, "H", r, &a_m[j]);

        if (j == 1) {
            /* r = r - a(j)*p(j)
             * Index of P_m starts from 1 not 0.
             */
            vztplus ("-", &a_m[j], p, r, r);
        }
        else {
            /* r = r - a(j)*p(j) - b(j-1)*p(j-1)    */
            vztplus ("-", &a_m[j], p, r, r);
            p->vec = p->vec - n;
            vdtplus ("-", b_m[j-1], p, r, r);
        }
    
        if (j < n) {
            b_m[j] = vecnorm (r);

            if (j > 2) {
                /* zval = b(1) * conjg(wCur(2)) 
                 *      + a(1) * conjg(wCur(1))   
                 *      - a(j) * wCur(1)
                 *      - b(j-1) * wOld(1)
                 */
                dzmult ("+", b_m[1], &wCur_m[2], "C", NULL, &zval);
                zzmult ("+", &a_m[1], &wCur_m[1], "C", &zval, &zval);
                zzmult ("-", &a_m[j], &wCur_m[1], "N", &zval, &zval);
                dzmult ("-", b_m[j - 1], &wOld_m[1], "N", &zval, &zval);
                /* wOld(1) = zval / b(j)            */
                zdiv (&zval, b_m[j], &wOld_m[1]);

                /* wOld(1) = wOld(1) 
                 *           + eps*(b(1)+b(j))*0.3*(randn + IM*randn);
                 */
#ifdef MATLAB_COMP
                dzmult ("+", eps*(b_m[1]+b_m[j])*0.3, &randVec[0], "N", 
                        &wOld_m[1], &wOld_m[1]);
#else
                randComplex = randComplexNum ();
                dzmult ("+", eps*(b_m[1]+b_m[j])*0.3, &randComplex, "N", 
                        &wOld_m[1], &wOld_m[1]);                                
#endif

                /* workZ = b(2:j-1) .* conjg(wCur(3:j))
                 *         + a(2:j-1) .* conjg(wCur(2:j-1))
                 *         - a(j) * wCur(2:j-1)
                 *         + b(1:j-2) .* conjg(wCur(1:j-2))
                 */
                workZ->vec = &workZ_m[1];
                workZ->n = j - 2;
                b->vec = &b_m[2];
                b->n = j - 2;
                wCur->vec = &wCur_m[3];
                wCur->n = j - 2;
                vdemult (1.0, b, wCur, "C", NULL, workZ);
                
                a->vec = &a_m[2];
                a->n = j - 2;
                wCur->vec = &wCur_m[2];
                vzemult (a, wCur, "C", workZ, workZ);

                vztplus ("-", &a_m[j], wCur, workZ, workZ);

                b->vec = &b_m[1];
                wCur->vec = &wCur_m[1];
                vdemult (1.0, b, wCur, "C", workZ, workZ);
                
                /* workZ = workZ - b(j-1)*wOld(2:j-1)       */
                wOld->vec = &wOld_m[2];
                wOld->n = j - 2;
                vdtplus ("-", b_m[j-1], wOld, workZ, workZ);

                /* wOld(2:j-1) = workZ / b(j)               */
                vscal (1 / b_m[j], workZ, wOld);

                /* workD = b(2:j-1)+b(j)*ones(j-2,1)        */
                workD->vec = &workD_m[1];
                workD->n = j - 2;
                setRealArray (workD->vec, workD->n, b_m[j]);
                b->vec = &b_m[2];
                b->n = j - 2;
                vplus (b, workD, workD);
                

                /* workZ = randn(j-2,1) + IM*randn(j-2,1)   */
#ifdef MATLAB_COMP
                workZ->vec = &randVec[0];
#else
                randComplexVec (workZ);
#endif                                               
                /* wOld(2:j-1) = wOld(2:j-1) + eps*0.3*workD.*workZ */
                vdemult (eps*0.3, workD, workZ, "N", wOld, wOld);

                /* Swap wOld and wCur   */
                tmp = wOld_m;
                wOld_m = wCur_m;
                wCur_m = tmp;
                wOld_m[j].r = 1.0;
                wOld_m[j].i = 0.0;
            } /* if (j > 2)             */

            /* wCur(j) = eps*n*(b(1)/b(j))*0.6*(randn + IM*randn)   */
#ifdef MATLAB_COMP
            dzmult ("+", eps*n*(b_m[1]/b_m[j])*0.6, &randVec[0], "N", 
                    NULL, &wCur_m[j]);
#else
            randComplex = randComplexNum ();
            dzmult ("+", eps*n*(b_m[1]/b_m[j])*0.6, &randComplex, "N", 
                    NULL, &wCur_m[j]);            
#endif
            wCur_m[j + 1].r = 1.0;
            wCur_m[j + 1].i = 0.0;

            if (second2 == FALSE) {
                detectw (j);
            }

            if ((doOrtho2 == TRUE) || (second2 == TRUE)) {
                orthr (j);
            }

            if (abs(b_m[j]) < eps) {
                /* b(j) = 0, quit.      
                 * a = a(1:j);
                 * b = b(1:j-1);
                 */
                if (!(a_m = (DoubleComplex *)realloc (a_m, 
                                j * sizeof (DoubleComplex)))) {
                    lanTriExcept ("lanTri", "Reallocate *a_m failed");
                }
                if (!(b_m = (DoubleReal *)realloc (b_m, 
                                (j - 1) * sizeof (DoubleComplex)))) {
                    lanTriExcept ("lanTri", "Reallocate *b_m failed");
                }
                break;
            }
            else {
                /* P(:,j+1) = r / b(j); */
                p->vec = P_m + j * n + 1;
                vscal (1 / b_m[j], r, p);
            }
        }
        
#ifdef INFO
        if (j == n - 1) { 
            lanTriInfo (j);
        }
#endif
    }

    /* Restore vector index.    */
    M_m++;
    B_m++;
    P_m++;
    a_m++;
    b_m++;
    wOld_m++;
    wCur_m++;
    workZ_m++;
    workD_m++;
    r_m++;
    up2++;
    low2++;

    return 0;
}

/***************************************************************************
 *
 * Abstract 
 *   Complex symmetric, block tridiagonal matrix-vector multiplication,
 *      r = J * conjg(P(:,iter))
 *   where J is complex symmetric and block tridiagonal: its main diagonal
 *   blocks are M(:,:,1..blks) and its subdiagonal blocks are
 *   B(:,:,1..blks-1), so the superdiagonal blocks are B.'. M and B are
 *   computed in the block tridiagonalization stage. Because M is
 *   symmetric, symvmult() is used for its products instead of mvmult().
 *
 *   The result is built one block (bs entries) at a time.
 *
 * Input    
 *   iter   Iteration index. Not used: the input column is read from the
 *          global vector *p, which the caller points at P(:,iter).
 *
 * Output     
 *   r      Global vector, size n, receives the product. On return r->vec
 *          and r->n describe only one block; callers must reset them
 *          before using r as a full vector. workZ_m[1..n] holds
 *          conjg(P(:,iter)).
 *
 ***************************************************************************/

void sbmvmul (int iter)
{
    int i;
    int low, up;        /* lower and upper bound    */

    (void) iter;        /* Kept for interface compatibility. */

    /* workZ = conjg (P(:,iter))                    */
    workZ->vec = &workZ_m[1];
    workZ->n = n;
    conjvec (p, workZ);

    /* The function obtains one part of vector r one time. The  
     * size of the subvector is the number of rows or columns
     * of one block.
     */
    r->n = bs;
    workZ->n = bs;

    /* A single block has no off-diagonal blocks: r = M(:,:,1) * workZ.
     * The general path below would read B(:,:,1) and B(:,:,0), beyond
     * the blks-1 subdiagonal blocks.
     */
    if (blks == 1) {
        M->mat = M_m[1];
        workZ->vec = &workZ_m[1];
        r->vec = &r_m[1];
        symvmult (M, workZ, NULL, r);
        return;
    }

    /* r(1:bs) = M(:,:,1) * workZ(1:bs)             */
    M->mat = M_m[1];
    workZ->vec = &workZ_m[1];
    r->vec = &r_m[1];
    symvmult (M, workZ, NULL, r);

    /* r(1:bs) = r(1:bs) + B(:,:,1).' * workZ(bs+1:2*bs)        */
    B->mat = B_m[1];
    workZ->vec = &workZ_m[bs + 1];
    mvmult (B, "T", workZ, r, r);

    for (i = 1; i <= blks - 2; i++) {
        low = i * bs + 1;
        up = (i + 1) * bs;

        /* r(low:up) = B(:,:,i) * workZ((low-bs):(low-1))       */
        r->vec = &r_m[low];
        B->mat = B_m[i];
        workZ->vec = &workZ_m[low - bs];
        mvmult (B, "N", workZ, NULL, r);

        /* r(low:up) = r(low:up) + M(:,:,i+1) * workZ(low:up)   */
        M->mat = M_m[i + 1];
        workZ->vec = &workZ_m[low];
        symvmult (M, workZ, r, r);

        /* r(low:up) = r(low:up) + B(:,:,i+1).' * workZ(up+1:up+bs) */
        B->mat = B_m[i + 1];
        workZ->vec = &workZ_m[up + 1];
        mvmult (B, "T", workZ, r, r);
    }

    /* r(n-bs+1:n) = B(:,:,blks-1) * workZ((n-2*bs+1):(n-bs))   */
    r->vec = &r_m[n - bs + 1];
    B->mat = B_m[blks - 1];
    workZ->vec = &workZ_m[n - 2 * bs + 1];
    mvmult (B, "N", workZ, NULL, r);

    /* r(n-bs+1:n) = r(n-bs+1:n) + M(:,:,blks) * workZ(n-bs+1:n)*/
    M->mat = M_m[blks];
    workZ->vec = &workZ_m[n - bs + 1];
    symvmult (M, workZ, r, r);
}

/***************************************************************************
 *
 * Abstract 
 *   Compute w's, the detectors of the loss of orthogonality in
 *   tridiagonalization stage.
 *
 * Input  
 *   iter   The iteration.
 *
 * Output    
 *   None
 *
 ***************************************************************************/

void detectw (int iter)
{
    int j, k, p;
    DoubleComplex w;

    doOrtho2 = FALSE;
    interNum = 0;
    j = iter;
    k = 1;

    while (k <= j) {
        w.r = wCur_m[k].r;
        w.i = wCur_m[k].i;
        if (w.r * w.r + w.i * w.i > eps) {
            /* Lost orthogonality.      */
            doOrtho2 = TRUE;
            interNum = interNum + 1;

            /* Find the upper bound.    */
            p = k + 1;
            while (p < j + 1) {
                w.r = wCur_m[p].r;
                w.i = wCur_m[p].i;
                if (w.r * w.r + w.i * w.i >= mideps2) {
                    /* Nearly lost orthogonality.   */
                    p++;
                }
                else {
                    break;
                }
            }
            up2[interNum] = p - 1;
            
            /* Find the lower bound.                */
            p = k - 1;
            while (p > 0) {
                w.r = wCur_m[p].r;
                w.i = wCur_m[p].i;
                if (w.r * w.r + w.i * w.i >= mideps2) {
                    /* Nearly lost orthogonality.   */
                    p--;
                }
                else {
                    break;
                }
            }
            low2[interNum] = p + 1;
            
            /* Continue search. */
            k = up2[interNum] + 1;       
        }
        else {
            k = k + 1;
        }
    }
}

/***************************************************************************
 *
 * Abstract 
 *   Carry out reorthogonalization in the tridiagonalization stage:
 *      r = r - (P(:,i)' * r) * P(:,i)
 *   for every i in [low2[k], up2[k]], k = 1..interNum, resetting the
 *   matching estimates wCur(i) to O(eps) random values.
 *
 *   Reorthogonalization is done at two consecutive steps. After the first
 *   pass (second2 == FALSE on entry) every interval is widened by one on
 *   each side, second2 is set so that lanTri() calls orthr() again at the
 *   next step, and doOrtho2 is cleared. After the second pass the
 *   intervals are cleared and second2 is reset.
 *
 * Input  
 *   iter   The iteration.
 *
 * Output    
 *   None. Updates r, wCur, nVec, second2, doOrtho2, low2/up2 and sets
 *   b(iter) = norm(r).
 *
 ***************************************************************************/

void orthr (int iter)
{
    int j = iter;
    int k, i;
    DoubleComplexVec p1;
    DoubleComplex zval;

    p1.n = n;

    /* Carry out orthogonalization, interval by interval. */
    for (k = 1; k <= interNum; k++) {
        for (i = low2[k]; i <= up2[k]; i++) {
            /* r = r - (P(:,i)' * r) * P(:,i). */
            p1.vec = P_m + (i - 1) * n + 1; 
            vmult (&p1, "H", r, &zval);
            vztplus ("-", &zval, &p1, r, r);

            /* Reset ortho estimates.   
             * wCur(i) = eps*1.5*(randn + IM*randn) 
             */
#ifdef MATLAB_COMP
            dzmult ("+", eps * 1.5, &randVec[0], "N", NULL, &wCur_m[i]);
#else
            randComplex = randComplexNum ();
            dzmult ("+", eps * 1.5, &randComplex, "N", NULL, &wCur_m[i]);
#endif
        }

        /* Count the number of vectors selected.    */
        nVec = nVec + up2[k] - low2[k] + 1;
    }

    /* Advance the two-pass state once per call, for all intervals.
     * Toggling inside the interval loop flipped second2 back and forth
     * when interNum > 1 and zeroed every other interval, which made the
     * next pass orthogonalize against P(:,0) / write wCur(0).
     */
    if (interNum > 0) {
        if (second2 == TRUE) {
            /* Second pass done: clear the intervals. */
            second2 = FALSE;
            for (k = 1; k <= interNum; k++) {
                low2[k] = 0;
                up2[k] = 0;
            }
        }
        else {
            /* First pass done: widen the intervals for the second time. */
            second2 = TRUE;
            doOrtho2 = FALSE;
            for (k = 1; k <= interNum; k++) {
                low2[k] = max (1, low2[k] - 1);
                up2[k] = min (j + 1, up2[k] + 1);
            }
        }
    }

    /* Recalculate b(j)     */
    b_m[j] = vecnorm (r);
}
