/*****************************************************************************/
/* mpCred - contains everything needed for MPC_Reduce                        */
/* Coded by Dm. Arapov for A.Lastovtsky 1995-1996                            */
/* Release 1.10.96                                                           */
/*****************************************************************************/
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "mpC.h"
#include <mpi.h>

/******************* iMPC_Op_create_commutative ******************************/
/*
 * Allocates an MPI_Op on the heap, registers `function` with MPI as a
 * commutative reduction operation (MPI_Op_create with commute = 1) and stores
 * the heap pointer in *op.
 *
 * Returns MPC_OK on success, MPC_ERR_NOMEM if the allocation fails, or the
 * I2C-translated MPI error if MPI_Op_create fails.  On failure *op is left
 * untouched and the temporary allocation is released.  A successfully
 * created op is released with iMPC_Op_free / MPC_Op_free.
 */
Int iMPC_Op_create_commutative(MPC_User_function* function, MPC_Op* op)
  {
    Int RC;
    MPI_Op* pop;

    if ((pop = malloc(sizeof *pop)) == NULL) return MPC_ERR_NOMEM;
    RC = I2C(MPI_Op_create((MPI_User_function*)function, 1, pop));
    if (RC != MPC_OK)
      {
        /* The op was never handed to the caller, so nobody else can free it. */
        free(pop);
        return RC;
      }
    *op = pop;
    return MPC_OK;
  } /* iMPC_Op_create_commutative */

/******************* MPC_Op_create_commutative *******************************/
/* Public entry point: emits entry/exit trace lines when MPC_DEBUG is set and
   forwards to iMPC_Op_create_commutative, returning its result unchanged. */
Int MPC_Op_create_commutative(MPC_User_function* function, MPC_Op* op)
  {
    Int result;

    if (MPC_DEBUG)
      MPC_Debug_printf( "---[%2d]-> MPC_Op_create_commutative\n", MPC_Net_global.rank );

    result = iMPC_Op_create_commutative(function, op);

    if (MPC_DEBUG)
      MPC_Debug_printf( "<--[%2d]-- MPC_Op_create_commutative, RC = %d\n", MPC_Net_global.rank, result );

    return result;
  } /* MPC_Op_create_commutative */

/******************* iMPC_Op_create_uncommutative ****************************/
/*
 * Allocates an MPI_Op on the heap, registers `function` with MPI as a
 * non-commutative reduction operation (MPI_Op_create with commute = 0) and
 * stores the heap pointer in *op.
 *
 * Returns MPC_OK on success, MPC_ERR_NOMEM if the allocation fails, or the
 * I2C-translated MPI error if MPI_Op_create fails.  On failure *op is left
 * untouched and the temporary allocation is released.  A successfully
 * created op is released with iMPC_Op_free / MPC_Op_free.
 */
Int iMPC_Op_create_uncommutative(MPC_User_function* function, MPC_Op* op)
  {
    Int RC;
    MPI_Op* pop;

    if ((pop = malloc(sizeof *pop)) == NULL) return MPC_ERR_NOMEM;
    RC = I2C(MPI_Op_create((MPI_User_function*)function, 0, pop));
    if (RC != MPC_OK)
      {
        /* The op was never handed to the caller, so nobody else can free it. */
        free(pop);
        return RC;
      }
    *op = pop;
    return MPC_OK;
  } /* iMPC_Op_create_uncommutative */

/******************* MPC_Op_create_uncommutative *****************************/
/* Public entry point: emits entry/exit trace lines when MPC_DEBUG is set and
   forwards to iMPC_Op_create_uncommutative, returning its result unchanged. */
Int MPC_Op_create_uncommutative(MPC_User_function* function, MPC_Op* op)
  {
    Int result;

    if (MPC_DEBUG)
      MPC_Debug_printf( "---[%2d]-> MPC_Op_create_uncommutative\n", MPC_Net_global.rank );

    result = iMPC_Op_create_uncommutative(function, op);

    if (MPC_DEBUG)
      MPC_Debug_printf( "<--[%2d]-- MPC_Op_create_uncommutative, RC = %d\n", MPC_Net_global.rank, result );

    return result;
  } /* MPC_Op_create_uncommutative */

/******************* iMPC_Op_free ********************************************/
/*
 * Releases an op created by iMPC_Op_create_commutative or
 * iMPC_Op_create_uncommutative.  It calls MPI_Op_free on the wrapped MPI_Op,
 * then frees the heap cell holding it and sets *op to NULL.
 *
 * Returns:
 *   MPC_ERR_INTERNAL  if op or *op is NULL (for example, freeing twice);
 *   the I2C-translated MPI error if MPI_Op_free fails, in which case *op is
 *                     left intact;
 *   MPC_OK            otherwise.
 */
Int iMPC_Op_free(MPC_Op* op)
  {
    Int RC;
    MPI_Op* pop;

    /* Reject NULL / already-freed handles instead of dereferencing them. */
    if (op == NULL || *op == NULL) return MPC_ERR_INTERNAL;
    pop = (MPI_Op*)(*op);
    RC = I2C(MPI_Op_free(pop));
    if (RC != MPC_OK) return RC;
    free(pop);
    *op = NULL;
    return MPC_OK;
  } /* iMPC_Op_free */

/******************* MPC_Op_free ********************************************/
/* Public entry point: emits entry/exit trace lines when MPC_DEBUG is set and
   forwards to iMPC_Op_free, returning its result unchanged. */
Int MPC_Op_free(MPC_Op* op)
  {
    Int status;

    if (MPC_DEBUG)
      MPC_Debug_printf( "---[%2d]-> MPC_Op_free\n", MPC_Net_global.rank );

    status = iMPC_Op_free(op);

    if (MPC_DEBUG)
      MPC_Debug_printf( "<--[%2d]-- MPC_Op_free, RC = %d\n", MPC_Net_global.rank, status );

    return status;
  } /* MPC_Op_free */

/******************* MPC_Pack_coeff *****************************************/
/* Product of MPC_Get_step() over every type visited while descending from
   `datatype` through MPC_Get_element() until `elemtype` is reached
   (elemtype itself is not included).  Returns 1 when datatype == elemtype.
   iMPC_Reduce treats a result of 1 as "no packing needed".
   NOTE(review): the loop does not terminate if elemtype is never reached
   from datatype; presumably callers always pass the innermost element
   type. Confirm this. */
Int MPC_Pack_coeff( MPC_Datatype datatype, MPC_Datatype elemtype )
  {
    Int coeff = 1;
    MPC_Datatype level;

    for (level = datatype; level != elemtype; level = MPC_Get_element(level))
      coeff *= MPC_Get_step(level);

    return coeff;
  } /* MPC_Pack_coeff */

/******************* MPC_Get_total_count *************************************/
/* Product of MPC_Get_count() over every type visited while descending from
   `datatype` through MPC_Get_element() until `elemtype` is reached
   (elemtype itself is not included).  Returns 1 when datatype == elemtype.
   iMPC_Reduce uses this product as the number of `elemtype` items to reduce.
   NOTE(review): the loop does not terminate if elemtype is never reached
   from datatype. */
Int MPC_Get_total_count(MPC_Datatype datatype, MPC_Datatype elemtype )
  {
    Int product;
    MPC_Datatype level;

    for (product = 1, level = datatype; level != elemtype; level = MPC_Get_element(level))
      product *= MPC_Get_count(level);

    return product;
  } /* MPC_Get_total_count */

/******************* forward declaration *************************************/
/* Prototype of the tracing wrapper defined after iMPC_Pack_array.  It is
   needed because iMPC_Pack_array recurses through MPC_Pack_array. */
Int MPC_Pack_array( Int packing, void* buf, void* packbuf,
                    MPC_Datatype datatype, MPC_Datatype elemtype,
                    Int* offset );

/******************* iMPC_Pack_array *****************************************/
/*
 * Recursively copies the leaf elements of a (possibly strided, possibly
 * nested) array between its natural layout at `buf` and a dense buffer at
 * `packbuf`.
 *
 *   packing  - non-zero: copy buf -> packbuf (pack);
 *              zero: copy packbuf -> buf (unpack).
 *   buf      - start of the object of type `datatype` in its natural layout.
 *   packbuf  - start of the dense slots that hold THIS object's leaves.
 *   datatype - type being (un)packed; it is descended via MPC_Get_element()
 *              until it equals `elemtype`.
 *   elemtype - leaf type; each leaf is copied with memcpy of
 *              MPC_Type_size(elemtype) bytes.
 *   offset   - running count of leaves processed, incremented once per leaf.
 *              iMPC_Reduce starts it at 0.
 *
 * Sub-object i of `datatype` lives at buf + i*step*subsize, where step is
 * MPC_Get_step(datatype) and subsize is MPC_Type_size of its element type.
 *
 * Returns MPC_OK, or the first error returned by a recursive call.
 */
Int iMPC_Pack_array( 
  Int packing, 
  void* buf, 
  void* packbuf, 
  MPC_Datatype datatype, 
  MPC_Datatype elemtype, 
  Int* offset )
  {
    Int RC;
    Int step;
    Int count;
    Int i;
    Int elemsize;
    Int subsize;
    Int base;
    MPC_Datatype subtype;

    elemsize = MPC_Type_size(elemtype);
    if (elemtype == datatype) 
      {
        if (packing) memcpy(packbuf,buf,elemsize); else memcpy(buf,packbuf,elemsize);
        (*offset)++;
        return MPC_OK;
      }
    step = MPC_Get_step( datatype );
    count = MPC_Get_count( datatype );
    subtype = MPC_Get_element( datatype );
    subsize = MPC_Type_size( subtype );

    /* packbuf already points at the first dense slot of THIS object, so the
       slot for sub-object i is "leaves written since entry"
       (*offset - base), not the absolute *offset.  Indexing with the
       absolute counter double-counts the position once arrays are nested
       two or more levels deep and writes past the end of packbuf.  For a
       single level with *offset starting at 0, both forms agree. */
    base = *offset;
    for (i=0;i<count;i++)
      {
        char* subbuf  = (char*)buf + i*step*subsize;
        char* subpack = (char*)packbuf + elemsize*(*offset - base);

        /* Trace positions only: the buffers may hold any element type, so
           their contents must not be read as double here. */
        if (MPC_DEBUG > 1)
          MPC_Debug_printf( "==> %d*%d*%d suboffset = %d\n", i, step, subsize, *offset - base );
        RC = MPC_Pack_array(packing, subbuf, subpack, subtype, elemtype, offset );
        if (RC != MPC_OK) return RC;
      } /* for */
    return MPC_OK;
  } /* iMPC_Pack_array */

/****************** MPC_Pack_array *******************************************/
/* Tracing wrapper around iMPC_Pack_array.  It prints entry/exit lines when
   MPC_DEBUG > 1 and is also the target of iMPC_Pack_array's recursive
   calls.  Arguments and the return code pass through unchanged. */
Int MPC_Pack_array(
  Int packing, 
  void* buf, 
  void* packbuf, 
  MPC_Datatype datatype, 
  MPC_Datatype elemtype, 
  Int* offset )
  {
    Int status;

    if (MPC_DEBUG > 1)
      MPC_Debug_printf( "---[%2d]-> MPC_Pack_array\n", MPC_Net_global.rank );

    status = iMPC_Pack_array(packing, buf, packbuf, datatype, elemtype, offset);

    if (MPC_DEBUG > 1)
      MPC_Debug_printf( "<--[%2d]-- MPC_Pack_array, RC = %d\n", MPC_Net_global.rank, status );

    return status;
  } /* MPC_Pack_array */

/* Legacy implementation of iMPC_Reduce, kept for reference.  It performs the
   reduction with a single MPI_Allreduce over net->pweb and resets
   net->oldroot directly, instead of the MPI_Reduce + MPC_Change_root_done +
   MPI_Bcast sequence used by the active version below.
   NOTE(review): defining OLD would compile two definitions of iMPC_Reduce
   (this one and the one after #endif), so this branch cannot be enabled as-is.
   NOTE(review): the early error returns below leak packedrecvbuf (and
   packedsendbuf when it was allocated). */
#ifdef OLD
/******************* iMPC_Reduce *********************************************/   
Int iMPC_Reduce( 
  MPC_Net* net,     
  void* sendbuf,    
  void* recvbuf,           
  MPC_Datatype datatype,
  MPC_Datatype elemtype,  
  MPC_Op op )             
  {
    Int RC;       
    void* packedsendbuf;
    void* packedrecvbuf;
    Int count;
    Int offset;
    MPC_Rts_datatype type;
    MPI_Datatype mpi_type;
    Int size;

    MPI_Op* pop;
    
    if (op == NULL) return MPC_ERR_INTERNAL;
    pop = (MPI_Op*)op;
    if ((RC = MPC_Test_net(net)) != MPC_OK) return RC;
    type = MPC_Get_datatype( elemtype );
    mpi_type = *(MPI_Datatype*)type;
    count = MPC_Get_total_count(datatype,elemtype);
    size = MPC_Type_size(elemtype);
    if ((packedrecvbuf = malloc(size*count)) == NULL) return MPC_ERR_NOMEM;
    /* Coefficient 1: sendbuf is used as-is; otherwise pack into a dense copy. */
    if (MPC_Pack_coeff(datatype,elemtype) == 1) packedsendbuf = sendbuf;
	else {
      offset = 0;                      
      if ((packedsendbuf = malloc(size*count)) == NULL) return MPC_ERR_NOMEM;           
      if ((RC = MPC_Pack_array( 1,sendbuf, packedsendbuf, datatype, elemtype, &offset )) != MPC_OK) return RC;
    } /* if */
    RC = I2C( MPI_Allreduce(packedsendbuf,packedrecvbuf,count,mpi_type,*pop,(*(MPI_Comm*)(net->pweb)) ));
    if (RC != MPC_OK) return RC;
    if (net->oldroot != MPC_MULTI_ROOT) net->oldroot = MPC_UNDEFINED_ROOT;
    if (MPC_Pack_coeff(datatype,elemtype) == 1) MPC_Elem_copy(recvbuf,packedrecvbuf,datatype );
	else {
      offset = 0;
      if((RC = MPC_Pack_array( 0,recvbuf, packedrecvbuf, datatype, elemtype, &offset )) != MPC_OK) return RC;
      free( packedsendbuf );
    } /* if */
    free(packedrecvbuf);
    return MPC_OK;
  }  /* iMPC_Reduce */

#endif
/******************* iMPC_Reduce *********************************************/
/*
 * Reduces `sendbuf` with operation `op` over the communicator of `net` and
 * leaves the result in `recvbuf` on every process of that communicator.
 *
 * When MPC_Pack_coeff(datatype, elemtype) != 1, the leaves of `datatype`
 * (of type `elemtype`) are first packed into a dense temporary buffer.
 * Otherwise sendbuf is used directly.  The data is then processed in four
 * steps:
 *   1. MPI_Reduce onto MPC_Parent(net);
 *   2. MPC_Change_root_done is notified;
 *   3. MPI_Bcast from the parent;
 *   4. the result is copied or unpacked into recvbuf.
 *
 * Returns:
 *   MPC_OK            on success;
 *   MPC_ERR_INTERNAL  if op is NULL;
 *   MPC_ERR_NOMEM     if allocation fails;
 *   otherwise the first error from MPC_Test_net, packing/unpacking, MPI, or
 *   MPC_Change_root_done.
 * Temporary buffers are released on every path.
 */
Int iMPC_Reduce(
  MPC_Net* net,     
  void* sendbuf,    
  void* recvbuf,           
  MPC_Datatype datatype,
  MPC_Datatype elemtype,  
  MPC_Op op )             
  {
    Int RC;
    void* packedsendbuf = NULL;
    void* packedrecvbuf = NULL;
    Int count;
    Int offset;
    MPC_Rts_datatype type;
    MPI_Datatype mpi_type;
    Int size;
    Int parent;
    Int contiguous;
    MPI_Op* pop;

    if (op == NULL) return MPC_ERR_INTERNAL;
    pop = (MPI_Op*)op;
    if ((RC = MPC_Test_net(net)) != MPC_OK) return RC;
    type = MPC_Get_datatype( elemtype );
    mpi_type = *(MPI_Datatype*)type;
    count = MPC_Get_total_count(datatype,elemtype);
    size = MPC_Type_size(elemtype);
    /* A packing coefficient of 1 means sendbuf/recvbuf need no (un)packing. */
    contiguous = (MPC_Pack_coeff(datatype,elemtype) == 1);

    if ((packedrecvbuf = malloc(size*count)) == NULL) return MPC_ERR_NOMEM;
    if (contiguous)
      packedsendbuf = sendbuf;
    else
      {
        if ((packedsendbuf = malloc(size*count)) == NULL)
          {
            RC = MPC_ERR_NOMEM;
            goto cleanup;
          }
        offset = 0;
        RC = MPC_Pack_array( 1,sendbuf, packedsendbuf, datatype, elemtype, &offset );
        if (RC != MPC_OK) goto cleanup;
      } /* if */

    parent = MPC_Parent(net);
    RC = I2C( MPI_Reduce(packedsendbuf,packedrecvbuf,count,mpi_type,*pop,parent,(*(MPI_Comm*)(net->pweb)) ));
    if (RC != MPC_OK) goto cleanup;
    RC = MPC_Change_root_done(net, parent);
    if (RC != MPC_OK) goto cleanup;
    /* net->pweb is re-read after MPC_Change_root_done, as in the original. */
    RC = I2C(MPI_Bcast(packedrecvbuf,count,mpi_type,parent, (*(MPI_Comm*)(net->pweb))));
    if (RC != MPC_OK) goto cleanup;

    if (contiguous)
      {
        MPC_Elem_copy(recvbuf,packedrecvbuf,datatype );
        RC = MPC_OK;
      }
    else
      {
        offset = 0;
        RC = MPC_Pack_array( 0,recvbuf, packedrecvbuf, datatype, elemtype, &offset );
      } /* if */

cleanup:
    /* packedsendbuf aliases the caller's sendbuf in the contiguous case. */
    if (!contiguous) free(packedsendbuf);
    free(packedrecvbuf);
    return RC;
  }  /* iMPC_Reduce */

/******************* MPC_Reduce **********************************************/
/* Public entry point for reductions over a net.  It traces entry/exit when
   MPC_DEBUG is set and delegates to iMPC_Reduce.  Any error is fatal: it is
   reported through MPC_Debug_printf and the process ends with exit(RC), so a
   return from this function always yields MPC_OK. */
Int MPC_Reduce( 
  MPC_Net* net,     
  void* sendbuf,    
  void* recvbuf,           
  MPC_Datatype datatype,
  MPC_Datatype elemtype,  
  MPC_Op op ) 
  {
    Int status;
    char err_text[MAX_ERR_STR];

    if (MPC_DEBUG)
      MPC_Debug_printf( "---[%2d]-> MPC_Reduce, rank = %d\n", MPC_Net_global.rank, net->rank );

    status = iMPC_Reduce(net,sendbuf,recvbuf,datatype,elemtype,op);

    if (status == MPC_OK)
      {
        if (MPC_DEBUG)
          MPC_Debug_printf( "<--[%2d]-- MPC_Reduce\n", MPC_Net_global.rank );
        return status;
      } /* if */

    RC2STR( err_text, status );
    MPC_Debug_printf( "<==[%2d]== MPC error %s during MPC_Reduce\n", MPC_Net_global.rank, err_text );
    exit(status);
    return status; /* not reached */
  } /* MPC_Reduce */
