
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <mpi.h>

#include "counter.h"

/*
 * Factor - divide every power of `divisor` out of `*quotient`, appending
 * `divisor` to the factor list once per successful division.
 *
 * divisor           candidate factor; 0, 1 and -1 are rejected
 * quotient          in/out: number being reduced; 0 is rejected
 * factors           in/out: address of the heap-allocated factor array; the
 *                   array may be moved by realloc, so callers must re-read
 *                   (*factors) after the call
 * numf              in/out: number of entries currently stored in the array
 * max_n_of_factors  in/out: capacity of the array, doubled whenever it fills
 *
 * On entry the array must have at least one free slot
 * (*numf < *max_n_of_factors). On successful return that is true again, so
 * the caller may append one more entry without checking the capacity.
 *
 * Returns 0 on success and -1 on invalid arguments or allocation failure.
 * After an allocation failure (*factors) is still valid and still owned by
 * the caller.
 */
int Factor( int divisor, int* quotient, int **factors, int *numf, int *max_n_of_factors) {
      /*
       * 0 is divisible by every divisor and +/-1 divide every number, so the
       * loop below would never terminate for those inputs; a zero divisor
       * would be a division by zero.
       */
      if ((divisor == 0) || (divisor == 1) || (divisor == -1)
          || ((*quotient) == 0))
      {
         return -1;
      }

      /* The first store below needs a free slot. */
      if (((*max_n_of_factors) < 1)
          || ((*numf) < 0)
          || ((*numf) >= (*max_n_of_factors)))
      {
         return -1;
      }

      while (((*quotient)%divisor) == 0)
      {
         (*factors)[(*numf)++] = divisor;

         /*
          * Enlarge the array as soon as it is full so that there is always
          * room for the next factor.
          */
         if ((*numf) >= (*max_n_of_factors))
         {
            int newmax = (*max_n_of_factors) * 2;
            int *grown = realloc(*factors, sizeof(int) * (size_t)newmax);

            if (grown == NULL)
            {
               /* realloc left the old array untouched; nothing is leaked. */
               return -1;
            }

            *factors = grown;
            *max_n_of_factors = newmax;
         }

         (*quotient) /= divisor;
      }

      return 0;
}

/*
 * Get_factors - compute the prime factorisation of `n` by trial division.
 *
 * n       number to factor; must be >= 1
 * numf    out: number of factors stored
 * result  out: result[0] receives a malloc'ed array holding the prime
 *         factors in non-decreasing order (with multiplicity). For n == 1
 *         the array is allocated but *numf is 0. The caller frees result[0].
 *
 * Returns 0 on success. On failure (n < 1, allocation failure, or an error
 * reported by Factor) a non-zero code is returned, *numf is 0 and result[0]
 * is NULL, so the caller has nothing to release.
 */
int Get_factors( int n, int *numf, int **result) {
      int rc;
      int quotient = n;
      int divisor;
      int max_n_of_factors = 2;

      *numf = 0;
      result[0] = NULL;

      /* Zero and negative numbers have no prime factorisation. */
      if (n < 1)
      {
         return -1;
      }

      result[0] = malloc(sizeof(int) * (size_t)max_n_of_factors);

      if (result[0] == NULL)
      {
         return -1;
      }

      //
      // Try special cases of 2 and 3
      rc = Factor(2, &quotient, result, numf, &max_n_of_factors);

      if (rc == 0)
      {
         rc = Factor(3, &quotient, result, numf, &max_n_of_factors);
      }

      //
      // Try pairs of the form 6m-1 and 6m+1
      // (i.e. 5, 7, 11, 13, 17, 19, . .)
      //
      // "divisor <= quotient / divisor" is the overflow-free integer form of
      // divisor*divisor <= quotient; it tightens as factors are removed.
      for (divisor = 5;
           (rc == 0) && (divisor <= (quotient / divisor));
           divisor += 6)
      {
         rc = Factor(divisor, &quotient, result, numf, &max_n_of_factors);

         if (rc == 0)
         {
            rc = Factor(divisor+2, &quotient, result, numf, &max_n_of_factors);
         }
      }

      if (rc != 0)
      {
         /* Do not hand a half-built list back to the caller. */
         free(result[0]);
         result[0] = NULL;
         *numf = 0;
         return rc;
      }

      // store final factor: whatever is left is 1 or a single prime.
      // Factor always leaves one free slot, so no capacity check is needed.
      if (quotient > 1)
      {
         (*result)[(*numf)++] = quotient;
      }

      return 0;
}

/*
 * Allocate `count` doubles, all set to 0.0, for rank `me`.
 * On failure, report which buffer (`name`) could not be allocated and abort
 * the whole MPI job.
 */
static double *alloc_zeroed_doubles(size_t count, const char *name, int me)
{
   double *buf = malloc(sizeof *buf * count);
   size_t idx;

   if (buf == NULL)
   {
      printf("Process(%d): Cannot allocate %s\n", me, name);
      MPI_Abort(MPI_COMM_WORLD, 1);
      return NULL;
   }

   for (idx = 0; idx < count; idx++)
   {
      buf[idx] = 0.0;
   }

   return buf;
}

/*
 * Jacobi-style 4-point stencil sweep on an N x N grid distributed over a
 * p1 x p2 grid of MPI processes (row-major rank order).
 *
 * Each rank owns an m1 x m2 block. Every iteration it exchanges its edge
 * rows/columns with its four neighbours, computes the new block from the
 * current one and the received halos, and then swaps the two blocks. Halo
 * values outside the global grid stay at 0.0. With `output` set, each rank
 * writes its final block to jacobi_output<rank>.dat; rank 0 prints the
 * elapsed time.
 *
 * N and NUM_OF_ITERATIONS are not defined in this file (presumably they come
 * from counter.h).
 */
int main(int argc, char *argv[])
{
  int i, j, k, me;
  double tstart = 0.0, tend;
  double *mysquare, *new, *swap;
  double *ml, *mr, *mb, *mt;
  double *templ, *tempr, *tempb, *tempt;
  int p;
  int p1 = 1;
  int p2 = 1;
  int m1, m2;
  int tag = 0xff;
  MPI_Status status;
  int left, right, top, bottom;
  int icoord, jcoord;
  unsigned char debug = 0;
  unsigned char output = 1;

  MPI_Init(&argc, &argv);
  MPI_Comm_size(MPI_COMM_WORLD, &p);
  MPI_Comm_rank(MPI_COMM_WORLD, &me);

  MPI_Barrier(MPI_COMM_WORLD);

  if (me == 0)
  {
     tstart = MPI_Wtime();
  }

  /*
   * Use a square process grid when p is a perfect square; otherwise deal
   * the prime factors of p alternately to the two grid dimensions.
   */
  p1 = (int)sqrt((double)p);
  p2 = p1;

  if ((p1*p2) != p)
  {
     int rc, numf;
     int *factors = NULL;

     rc = Get_factors(p, &numf, &factors);

     if (rc != 0)
     {
        printf("Process(%d): Cannot factor the number of processes\n", me);
        MPI_Abort(MPI_COMM_WORLD, 1);
        return rc;
     }

     for (p1 = 1, j = 0; j < numf; j+=2)
     {
        p1 *= factors[j];
     }

     for (p2 = 1, j = 1; j < numf; j+=2)
     {
        p2 *= factors[j];
     }

     free(factors);
  }

  if ((debug) && (me == 0))
  {
     printf("Processor grid is (%d,%d)\n", p1, p2);
  }

  icoord = me/p2;
  jcoord = me%p2;

  /*
   * Compute size of local block.
   *
   * The remainder rows go to the first N%p1 process rows and the remainder
   * columns to the first N%p2 process columns. Every rank in a process row
   * must have the same m1, and every rank in a process column the same m2,
   * because those are the message lengths used in the halo exchange.
   */
  m1 = N/p1;

  if (icoord < (N%p1))
  {
     m1++;
  }

  m2 = N/p2;

  if (jcoord < (N%p2))
  {
     m2++;
  }

  if (debug)
  {
     printf("me=%d, allocation=(%d,%d)\n", me, m1, m2);
  }

  if ((m1 < 1) || (m2 < 1))
  {
     printf("Process(%d): Empty local block (%d,%d)\n", me, m1, m2);
     MPI_Abort(MPI_COMM_WORLD, 1);
     return -1;
  }

  /*
   * set up top, bottom, left and right neighbors
   */
  top = me - p2;
  bottom = me + p2;
  left = me - 1;
  right = me + 1;

  if (icoord == 0)
  {
     top = MPI_PROC_NULL;
  }

  if (jcoord == 0)
  {
     left = MPI_PROC_NULL;
  }

  if (icoord == (p1-1))
  {
     bottom = MPI_PROC_NULL;
  }

  if (jcoord == (p2-1))
  {
     right = MPI_PROC_NULL;
  }

  if (debug)
  {
     printf("me=%d, top=%d, left=%d, bottom=%d, right=%d\n", me, top, left, bottom, right);
  }

  /*
   * All outer boundary elements = 10.0
   * Rest of the elements = 0
   *
   * Only edges that lie on the boundary of the global grid are set, so the
   * initial state does not depend on how the grid is split up.
   */
  mysquare = alloc_zeroed_doubles((size_t)m1*(size_t)m2, "mysquare", me);
  new = alloc_zeroed_doubles((size_t)m1*(size_t)m2, "new", me);

  if (icoord == 0)
  {
     for (i = 0; i < m2; i++)
     {
        mysquare[i] = 10.0;
        new[i] = 10.0;
     }
  }

  if (jcoord == 0)
  {
     for (i = 0; i < m1; i++)
     {
        mysquare[i*m2] = 10.0;
        new[i*m2] = 10.0;
     }
  }

  if (icoord == (p1-1))
  {
     for (i = 0; i < m2; i++)
     {
        mysquare[(m1-1)*m2 + i] = 10.0;
        new[(m1-1)*m2 + i] = 10.0;
     }
  }

  if (jcoord == (p2-1))
  {
     for (i = 0; i < m1; i++)
     {
        mysquare[i*m2 + m2 - 1] = 10.0;
        new[i*m2 + m2 - 1] = 10.0;
     }
  }

  /*
   * Allocate temporaries.
   *
   * templ/tempr/tempt/tempb hold the halo received from the left/right/
   * top/bottom neighbour; they stay 0.0 where there is no neighbour.
   * ml/mr/mt/mb are the packed edges sent to those neighbours.
   */
  tempr = alloc_zeroed_doubles((size_t)m1, "tempr", me);
  templ = alloc_zeroed_doubles((size_t)m1, "templ", me);
  tempt = alloc_zeroed_doubles((size_t)m2, "tempt", me);
  tempb = alloc_zeroed_doubles((size_t)m2, "tempb", me);

  mr = alloc_zeroed_doubles((size_t)m1, "mr", me);
  ml = alloc_zeroed_doubles((size_t)m1, "ml", me);
  mt = alloc_zeroed_doubles((size_t)m2, "mt", me);
  mb = alloc_zeroed_doubles((size_t)m2, "mb", me);

  for (k = 0; k < NUM_OF_ITERATIONS; k++)
  {
      /*
       * Halo exchange of the current state (mysquare).
       * Ranks with a left/top neighbour receive first and then send, ranks
       * with a right/bottom neighbour send first and then receive, so every
       * blocking call has a matching partner.
       */
      if (left != MPI_PROC_NULL)
      {
         if (debug)
         {
            printf("me=%d receiving from left=%d\n", me, left);
         }

         MPI_Recv(templ, m1, MPI_DOUBLE, left, tag, MPI_COMM_WORLD, &status);

         for (i = 0; i < m1; i++)
         {
            ml[i] = mysquare[i*m2];
         }

         if (debug)
         {
            printf("me=%d sending to left=%d\n", me, left);
         }

         MPI_Send(ml, m1, MPI_DOUBLE, left, tag, MPI_COMM_WORLD);
      }

      if (right != MPI_PROC_NULL)
      {
         for (i = 0; i < m1; i++)
         {
            mr[i] = mysquare[i*m2 + m2 - 1];
         }

         if (debug)
         {
            printf("me=%d sending to right=%d\n", me, right);
         }

         MPI_Send(mr, m1, MPI_DOUBLE, right, tag, MPI_COMM_WORLD);

         if (debug)
         {
            printf("me=%d receiving from right=%d\n", me, right);
         }

         MPI_Recv(tempr, m1, MPI_DOUBLE, right, tag, MPI_COMM_WORLD, &status);
      }

      if (top != MPI_PROC_NULL)
      {
         if (debug)
         {
            printf("me=%d receiving from top=%d\n", me, top);
         }

         MPI_Recv(tempt, m2, MPI_DOUBLE, top, tag, MPI_COMM_WORLD, &status);

         for (i = 0; i < m2; i++)
         {
            mt[i] = mysquare[i];
         }

         if (debug)
         {
            printf("me=%d sending to top=%d\n", me, top);
         }

         MPI_Send(mt, m2, MPI_DOUBLE, top, tag, MPI_COMM_WORLD);
      }

      if (bottom != MPI_PROC_NULL)
      {
         for (i = 0; i < m2; i++)
         {
            mb[i] = mysquare[(m1-1)*m2 + i];
         }

         if (debug)
         {
            printf("me=%d sending to bottom=%d\n", me, bottom);
         }

         MPI_Send(mb, m2, MPI_DOUBLE, bottom, tag, MPI_COMM_WORLD);

         if (debug)
         {
            printf("me=%d receiving from bottom=%d\n", me, bottom);
         }

         MPI_Recv(tempb, m2, MPI_DOUBLE, bottom, tag, MPI_COMM_WORLD, &status);
      }

      /*
       * Update the elements: each cell becomes the mean of its four
       * neighbours. A neighbour outside the local block is read from the
       * halo received from that side.
       */
      for (i = 0; i < m1; i++)
      {
         for (j = 0; j < m2; j++)
         {
            double north = (i > 0)      ? mysquare[(i-1)*m2+j] : tempt[j];
            double south = (i < (m1-1)) ? mysquare[(i+1)*m2+j] : tempb[j];
            double west  = (j > 0)      ? mysquare[i*m2+(j-1)] : templ[i];
            double east  = (j < (m2-1)) ? mysquare[i*m2+(j+1)] : tempr[i];

            new[i*m2+j] = 0.25*(north + south + west + east);
         }
      }

      /*
       * The freshly computed block becomes the current state for the next
       * iteration; the old block is reused as scratch space.
       */
      swap = mysquare;
      mysquare = new;
      new = swap;
  }

  if (output)
  {
    FILE *fp;
    char fname[64];

    /* Sized for any int rank; snprintf never writes past the buffer. */
    snprintf(fname, sizeof fname, "jacobi_output%d.dat", me);
    fp = fopen(fname, "w");

    if (fp == NULL)
    {
       printf("Process(%d): Cannot open %s\n", me, fname);
    }
    else
    {
       /* After the final swap, mysquare holds the latest state. */
       for (i = 0; i<m1; i++)
       {
          for (j = 0; j<m2; j++)
          {
             fprintf(fp, "%f ", mysquare[i*m2+j]);
          }
          fprintf(fp, "\n");
       }

       fclose(fp);
    }
  }

  free(mysquare);
  free(new);
  free(ml);
  free(mr);
  free(mt);
  free(mb);
  free(templ);
  free(tempr);
  free(tempt);
  free(tempb);

  if (me == 0)
  {
     tend = MPI_Wtime();
     printf("N=%d, Grid=(%d,%d), Time(seconds)=%f\n", N, p1, p2, (tend - tstart));
  }

  MPI_Finalize();

  return 0;
}
