/*****************************************************************************/
/* mpCctrl - contains barrier and control functions                          */
/* Coded by Dm. Arapov for A.Lastovtsky 1995-1996                            */
/* Release 1.10.96                                                           */
/*****************************************************************************/
#include <stdio.h>
#include <stdlib.h>
#include "mpC.h"
#include "mpCtags.h"
#include <mpi.h>

/******************* MPC_Break_send ******************************************/ 
/* Stub for sending a break signal over net.
   Currently performs no work and always reports success.
   net - ignored.
   Returns MPC_OK. */
Int MPC_Break_send( const MPC_Net* net )
{
    (void)net;  /* intentionally unused */
    return MPC_OK;
}

/******************* MPC_Break_test ******************************************/  
/* Stub for testing whether a break signal has arrived on net.
   Currently performs no work and always reports success.
   net - ignored.
   Returns MPC_OK. */
Int MPC_Break_test( const MPC_Net* net )
{
    (void)net;  /* intentionally unused */
    return MPC_OK;
}

/******************** iMPC_Ctrl_known ****************************************/
/* Point-to-point hand-over of a control value from node ctrl_index to the
   parent node of net (as reported by MPC_Parent).
   Only the sender and the parent take part. Every other node returns
   MPC_OK at once, as does the sender when it is itself the parent.
   Returns MPC_OK, or the I2C-converted code of the failing MPI call. */
Int iMPC_Ctrl_known(
  const MPC_Net* net,   /* net over which the control value travels */
  Int* ctrl_value,      /* in: valid on node ctrl_index; out: valid on the parent node */
  Int  ctrl_index )     /* rank of the sending node */
{
    MPI_Status status;
    Int parent = MPC_Parent(net);
    Int sending = (net->rank == ctrl_index);
    Int receiving = (net->rank == parent);

    /* Both false: this node is not involved.
       Both true: the value already sits on the parent. */
    if (sending == receiving) {
        return MPC_OK;
    }

    if (sending) {
        return I2C(MPI_Send((void*)ctrl_value, 1, MPI_INT, parent, MPC_CTRL_TAG,
                            *(MPI_Comm*)(net->pweb)));
    }
    return I2C(MPI_Recv((void*)ctrl_value, 1, MPI_INT, ctrl_index, MPC_CTRL_TAG,
                        *(MPI_Comm*)(net->pweb), &status));
}

/******************** MPC_Ctrl_known *****************************************/
/* Checked, traced front end for iMPC_Ctrl_known.
   Logs entry and exit when MPC_DEBUG is set. If iMPC_Ctrl_known fails, it
   prints the error text and terminates the process with exit(code), so a
   normal return always yields MPC_OK. */
Int MPC_Ctrl_known(
  const MPC_Net* net,   /* net over which the control value travels */
  Int* ctrl_value,      /* in: valid on node ctrl_index; out: valid on the parent node */
  Int  ctrl_index )     /* rank of the sending node */
{
    Int rc;
    char msg[MAX_ERR_STR];

    if (MPC_DEBUG) {
        MPC_Debug_printf( "---[%2d]-> MPC_Ctrl_known, rank = %d, ctrl_index = %d, ctrl_value = %d,\n", MPC_Net_global.rank, net->rank, ctrl_index, *ctrl_value );
    }

    rc = iMPC_Ctrl_known(net, ctrl_value, ctrl_index);
    if (rc != MPC_OK) {
        RC2STR( msg, rc );
        MPC_Debug_printf( "<==[%2d]== MPC error %s during MPC_Ctrl_known\n", MPC_Net_global.rank, msg );
        exit(rc);
    }

    if (MPC_DEBUG) {
        MPC_Debug_printf( "<--[%2d]-- MPC_Ctrl_known, rank = %d, am_parent = %d, ctrl_value = %d\n", MPC_Net_global.rank, net->rank, MPC_Is_parent(net), *ctrl_value );
    }
    return rc;
}

/******************** iMPC_Ctrl_any ******************************************/
/* Hand-over of a control value from whichever node passes am_sender != 0
   to the parent node of net. The sender's rank is not passed in, so the
   parent receives from MPI_ANY_SOURCE. Afterwards all members of net meet
   in an MPI_Barrier on net's communicator, and net->oldroot is reset to
   MPC_NULL_ROOT.
   Returns MPC_OK, or the first MPC error code encountered. */
Int iMPC_Ctrl_any(
  MPC_Net* net,         /* net over which the control value travels */
  Int* ctrl_value,      /* in: valid on the node with am_sender != 0; out: valid on the parent node */
  Int am_sender )       /* zero everywhere except on the sending node */
{
    MPI_Status status;
    Int rc;
    Int parent = MPC_Parent(net);
    Int is_parent = (net->rank == parent);

    /* A transfer is needed only when exactly one of "sender" and "parent" holds. */
    if (am_sender && !is_parent) {
        rc = I2C(MPI_Send((void*)ctrl_value, 1, MPI_INT, parent, MPC_CTRL_TAG,
                          *(MPI_Comm*)(net->pweb)));
        if (rc != MPC_OK) return rc;
    } else if (!am_sender && is_parent) {
        rc = I2C(MPI_Recv((void*)ctrl_value, 1, MPI_INT, MPI_ANY_SOURCE, MPC_CTRL_TAG,
                          *(MPI_Comm*)(net->pweb), &status));
        if (rc != MPC_OK) return rc;
    }

    rc = I2C(MPI_Barrier(*(MPI_Comm*)(net->pweb)));
    if (rc != MPC_OK) return rc;

    net->oldroot = MPC_NULL_ROOT;
    return MPC_OK;
}

/******************** MPC_Ctrl_any ********************************************/
/* Checked, traced front end for iMPC_Ctrl_any.
   Logs entry and exit when MPC_DEBUG is set. If iMPC_Ctrl_any fails, it
   prints the error text and terminates the process with exit(code), so a
   normal return always yields MPC_OK. */
Int MPC_Ctrl_any(
  MPC_Net* net,         /* net over which the control value travels */
  Int* ctrl_value,      /* in: valid on the node with am_sender != 0; out: valid on the parent node */
  Int am_sender )       /* zero everywhere except on the sending node */
{
    Int rc;
    char msg[MAX_ERR_STR];

    if (MPC_DEBUG) {
        MPC_Debug_printf( "---[%2d]-> MPC_Ctrl_any, am_sender = %d, ctrl_value = %d\n", MPC_Net_global.rank, am_sender, *ctrl_value );
    }

    rc = iMPC_Ctrl_any(net, ctrl_value, am_sender);
    if (rc != MPC_OK) {
        RC2STR( msg, rc );
        MPC_Debug_printf( "<==[%2d]== MPC error %s during MPC_Ctrl_any\n", MPC_Net_global.rank, msg );
        exit(rc);
    }

    if (MPC_DEBUG) {
        MPC_Debug_printf( "<--[%2d]-- MPC_Ctrl_any, am_parent = %d, ctrl_value = %d\n", MPC_Net_global.rank, MPC_Is_parent(net), *ctrl_value );
    }
    return rc;
}

/******************** iMPC_Set_web_root ***************************************/
/* Copies slot 0 of the communicator array at net->pweb into the slot that
   belongs to newroot (index newroot+1).
   Assumes net->pweb points to at least net->power+1 MPI_Comm entries
   (TODO confirm against the allocator of pweb).
   Returns MPC_ERR_INTERNAL if newroot is outside [0, net->power),
   otherwise MPC_OK. */
Int iMPC_Set_web_root(
  MPC_Net* net,
  Int newroot )
  {
    MPI_Comm* webs = (MPI_Comm*)(net->pweb);

    /* The lower bound matters too: a negative newroot would make the index
       newroot+1 point before the per-root slots (or out of the array). */
    if ((newroot < 0) || (newroot >= net->power)) return MPC_ERR_INTERNAL;
    webs[newroot+1] = webs[0];
    return MPC_OK;
  } /* iMPC_Set_web_root */

/******************** MPC_Set_web_root ****************************************/
/* Tracing front end for iMPC_Set_web_root: logs entry and exit when
   MPC_DEBUG > 1 and hands the result code straight back to the caller. */
Int MPC_Set_web_root(
  MPC_Net* net,
  Int newroot )
{
    Int rc;

    if (MPC_DEBUG > 1) {
        MPC_Debug_printf("---[%2d]--> MPC_Set_web_root, root = %d\n", MPC_Net_global.rank, newroot);
    }
    rc = iMPC_Set_web_root(net, newroot);
    if (MPC_DEBUG > 1) {
        MPC_Debug_printf("<--[%2d]--- MPC_Set_web_root, RC = %d\n", MPC_Net_global.rank, rc);
    }
    return rc;
}

/******************** iMPC_Change_root_unconditional **************************/
/* Synchronises the members of net on a change of root to newroot.
   The method is selected at build time by MPC_CHANGE_ROOT:
     MPC_BY_BARRIER - MPC_Local_barrier over net;
     MPC_BY_GATHER  - MPI_Gather of one dummy int rooted at newroot;
     MPC_BY_REDUCE  - MPI_Reduce (MPI_LOR) of one dummy int rooted at newroot.
   Returns MPC_OK, MPC_ERR_NOMEM if the gather buffer cannot be allocated,
   or the first MPC error code reported along the way. */
Int iMPC_Change_root_unconditional( 
  MPC_Net* net,
  Int newroot )
  {
    Int RC = MPC_OK;

    if (MPC_CHANGE_ROOT == MPC_BY_BARRIER) {
      /* Propagate the barrier's result instead of silently dropping it. */
      RC = MPC_Local_barrier(net);
      if (RC != MPC_OK) return RC;
    } /* if */
    if (MPC_CHANGE_ROOT == MPC_BY_GATHER) {
      /* Plain int buffers match the MPI_INT datatype given to MPI.
         dummy is initialised so no indeterminate value is transmitted. */
      int* buffer;
      int  dummy = 0;

      if ((buffer = malloc(sizeof(*buffer)*(net->power))) == NULL) return MPC_ERR_NOMEM;
      RC = I2C(MPI_Gather(&dummy,1,MPI_INT,buffer,1,MPI_INT,newroot,(*(MPI_Comm*)(net->pweb))));
      /* Release the buffer before inspecting RC so a failed gather cannot leak it. */
      free(buffer);
      if (RC != MPC_OK) return RC;
    } /* if */
    if (MPC_CHANGE_ROOT == MPC_BY_REDUCE) {
      int result;
      int dummy = 0;   /* initialised: its value is fed into MPI_LOR */

      RC = I2C(MPI_Reduce(&dummy,&result,1,MPI_INT,MPI_LOR,newroot,*(MPI_Comm*)(net->pweb)));
      if (RC != MPC_OK) return RC;
    } /* if */
    return MPC_OK;
  } /* iMPC_Change_root_unconditional */

/******************** MPC_Change_root_unconditional ***************************/
/* Tracing front end for iMPC_Change_root_unconditional.
   When MPC_DEBUG > 1, logs the synchronisation method selected by
   MPC_CHANGE_ROOT together with the old and new root.
   Returns whatever iMPC_Change_root_unconditional returns. */
Int MPC_Change_root_unconditional( 
  MPC_Net* net,
  Int newroot )
  {
    Int RC;
    /* Resolve the method name by comparing against the named constants,
       the same way iMPC_Change_root_unconditional selects the method.
       The old code indexed a 3-element array with MPC_CHANGE_ROOT, which
       relied on the MPC_BY_* values being exactly 0..2 and was unchecked. */
    const char* method =
      (MPC_CHANGE_ROOT == MPC_BY_BARRIER) ? "barrier" :
      (MPC_CHANGE_ROOT == MPC_BY_GATHER)  ? "gather"  :
      (MPC_CHANGE_ROOT == MPC_BY_REDUCE)  ? "reduce"  : "unknown";

    if (MPC_DEBUG>1) MPC_Debug_printf("---[%2d]--> MPC_Change_root_unconditional(%s), oldroot = %d, newroot = %d\n", MPC_Net_global.rank, method, net->oldroot, newroot );
    RC = iMPC_Change_root_unconditional( net, newroot );
    if (MPC_DEBUG>1) MPC_Debug_printf("<--[%2d]--- MPC_Change_root_unconditional(%s), RC = %d, newroot = %d\n", MPC_Net_global.rank, method, RC, newroot );
    return RC;
  } /* MPC_Change_root_unconditional */

/******************** iMPC_Change_root_internal *******************************/
/* Updates net's current root (net->oldroot) to newroot. The action depends
   on the state held in net->oldroot:
     MPC_NULL_ROOT      - newroot is simply recorded in net->oldroot;
     MPC_UNDEFINED_ROOT - if make_barrier is non-zero, first call
                          MPC_Change_root_unconditional; then record newroot;
     MPC_MULTI_ROOT     - ensure slot newroot+1 of the communicator array at
                          net->pweb holds a communicator (MPI_Comm_dup of
                          slot 0 when it is MPI_COMM_NULL), then copy that
                          slot into slot 0; net->oldroot is NOT updated;
     any other value    - presumably a concrete root rank: if it differs from
                          newroot, optionally call MPC_Change_root_unconditional
                          (when make_barrier is non-zero), then record newroot.
   Returns MPC_OK, or the first MPC error code from
   MPC_Change_root_unconditional or MPI_Comm_dup. */
Int iMPC_Change_root_internal( 
  MPC_Net* net,
  Int newroot,
  Int make_barrier ) {
    Int RC;

    switch(net->oldroot) {
      case MPC_NULL_ROOT : 
        {
          /* No pending root: just remember the new one, no synchronisation. */
          net->oldroot = newroot;          
          break;
        }     
      case MPC_UNDEFINED_ROOT : 
        {
          if (make_barrier) {
            RC = MPC_Change_root_unconditional( net, newroot );
            if (RC != MPC_OK) return RC;
          } /* if */
          net->oldroot = newroot;
          break;
        }
      case MPC_MULTI_ROOT : 
        {
          /* NOTE(review): newroot is not range-checked before indexing slot
             newroot+1; assumes 0 <= newroot < net->power - TODO confirm. */
          if (((MPI_Comm*)(net->pweb))[newroot+1] == MPI_COMM_NULL) {
            /* NOTE(review): MPI_Comm_dup is collective over the communicator
               in slot 0, so presumably all members must reach this branch
               together - verify. */
            RC = I2C(MPI_Comm_dup(*(MPI_Comm*)(net->pweb),(MPI_Comm*)(net->pweb)+newroot+1));
            if (RC != MPC_OK) return RC;
          } /* if */
          /* NOTE(review): slot 0 is overwritten and its previous value is not
             saved here, so later MPI_Comm_dup calls duplicate whatever slot 0
             holds at that time - confirm this is intended. */
          *(MPI_Comm*)(net->pweb) = ((MPI_Comm*)(net->pweb))[newroot+1];
          break;
        }
      default :
        {
          /* Only act when the root actually changes. */
          if (newroot != (net->oldroot)) {
            if (make_barrier) {
              RC = MPC_Change_root_unconditional( net, newroot );
              if (RC != MPC_OK) return RC;
            }
            net->oldroot = newroot;
          } /* if */
          break;
        }
    }
    return MPC_OK;  
  } /* iMPC_Change_root_internal */

/******************** MPC_Change_root_internal ********************************/
/* Tracing front end for iMPC_Change_root_internal: when MPC_DEBUG > 1 it logs
   the old root before the call and the resulting root afterwards, then
   hands back the result code unchanged. */
Int MPC_Change_root_internal( 
  MPC_Net* net,
  Int newroot,
  Int make_barrier )
{
    Int rc;

    if (MPC_DEBUG > 1) {
        MPC_Debug_printf("---[%2d]-> MPC_Change_root_internal, oldroot = %d, newroot = %d, make_barrier = %d\n", MPC_Net_global.rank, net->oldroot, newroot, make_barrier);
    }
    rc = iMPC_Change_root_internal( net, newroot, make_barrier );
    if (MPC_DEBUG > 1) {
        MPC_Debug_printf("<--[%2d]-- MPC_Change_root_internal, root = %d, RC = %d\n", MPC_Net_global.rank, net->oldroot, rc);
    }
    return rc;
}

/******************** iMPC_Change_root_done ***********************************/
/* Switches net to newroot through MPC_Change_root_internal with
   make_barrier = 0, so MPC_Change_root_unconditional is never invoked.
   Returns the result of MPC_Change_root_internal. */
Int iMPC_Change_root_done(
  MPC_Net* net,
  Int newroot )
{
    return MPC_Change_root_internal( net, newroot, 0 );
}

/******************** MPC_Change_root_done ************************************/

/* Tracing front end for iMPC_Change_root_done: logs the old root before and
   the resulting root after the call when MPC_DEBUG > 1, and passes the
   result code through unchanged. */
Int MPC_Change_root_done(
  MPC_Net* net,
  Int newroot )
{
    Int rc;

    if (MPC_DEBUG > 1) {
        MPC_Debug_printf("---[%2d]-> MPC_Change_root_done, oldroot = %d, newroot = %d\n", MPC_Net_global.rank, net->oldroot, newroot);
    }
    rc = iMPC_Change_root_done( net, newroot );
    if (MPC_DEBUG > 1) {
        MPC_Debug_printf("<--[%2d]-- MPC_Change_root_done, root = %d, RC = %d\n", MPC_Net_global.rank, net->oldroot, rc);
    }
    return rc;
}

/******************** iMPC_Change_root ****************************************/
/* Switches net to newroot through MPC_Change_root_internal with
   make_barrier = 1. This lets MPC_Change_root_unconditional synchronise the
   members when the root really changes.
   Returns the result of MPC_Change_root_internal. */
Int iMPC_Change_root( 
  MPC_Net* net,
  Int newroot )
{
    return MPC_Change_root_internal( net, newroot, 1 );
}

/******************** MPC_Change_root *****************************************/
/* Tracing front end for iMPC_Change_root: logs the old root before and the
   resulting root after the call when MPC_DEBUG > 1, and passes the result
   code through unchanged. */
Int MPC_Change_root( 
  MPC_Net* net,
  Int newroot )
{
    Int rc;

    if (MPC_DEBUG > 1) {
        MPC_Debug_printf("---[%2d]-> MPC_Change_root, oldroot = %d, newroot = %d\n", MPC_Net_global.rank, net->oldroot, newroot);
    }
    rc = iMPC_Change_root( net, newroot );
    if (MPC_DEBUG > 1) {
        MPC_Debug_printf("<--[%2d]-- MPC_Change_root, root = %d, RC = %d\n", MPC_Net_global.rank, net->oldroot, rc);
    }
    return rc;
}

/******************** iMPC_Ctrl_prop ******************************************/       
/* Broadcasts *ctrl_value from the parent node of net to every member.
   It first makes the parent the current root with MPC_Change_root, then
   issues an MPI_Bcast rooted at the parent on net's communicator.
   Returns MPC_OK, or the first MPC error code encountered. */
Int iMPC_Ctrl_prop(
  MPC_Net* net,         /* net over which the value is propagated */
  Int* ctrl_value )     /* in: valid only on the parent node; out: valid on every member */
{
    Int root = MPC_Parent(net);
    Int rc = MPC_Change_root(net, root);

    /* Read the communicator only after MPC_Change_root, which may replace
       the one stored at net->pweb. */
    if (rc == MPC_OK) {
        rc = I2C(MPI_Bcast((void*)ctrl_value, 1, MPI_INT, root, *(MPI_Comm*)(net->pweb)));
    }
    return rc;
}

/******************** MPC_Ctrl_prop ******************************************/
/* Checked, traced front end for iMPC_Ctrl_prop.
   Logs entry and exit when MPC_DEBUG is set. If iMPC_Ctrl_prop fails, it
   prints the error text and terminates the process with exit(code), so a
   normal return always yields MPC_OK. */
Int MPC_Ctrl_prop(
  MPC_Net* net,
  Int* ctrl_value)
{
    Int rc;
    char msg[MAX_ERR_STR];

    if (MPC_DEBUG) {
        MPC_Debug_printf( "---[%2d]-> MPC_Ctrl_prop, ctrl_value = %d\n", MPC_Net_global.rank, *ctrl_value );
    }

    rc = iMPC_Ctrl_prop(net, ctrl_value);
    if (rc != MPC_OK) {
        RC2STR( msg, rc );
        MPC_Debug_printf( "<==[%2d]== MPC error %s during MPC_Ctrl_prop\n", MPC_Net_global.rank, msg );
        exit(rc);
    }

    if (MPC_DEBUG) {
        MPC_Debug_printf( "<--[%2d]-- MPC_Ctrl_prop, ctrl_value = %d\n", MPC_Net_global.rank, *ctrl_value );
    }
    return rc;
}

/******************** MPC_Propagate_barrier **********************************/
/* Hook called by iMPC_Barrier after it clears a pending root.
   Currently a no-op that always reports success.
   net - ignored.
   Returns MPC_OK. */
Int MPC_Propagate_barrier( MPC_Net* net )
{
    (void)net;  /* intentionally unused */
    return MPC_OK;
}

/******************** iMPC_Barrier *******************************************/
/* Validates net with MPC_Test_net, then runs an MPI_Barrier on its
   communicator. After the barrier, any root state other than MPC_NULL_ROOT
   or MPC_MULTI_ROOT (i.e. MPC_UNDEFINED_ROOT or a specific root) is reset to
   MPC_NULL_ROOT and MPC_Propagate_barrier is called.
   Returns MPC_OK, or the first MPC error code encountered. */
Int iMPC_Barrier( MPC_Net* net )
{
    Int rc = MPC_Test_net(net);

    if (rc != MPC_OK) return rc;

    rc = I2C(MPI_Barrier(*(MPI_Comm*)(net->pweb)));
    if (rc != MPC_OK) return rc;

    /* MPC_NULL_ROOT is already cleared; MPC_MULTI_ROOT is deliberately kept. */
    if ((net->oldroot != MPC_NULL_ROOT) && (net->oldroot != MPC_MULTI_ROOT)) {
        net->oldroot = MPC_NULL_ROOT;
        return MPC_Propagate_barrier(net);
    }
    return MPC_OK;
}

/******************* MPC_Local_barrier ***************************************/
/* Barrier over the members of net, with debug tracing.
   Delegates to iMPC_Barrier. On failure it prints the error text and
   terminates the process with exit(code), so it only ever returns MPC_OK. */
Int MPC_Local_barrier( MPC_Net* net )
  {
    Int RC;   
    char err_str[MAX_ERR_STR];

    if (MPC_DEBUG)
      {
         MPC_Debug_printf("---[%2d]-> MPC_Local_barrier, rank = %d, power = %d\n", MPC_Net_global.rank, net->rank, net->power);
      }
    RC = iMPC_Barrier(net);
    if (RC != MPC_OK)
      {        
        RC2STR( err_str, RC );
        /* Report this function's own name, as every sibling wrapper does
           (previously the message said "MPC_Barrier", which does not exist here). */
        MPC_Debug_printf( "<==[%2d]== MPC error %s during MPC_Local_barrier\n", MPC_Net_global.rank, err_str );
        exit(RC);
      } /* if */
    if (MPC_DEBUG)
      {
        MPC_Debug_printf( "<--[%2d]-- MPC_Local_barrier, rank = %d\n", MPC_Net_global.rank, net->rank );
      } /* if */        
    return MPC_OK;
  } /* MPC_Local_barrier */

/******************* MPC_Global_barrier **************************************/
/* Barrier over the global net MPC_Net_global, with debug tracing.
   Delegates to iMPC_Barrier. On failure it prints the error text and
   terminates the process with exit(code), so it only ever returns MPC_OK. */
Int MPC_Global_barrier( void )
{
    Int rc;
    char msg[MAX_ERR_STR];

    if (MPC_DEBUG) {
        MPC_Debug_printf("---[%2d]-> MPC_Global_barrier\n", MPC_Net_global.rank);
    }

    rc = iMPC_Barrier(&MPC_Net_global);
    if (rc != MPC_OK) {
        RC2STR( msg, rc );
        MPC_Debug_printf( "<==[%2d]== MPC error %s during MPC_Global_barrier\n", MPC_Net_global.rank, msg );
        exit(rc);
    }

    if (MPC_DEBUG) {
        MPC_Debug_printf( "<--[%2d]-- MPC_Global_barrier\n", MPC_Net_global.rank );
    }
    return MPC_OK;
}
