#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <sys/time.h>
#include "mpi.h"
#include <complex.h>
#include "blas.h"
#include "blacs.h"
#include "lapack.h"
#include "scalapack.h"

/* Larger of two ints. */
static int max( int a, int b ){
        return (a > b) ? a : b;
}
/* Smaller of two ints. */
static int min( int a, int b ){
        return (a < b) ? a : b;
}

/*
 * Forward declarations of the helpers defined further down in this file.
 */

/* Returns the Frobenius norm of (U^T U - I) when m > n, or of (U U^T - I)
 * otherwise, divided by max(m,n). */
extern double verif_orthogonality(int m, int n, double *U, int iu, int ju, int *descU);
/* Returns || U diag(S) VT - A ||_F / || A ||_F / max(m,n); A, U and VT are
 * left untouched (the work is done on private copies). */
extern double verif_representativity(int m, int n, double *A, int ia, int ja, int *descA,
                                                     double *U, int iu, int ju, int *descU,
                                                     double *VT, int ivt, int jvt, int *descVT,
                                                     double *S);
/* Calls pdgesvd_ on a private copy of A, returns the routine's info value and
 * stores the MPI_Wtime() difference around the call in *MPIelapsedNN. */
extern int driver_pdgesvd(char jobU, char jobVT, int m, int n, double *A, int ia, int ja, int *descA,
                        double *S_NN, double *U_NN, int iu, int ju, int *descU, double *VT_NN, int ivt, int jvt, int *descVT,
                        double *MPIelapsedNN);


/*
 * Test driver for the ScaLAPACK double precision SVD routine pdgesvd.
 *
 * Fills an m-by-n matrix, distributed over an nprow-by-npcol BLACS process
 * grid with square blocks of size nb, with random entries, runs the SVD with
 * both sets of singular vectors through driver_pdgesvd(), and prints on
 * process 0 one report line: info code, scaled residual, orthogonality of U
 * and VT, elapsed time and S[0]/S[min(m,n)-1] (the "cond(A)" column).
 *
 * The problem size, block size and grid shape are hard-coded below.
 * Exits with a failure status when there are not enough MPI processes for
 * the grid or when a local allocation fails.
 */
int main(int argc, char **argv) {
        static const char rule[] =
                "--------------------------------------------------------------------------------------------------------------------\n";
        int iam, nprocs;
        int myrank_mpi, nprocs_mpi;
        int ictxt, nprow, npcol, myrow, mycol;
        int nb, m, n;
        int mpA, nqA, mpU, nqU, mpVT, nqVT;
        int k, itemp, min_mn;
        int descA[9], descU[9], descVT[9];
        double *A=NULL;
        int info, infoVV;
        double  *U_VV=NULL;
        double  *VT_VV=NULL;
        double  *S_VV=NULL;
        double orthU_VV, residF, orthVT_VV;
        double  eps;
/**/
        int izero=0,ione=1;
/**/
        double MPIelapsedVV;
        char jobU, jobVT;
        int iseed[4], idist;
/**/
        MPI_Init( &argc, &argv);
        MPI_Comm_rank(MPI_COMM_WORLD, &myrank_mpi);
        MPI_Comm_size(MPI_COMM_WORLD, &nprocs_mpi);
/**/
        /* 'V' requests the singular vectors; both sets are needed by the checks below. */
        m = 100; n = 100; nprow = 2; npcol = 2; nb = 64; jobU='V'; jobVT='V';

        if (myrank_mpi==0){
                printf("\n");
                fputs(rule, stdout);
                printf("                            Testing pdgesvd -- double precision SVD ScaLAPACK routine                \n");
                printf("jobU jobVT    m     n     nb   p   q   || info   resid     orthU    orthVT   |SNN-SVV|    time(s)   cond(A) \n");
                fputs(rule, stdout);
        }
/**/

        /* Never use a block wider than the matrix itself. */
        if (nb>n)
                nb = n;

        if (nprow*npcol>nprocs_mpi){
                /* Only rank 0 reports; every rank then shuts down. */
                if (myrank_mpi==0){
                        printf(" **** ERROR : we do not have enough processes available to make a p-by-q process grid ***\n");
                        printf(" **** Bye-bye                                                                         ***\n");
                }
                MPI_Finalize(); exit(EXIT_FAILURE);
        }
/**/
        Cblacs_pinfo( &iam, &nprocs ) ;
        Cblacs_get( -1, 0, &ictxt );
        Cblacs_gridinit( &ictxt, "Row", nprow, npcol );
        Cblacs_gridinfo( ictxt, &nprow, &npcol, &myrow, &mycol );
/**/
        min_mn = min(m,n);
/**/
        /* Only processes that received valid grid coordinates take part. */
        if ((myrow>-1)&&(mycol>-1)&&(myrow<nprow)&&(mycol<npcol)){

/**/
                /* Local row/column counts of A (m x n), U (m x min_mn) and VT (min_mn x n). */
                mpA    = numroc_( &m     , &nb, &myrow, &izero, &nprow );
                nqA    = numroc_( &n     , &nb, &mycol, &izero, &npcol );
                mpU    = numroc_( &m     , &nb, &myrow, &izero, &nprow );
                nqU    = numroc_( &min_mn, &nb, &mycol, &izero, &npcol );
                mpVT   = numroc_( &min_mn, &nb, &myrow, &izero, &nprow );
                nqVT   = numroc_( &n     , &nb, &mycol, &izero, &npcol );
/**/
                /* Allocate at least one element: calloc may return NULL for a zero
                 * count, which must not be mistaken for an out-of-memory condition
                 * on a process that owns no part of the matrix. */
                A = (double *)calloc((size_t) max( 1, mpA*nqA ),sizeof(double)) ;
                if (A==NULL){ printf("error of memory allocation A on proc %dx%d\n",myrow,mycol); exit(EXIT_FAILURE); }
/**/
                /* Process-dependent seed so that local blocks differ between processes. */
                idist = 2;
                iseed[0] = mpA%4096;
                iseed[1] = iam%4096;
                iseed[2] = nqA%4096;
                iseed[3] = 23;
/**/
                /* One dlarnv_ call per local entry, in storage order. */
                for (k = 0; k < mpA*nqA; k++)
                        dlarnv_( &idist, iseed, &ione, &(A[k]) );
/*
*
*     Initialize the array descriptors for the distributed matrices A, U and VT
*
*/
                itemp = max( 1, mpA );
                descinit_( descA,  &m, &n, &nb, &nb, &izero, &izero, &ictxt, &itemp, &info );
                itemp = max( 1, mpU );
                descinit_( descU,  &m, &min_mn, &nb, &nb, &izero, &izero, &ictxt, &itemp, &info );
                itemp = max( 1, mpVT );
                descinit_( descVT, &min_mn, &n, &nb, &nb, &izero, &izero, &ictxt, &itemp, &info );
/**/
                /* NOTE(review): eps is computed but not used below -- confirm whether
                 * this call can be dropped. */
                eps = pdlamch_( &ictxt, "Epsilon" );
/**/
                U_VV = (double *)calloc((size_t) max( 1, mpU*nqU ),sizeof(double)) ;
                if (U_VV==NULL){ printf("error of memory allocation U_VV on proc %dx%d\n",myrow,mycol); exit(EXIT_FAILURE); }
                VT_VV = (double *)calloc((size_t) max( 1, mpVT*nqVT ),sizeof(double)) ;
                if (VT_VV==NULL){ printf("error of memory allocation VT_VV on proc %dx%d\n",myrow,mycol); exit(EXIT_FAILURE); }
                S_VV = (double *)calloc((size_t) max( 1, min_mn ),sizeof(double)) ;
                if (S_VV==NULL){ printf("error of memory allocation S_VV on proc %dx%d\n",myrow,mycol); exit(EXIT_FAILURE); }
/**/
                /* Factorize, then check the factors. */
                infoVV = driver_pdgesvd( jobU, jobVT, m, n, A, 1, 1, descA,
                        S_VV, U_VV, 1, 1, descU, VT_VV, 1, 1, descVT,
                        &MPIelapsedVV);
                orthU_VV  = verif_orthogonality(m,min_mn,U_VV , 1, 1, descU);
                orthVT_VV = verif_orthogonality(min_mn,n,VT_VV, 1, 1, descVT);
                residF =  verif_representativity( m, n,     A, 1, 1, descA,
                                                         U_VV, 1, 1, descU,
                                                        VT_VV, 1, 1, descVT,
                                                         S_VV);
                if ( iam==0 ){
                        printf(" %c    %c   %6d %6d  %3d  %3d %3d  ||  %3d  %7.1e   %7.1e   %7.1e              %8.2f    %7.1e\n",
                                jobU,jobVT,m,n,nb,nprow,npcol,infoVV,residF,orthU_VV,orthVT_VV,MPIelapsedVV,S_VV[0]/S_VV[min_mn-1]);
                        fputs(rule, stdout);
                }
/**/
                free(U_VV); free(S_VV); free(VT_VV);
                free(A);
                /* Release the grid that was actually created above. */
                Cblacs_gridexit( ictxt );
        }
/**/
        MPI_Finalize();
        return 0;
}
/**/
/*
 * Measure how far the distributed m-by-n matrix U is from having
 * orthonormal columns (m > n) or orthonormal rows (m <= n).
 *
 * Builds R = U^T U - I when m > n, or R = U U^T - I otherwise, where R is
 * min(m,n)-by-min(m,n) and distributed on the context and with the block
 * size read from descU, and returns ||R||_F / max(m,n).
 *
 *   m, n    global dimensions of U
 *   U       local part of the distributed matrix (not modified)
 *   iu, ju  global indices handed to pdgemm_ as the origin of U
 *   descU   array descriptor of U
 *
 * Terminates the process with a failure status if the local part of R
 * cannot be allocated.
 */
double verif_orthogonality(int m, int n, double *U, int iu, int ju, int *descU){

        double *R=NULL;
        int nprow, npcol, myrow, mycol;
        int mpR, nqR, nb, itemp, descR[9], ictxt, info, min_mn, max_mn;
        int ctxt_ = 1, nb_ = 5;
        int izero = 0, ione = 1;
        double *wwork=NULL;
        double tmone= -1.0e+00,  tpone= +1.0e+00,  tzero= +0.0e+00;
        double orthU;

        min_mn = min(m,n);
        max_mn = max(m,n);
        ictxt = descU[ctxt_];
        nb = descU[nb_];
        Cblacs_gridinfo( ictxt, &nprow, &npcol, &myrow, &mycol );

        mpR    = numroc_( &min_mn, &nb, &myrow, &izero, &nprow );
        nqR    = numroc_( &min_mn, &nb, &mycol, &izero, &npcol );
        /* Allocate at least one element: calloc may return NULL for a zero
         * count, which is not an error on a process owning no part of R. */
        R = (double *)calloc((size_t) max( 1, mpR*nqR ),sizeof(double)) ;
        if (R==NULL){ printf("error of memory allocation R on proc %dx%d\n",myrow,mycol); exit(EXIT_FAILURE); }
        itemp = max( 1, mpR );
        descinit_( descR,  &min_mn, &min_mn, &nb, &nb, &izero, &izero, &ictxt, &itemp, &info );

        /* R := I (zero off-diagonal, one on the diagonal). */
        pdlaset_( "F", &min_mn, &min_mn, &tzero, &tpone, R, &ione, &ione, descR );
        /* R := (+1) * product - R, i.e. the Gram matrix of the short side minus I. */
        if (m>n)
                pdgemm_( "T", "N", &min_mn, &min_mn, &m, &tpone, U, &iu, &ju, descU, U,
                        &iu, &ju, descU, &tmone, R, &ione, &ione, descR );
        else
                pdgemm_( "N", "T", &min_mn, &min_mn, &n, &tpone, U, &iu, &ju, descU, U,
                        &iu, &ju, descU, &tmone, R, &ione, &ione, descR );
        /* No workspace is supplied for the Frobenius norm (wwork stays NULL). */
        orthU = pdlange_( "F", &min_mn, &min_mn, R, &ione, &ione, descR, wwork );
        orthU = orthU / ((double) max_mn);
        free(R);

        return orthU;

}
/**/
/*
 * Scaled residual of a computed SVD:
 *
 *     || U diag(S) VT - A ||_F / || A ||_F / max(m,n)
 *
 *   m, n          global dimensions of A
 *   A, ia, ja     local part, origin and descriptor (descA) of A
 *   U, iu, ju     same for the m-by-min(m,n) matrix U (descU)
 *   VT, ivt, jvt  same for the min(m,n)-by-n matrix VT (descVT)
 *   S             the min(m,n) singular values, indexed by global column of U
 *
 * A, U and VT are not modified: the scaling and the product are done on
 * private copies. Terminates the process with a failure status if a copy
 * cannot be allocated.
 *
 * NOTE(review): the copies are sized for an m-by-n (resp. m-by-min(m,n))
 * matrix but are addressed through descA/descU at (ia,ja)/(iu,ju), and the
 * column scaling indexes Ucpy from its first local column; this looks correct
 * only for ia = ja = iu = ju = 1, which is how main() calls it -- confirm
 * before using other origins.
 */
double verif_representativity(int m, int n, double *A, int ia, int ja, int *descA,
                                              double *U, int iu, int ju, int *descU,
                                              double *VT, int ivt, int jvt, int *descVT,
                                              double *S){

        double *Acpy=NULL, *Ucpy=NULL;
        int nprow, npcol, myrow, mycol;
        int min_mn, max_mn, mpA, pcol, localcol, i, nqA;
        int ictxt, nbA, rsrcA, csrcA, nbU, rsrcU, csrcU, mpU, nqU;
        int ctxt_ = 1, nb_ = 5, rsrc_ = 6, csrc_ = 7, lld_ = 8;
        int izero = 0, ione = 1;
        double *wwork=NULL;
        double tmone= -1.0e+00, tpone= +1.0e+00;
        double residF, AnormF;

        min_mn = min(m,n);
        max_mn = max(m,n);
        ictxt = descA[ctxt_];
        Cblacs_gridinfo( ictxt, &nprow, &npcol, &myrow, &mycol );

        /* Private copy of A; at least one element is allocated because calloc
         * may return NULL for a zero count on a process owning no entries. */
        nbA = descA[nb_]; rsrcA = descA[rsrc_] ; csrcA = descA[csrc_] ;
        mpA    = numroc_( &m     , &nbA, &myrow, &rsrcA, &nprow );
        nqA    = numroc_( &n     , &nbA, &mycol, &csrcA, &npcol );
        Acpy = (double *)calloc((size_t) max( 1, mpA*nqA ),sizeof(double)) ;
        if (Acpy==NULL){ printf("error of memory allocation Acpy on proc %dx%d\n",myrow,mycol); exit(EXIT_FAILURE); }
        pdlacpy_( "All", &m, &n, A, &ia, &ja, descA, Acpy, &ia, &ja, descA );

        /* Private copy of U, to be scaled column by column. */
        nbU = descU[nb_]; rsrcU = descU[rsrc_] ; csrcU = descU[csrc_] ;
        mpU    = numroc_( &m     , &nbU, &myrow, &rsrcU, &nprow );
        nqU    = numroc_( &min_mn, &nbU, &mycol, &csrcU, &npcol );
        Ucpy = (double *)calloc((size_t) max( 1, mpU*nqU ),sizeof(double)) ;
        if (Ucpy==NULL){ printf("error of memory allocation Ucpy on proc %dx%d\n",myrow,mycol); exit(EXIT_FAILURE); }
        pdlacpy_( "All", &m, &min_mn, U, &iu, &ju, descU, Ucpy, &iu, &ju, descU );

        AnormF = pdlange_( "F", &m, &n, A, &ia, &ja, descA, wwork);

        /* Ucpy := U diag(S): each process scales the global columns it owns.
         * The owner is looked up with the source column of descU, and the local
         * row count of U (not of A) bounds the scaled vector. */
        for (i=1;i<=min_mn;i++){
                pcol = indxg2p_( &i, &nbU, &izero, &csrcU, &npcol );
                localcol = indxg2l_( &i, &nbU, &izero, &izero, &npcol );
                if( mycol==pcol )
                        dscal_( &mpU, &(S[i-1]), &(Ucpy[ ( localcol-1 )*descU[lld_] ]), &ione );
        }
        /* Acpy := Ucpy * VT - Acpy = U diag(S) VT - A */
        pdgemm_( "N", "N", &m, &n, &min_mn, &tpone, Ucpy, &iu, &ju, descU, VT, &ivt, &jvt, descVT,
                        &tmone, Acpy, &ia, &ja, descA );
        /* Read the residual from the same origin pdgemm_ wrote it to. */
        residF = pdlange_( "F", &m, &n, Acpy, &ia, &ja, descA, wwork);
        residF = residF/AnormF/((double) max_mn);

        free(Ucpy);
        free(Acpy);

        return residF;
}
/**/
/*
 * Run pdgesvd_ on a private copy of the distributed matrix A and time it.
 *
 *   jobU, jobVT   job options forwarded to pdgesvd_
 *   m, n          global dimensions of A
 *   A, descA      local part and descriptor of A (A itself is not modified)
 *   S_NN          receives the singular values
 *   U_NN, descU   local part and descriptor of U
 *   VT_NN, descVT local part and descriptor of VT
 *   MPIelapsedNN  receives the MPI_Wtime() difference around the
 *                 factorization call (0.0 if the factorization is not run)
 *
 * Returns the info value set by pdgesvd_. If the workspace query already
 * reports a non-zero info, that value is returned and the factorization is
 * skipped. Terminates the process with a failure status if a local
 * allocation fails.
 *
 * NOTE(review): ia, ja, iu, ju, ivt and jvt are accepted for interface
 * symmetry but ignored -- every matrix is addressed from (1,1), which matches
 * the only caller in this file.
 */
int driver_pdgesvd( char jobU, char jobVT, int m, int n, double *A, int ia, int ja, int *descA,
                double *S_NN, double *U_NN, int iu, int ju, int *descU, double *VT_NN, int ivt, int jvt, int *descVT,
                double *MPIelapsedNN){

        double *Acpy=NULL, *work=NULL;
        double wquery = 0.0;
        int lwork;
/**/
        int ione=1;
/**/
        int nprow, npcol, myrow, mycol;
        int mpA, nqA;
        int ictxt, nbA, rsrcA, csrcA;
        int ctxt_ = 1, nb_ = 5, rsrc_ = 6, csrc_ = 7;
        int infoNN = 0;

        double MPIt1, MPIt2;
/**/
        (void) ia; (void) ja; (void) iu; (void) ju; (void) ivt; (void) jvt;
        *MPIelapsedNN = 0.0;
/**/
        ictxt = descA[ctxt_];
        Cblacs_gridinfo( ictxt, &nprow, &npcol, &myrow, &mycol );

        nbA = descA[nb_]; rsrcA = descA[rsrc_] ; csrcA = descA[csrc_] ;

        mpA    = numroc_( &m     , &nbA, &myrow, &rsrcA, &nprow );
        nqA    = numroc_( &n     , &nbA, &mycol, &csrcA, &npcol );

        /* The routine is run on a copy, so the caller's A stays intact for the
         * residual check. Allocate at least one element: calloc may return NULL
         * for a zero count on a process owning no entries. */
        Acpy = (double *)calloc((size_t) (mpA*nqA > 0 ? mpA*nqA : 1),sizeof(double)) ;
        if (Acpy==NULL){ printf("error of memory allocation Acpy on proc %dx%d\n",myrow,mycol); exit(EXIT_FAILURE); }

        pdlacpy_( "All", &m, &n, A, &ione, &ione, descA, Acpy, &ione, &ione, descA );

        /* Workspace query: lwork = -1, the required size comes back in the
         * first work entry. */
        lwork=-1;

        pdgesvd_( &jobU, &jobVT, &m, &n, Acpy, &ione, &ione, descA,
                S_NN, U_NN, &ione, &ione, descU, VT_NN, &ione, &ione, descVT,
                &wquery, &lwork, &infoNN);

        if (infoNN!=0){
                /* The query itself was rejected: report it instead of allocating
                 * a workspace from a meaningless size. */
                free(Acpy);
                return infoNN;
        }

        lwork = (int) wquery;
        if (lwork<1)
                lwork = 1;

        work = (double *)calloc((size_t) lwork,sizeof(double)) ;
        if (work==NULL){ printf("error of memory allocation work on proc %dx%d\n",myrow,mycol); exit(EXIT_FAILURE); }
/**/
        MPIt1 = MPI_Wtime();
/**/
        pdgesvd_( &jobU, &jobVT, &m, &n, Acpy, &ione, &ione, descA,
                S_NN, U_NN, &ione, &ione, descU, VT_NN, &ione, &ione, descVT,
                work, &lwork, &infoNN);
/**/
        MPIt2 = MPI_Wtime();
        (*MPIelapsedNN)=MPIt2-MPIt1;
/**/
        free(work);
        free(Acpy);
        return infoNN;
}