#include "petscmat.h"
#include "parpre_vec.h"
#include "src/vec/impls/dvecimpl.h"
#include "dotproducts_impl.h"
#include "mpi.h"

/****************************************************************
 * Combined Dot Products routines
 * 
 * The user needs to create (and later destroy) a `DotProducts' object
 * with DotProductsCreate and DotProductsDestroy.
 * Combined dot products are then performed as follows:
 * - a sequence of DotProductsSet (or DotProductsSetT) calls computes the
 * local parts of the dot products and stores them; there is no global
 * communication.
 * - the first DotProductsGet call after such a sequence performs the
 * global communication (one reduction for all stored values); every
 * subsequent DotProductsGet call reads the next result from the results
 * array, in the order in which the dot products were submitted.
 ****************************************************************/

/* Values of dp->state:
   WRITE_STATE - DotProductsSet/SetT calls are accumulating local results
                 in dp->lvalues; the next DotProductsGet must reduce them.
   READ_STATE  - DotProductsGet has performed the global reduction and
                 results are being read back from dp->gvalues. */
#define WRITE_STATE 0
#define READ_STATE 1

#undef __FUNC__
#define __FUNC__ "DotProductsCreate"
/*
  DotProductsCreate
  Create a dot products object.

  Input:
    comm - communicator over which DotProductsGet reduces the results
    size - dimension of the internal (sequential) dense matrices and vectors
  Output:
    dp   - the new object; free it with DotProductsDestroy
*/
int DotProductsCreate(MPI_Comm comm,int size,DotProducts *dp)
{
  DotProducts newdp; int ierr;
  PetscFunctionBegin;

  newdp = PetscNew(struct _p_DotProducts); CHKPTRQ(newdp);

  /* set up internal matrices for full dot product */
  ierr = MatCreateSeqDense
    (MPI_COMM_SELF,size,size,PETSC_NULL,&newdp->mat); CHKERRQ(ierr);
  ierr = MatAssemblyBegin(newdp->mat,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd(newdp->mat,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatDuplicate
    (newdp->mat,MAT_DO_NOT_COPY_VALUES,&(newdp->tmpmat)); CHKERRQ(ierr);
  /* NOTE(review): the raw array pointers are kept after the matching
     Restore call; this assumes the dense storage never moves - confirm. */
  ierr = MatGetArray(newdp->mat,&newdp->mat_array); CHKERRQ(ierr);
  ierr = MatRestoreArray(newdp->mat,&newdp->mat_array); CHKERRQ(ierr);
  ierr = MatGetArray(newdp->tmpmat,&newdp->tmpmat_array); CHKERRQ(ierr);
  ierr = MatRestoreArray(newdp->tmpmat,&newdp->tmpmat_array); CHKERRQ(ierr);

  /* set up internal vectors */
  ierr = VecCreateSeq(MPI_COMM_SELF,size,&newdp->vec); CHKERRQ(ierr);
  ierr = VecAssemblyBegin(newdp->vec); CHKERRQ(ierr);
  ierr = VecAssemblyEnd(newdp->vec); CHKERRQ(ierr);
  ierr = VecDuplicate(newdp->vec,&newdp->tmpvec); CHKERRQ(ierr);
  ierr = VecGetArray(newdp->vec,&newdp->vec_array); CHKERRQ(ierr);
  ierr = VecRestoreArray(newdp->vec,&newdp->vec_array); CHKERRQ(ierr);

  newdp->dimension = size;
  newdp->full = newdp->complex = 0;
  /* result buffers are allocated lazily by the first DotProductsSet */
  newdp->size = newdp->write = newdp->read = newdp->high_read = 0;
  newdp->lvalues = newdp->gvalues = 0;
  /* Get/Restore bookkeeping must start cleared: DotProductsGetMat and
     DotProductsGetDiagonal test these flags on their very first call, and
     PetscNew is not relied upon to zero the new structure. */
  newdp->gotmat = newdp->gotvec = 0;
  newdp->state = WRITE_STATE;
  newdp->comm = comm;
  *dp = newdp;
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "DotProductsDestroy"
/*
  DotProductsDestroy
  Release every resource owned by a dot products object, then the object
  itself. The handle must not be used afterwards.
*/
int DotProductsDestroy(DotProducts dp)
{
  int ierr;
  PetscFunctionBegin;

  /* result buffers only exist once DotProductsSet has allocated them */
  if (dp->lvalues != 0) {
    PetscFree(dp->lvalues);
  }
  if (dp->gvalues != 0) {
    PetscFree(dp->gvalues);
  }

  /* internal objects created in DotProductsCreate */
  ierr = MatDestroy(dp->mat);    CHKERRQ(ierr);
  ierr = MatDestroy(dp->tmpmat); CHKERRQ(ierr);
  ierr = VecDestroy(dp->vec);    CHKERRQ(ierr);
  ierr = VecDestroy(dp->tmpvec); CHKERRQ(ierr);

  /* finally the structure itself */
  PetscFree(dp);
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "DotProductsSetFull"
/*
  DotProductsSetFull
  Select full mode: DotProductsAssemblyBegin will then copy the diagonal
  of the internal matrix into the internal vector.
*/
int DotProductsSetFull(DotProducts dp)
{
  PetscFunctionBegin;

  dp->full = 1;   /* non-zero: full mode */

  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "DotProductsSetDiagonal"
/*
  DotProductsSetDiagonal
  Select diagonal mode: DotProductsAssemblyBegin will then rebuild the
  internal matrix as a diagonal matrix from the internal vector.
*/
int DotProductsSetDiagonal(DotProducts dp)
{
  PetscFunctionBegin;

  dp->full = 0;   /* zero: diagonal mode */

  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "DotProductsAssemblyBegin"
/*
  DotProductsAssemblyBegin
  Bring the internal matrix and vector into agreement:
  - full mode:     copy the diagonal of dp->mat into dp->vec;
  - diagonal mode: zero dp->mat, place the entries of dp->vec on its
                   diagonal and start its assembly (finished by
                   DotProductsAssemblyEnd).
*/
int DotProductsAssemblyBegin(DotProducts dp)
{
  int ierr;
  PetscFunctionBegin;
  if (dp->full) {
    ierr = MatGetDiagonal(dp->mat,dp->vec); CHKERRQ(ierr);
  } else {
    Scalar *a; int i;
    ierr = MatZeroEntries(dp->mat); CHKERRQ(ierr);

    ierr = VecGetArray(dp->vec,&a); CHKERRQ(ierr);
    for (i=0; i<dp->dimension; i++) {
      /* one (i,i) entry per call: passing index lists to a single
         MatSetValues call would address a dense block, not the diagonal */
      ierr = MatSetValues(dp->mat,1,&i,1,&i,a+i,INSERT_VALUES); CHKERRQ(ierr);
    }
    /* MatSetValues has taken the values; the array is no longer needed */
    ierr = VecRestoreArray(dp->vec,&a); CHKERRQ(ierr);

    ierr = MatAssemblyBegin(dp->mat,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  }
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "DotProductsAssemblyEnd"
/*
  DotProductsAssemblyEnd
  Complete the work started by DotProductsAssemblyBegin. Only diagonal
  mode began a matrix assembly, so full mode has nothing to finish.
*/
int DotProductsAssemblyEnd(DotProducts dp)
{
  int ierr;
  PetscFunctionBegin;
  if (!dp->full) {
    ierr = MatAssemblyEnd(dp->mat,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  }
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "DotProductsGetMat"
/*
  DotProductsGetMat
  Hand out the internal matrix. It must be given back with
  DotProductsRestoreMat before it can be obtained again.
*/
int DotProductsGetMat(DotProducts dp,Mat *mat)
{
  PetscFunctionBegin;
  if (dp->gotmat) SETERRQ(1,1,"Cannot get mat a second time");
  *mat = dp->mat;
  /* A plain flag, not the address cast to int: that cast truncates on
     64-bit platforms and could even produce 0, losing the checkout. */
  dp->gotmat = 1;
  PetscFunctionReturn(0);
}
#undef __FUNC__
#define __FUNC__ "DotProductsRestoreMat"
/*
  DotProductsRestoreMat
  Give back the matrix obtained with DotProductsGetMat; *mat must still
  hold the handle that was handed out.
*/
int DotProductsRestoreMat(DotProducts dp,Mat *mat)
{
  PetscFunctionBegin;
  if (!dp->gotmat) SETERRQ(1,1,"You did not get the mat to begin with");
  if (!mat || *mat!=dp->mat) SETERRQ(1,1,"Restoring from wrong matrix");
  dp->gotmat = 0;
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "DotProductsGetDiagonal"
/*
  DotProductsGetDiagonal
  Hand out the internal vector. It must be given back with
  DotProductsRestoreDiagonal before it can be obtained again.
*/
int DotProductsGetDiagonal(DotProducts dp,Vec *vec)
{
  PetscFunctionBegin;
  if (dp->gotvec) SETERRQ(1,1,"Cannot get vec a second time");
  *vec = dp->vec;
  /* A plain flag, not the address cast to int: that cast truncates on
     64-bit platforms and could even produce 0, losing the checkout. */
  dp->gotvec = 1;
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "DotProductsRestoreDiagonal"
/*
  DotProductsRestoreDiagonal
  Give back the vector obtained with DotProductsGetDiagonal; *vec must
  still hold the handle that was handed out.
*/
int DotProductsRestoreDiagonal(DotProducts dp,Vec *vec)
{
  PetscFunctionBegin;
  if (!dp->gotvec) SETERRQ(1,1,"You did not get the vec to begin with");
  if (!vec || *vec!=dp->vec) SETERRQ(1,1,"Restoring from wrong vec");
  dp->gotvec = 0;
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "DotProductsSet"
/*
  DotProductsSet
  Submit one dot product (two vectors) for execution; the result can
  later be retrieved with DotProductsGet.
*/
int DotProductsSet(DotProducts dp,Vec x,Vec y)
{
  Scalar res;
  int the_loc,ierr;

  PetscFunctionBegin;
  /* compute local result */
  ierr = VecDot_Seq(x,y,&res); CHKERRQ(ierr);
  /* decide where we are going to write it */
#if PETSC_DECIDE >= 0
#error "PETSC_DECIDE conflict: has to be negative"
#endif
  the_loc = dp->write++;
  /* make sure that there is a place to write it */
  if (the_loc>=dp->size) {
    if (dp->size==0) {
      /* initial allocation */
      dp->lvalues = (Scalar *) PetscMalloc(100*sizeof(Scalar));
      CHKPTRQ(dp->lvalues);
      dp->gvalues = (Scalar *) PetscMalloc(100*sizeof(Scalar));
      CHKPTRQ(dp->gvalues);
      dp->size = 100;
    } else {
      /* double existing storage */
      Scalar *tmp; int new_size;
      new_size = 2*dp->size;

      tmp = (Scalar *) PetscMalloc(new_size*sizeof(Scalar)); CHKPTRQ(tmp);
      if (dp->write)
	PetscMemcpy(tmp,dp->lvalues,dp->write*sizeof(Scalar));
      PetscFree(dp->lvalues); dp->lvalues = tmp;

      tmp = (Scalar *) PetscMalloc(new_size*sizeof(Scalar)); CHKPTRQ(tmp);
      if (dp->high_read)
	PetscMemcpy(tmp,dp->gvalues,dp->high_read*sizeof(Scalar));
      PetscFree(dp->gvalues); dp->gvalues = tmp;
      dp->size = new_size;
    }
  }
  dp->lvalues[the_loc] = res;
  dp->state = WRITE_STATE;
  PetscFunctionReturn(0);
}
int DotProductsSetT(DotProducts dp,Vec x,Vec y)
{
  Scalar res;
  int the_loc,ierr;

  PetscFunctionBegin;
  /* compute local result */
  ierr = VecTDot_Seq(x,y,&res); CHKERRQ(ierr);
  /* decide where we are going to write it */
#if PETSC_DECIDE >= 0
#error "PETSC_DECIDE conflict: has to be negative"
#endif
  the_loc = dp->write++;
  /* make sure that there is a place to write it */
  if (the_loc>=dp->size) {
    if (dp->size==0) {
      /* initial allocation */
      dp->lvalues = (Scalar *) PetscMalloc(100*sizeof(Scalar));
      CHKPTRQ(dp->lvalues);
      dp->gvalues = (Scalar *) PetscMalloc(100*sizeof(Scalar));
      CHKPTRQ(dp->gvalues);
      dp->size = 100;
    } else {
      /* double existing storage */
      Scalar *tmp; int new_size;
      new_size = 2*dp->size;

      tmp = (Scalar *) PetscMalloc(new_size*sizeof(Scalar)); CHKPTRQ(tmp);
      if (dp->write)
	PetscMemcpy(tmp,dp->lvalues,dp->write*sizeof(Scalar));
      PetscFree(dp->lvalues); dp->lvalues = tmp;

      tmp = (Scalar *) PetscMalloc(new_size*sizeof(Scalar)); CHKPTRQ(tmp);
      if (dp->high_read)
	PetscMemcpy(tmp,dp->gvalues,dp->high_read*sizeof(Scalar));
      PetscFree(dp->gvalues); dp->gvalues = tmp;
      dp->size = new_size;
    }
  }
  dp->lvalues[the_loc] = res;
  dp->state = WRITE_STATE;
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "DotProductsGet"
/*
  DotProductsGet
  Retrieve the next dot product value, in submission order.
  If local values are still pending (write mode), they are first summed
  over dp->comm with a single MPI_Allreduce and the object switches to
  read mode. Requesting more values than were reduced is an error.
*/
int DotProductsGet(DotProducts dp,Scalar *r)
{
  int npending;
  PetscFunctionBegin;
  if (dp->state==WRITE_STATE) {
    npending = dp->write;
#if defined(USE_PETSC_COMPLEX)
    /* a complex Scalar is reduced as two doubles */
    MPI_Allreduce(dp->lvalues,dp->gvalues,
		  2*npending,MPI_DOUBLE,MPI_SUM,dp->comm);
    PLogFlops(2*(npending-1));
#else
    MPI_Allreduce((void*)dp->lvalues,(void*)dp->gvalues,
		  npending,MPI_DOUBLE,MPI_SUM,dp->comm);
    PLogFlops(npending-1);
#endif
    dp->state     = READ_STATE;
    dp->read      = 0;
    dp->high_read = npending-1;   /* index of the last valid result */
    dp->write     = 0;
  }
  if (dp->read>dp->high_read)
    SETERRQ(1,0,"Too many results requested");
  *r = dp->gvalues[dp->read];
  dp->read++;
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "DotProductsClear"
/*
  DotProductsClear
  Discard all pending and already reduced dot products so the object can
  be reused for a fresh sequence of DotProductsSet / DotProductsGet calls.
*/
int DotProductsClear(DotProducts dp)
{
  PetscFunctionBegin;
  dp->write = dp->high_read = dp->read = 0;
  /* Return to write mode; otherwise a following DotProductsGet skips the
     reduction and reads gvalues[0], which is stale, zeroed or (before any
     DotProductsSet) not even allocated. */
  dp->state = WRITE_STATE;
  /* the buffers do not exist until the first DotProductsSet */
  if (dp->size) {
    PetscMemzero(dp->lvalues,dp->size*sizeof(Scalar));
    PetscMemzero(dp->gvalues,dp->size*sizeof(Scalar));
  }
  PetscFunctionReturn(0);
}
