//
// nbody_MP_MPI.C
//
#include  <stdlib.h>
#include  <math.h>
#include  <stdiostream.h>
#include  <fstream.h>
#include  <string.h>
#include  <unistd.h>
#include "mpi.h"

#define real double
#include "vector.h"
#include "nbody_particle.h"
#include "BHtree.h"
#include "nbody.h"
// Status object passed to every MPI_Recv / MPI_Sendrecv in this file.
// Its contents are never inspected here.
MPI_Status status;

// Rank of this process and number of processes in MPI_COMM_WORLD.
// Both are assigned in MP_initialize().
static int local_proc_id;
static int total_proc_count;

// Debug-print helpers: write "<expression text> = <value>" to cerr.
//   PRCI / PRLI  prefix the output with "ID <local_proc_id> ".
//   PRC  / PRCI  end with ",  " so several values can share one line.
//   PRL          ends with "\n"; PRLI ends with endl.
// NOTE(review): PRLI already contains a trailing ';' while the other three
// do not, so "PRLI(x);" expands to two statements.
#define PRCI(x) cerr << "ID " << local_proc_id << " " << #x << " = " << x << ",  "
#define PRLI(x) cerr << "ID " << local_proc_id << " " << #x << " = " << x << endl;
#define PRC(x) cerr << #x << " = " << x << ",  "
#define PRL(x) cerr << #x << " = " << x << "\n"


// Hmm, it may be better to keep these local to this file
// rather than in the particle_system class?


// Returns the rank of the calling process as cached in local_proc_id.
int MP_myprocid()
{
    const int my_rank = local_proc_id;
    return my_rank;
}

// Returns the number of processes as cached in total_proc_count.
int MP_proccount()
{
    const int nprocs = total_proc_count;
    return nprocs;
}


void MP_copyparams(real &dt,
		   real &dtsnapout,
		   int &outlogstep,
		   real &tend,
		   real &eps,
		   real &theta,
		   int &ncrit,
		   real &pos_scale,
		   real &vel_scale)
{
	MPI_Bcast(&dt,1,MPI_DOUBLE,0,MPI_COMM_WORLD);
	MPI_Bcast(&dtsnapout,1,MPI_DOUBLE,0,MPI_COMM_WORLD); 
	MPI_Bcast(&outlogstep,1,MPI_INT,0,MPI_COMM_WORLD); 
	MPI_Bcast(&tend,1,MPI_DOUBLE,0,MPI_COMM_WORLD); 
	MPI_Bcast(&eps,1,MPI_DOUBLE,0,MPI_COMM_WORLD); 
	MPI_Bcast(&theta,1,MPI_DOUBLE,0,MPI_COMM_WORLD); 
	MPI_Bcast(&ncrit,1,MPI_INT,0,MPI_COMM_WORLD); 
	MPI_Bcast(&pos_scale,1,MPI_DOUBLE,0,MPI_COMM_WORLD); 
	MPI_Bcast(&vel_scale,1,MPI_DOUBLE,0,MPI_COMM_WORLD);
}

void MP_initialize(int argc,char *argv[])
{
    int  namelen;
    int myid, numprocs;
    char processor_name[MPI_MAX_PROCESSOR_NAME];
    MPI_Init(&argc,&argv);
    MPI_Comm_size(MPI_COMM_WORLD,&numprocs);
    MPI_Comm_rank(MPI_COMM_WORLD,&myid);
    MPI_Get_processor_name(processor_name,&namelen);
    cerr << "Initialize:Myid = " << myid
	 <<  " Myname = " << processor_name
	 << " Nprocs = " << numprocs <<endl;
    total_proc_count = numprocs;
    local_proc_id = myid;
    MPI_Barrier(MPI_COMM_WORLD);
}

void MP_convert_snap_name(int& flag, char * name)
{
    char work[255];
    int flag_copy = flag;
    if (total_proc_count == 1) return;
    MPI_Bcast(&flag_copy,1,MPI_INT,0,MPI_COMM_WORLD);
    MPI_Barrier(MPI_COMM_WORLD);
    flag = flag_copy;
    if (flag){
	MPI_Bcast(name,254,MPI_CHAR,0,MPI_COMM_WORLD);
	sprintf(work,"%sMP%d-%d",name,total_proc_count,local_proc_id);
	strncpy(name, work, 254);
    }
}


void MP_gather_sample_coords(int & nsample,
			     vector * sample_array)
{
    if(local_proc_id != 0){
	// send samples and return
	MPI_Send( &nsample, 1, MPI_INT, 0,local_proc_id*2 , MPI_COMM_WORLD);
	MPI_Send( (real*)sample_array, nsample*3, MPI_DOUBLE, 0,local_proc_id*2+1,
		  MPI_COMM_WORLD);
    }else{
	for(int i=1;i<total_proc_count; i++){
	    int nreceive;
	    MPI_Recv( &nreceive, 1, MPI_INT, i,i*2, MPI_COMM_WORLD,&status);
	    MPI_Recv((real*)(sample_array+nsample), 3*nreceive, MPI_DOUBLE,
		     i,i*2+1, MPI_COMM_WORLD,&status);
	    nsample+=nreceive;
	}
    }
}

// Broadcasts a single int from rank 0; on return every rank holds
// rank 0's value in i. Collective over MPI_COMM_WORLD.
void MP_int_bcast(int& i)
{
    const int root = 0;
    const int count = 1;
    MPI_Bcast(&i, count, MPI_INT, root, MPI_COMM_WORLD);
}
// Broadcasts nwords doubles starting at data from rank 0 to all ranks.
// Collective over MPI_COMM_WORLD.
void MP_double_bcast(double* data, int nwords)
{
    const int root = 0;
    MPI_Bcast(data, nwords, MPI_DOUBLE, root, MPI_COMM_WORLD);
}
// Sums i over all processes. Rank 0 gets the total written back into i;
// on all other ranks i is left unchanged.
void MP_int_sum(int& i)
{
    int total;
    MPI_Reduce(&i, &total, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);

    // The reduction result is only stored on the root.
    if (local_proc_id != 0) return;
    i = total;
}
// Sums r over all processes. Rank 0 gets the total written back into r;
// on all other ranks r is left unchanged.
void MP_sum(double& r)
{
    double total;
    MPI_Reduce(&r, &total, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);

    // The reduction result is only stored on the root.
    if (local_proc_id != 0) return;
    r = total;
}

// Sends a run of particles to process ibox while receiving particles from
// process isource into the same array.
//   ibox       : rank the particles are sent to.
//   pb         : particle array used for both sending and receiving.
//   firstloc   : index in pb of the first particle to send.
//   nparticles : number of particles to send.
//   isource    : rank particles are received from.
//   iloc       : in/out; received particles are stored just below this
//                index, and iloc is lowered by the number received.
// The counts are exchanged first. If the incoming particles would overlap
// the outgoing range [firstloc, firstloc+nparticles), a diagnostic is
// printed and MPI_Abort is called. Particles travel as raw bytes.
void MP_exchange_particle(int ibox,
                          nbody_particle * pb,
                          int firstloc,
                          int nparticles,
                          int isource,
                          int &iloc)
{
    MP_sync();
    cerr.flush();

    const int send_tag = local_proc_id*10;
    const int recv_tag = isource*10;

    // Step 1: trade particle counts with the two partners.
    int incoming;
    MPI_Sendrecv(&nparticles, 1, MPI_INT, ibox, send_tag,
                 &incoming, 1, MPI_INT, isource, recv_tag,
                 MPI_COMM_WORLD, &status);

    // Step 2: make sure the incoming block fits above the outgoing one.
    const int send_end = firstloc + nparticles;
    if (iloc - incoming < send_end){
        cerr << "Proc " << MP_myprocid()
             << " ran out of buffer space for exchange_particles \n"
             << ibox << " " << " " << isource << " " << firstloc << " "
             << nparticles << " " << iloc << " " << incoming << endl;
        MPI_Abort(MPI_COMM_WORLD,-1);
    }

    // Step 3: move the particle data itself.
    iloc -= incoming;
    const int particle_bytes = sizeof(nbody_particle);
    MPI_Sendrecv(pb + firstloc, particle_bytes*nparticles, MPI_BYTE,
                 ibox, send_tag + 1,
                 pb + iloc, particle_bytes*incoming, MPI_BYTE,
                 isource, recv_tag + 1,
                 MPI_COMM_WORLD, &status);
}
    
// Like MP_exchange_particle, but never sends more particles than the
// destination reports it has room for.
//   ibox       : rank the particles are sent to.
//   pb         : particle array used for both sending and receiving.
//   firstloc   : index in pb of the first candidate particle to send.
//   nparticles : number of particles we would like to send.
//   isource    : rank particles are received from.
//   iloc       : in/out; lowered by the number of particles received, which
//                are stored starting at the new iloc.
//   nsend      : out; number of particles actually sent (the last nsend of
//                the nparticles candidates).
// Returns the maximum over all processes of a local overflow flag, i.e. 1 if
// any process had to send fewer particles than requested, otherwise 0.
// The final MPI_Allreduce makes this a collective call.
int MP_exchange_particle_with_overflow_check(int ibox,
                                             nbody_particle * pb,
                                             int firstloc,
                                             int nparticles,
                                             int isource,
                                             int &iloc,
                                             int &nsend)
{
    const int send_tag = local_proc_id*10;
    const int src_tag  = isource*10;
    const int box_tag  = ibox*10;

    // Tell our source how much room we have, and learn how much room our
    // destination has.
    int local_room = iloc - (firstloc + nparticles);
    int remote_room;
    MPI_Sendrecv(&local_room, 1, MPI_INT, isource, send_tag + 9,
                 &remote_room, 1, MPI_INT, ibox, box_tag + 9,
                 MPI_COMM_WORLD, &status);

    // Clamp the send count to what the destination can take.
    int overflowed = 0;
    if (remote_room < nparticles){
        nsend = remote_room;
        overflowed = 1;
    }else{
        nsend = nparticles;
    }

    // Exchange the (possibly reduced) particle counts.
    int incoming;
    MPI_Sendrecv(&nsend, 1, MPI_INT, ibox, send_tag,
                 &incoming, 1, MPI_INT, isource, src_tag,
                 MPI_COMM_WORLD, &status);

    // Exchange the particle data as raw bytes; when clamped, the first
    // (nparticles - nsend) candidates stay behind.
    iloc -= incoming;
    const int particle_bytes = sizeof(nbody_particle);
    const int skip = nparticles - nsend;
    MPI_Sendrecv(pb + firstloc + skip, particle_bytes*nsend, MPI_BYTE,
                 ibox, send_tag + 1,
                 pb + iloc, particle_bytes*incoming, MPI_BYTE,
                 isource, src_tag + 1,
                 MPI_COMM_WORLD, &status);

    // Let every process know whether anyone overflowed.
    int any_overflow;
    MPI_Allreduce(&overflowed, &any_overflow, 1, MPI_INT, MPI_MAX,
                  MPI_COMM_WORLD);
    return any_overflow;
}
    
	
// Sends a list of positions and masses to process ibox while receiving
// such a list from process isource.
//   ibox      : rank the list is sent to.
//   nlist     : number of entries to send.
//   nbmax     : largest number of entries the receive buffers may get;
//               a larger incoming list aborts the run.
//   plist     : positions to send (nlist entries, sent as 3 doubles each).
//   mlist     : masses to send (nlist entries).
//   isource   : rank the list is received from.
//   nrecvlist : out; number of entries received.
//   precvbuf  : receives the incoming positions.
//   mrecvbuf  : receives the incoming masses.
// NOTE(review): sending vectors as 3*n MPI_DOUBLEs assumes a vector is
// three consecutive doubles -- confirm against vector.h.
void MP_exchange_bhlist(int ibox,
                        int nlist,
                        int nbmax,
                        vector * plist,
                        real * mlist,
                        int isource,
                        int & nrecvlist,
                        vector * precvbuf,
                        real * mrecvbuf)
{
    const int send_tag = local_proc_id*10;
    const int recv_tag = isource*10;

    // List lengths first, so the receiver knows what to expect.
    MPI_Sendrecv(&nlist, 1, MPI_INT, ibox, send_tag,
                 &nrecvlist, 1, MPI_INT, isource, recv_tag,
                 MPI_COMM_WORLD, &status);

    if (nbmax < nrecvlist){
        cerr << "Myid = " <<MP_myprocid() << "MP_exchange_bhlist: buffer too small "
             << nbmax << " "<< nrecvlist<<endl;
        MPI_Abort(MPI_COMM_WORLD,-1);
    }

    // Positions, three doubles per entry.
    MPI_Sendrecv(plist, nlist*3, MPI_DOUBLE, ibox, send_tag + 1,
                 precvbuf, nrecvlist*3, MPI_DOUBLE, isource, recv_tag + 1,
                 MPI_COMM_WORLD, &status);

    // Masses, one double per entry.
    MPI_Sendrecv(mlist, nlist, MPI_DOUBLE, ibox, send_tag + 2,
                 mrecvbuf, nrecvlist, MPI_DOUBLE, isource, recv_tag + 2,
                 MPI_COMM_WORLD, &status);
}
    
	


// Sums the centre-of-mass terms (pos, vel, mass) over all processes.
//   pos, vel : in/out; three components each are summed element-wise.
//   mass     : in/out; summed.
// On rank 0 the arguments are overwritten with the global sums. On every
// other rank they are left unchanged: MPI_Reduce only defines the receive
// buffer on the root, so copying it back elsewhere would read
// uninitialized memory. This matches MP_int_sum and MP_sum.
// Collective: starts with a barrier and must be called by every rank.
void MP_collect_cmterms(vector& pos,vector& vel,real& mass)
{
    MPI_Barrier(MPI_COMM_WORLD);

    // Pack the seven scalars so one reduction handles them all.
    real source[7];
    real dest[7];
    for(int i = 0;i<3;i++){
	source[i]= pos[i];
	source[i+3]= vel[i];
    }
    source[6]=mass;

    MPI_Reduce(source,dest,7, MPI_DOUBLE, MPI_SUM,0,MPI_COMM_WORLD);

    // Only the root holds a valid result in dest.
    if (local_proc_id != 0) return;

    for(int i = 0;i<3;i++){
	pos[i] = dest[i];
	vel[i] = dest[i+3];
    }
    mass = dest[6];
    cerr << "Exit collect_cmterms\n";
}
// Sums three energy terms over all processes.
//   e1, e2, e3 : in/out; each is summed independently.
// On rank 0 the arguments are overwritten with the global sums. On every
// other rank they are left unchanged: MPI_Reduce only defines the receive
// buffer on the root, so copying it back elsewhere would read
// uninitialized memory. This matches MP_int_sum and MP_sum.
// Collective: starts with a barrier and must be called by every rank.
void MP_collect_energies(real& e1,
			 real& e2,
			 real& e3)
{
    MPI_Barrier(MPI_COMM_WORLD);

    // Pack the three terms so one reduction handles them all.
    real source[3];
    real dest[3];
    source[0]=e1;
    source[1]=e2;
    source[2]=e3;

    MPI_Reduce(source,dest,3, MPI_DOUBLE, MPI_SUM,0,MPI_COMM_WORLD);

    // Only the root holds a valid result in dest.
    if (local_proc_id != 0) return;

    e1 = dest[0];
    e2 = dest[1];
    e3 = dest[2];
}

// Returns the largest localval passed by any process; every rank receives
// the same result. Collective over MPI_COMM_WORLD.
int MP_intmax(int localval)
{
    int overall_max = localval;
    MPI_Allreduce(&localval, &overall_max, 1, MPI_INT, MPI_MAX,
                  MPI_COMM_WORLD);
    return overall_max;
}
// Returns the largest localval passed by any process; every rank receives
// the same result. Collective over MPI_COMM_WORLD.
double MP_doublemax(double localval)
{
    double overall_max = localval;
    MPI_Allreduce(&localval, &overall_max, 1, MPI_DOUBLE, MPI_MAX,
                  MPI_COMM_WORLD);
    return overall_max;
}


void  MP_sync()
{
    //    if (local_proc_id == 0) cerr << "Enter MP_SYNC\n";
    MPI_Barrier(MPI_COMM_WORLD);
}


// Shuts down MPI for this process by calling MPI_Finalize.
void MP_end()
{
    MPI_Finalize();
}

// Writes one line of CPU and wall-clock time per process to s.
//   s : output stream; only written to on rank 0.
// After a barrier every rank samples cpu_time() and wall_time(). Non-root
// ranks send both values to rank 0, which prints its own line first and
// then one line per remote rank in rank order.
void MP_print_times(ostream &s)
{
    MP_sync();
    real cpu = cpu_time();
    real etime = wall_time();

    if (local_proc_id != 0){
        const int tag_base = local_proc_id*13;
        MPI_Send(&cpu, 1, MPI_DOUBLE, 0, tag_base + 1, MPI_COMM_WORLD);
        MPI_Send(&etime, 1, MPI_DOUBLE, 0, tag_base + 2, MPI_COMM_WORLD);
        return;
    }

    // Root: report every rank, starting with ourselves.
    for (int rank = 0; rank < total_proc_count; ++rank){
        real rank_cpu = cpu;
        real rank_wall = etime;
        if (rank != 0){
            const int tag_base = rank*13;
            MPI_Recv(&rank_cpu, 1, MPI_DOUBLE, rank, tag_base + 1,
                     MPI_COMM_WORLD, &status);
            MPI_Recv(&rank_wall, 1, MPI_DOUBLE, rank, tag_base + 2,
                     MPI_COMM_WORLD, &status);
        }
        s << " CPU time " << rank_cpu
          << " Wallclock time " << rank_wall <<endl;
    }
}


// Writes one line of tree statistics per process to s.
//   total_interactions : this process's interaction count.
//   tree_walks         : this process's number of tree walks.
//   nisum              : this process's "ni" sum.
//   s                  : output stream; only written to on rank 0.
// After a barrier, non-root ranks send their three values to rank 0, which
// prints its own line first and then one line per remote rank in rank order.
//
// The interaction count is labelled "total_interactions"; it was previously
// printed under a second, copy-pasted "tree_walks" label.
void MP_print_treestats(real total_interactions,
			int tree_walks,
			int nisum,
			ostream &s)
{
    MP_sync();
    if (local_proc_id != 0){
	MPI_Send( &total_interactions, 1, MPI_DOUBLE, 0,local_proc_id*13+1 , MPI_COMM_WORLD);
	MPI_Send( &tree_walks, 1, MPI_INT, 0,local_proc_id*13+2 , MPI_COMM_WORLD);
	MPI_Send( &nisum, 1, MPI_INT, 0,local_proc_id*13+3 , MPI_COMM_WORLD);
	return;
    }

    // Root: report every rank, starting with ourselves.
    for(int i=0;i<total_proc_count; i++){
	real local_total_ints = total_interactions;
	int local_tw = tree_walks;
	int local_ni = nisum;
	if (i != 0){
	    MPI_Recv( &local_total_ints, 1, MPI_DOUBLE, i,i*13+1 ,MPI_COMM_WORLD,&status);
	    MPI_Recv( &local_tw, 1, MPI_INT, i,i*13+2 ,MPI_COMM_WORLD,&status);
	    MPI_Recv( &local_ni, 1, MPI_INT, i,i*13+3 ,MPI_COMM_WORLD,&status);
	}
	s << "CPU " << i  << " total_interactions " << local_total_ints
	  << " tree_walks " << local_tw << " ni  " << local_ni << endl;
    }
}


