/****************************************************************************
 *
 * BlkLanAux.c
 *
 * Abstract 
 *   Define auxiliary functions used in the block Lanczos algorithm.
 *
 * Student:  Guohong Liu
 *
 * ID:       0385117
 *
 ****************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <malloc.h>
#include <string.h>
#include <time.h>
#include "blklan.h"

/***************************************************************************
 *
 * Abstract 
 *   Check errors in factorization and orthogonality of matrices 
 *   computed by block Lanczos tridiagonalization, and print both
 *   errors to stdout.
 *
 *   Orthogonality error: orthErr () applied to R1 = Q * P.
 *   Factorization error: matnorm (R1' * A * conj(R1) - T) / (n * n),
 *   where T is the tridiagonal matrix whose diagonal is vector a and
 *   whose sub- and superdiagonal are vector b.
 *   Matrices are n-by-n with element (i, j) stored at mat[j * n + i].
 *
 * Input    
 *   n      Order of matrix A.
 *   A_m    Pointer to the memory of matrix A (n-by-n).
 *   a_m    Pointer to the memory of vector a (n diagonal entries of T).
 *   b_m    Pointer to the memory of vector b (n - 1 off-diagonal
 *          entries of T).
 *   Q_m    Pointer to the memory of matrix Q (n-by-n).
 *   P_m    Pointer to the memory of matrix P (n-by-n).
 *
 * Output
 *   None
 *
 * Return
 *   0      Successful exit
 *   1      Memory allocation exception occurs (every temporary that
 *          was allocated is released before returning)
 *
 ***************************************************************************/

int triErrChk (int n, DoubleComplex *A_m, DoubleComplex *a_m, 
               DoubleReal *b_m, DoubleComplex *Q_m, DoubleComplex *P_m) 
{
    double err;
    DoubleComplexMat A, Q, P, R1, R2, Tmp;
    DoubleComplex *mat;
    /* Computed in size_t so n * n cannot overflow int. */
    size_t bytes = (size_t) n * (size_t) n * sizeof (DoubleComplex);
    int j;
    int status = 1;             /* Assume failure until every step is done. */

    /* Start with NULL so the shared cleanup can free unconditionally. */
    R1.mat = NULL;
    R2.mat = NULL;
    Tmp.mat = NULL;

    Q.m = n;
    Q.n = n;
    Q.mat = Q_m;
    P.m = n;
    P.n = n;
    P.mat = P_m;
    A.m = n;
    A.n = n;
    A.mat = A_m;
    R1.m = n;
    R1.n = n;
    R2.m = n;
    R2.n = n;
    Tmp.m = n;
    Tmp.n = n;

    /* Check the error in orthogonality.
     * First create the final unitary R1 = Q * P.
     */
    if ((R1.mat = malloc (bytes)) == NULL) {
        printf ("triErrChk: Malloc for R1.mat failed.\n");
        goto cleanup;
    }
    mmult ("+", &Q, "N", &P, "N", NULL, &R1);
    if (orthErr (&R1, &err) != 0) {
        /* orthErr reports its own allocation failure and leaves err
         * unwritten, so it must not be printed. */
        goto cleanup;
    }
    printf ("\nError in orthogonality: %1.3e\n", err);

    /* Check error in factorization. */
    if ((Tmp.mat = malloc (bytes)) == NULL) {
        printf ("triErrChk: Malloc for Tmp.mat failed.\n");
        goto cleanup;
    }
    if ((R2.mat = malloc (bytes)) == NULL) {
        printf ("triErrChk: Malloc for R2.mat failed.\n");
        goto cleanup;
    }

    /* R2 = R1' * A * conj(R1) where R1 = Q * P.
     * All operands are n-by-n, so bmmult () is called rather than mmult ().
     */
    bmmult ("+", &R1, "H", &A, "N", NULL, &Tmp);
    bmmult ("+", &Tmp, "N", &R1, "C", NULL, &R2);

    /* R2 = R2 - T. Only the three nonzero diagonals of T are visited;
     * element (i, j) lives at mat[j * n + i].
     */
    mat = R2.mat;
    for (j = 0; j < n; j++) {
        /* diagonal element (j, j). */
        mat[j * n + j].r -= a_m[j].r;
        mat[j * n + j].i -= a_m[j].i;
        if (j > 0) {
            /* superdiagonal element (j - 1, j). */
            mat[j * n + (j - 1)].r -= b_m[j - 1];
        }
        if (j < n - 1) {
            /* subdiagonal element (j + 1, j). */
            mat[j * n + (j + 1)].r -= b_m[j];
        }
    }

    err = matnorm (&R2) / ((double) n * (double) n);
    printf ("Error in factorization: %1.3e\n", err);
    status = 0;

cleanup:
    /* free (NULL) is a no-op, so partially-allocated states are safe. */
    free (R1.mat);
    free (R2.mat);
    free (Tmp.mat);
    return status;
}

/***************************************************************************
 *
 * Abstract 
 *   Orthogonality error of A: matnorm (A'*A - eye(n)) / (n * n), where
 *   n is the number of columns of A (intended as
 *   norm (A'*A - eye(n), 'fro') scaled by n * n).
 *
 * Input    
 *   A      Complex matrix with A->n columns (the callers in this file
 *          pass square n-by-n matrices). A'*A is n-by-n.
 *
 * Output     
 *   err    The scaled norm. Not written when 1 is returned.
 *
 * Return
 *   0      Successful exit
 *   1      Memory allocation exception occurs
 *
 ***************************************************************************/

int orthErr (DoubleComplexMat *A, double *err)
{
    DoubleComplexMat R;
    DoubleComplex *mat;
    int j, n;

    /* A' * A has as many rows and columns as A has columns. */
    n = A->n;
    R.m = n;
    R.n = n;
    if ((R.mat = malloc ((size_t) n * (size_t) n * sizeof (DoubleComplex)))
            == NULL) {
        printf ("orthErr: Malloc for R.mat failed.\n");
        return 1;
    }

    /* R = A' * A   */
    mmult ("+", A, "H", A, "N", NULL, &R);

    /* R = R - I: subtract 1 from the real part of each diagonal element
     * (element (j, j) of the n-by-n R is at index j * n + j). */
    mat = R.mat;
    for (j = 0; j < n; j++) {
        mat[j * n + j].r -= 1.0;
    }

    /* Doubles in the divisor avoid int overflow of n * n. */
    *err = matnorm (&R) / ((double) n * (double) n);

    free (R.mat);
    return 0;
}

/***************************************************************************
 *
 * Abstract 
 *   Clear a complex array: every element becomes 0 + 0i.
 *
 * Input    
 *   array      Complex array.
 *   n          Number of leading elements to clear; nothing is
 *              written when n <= 0.
 *
 * Output     
 *   array      Complex array whose first n elements are all zeros.
 *
 ***************************************************************************/

void zeros (DoubleComplex *array, int n)
{
    DoubleComplex *elem = array;
    int remaining;

    for (remaining = n; remaining > 0; remaining--) {
        elem->r = 0;
        elem->i = 0;
        elem++;
    }
}

/***************************************************************************
 *
 * Abstract 
 *   Fill a complex array with the value 1 + 0i.
 *
 * Input    
 *   array      Complex array.
 *   n          Number of leading elements to set; nothing is written
 *              when n <= 0.
 *
 * Output     
 *   array      Complex array whose first n elements are all ones.
 *
 ***************************************************************************/

void ones (DoubleComplex *array, int n)
{
    DoubleComplex *elem = array;
    int remaining = n;

    while (remaining > 0) {
        elem->r = 1;
        elem->i = 0;
        ++elem;
        --remaining;
    }
}

/***************************************************************************
 *
 * Abstract 
 *   Fill a real array with one value.
 *
 * Input    
 *   array      Real array.
 *   n          Number of leading elements to set; nothing is written
 *              when n <= 0.
 *   val        Value stored into each element.
 *
 * Output     
 *   array      Array whose first n elements all equal val.
 *
 ***************************************************************************/

void setRealArray (DoubleReal *array, int n, DoubleReal val)
{
    DoubleReal *dst = array;
    int left = n;

    while (left > 0) {
        *dst++ = val;
        left--;
    }
}

#define DISP_COL 2          /* Maximum columns of displayable complex.    */
#define DISP_ROW 8          /* Maximum rows of displayable numbers.       */
char complexFormat[] = "%13.10f %+13.10fi";
char realFormat[] = "%13.10f";

/***************************************************************************
 *
 * Abstract 
 *   Print one complex matrix. Only the first and last column are 
 *   printed. If the matrix is large, only the first and the last
 *   DISP_ROW/2 rows are printed.
 *   
 * Input    
 *   A      Complex matrix to be printed
 *   tag    String to identify the matrix.
 *
 * Output     
 *   None
 *
 ***************************************************************************/

void showComplexMat (DoubleComplexMat *A, char *tag)
{
    int m = A->m;
    int n = A->n;
    DoubleComplex *mat = A->mat;
    int i, j, col, row, idx;

    if (tag != NULL) {
        printf ("%s\n", tag);
    }
    
    row = min (m, DISP_ROW / 2);
    col = min (n, DISP_COL - 1);
    for (i = 0; i < row; i++) {
        for (j = 0; j < col; j++) {
            idx = i + j * m;
            printf (complexFormat, mat[idx].r, mat[idx].i);
            if ((j != col - 1) || ((j == col - 1) && (n <= DISP_COL))) {
                printf ("  ");
            }
        }

        if (n > DISP_COL) {
            printf (" ... ");
        }

        /* Last column. */
        idx = i + (n - 1) * m;
        printf (complexFormat, mat[idx].r, mat[idx].i);
        printf ("\n");
    }

    if (m > DISP_ROW) {
        printf (" ...... ......\n");
        printf (" ...... ......\n");
    }
    else {
        return;
    }

    row = min (DISP_ROW / 2, m - row);
    for (i = m - row; i < m; i++) {
        for (j = 0; j < col; j++) {
            idx = i + j * m;
            printf (complexFormat, mat[idx].r, mat[idx].i);
            if ((j != col - 1) || ((j == col - 1) && (n <= DISP_COL))) {
                printf ("  ");
            }
        }
        if (n > DISP_COL) {
            printf (" ... ");
        } 

        idx = i + (n - 1) * m;
        printf (complexFormat, mat[idx].r, mat[idx].i);
        printf ("\n");
    }
}

/***************************************************************************
 *
 * Abstract 
 *   Print one complex vector. If the vector is large, only the first 
 *   and the last DISP_ROW / 2 * DISP_COL elements of the vector 
 *   are printed.
 *   
 * Input    
 *   a      Complex vector to be printed
 *   tag    String to identify the vector.
 *
 * Output     
 *   None
 *
 ***************************************************************************/

void showComplexVec (DoubleComplexVec *a, char *tag)
{
    int n = a->n;
    DoubleComplex *vec = a->vec;
    int i, j, total, idx;

    if (tag != NULL) {
        printf ("%s\n", tag);
    }
    
    total = min (n, DISP_ROW / 2 * DISP_COL);
    for (i = 0; i < total; i++) {
        printf (complexFormat, vec[i].r, vec[i].i);
        if ((i + 1) % DISP_COL == 0) {
            printf ("\n");
        }
        else {
            printf ("  ");
        }
    }

    if (n > total) {
        printf (" ...... ......\n");
        printf (" ...... ......\n");
    }
    else {
        printf ("\n");
        return;
    }

    total = min (DISP_ROW / 2 * DISP_COL, n - total);
    j = 1;
    for (i = n - total; i < n; i++) {
        printf (complexFormat, vec[i].r, vec[i].i);
        if ((j % DISP_COL == 0) && (i != n - 1)) {
            printf ("\n");
            j = 1;
        }
        else {
            printf ("  ");
            j++;
        }
    }
    printf ("\n");
}

/***************************************************************************
 *
 * Abstract 
 *   Print one real vector. If the vector is large, only the first 
 *   and the last DISP_ROW / 2 * DISP_COL elements of the vector 
 *   are printed.
 *   
 * Input    
 *   a      Real vector to be printed
 *   tag    String to identify the vector.
 *
 * Output     
 *   None
 *
 ***************************************************************************/

void showRealVec (DoubleRealVec *a, char *tag)
{
    int n = a->n;
    DoubleReal *vec = a->vec;
    int i, j, total, idx;

    if (tag != NULL) {
        printf ("%s\n", tag);
    }
    
    total = min (n, DISP_ROW * DISP_COL);
    for (i = 0; i < total; i++) {
        printf (realFormat, vec[i]);
        if ((i + 1) % (DISP_COL * 2) == 0) {
            printf ("\n");
        }
        else {
            printf ("  ");
        }
    }

    if (n > total) {
        printf (" ...... ......\n");
        printf (" ...... ......\n");
    }
    else {
        printf ("\n");
        return;
    }

    total = min (DISP_ROW * DISP_COL, n - total);
    j = 1;
    for (i = n - total; i < n; i++) {
        printf (realFormat, vec[i]);
        if ((j % (DISP_COL * 2) == 0) && (i != n - 1)) {
            printf ("\n");
            j = 1;
        }
        else {
            printf ("  ");
            j++;
        }
    }
    printf ("\n");
}