/*****************************************************************************/
/* mpCvar - contains variable-length operations.                             */
/* Coded by Dm. Arapov for A.Lastovtsky 1995-1996                            */
/* Release 1.10.96                                                           */
/*****************************************************************************/
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "mpC.h"
#include "mpCtags.h"
#include <mpi.h>

/******************* MPC_Local_copy ******************************************/
/* Copy `count` elements of `datatype` from send_buffer to recv_buffer within
 * the calling process.
 *
 * send_step / recv_step are strides measured in elements (not bytes):
 * element i is read at byte offset i*elem_size*send_step of send_buffer and
 * written at byte offset i*elem_size*recv_step of recv_buffer.
 * Source and destination regions may overlap.
 *
 * Returns MPC_OK on success (including when there is nothing to copy),
 * MPC_ERR_UNDEFSIZE if MPC_Type_size reports MPC_UNDEFINED_SIZE for
 * `datatype`, MPC_ERR_COUNT if `count` is negative, and MPC_ERR_NOMEM if the
 * staging buffer for a strided copy cannot be allocated.
 */
Int MPC_Local_copy(
  const void* send_buffer,
  Int send_step,
  void* recv_buffer,
  Int recv_step,
  Int count,
  MPC_Datatype datatype )
  {
    const char* src = (const char*)send_buffer;
    char* dst = (char*)recv_buffer;
    char* temp_buffer;
    size_t nbytes;
    Int   i;
    Int   elem_size;

    elem_size = MPC_Type_size(datatype);
    if (elem_size == MPC_UNDEFINED_SIZE) return MPC_ERR_UNDEFSIZE;
    /* A negative count would otherwise become a huge size_t below. */
    if (count < 0) return MPC_ERR_COUNT;

    nbytes = (size_t)count * (size_t)elem_size;
    /* Nothing to copy.  Returning here also avoids malloc(0), which may
       legitimately return NULL and would be misreported as MPC_ERR_NOMEM. */
    if (nbytes == 0) return MPC_OK;

    if ((send_step == 1) && (recv_step == 1)) {
      /* Both sides contiguous.  memmove is correct for every kind of overlap
         (memcpy is undefined when the regions overlap in either direction). */
      memmove(dst, src, nbytes);
      return MPC_OK;
    }

    /* Strided copy: gather into a private buffer first, then scatter, so an
       overlap between source and destination cannot clobber unread input. */
    temp_buffer = malloc(nbytes);
    if (temp_buffer == NULL) return MPC_ERR_NOMEM;
    for (i = 0; i < count; i++)
      memcpy(temp_buffer + i*elem_size, src + i*elem_size*send_step, elem_size);
    for (i = 0; i < count; i++)
      memcpy(dst + i*elem_size*recv_step, temp_buffer + i*elem_size, elem_size);
    free(temp_buffer);
    return MPC_OK;
  } /* MPC_Local_copy */

/******************* iMPC_Var_send_receive ***********************************/
/* Point-to-point transfer of `count` strided elements of `datatype` from the
 * process identified by `sender` to the process identified by `receiver`
 * within `net` (both are mapped to ranks with MPC_Number).
 *
 * send_buffer/send_step describe the data on the sender (every send_step-th
 * element); recv_buffer/recv_step describe where it is stored on the
 * receiver.  Steps are in elements.  Processes that are neither sender nor
 * receiver return MPC_OK immediately.  If sender and receiver are the same
 * process, the data is copied with MPC_Local_copy instead of through MPI.
 *
 * Returns MPC_OK or the first MPC error code encountered.  Once the MPI
 * vector type has been created it is always freed, also when the commit or
 * the transfer fails.
 */
Int iMPC_Var_send_receive(
  const MPC_Net* net,
  const Int* sender,
  const void* send_buffer,
  Int   send_step,
  Int   count,
  const Int* receiver,
  void* recv_buffer,
  Int   recv_step,
  MPC_Datatype datatype )
  {
    Int RC;
    Int free_RC;
    Int send_rank;
    Int recv_rank;
    MPC_Rts_datatype elem_type;
    MPI_Datatype mpi_elem_type;
    MPI_Datatype mpi_vec_type;
    MPI_Comm comm;
    MPI_Status status;

    send_rank = MPC_Number( net, sender );
    recv_rank = MPC_Number( net, receiver );

    /* Processes that are neither end of the transfer take no part. */
    if ((net->rank != send_rank) && (net->rank != recv_rank)) return MPC_OK;

    if (send_rank == recv_rank) {
      /* Same process on both ends: nothing to do if the layouts coincide. */
      if ((send_step == recv_step) && (send_buffer == recv_buffer)) return MPC_OK;
      return MPC_Local_copy( send_buffer, send_step, recv_buffer, recv_step, count, datatype );
    }

    elem_type = MPC_Get_datatype( datatype );
    mpi_elem_type = *(MPI_Datatype*)(elem_type);
    comm = *(MPI_Comm*)(net->pweb);

    /* From here on this process is exactly one of sender / receiver. */
    if (net->rank == send_rank) {
      RC = I2C(MPI_Type_vector(count,1,send_step,mpi_elem_type,&mpi_vec_type));
      if (RC != MPC_OK) return RC;
      RC = I2C(MPI_Type_commit(&mpi_vec_type));
      if (RC == MPC_OK) {
        /* The cast only drops const to match the MPI prototype; the buffer is read-only here. */
        RC = I2C(MPI_Send((void*)send_buffer, 1, mpi_vec_type, recv_rank, MPC_VARSEND_TAG, comm));
      }
    } else {
      RC = I2C(MPI_Type_vector(count,1,recv_step,mpi_elem_type,&mpi_vec_type));
      if (RC != MPC_OK) return RC;
      RC = I2C(MPI_Type_commit(&mpi_vec_type));
      if (RC == MPC_OK) {
        RC = I2C(MPI_Recv(recv_buffer,1,mpi_vec_type,send_rank,MPC_VARSEND_TAG,comm,&status));
      }
    }

    /* Release the derived type on every path; the first error wins. */
    free_RC = I2C(MPI_Type_free(&mpi_vec_type));
    if (RC != MPC_OK) return RC;
    return free_RC;
  } /* iMPC_Var_send_receive */

/******************* MPC_Var_send_receive ************************************/
/* Public entry point for iMPC_Var_send_receive (same parameters).
 *
 * Prints entry/exit trace lines when MPC_DEBUG > 1.  Any error is fatal: it
 * is converted to text with RC2STR, reported via MPC_Debug_printf, and the
 * process terminates with exit(RC).  Consequently the only value this
 * function ever returns is MPC_OK.
 * NOTE(review): exit() ends only this process; peers blocked in the matching
 * operation may hang — consider whether MPI_Abort is intended.
 */
Int MPC_Var_send_receive(
  const MPC_Net* net,
  const Int* sender,
  const void* send_buffer,
  Int   send_step,
  Int   count,
  const Int* receiver,
  void* recv_buffer,
  Int   recv_step,
  MPC_Datatype datatype )
  {
    char err_str[MAX_ERR_STR];
    Int RC;

    if (MPC_DEBUG > 1) {
      MPC_Debug_printf("---[%2d]-> MPC_Var_send_receive\n", MPC_Net_global.rank );
    }

    RC = iMPC_Var_send_receive(net, sender,
                               send_buffer, send_step, count,
                               receiver,
                               recv_buffer, recv_step,
                               datatype);

    if (RC == MPC_OK) {
      if (MPC_DEBUG > 1) {
        MPC_Debug_printf("<--[%2d]-- MPC_Var_send_receive, RC = %d\n", MPC_Net_global.rank, RC );
      }
      return RC;
    }

    /* Failure path: report and terminate. */
    RC2STR( err_str, RC );
    MPC_Debug_printf( "<==[%2d]== MPC error %s during MPC_Var_send_receive\n", MPC_Net_global.rank, err_str );
    exit(RC);
  } /* MPC_Var_send_receive */

/******************** iMPC_Var_broadcast ***********************************/
/* Broadcast `count` strided elements of `datatype` from the process
 * identified by `sender` to every process of `net`.
 *
 * The root sends from send_buffer with stride send_step; every other process
 * receives into recv_buffer with stride recv_step (strides in elements).
 * The root then also copies its data into its own recv_buffer/recv_step
 * layout with MPC_Local_copy, unless that layout equals the send layout.
 * MPC_Change_root(net, root) is called before the broadcast.
 *
 * Returns MPC_OK or the first MPC error code encountered.  Once the MPI
 * vector type has been created it is always freed, also when the commit or
 * the broadcast fails.
 */
Int iMPC_Var_broadcast(
  MPC_Net* net,
  const Int* sender,
  const void* send_buffer,
  Int   send_step,
  Int   count,
  void* recv_buffer,
  Int   recv_step,
  MPC_Datatype datatype )
  {
    Int RC;
    Int free_RC;
    Int send_rank;
    Int step;
    void* buffer;
    MPC_Rts_datatype elem_type;
    MPI_Datatype mpi_elem_type;
    MPI_Datatype mpi_buffer_type;

    send_rank = MPC_Number(net, sender);
    RC = MPC_Change_root( net, send_rank );
    if (RC != MPC_OK) return RC;

    /* The root broadcasts from its send layout; the others receive into
       their recv layout. */
    if (send_rank == net->rank) {
      step = send_step;
      buffer = (void*)send_buffer;
    } else {
      step = recv_step;
      buffer = recv_buffer;
    }

    elem_type = MPC_Get_datatype(datatype);
    mpi_elem_type = *(MPI_Datatype*)(elem_type);
    RC = I2C(MPI_Type_vector(count,1,step,mpi_elem_type,&mpi_buffer_type));
    if (RC != MPC_OK) return RC;
    RC = I2C(MPI_Type_commit(&mpi_buffer_type));
    if (RC == MPC_OK) {
      RC = I2C(MPI_Bcast(buffer,1,mpi_buffer_type,send_rank,(*(MPI_Comm*)(net->pweb))));
    }
    /* Release the derived type on every path; the first error wins. */
    free_RC = I2C(MPI_Type_free(&mpi_buffer_type));
    if (RC != MPC_OK) return RC;
    if (free_RC != MPC_OK) return free_RC;

    /* Non-roots already hold the data in their recv layout. */
    if (send_rank != net->rank) return MPC_OK;
    /* Root: mirror the data into its own recv layout if it differs. */
    if ((send_buffer == recv_buffer) && (send_step == recv_step)) return MPC_OK;
    return MPC_Local_copy(send_buffer,send_step,recv_buffer,recv_step,count,datatype);
  } /* iMPC_Var_broadcast */

/******************* MPC_Var_broadcast ****************************************/
/* Public entry point for iMPC_Var_broadcast (same parameters).
 *
 * Prints entry/exit trace lines when MPC_DEBUG > 1 (the entry line
 * dereferences `sender`).  Any error is fatal: it is converted to text with
 * RC2STR, reported via MPC_Debug_printf, and the process terminates with
 * exit(RC).  Consequently the only value this function ever returns is MPC_OK.
 */
Int MPC_Var_broadcast(
  MPC_Net* net,
  const Int* sender,
  const void* send_buffer,
  Int   send_step,
  Int   count,
  void* recv_buffer,
  Int   recv_step,
  MPC_Datatype datatype )
  {
    char err_str[MAX_ERR_STR];
    Int RC;

    if (MPC_DEBUG > 1) {
      MPC_Debug_printf("---[%2d]-> MPC_Var_broadcast, sender = %d, send_step = %d, count = %d, recv_step = %d\n",
                       MPC_Net_global.rank, *sender, send_step, count, recv_step );
    }

    RC = iMPC_Var_broadcast(net, sender,
                            send_buffer, send_step, count,
                            recv_buffer, recv_step,
                            datatype);

    if (RC == MPC_OK) {
      if (MPC_DEBUG > 1) {
        MPC_Debug_printf("<--[%2d]-- MPC_Var_broadcast, RC = %d\n", MPC_Net_global.rank, RC );
      }
      return RC;
    }

    /* Failure path: report and terminate. */
    RC2STR( err_str, RC );
    MPC_Debug_printf( "<==[%2d]== MPC error %s during MPC_Var_broadcast\n", MPC_Net_global.rank, err_str );
    exit(RC);
  } /* MPC_Var_broadcast */

/******************* iMPC_Var_scatter ******************************************/ 
/* Scatter variable-sized pieces of the root's send_buffer over `net` with
 * MPI_Scatterv.
 *
 * The root is the process identified by `sender`.  On the root, counts[] and
 * displs[] are passed to MPI_Scatterv as per-rank element counts and
 * displacements; counts[root] must equal `count`, otherwise MPC_ERR_COUNT is
 * returned before MPC_Change_root or the collective is called.  Every
 * process receives `count` elements of `datatype` into recv_buffer.
 *
 * Returns MPC_OK or the first MPC error code encountered.
 */
Int iMPC_Var_scatter(
  MPC_Net* net,
  const Int* sender,
  const void* send_buffer,
  const Int* displs,
  const Int* counts,
  Int count,
  void* recv_buffer,
  MPC_Datatype datatype )
  {
    Int root;
    Int rc;
    MPC_Rts_datatype rts_type;
    MPI_Datatype elem;
    MPI_Comm comm;

    root = MPC_Number(net, sender);

    /* counts[] is only consulted on the root. */
    if (net->rank == root && counts[root] != count)
      return MPC_ERR_COUNT;

    rc = MPC_Change_root( net, root );
    if (rc != MPC_OK)
      return rc;

    rts_type = MPC_Get_datatype(datatype);
    elem = *(MPI_Datatype*)(rts_type);
    /* Read the communicator only after MPC_Change_root, which may update net. */
    comm = *(MPI_Comm*)(net->pweb);

    return I2C(MPI_Scatterv((void*)send_buffer, (Int*)counts, (Int*)displs, elem,
                            recv_buffer, count, elem,
                            root, comm));
  } /* iMPC_Var_scatter */

/******************* MPC_Var_scatter *******************************************/
/* Public entry point for iMPC_Var_scatter (same parameters).
 *
 * Prints entry/exit trace lines when MPC_DEBUG > 1.  Any error is fatal: it
 * is converted to text with RC2STR, reported via MPC_Debug_printf, and the
 * process terminates with exit(RC).  Consequently the only value this
 * function ever returns is MPC_OK.
 */
Int MPC_Var_scatter(
  MPC_Net* net,
  const Int* sender,
  const void* send_buffer,
  const Int* displs,
  const Int* counts,
  Int count,
  void* recv_buffer,
  MPC_Datatype datatype )
  {
    char err_str[MAX_ERR_STR];
    Int RC;

    if (MPC_DEBUG > 1) {
      MPC_Debug_printf("---[%2d]-> MPC_Var_scatter\n", MPC_Net_global.rank );
    }

    RC = iMPC_Var_scatter(net, sender, send_buffer,
                          displs, counts, count,
                          recv_buffer, datatype);

    if (RC == MPC_OK) {
      if (MPC_DEBUG > 1) {
        MPC_Debug_printf("<--[%2d]-- MPC_Var_scatter, RC = %d\n", MPC_Net_global.rank, RC );
      }
      return RC;
    }

    /* Failure path: report and terminate. */
    RC2STR( err_str, RC );
    MPC_Debug_printf( "<==[%2d]== MPC error %s during MPC_Var_scatter\n", MPC_Net_global.rank, err_str );
    exit(RC);
  } /* MPC_Var_scatter */

/******************* iMPC_Var_gather *****************************************/
/* Gather variable-sized contributions from every process of `net` into the
 * receiver's recv_buffer with MPI_Gatherv.
 *
 * Each process contributes `count` elements of `datatype` from send_buffer.
 * On the root (the process identified by `receiver`), counts[] and displs[]
 * are passed to MPI_Gatherv as per-rank element counts and displacements;
 * counts[root] must equal `count`, otherwise MPC_ERR_COUNT is returned
 * before the collective is called.
 * NOTE(review): unlike the scatter and broadcast routines, no MPC_Change_root
 * call is made here (net is const) — confirm this asymmetry is intended.
 *
 * Returns MPC_OK or the MPC error code of the failed MPI call.
 */
Int iMPC_Var_gather(
  const MPC_Net* net,
  const Int* receiver,
  void* recv_buffer,
  const Int* displs,
  const Int* counts,
  Int count,
  const void* send_buffer,
  MPC_Datatype datatype )
  {
    Int root;
    MPC_Rts_datatype rts_type;
    MPI_Datatype elem;
    MPI_Comm comm;

    root = MPC_Number(net, receiver);

    /* counts[] is only consulted on the root. */
    if (net->rank == root && counts[root] != count)
      return MPC_ERR_COUNT;

    rts_type = MPC_Get_datatype(datatype);
    elem = *(MPI_Datatype*)(rts_type);
    comm = *(MPI_Comm*)(net->pweb);

    return I2C(MPI_Gatherv((void*)send_buffer, count, elem,
                           recv_buffer, (Int*)counts, (Int*)displs, elem,
                           root, comm));
  } /* iMPC_Var_gather */

/******************* MPC_Var_gather ******************************************/
/* Public entry point for iMPC_Var_gather (same parameters).
 *
 * Prints entry/exit trace lines when MPC_DEBUG > 1.  Any error is fatal: it
 * is converted to text with RC2STR, reported via MPC_Debug_printf, and the
 * process terminates with exit(RC).  Consequently the only value this
 * function ever returns is MPC_OK.
 */
Int MPC_Var_gather(
  const MPC_Net* net,
  const Int* receiver,
  void* recv_buffer,
  const Int* displs,
  const Int* counts,
  Int count,
  const void* send_buffer,
  MPC_Datatype datatype )
  {
    char err_str[MAX_ERR_STR];
    Int RC;

    if (MPC_DEBUG > 1) {
      MPC_Debug_printf("---[%2d]-> MPC_Var_gather\n", MPC_Net_global.rank );
    }

    RC = iMPC_Var_gather(net, receiver, recv_buffer,
                         displs, counts, count,
                         send_buffer, datatype);

    if (RC == MPC_OK) {
      if (MPC_DEBUG > 1) {
        MPC_Debug_printf("<--[%2d]-- MPC_Var_gather, RC = %d\n", MPC_Net_global.rank, RC );
      }
      return RC;
    }

    /* Failure path: report and terminate. */
    RC2STR( err_str, RC );
    MPC_Debug_printf( "<==[%2d]== MPC error %s during MPC_Var_gather\n", MPC_Net_global.rank, err_str );
    exit(RC);
  } /* MPC_Var_gather */

