
#include <stdio.h>
#include <math.h>
#include <stdlib.h>
#include <ctype.h>
#if __STD_C
#include <stddef.h>   /* for size_t */
#else
#include <sys/types.h>
#endif
/*extern void* malloc(unsigned int);*/

/* Boolean values for plain-int flags (e.g. the out_of_mem flag in gstrf_gp). */
#define TRUE  1
#define FALSE 0

include(`cinclude.m4') /* This includes a file of m4 macros */

/* Tag stored in desc[CSC_DESC_TYPE] of a compressed-sparse-column descriptor. */
#define DESC_TYPE_CSC 10001

/*
  Layout of a CSC descriptor (an integer_t array):
    desc[CSC_DESC_TYPE]              DESC_TYPE_CSC
    desc[CSC_DESC_M]                 m (rows)
    desc[CSC_DESC_N]                 n (columns)
    desc[CSC_DESC_NNZ]               number of stored entries
    desc[CSC_DESC_BASE]              index base (0 or 1) of colptr/rowind
    desc[CSC_DESC_COLPTR ...]        n+1 column pointers
    desc[CSC_DESC_ROWIND(desc) ...]  nnz row indices
  The row indices start right after the n+1 column pointers, so their
  offset depends on n and is computed from the descriptor itself.
*/
#define CSC_DESC_TYPE 0
#define CSC_DESC_M    1
#define CSC_DESC_N    2
#define CSC_DESC_NNZ  3
#define CSC_DESC_BASE 4
#define CSC_DESC_COLPTR 5
#define CSC_DESC_ROWIND(desc) (CSC_DESC_COLPTR+((desc)[CSC_DESC_N])+1)

/*
  LU factor produced by gstrf_gp, consumed by gstrs_gp and released by
  gstff_gp. L and U entries share the lu_nz / lu_rowind storage.
  Index values stored in these arrays are 1-based.
*/
typedef struct {
  integer_t  lu_size;    /* allocated length of lu_nz and lu_rowind */
  scalar_t*  lu_nz;      /* values of the L and U entries */
  integer_t* lu_rowind;  /* row index of each entry in lu_nz */
  integer_t* l_colptr;   /* ncol entries; presumably start of the L part of each column -- confirm */
  integer_t* u_colptr;   /* ncol+1 entries; u_colptr[0] == 1 */

  integer_t* row_perm;   /* nrow entries; row permutation applied to lu_rowind */
  integer_t* col_perm;   /* ncol entries; order in which columns of A were factored */

} lu_t;

/*
  Optional statistics callback obtained from the gp object. It is called
  with the user context, a statistic name (e.g. "FLOPS", "NONZEROS") and a
  pointer to the value.
*/
typedef void (*statistics_reporter_t)(void* context, 
                                      character_t* name, 
                                      double_t* val);

/*
  gstrf_gp: column-by-column sparse LU factorization of the nrow x ncol
  CSC matrix (a_nz, desc_a), driven by the Fortran kernels ludfs, lucomp
  and lucopy. A maximum matching (maxmatch) assigns each column a row.
  That row is flagged in `pattern` with value 2, presumably so that
  lucopy prefers it as the pivot (confirm against lucopy).

  gp      in   options object: pivot/drop/fill parameters, an optional
               statistics reporter and an optional column permutation.
  mrows   in   number of rows.
  ncols   in   number of columns.
  a_nz    in   values of A.
  desc_a  in   CSC descriptor of A (see CSC_DESC_*). NOTE: a 0-based
               descriptor is converted to 1-based IN PLACE.
  lu      out  newly allocated factor (release with gstff_gp);
               NULL on any failure.
  info    out  0     success;
               -1    gp is NULL, or the column permutation has wrong length;
               -5    desc_a is not a CSC descriptor;
               -100  ludfs reported an error;
               -999  out of memory;
               j > 0 zero pivot in column j.
*/
prec_prefix(suffix(gstrf_gp)) ( 
                      void       handle_in_arg(gp),
                      integer_t  scalar_in_arg(mrows),
                      integer_t  scalar_in_arg(ncols),
                      scalar_t*  a_nz,
                      integer_t* desc_a,
                      lu_t       handle_out_arg(lu),
                      integer_t  scalar_out_arg(info)
                     )
{
  integer_t* a_rowind;
  integer_t* a_colptr;

  /* work arrays (all freed at free_and_exit) */

  scalar_t*    rwork  = NULL;
  scalar_t*    twork  = NULL;

  integer_t*   found  = NULL;
  integer_t*   parent = NULL;
  integer_t*   child  = NULL;

  integer_t*   pattern  = NULL;

  integer_t*   cmatch   = NULL;
  integer_t*   rmatch   = NULL;

  /* copies of object parameters */

  integer_t pivot_policy; 
  double_t  pivot_threshold;
  double_t  drop_threshold;
  double_t  col_fill_ratio;
  double_t  fill_ratio;
  double_t  expand_ratio;

  /* local variables */

  integer_t nrow = scalar_in(mrows);
  integer_t ncol = scalar_in(ncols);

  integer_t a_desc_type, a_m, a_n, a_nnz, a_base;

  /* jcol and lastlu are read on the exit path (error message, statistics
     reporter), so they must be defined even if we bail out early. */
  integer_t jcol = 0, i = 0; 
  integer_t lasta; 
  integer_t lastlu = 0;
  integer_t zpivot = 0; 

  integer_t local_pivot_policy;
  integer_t nz_count_limit;

  integer_t* user_col_perm        = NULL;
  integer_t  user_col_perm_length = 0;
  integer_t  user_col_perm_base   = 0;

  statistics_reporter_t reporter_func = NULL;
  void*                 reporter_ctxt = NULL;

  double_t   flops = 0.0;

  double_t   ujj, minujj;

  int       out_of_mem = FALSE;
  int       eline = -1;

  int pivt_row = 0, orig_row = 0, this_col = 0, othr_col = 0;

  /* constants */

  integer_t izero = 0;
  scalar_t  zero  = 0.0; /* this is not good for complex !!! replace with macro */

  /* The exit path inspects and may free handle_out(lu); never let it see
     whatever value the caller left there. */
  handle_out(lu) = NULL;

  /* extract data from gp object */

  if (handle_in(gp) == NULL) {
    scalar_out(info) = -1;
    goto free_and_exit;
  }
  suffix(gp_get_pivot_policy)        (handle_in(gp),&pivot_policy);
  suffix(gp_get_pivot_threshold)     (handle_in(gp),&pivot_threshold);
  suffix(gp_get_drop_threshold)      (handle_in(gp),&drop_threshold);
  suffix(gp_get_col_fill_ratio)      (handle_in(gp),&col_fill_ratio);
  suffix(gp_get_fill_ratio)          (handle_in(gp),&fill_ratio);
  suffix(gp_get_expand_ratio)        (handle_in(gp),&expand_ratio);
  suffix(gp_get_statistics_reporter) (handle_in(gp),&reporter_func,
                                                    &reporter_ctxt);
  suffix(gp_get_col_perm)            (handle_in(gp),&user_col_perm,
                                                    &user_col_perm_length,
                                                    &user_col_perm_base);

  /* A user-supplied column permutation must have exactly ncol entries. */

  if (user_col_perm != NULL && user_col_perm_length != ncol) {
    scalar_out(info) = -1;
    goto free_and_exit;
  }

  /* extract data from a's array descriptor */

  a_desc_type = desc_a[CSC_DESC_TYPE];
  if (a_desc_type != DESC_TYPE_CSC) { 
    scalar_out(info) = -5;
    goto free_and_exit;
  }
  a_m      = desc_a[CSC_DESC_M];
  a_n      = desc_a[CSC_DESC_N];
  a_nnz    = desc_a[CSC_DESC_NNZ];
  a_base   = desc_a[CSC_DESC_BASE];
  a_colptr = &( desc_a[CSC_DESC_COLPTR] );
  a_rowind = &( desc_a[CSC_DESC_ROWIND(desc_a)] );

  /* The Fortran kernels expect 1-based indices: convert the caller's
     descriptor in place if necessary. */

  if (a_base == 0) {
    for (jcol=0; jcol<(a_n+1); jcol++) (a_colptr[jcol])++;
    for (jcol=0; jcol<(a_nnz); jcol++) (a_rowind[jcol])++;
    desc_a[CSC_DESC_BASE] = 1;
    a_base                = 1;
  }

  /* Allocate work arrays. */

  if ((rwork  = (scalar_t*) malloc( nrow * sizeof(scalar_t)) ) == NULL)
    { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }

  if ((twork  = (scalar_t*) malloc( nrow * sizeof(scalar_t)) ) == NULL)
    { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }

  if ((found  = (integer_t*) malloc( nrow * sizeof(integer_t) )) == NULL) 
    { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }
  if ((child  = (integer_t*) malloc( nrow * sizeof(integer_t) )) == NULL)
    { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }
  if ((parent = (integer_t*) malloc( nrow * sizeof(integer_t) )) == NULL)
    { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }

  if ((pattern  = (integer_t*) malloc( nrow * sizeof(integer_t) )) == NULL)
    { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }

  /* Create lu structure */

  if ((handle_out(lu)  = (lu_t*) malloc( sizeof(lu_t) )) == NULL)
    { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }

  handle_out(lu)->lu_nz     = NULL;
  handle_out(lu)->lu_rowind = NULL;
  handle_out(lu)->l_colptr  = NULL;
  handle_out(lu)->u_colptr  = NULL;
  handle_out(lu)->row_perm  = NULL;
  handle_out(lu)->col_perm  = NULL;
  handle_out(lu)->lu_size   = (integer_t) (a_nnz * fill_ratio);

  /* lu_rowind doubles as a length-ncol work array for maxmatch below, and
     malloc(0) may legitimately return NULL: enforce a sane minimum. */
  if (handle_out(lu)->lu_size < ncol) handle_out(lu)->lu_size = ncol;
  if (handle_out(lu)->lu_size < 1)    handle_out(lu)->lu_size = 1;

  if ((handle_out(lu)->lu_nz = 
        (scalar_t*) malloc( (handle_out(lu)->lu_size) * sizeof(scalar_t) )) == NULL)
    { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }

  if ((handle_out(lu)->lu_rowind = 
       (integer_t*) malloc( (handle_out(lu)->lu_size) * sizeof(integer_t) )) == NULL)
    { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }

  if ((handle_out(lu)->u_colptr = 
        (integer_t*) malloc( (ncol+1) * sizeof(integer_t) )) == NULL)
    { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }

  if ((handle_out(lu)->l_colptr = 
        (integer_t*) malloc( (ncol) * sizeof(integer_t) )) == NULL)
    { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }

  if ((handle_out(lu)->row_perm = 
        (integer_t*) malloc( (nrow) * sizeof(integer_t) )) == NULL)
    { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }

  if ((handle_out(lu)->col_perm = 
        (integer_t*) malloc( (ncol) * sizeof(integer_t) )) == NULL)
    { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }

  /* Compute max matching. We use elements of the lu structure */
  /* for all the temporary arrays needed.                      */

  if ((cmatch = (integer_t*) malloc( ncol * sizeof(integer_t) )) == NULL) 
    { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }
  if ((rmatch = (integer_t*) malloc( nrow * sizeof(integer_t) )) == NULL) 
    { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }

  /* Column-sized arrays are cleared here; rmatch is row-sized and is
     cleared in the row loop (clearing it per column overflowed when
     ncol > nrow). */
  for (jcol = 0; jcol < ncol; jcol++) {
    (handle_out(lu)->l_colptr)[jcol]  =
    (handle_out(lu)->u_colptr)[jcol]  =
    (handle_out(lu)->col_perm)[jcol]  =
    (handle_out(lu)->lu_rowind)[jcol] =
    cmatch[jcol]                      = 0;
  }
  for ( i = 0; i < nrow; i++) {
    (handle_out(lu)->row_perm)[i] = rmatch[i] = 0;
  }

  fortran_suffix(maxmatch) ( 
                            &nrow ,                    /* in. */
                            &ncol ,                    /* in. */
                            a_colptr,                  /* in. */
                            a_rowind,                  /* in. */
                            handle_out(lu)->l_colptr,  /* work. prevcl(cols) */
                            handle_out(lu)->u_colptr,  /* work. prevrw(cols) */
                            handle_out(lu)->row_perm,  /* work. marker(rows) */
                            handle_out(lu)->col_perm,  /* work. tryrow(cols) */
                            handle_out(lu)->lu_rowind, /* work. nxtchp(cols) */
                            rmatch,                    /* out.  rowset(rows) */
                            cmatch                     /* out.  colset(cols) */ 
                           );
  
  for (jcol = 0; jcol < ncol; jcol++) 
    if (cmatch[jcol] == 0) {
      printf("Warning: Perfect matching not found\n");
      break;
    }

  /* Initialize useful values and zero out the dense vectors. */

  lastlu = 0;

  local_pivot_policy = pivot_policy;
  scalar_out(info) = 0;
  lasta = a_colptr[ncol] - 1;
  (handle_out(lu)->u_colptr)[0] = 1;
  
  fortran_suffix(ifill) (pattern, &nrow, &izero);
  fortran_suffix(ifill) (found, &nrow, &izero);
  fortran_suffix(rfill) (rwork, &nrow, &zero);
  fortran_suffix(ifill) (handle_out(lu)->row_perm, &nrow, &izero);

  /* Column order: identity, or the user's permutation rebased to 1. */
  if (user_col_perm == NULL) {
    for (jcol=0; jcol<ncol; jcol++) 
      (handle_out(lu)->col_perm)[jcol] = jcol + 1;
  } else {
    printf("user_col_perm_base = %d\n",(int) user_col_perm_base);
    for (jcol=0; jcol<ncol; jcol++) 
      (handle_out(lu)->col_perm)[jcol] = user_col_perm[jcol] + (1 - user_col_perm_base);
  }

  /* compute one column at a time */

  for ( jcol = 1; jcol <= ncol; jcol++) {

    int clear_col;

    /* Ensure there is room for a full dense column (up to nrow entries)
       beyond lastlu. */

    if (lastlu + nrow >= handle_out(lu)->lu_size) {
      integer_t  new_size = (integer_t) ( handle_out(lu)->lu_size * expand_ratio );
      scalar_t*  new_nz;
      integer_t* new_rowind;

      /* A single growth step must suffice even for a small expand_ratio. */
      if (new_size <= lastlu + nrow) new_size = lastlu + nrow + 1;

      /* realloc through temporaries: on failure the old arrays stay owned
         by lu and are released at free_and_exit instead of leaking. */
      if ((new_nz = (scalar_t*) realloc( handle_out(lu)->lu_nz,
                                         new_size * sizeof(scalar_t) )) == NULL)
        { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }
      handle_out(lu)->lu_nz = new_nz;

      if ((new_rowind = (integer_t*) realloc( handle_out(lu)->lu_rowind,
                                              new_size * sizeof(integer_t) )) == NULL)
        { out_of_mem = TRUE; eline = __LINE__; goto free_and_exit; }
      handle_out(lu)->lu_rowind = new_rowind;

      handle_out(lu)->lu_size = new_size;
    }

    /* Set up nonzero pattern: 1 for rows present in the column of A,
       2 for the row the matching assigned to this column. */

    this_col = (handle_out(lu)->col_perm)[jcol-1];
    orig_row = cmatch[ this_col - 1 ];

    for (i = a_colptr[this_col-1]; i < a_colptr[this_col]; i++)
      pattern[ a_rowind[i-1] - 1 ] = 1;

    /* orig_row == 0 means the matching left this column unmatched (see
       the warning above); skip the matching bookkeeping rather than
       indexing row -1. */
    if (orig_row >= 1) {
      pattern[ orig_row - 1 ] = 2;

      if ((handle_out(lu)->row_perm)[ orig_row - 1 ] != 0) {
        printf("ERROR: PIVOT ROW FROM MAX-MATCHING ALREADY USED.\n");
        exit(1);
      }
    }

    /*
      Depth-first search from each above-diagonal nonzero of column
      jcol of A, allocating storage for column jcol of U in
      topological order and also for the non-fill part of column
      jcol of L.
    */

    fortran_suffix(ludfs) (
                           &jcol, 
                           a_nz, a_rowind, a_colptr, 
                           &lastlu,
                           handle_out(lu)->lu_rowind, 
                           handle_out(lu)->l_colptr, handle_out(lu)->u_colptr, 
                           handle_out(lu)->row_perm,
                           handle_out(lu)->col_perm,
                           rwork,
                           found, parent, child,
                           info
                          );

    if (scalar_out(info) != 0) { scalar_out(info) = -100; goto free_and_exit; }

    /*
      Compute the values of column jcol of L and U in the dense
      vector, allocating storage for fill in L as necessary.
    */
    
    fortran_suffix(lucomp) (
                            &jcol, 
                            &lastlu, 
                            handle_out(lu)->lu_nz, handle_out(lu)->lu_rowind, 
                            handle_out(lu)->l_colptr, handle_out(lu)->u_colptr, 
                            handle_out(lu)->row_perm,
                            handle_out(lu)->col_perm,
                            rwork, 
                            found,
                            pattern, 
                            &flops
                           );

    /*
      Copy the dense vector into the sparse data structure, find the
      diagonal element (pivoting if specified), and divide the
      column of L by it. The fill limit is relative to the number of
      entries in the original column of A.
    */

    nz_count_limit = (integer_t) (col_fill_ratio * 
                            ((double)(a_colptr[this_col]-a_colptr[this_col-1] + 1)));

    fortran_suffix(lucopy) (
                            &local_pivot_policy, 
                            &pivot_threshold, 
                            &drop_threshold, 
                            &nz_count_limit,
                            &jcol, 
                            &ncol, 
                            &lastlu, 
                            handle_out(lu)->lu_nz, handle_out(lu)->lu_rowind, 
                            handle_out(lu)->l_colptr, handle_out(lu)->u_colptr, 
                            handle_out(lu)->row_perm,
                            handle_out(lu)->col_perm,
                            rwork,
                            pattern, 
                            twork,
                            &flops,
                            &zpivot
                           );

    if (zpivot == -1) {
      scalar_out(info) = jcol;
      goto free_and_exit;
    }

    /* Reset the pattern marks. col_perm is re-read because the kernels
       receive it and may have updated it. */

    clear_col = (handle_out(lu)->col_perm)[jcol-1];
    for (i = a_colptr[clear_col-1]; i < a_colptr[clear_col]; i++)
      pattern[ a_rowind[i-1] - 1 ] = 0;

    if (orig_row >= 1) pattern[ orig_row - 1 ] = 0;

    /* Update the matching so this_col is matched to the chosen pivot
       row, and the column previously matched to that row takes over
       orig_row. Indices of 0 mean "unmatched" and are never used as
       array positions. */

    pivt_row = zpivot;
    if (pivt_row >= 1) {
      othr_col = rmatch[ pivt_row - 1 ];

      cmatch[ this_col - 1 ] = pivt_row;
      if (othr_col >= 1) cmatch[ othr_col - 1 ] = orig_row;
      if (orig_row >= 1) rmatch[ orig_row - 1 ] = othr_col;
      rmatch[ pivt_row - 1 ] = this_col;
    }

    /*
      If there are no diagonal elements after this column, change
      the pivot mode.
    */
    
    if (jcol == nrow) local_pivot_policy = -1;
 
  } /* end of jcol loop */

  /*
    Fill in the zero entries of the permutation vector, and renumber the
    rows so the data structure represents L and U, not PtL and PtU.
  */

  jcol = ncol + 1;
  for (i = 0; i < nrow; i++) {
    if ((handle_out(lu)->row_perm)[i] == 0) {
      (handle_out(lu)->row_perm)[i] = jcol;
      jcol = jcol + 1;
    }
  }

  for (i = 0; i < lastlu; i++)
    (handle_out(lu)->lu_rowind)[i] = (handle_out(lu)->row_perm)[(handle_out(lu)->lu_rowind)[i]-1];

  /* Return */

free_and_exit:

  if (out_of_mem) {
    /* handle_out(lu) is NULL if allocating the lu_t itself failed. */
    fprintf(stderr,
            "Out of space in gstrf_gp. Limit of maxlu=%d exceeded at column %d line %d\n",
            handle_out(lu) ? (int) handle_out(lu)->lu_size : 0,
            (int) jcol, eline);
    scalar_out(info) = -999;
  }

  /* free(NULL) is a no-op, so unallocated work arrays are fine here. */
  free(rmatch);
  free(cmatch);

  free(pattern);

  free(parent);
  free(child);
  free(found);
  free(twork);
  free(rwork);

  if (scalar_out(info) != 0) { 
    /* On failure the partially built factor is released and the caller
       receives NULL. */
    if (handle_out(lu)) {

      free(handle_out(lu)->col_perm);
      free(handle_out(lu)->row_perm);
      free(handle_out(lu)->u_colptr);
      free(handle_out(lu)->l_colptr);
      free(handle_out(lu)->lu_rowind);
      free(handle_out(lu)->lu_nz);

      free(handle_out(lu));
      handle_out(lu) = NULL;
    }
  } else {
    /* Smallest |u_jj| (diagonal of U); currently only feeds a disabled
       diagnostic. */
    minujj = HUGE_VAL;

    for (jcol=1; jcol<=ncol; jcol++) {
      ujj = fabs((handle_out(lu)->lu_nz)[(handle_out(lu)->l_colptr)[jcol-1]-2]);
      if (ujj < minujj) minujj = ujj;
    }
  }

  if (reporter_func) {
    (*reporter_func)(reporter_ctxt,"FLOPS",&flops);
    flops = (double) lastlu;
    (*reporter_func)(reporter_ctxt,"NONZEROS",&flops);
  }

  return;
}

/*
  gstrs_gp: solve with a factor produced by gstrf_gp.
  trans "N" runs lsolve then usolve; trans "T" runs utsolve then ltsolve.
  The second kernel writes its result into b.
  Only a single right-hand side with ia = ja = ib = jb = 1 is supported.
  desc_b is not inspected, because a single dense right-hand side is assumed.

  info out:  0 on success (the solve kernels also receive info);
             -1  gp is NULL, or trans is NULL or not 'N'/'T';
             -3  nrhs != 1;
             -4  lu is NULL;
             -5 / -6 / -8 / -9  ia / ja / ib / jb != 1;
             -999 out of memory.
*/
prec_prefix(suffix(gstrs_gp)) ( 
                      void         handle_in_arg(gp),
                      character_t* trans,
                      integer_t    scalar_in_arg(n),
                      integer_t    scalar_in_arg(nrhs),
                      lu_t         handle_in_arg(lu),
                      integer_t    scalar_in_arg(ia),
                      integer_t    scalar_in_arg(ja),
                      scalar_t*    b,
                      integer_t    scalar_in_arg(ib),
                      integer_t    scalar_in_arg(jb),
                      integer_t*   desc_b,
                      integer_t    scalar_out_arg(info)
                     )
{
  scalar_t*    rwork  = NULL;
  int          trans_code;

  /* Initialized so the exit path never calls an indeterminate pointer
     when gp is NULL. */
  statistics_reporter_t reporter_func = NULL;
  void*                 reporter_ctxt = NULL;

  double_t   flops = 0.0;

  scalar_out(info) = 0;

  /* extract data from gp object */

  if (handle_in(gp) == NULL) {
    scalar_out(info) = -1;
    goto free_and_exit;
  }

  suffix(gp_get_statistics_reporter) (handle_in(gp),&reporter_func,
                                                    &reporter_ctxt);

  if (trans == NULL)         { scalar_out(info) = -1; goto free_and_exit; }
  if (handle_in(lu) == NULL) { scalar_out(info) = -4; goto free_and_exit; }

  if (scalar_in(ia) != 1) { scalar_out(info) = -5; goto free_and_exit; }  
  if (scalar_in(ja) != 1) { scalar_out(info) = -6; goto free_and_exit; }  
  if (scalar_in(ib) != 1) { scalar_out(info) = -8; goto free_and_exit; }  
  if (scalar_in(jb) != 1) { scalar_out(info) = -9; goto free_and_exit; }  

  if (scalar_in(nrhs) != 1) { scalar_out(info) = -3; goto free_and_exit; }  

  if ((rwork  = 
         (scalar_t*) malloc( scalar_in(n) * sizeof(scalar_t) )) == NULL)
    { scalar_out(info) = -999; goto free_and_exit; }

  /* toupper requires a value representable as unsigned char (or EOF). */
  trans_code = toupper((unsigned char) trans[0]);

  if (trans_code == 'N') {

    fortran_suffix(lsolve) (&(scalar_in(n)), 
                            handle_in(lu)->lu_nz, handle_in(lu)->lu_rowind, 
                            handle_in(lu)->l_colptr, handle_in(lu)->u_colptr, 
                            handle_in(lu)->row_perm, 
                            handle_in(lu)->col_perm, 
                            b, 
                            rwork,
                            info);

    fortran_suffix(usolve) (&(scalar_in(n)), 
                            handle_in(lu)->lu_nz, handle_in(lu)->lu_rowind, 
                            handle_in(lu)->l_colptr, handle_in(lu)->u_colptr, 
                            handle_in(lu)->row_perm, 
                            handle_in(lu)->col_perm, 
                            rwork, 
                            b,
                            info);

  } else if (trans_code == 'T') {

    fortran_suffix(utsolve) (&(scalar_in(n)), 
                            handle_in(lu)->lu_nz, handle_in(lu)->lu_rowind, 
                            handle_in(lu)->l_colptr, handle_in(lu)->u_colptr, 
                            handle_in(lu)->row_perm, 
                            handle_in(lu)->col_perm, 
                            b, 
                            rwork,
                            info);

    fortran_suffix(ltsolve) (&(scalar_in(n)), 
                            handle_in(lu)->lu_nz, handle_in(lu)->lu_rowind, 
                            handle_in(lu)->l_colptr, handle_in(lu)->u_colptr, 
                            handle_in(lu)->row_perm, 
                            handle_in(lu)->col_perm, 
                            rwork, 
                            b,
                            info);

  } else { scalar_out(info) = -1; goto free_and_exit; }

  /* Two flops per stored entry, with u_colptr[n] - 1 entries counted. */
  flops = (double) ( 2 * (handle_in(lu)->u_colptr[scalar_in(n)] - 1) );

free_and_exit:
  free(rwork);

  if (reporter_func) 
    (*reporter_func)(reporter_ctxt,"FLOPS",&flops);

  return;
}

/*
  gstff_gp: release an LU factor created by gstrf_gp, meaning all of its
  arrays and then the lu_t itself. A NULL handle is ignored. The caller's
  handle is not cleared, so it must not be used after this call.
  NOTE(review): unlike gstrf_gp/gstrs_gp there is no prec_prefix() here;
  confirm that suffix() supplies a return type (otherwise implicit int).
*/
suffix(gstff_gp) (lu_t handle_in_arg(lu))
{
  if ( handle_in(lu) ) {

    /* free(NULL) is a no-op, so a partially populated factor is fine.
       (Previously col_perm was only freed when row_perm was non-NULL.) */
    free(handle_in(lu)->col_perm);
    free(handle_in(lu)->row_perm);
    free(handle_in(lu)->u_colptr);
    free(handle_in(lu)->l_colptr);
    free(handle_in(lu)->lu_rowind);
    free(handle_in(lu)->lu_nz);

    free( handle_in(lu) );
  }
}

