/*
 *	Ohio Trollius
 *	Copyright 1996 The Ohio State University
 *	RBD
 *
 *	$Id: alltoall.c,v 6.1 96/11/23 22:50:46 nevin Rel $
 *
 *	Function:	- send/recv data from each node to all nodes
 *	Accepts:	- send buffer
 *			- send count
 *			- send datatype
 *			- recv buffer
 *			- recv count
 *			- recv datatype
 *			- communicator
 *	Returns:	- MPI_SUCCESS or an MPI error code
 */

#include <stdlib.h>

#include <blktype.h>
#include <mpi.h>
#include <mpisys.h>
#include <terror.h>

/*
 * local functions
 *
 * Forward declarations (old-style, without parameter lists) of the two
 * RPI-specific implementations that MPI_Alltoall passes to RPI_SPLIT().
 * Both take the same seven arguments as MPI_Alltoall and return
 * MPI_SUCCESS or an MPI error code.
 */
static int		c2c_alltoall();
static int		lamd_alltoall();


int
MPI_Alltoall(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm)

void			*sbuf;
int			scount;
MPI_Datatype		sdtype;
void			*rbuf;
int			rcount;
MPI_Datatype		rdtype;
MPI_Comm		comm;

{
	lam_initerr();
	lam_setfunc(BLKMPIALLTOALL);
/*
 * Check for invalid arguments.
 */
	if ((comm == MPI_COMM_NULL) || LAM_IS_INTER(comm)) {
		return(lam_errfunc(comm, BLKMPIALLTOALL,
			lam_mkerr(MPI_ERR_COMM, 0)));
	}

	if ((sdtype == MPI_DATATYPE_NULL) || (rdtype == MPI_DATATYPE_NULL)) {
		return(lam_errfunc(comm, BLKMPIALLTOALL,
			lam_mkerr(MPI_ERR_TYPE, 0)));
	}

	if ((scount < 0) || (rcount < 0)) {
		return(lam_errfunc(comm, BLKMPIALLTOALL,
			lam_mkerr(MPI_ERR_COUNT, 0)));
	}

	LAM_TRACE(lam_tr_cffstart(BLKMPIALLTOALL));

	return(RPI_SPLIT(lamd_alltoall, c2c_alltoall,
		(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm)));
}

/*
 *	c2c_alltoall
 *
 *	Function:	- MPI_Alltoall for the C2C RPI
 *			- copies the local block with lam_dtsndrcv, then
 *			  exchanges with every other rank using persistent
 *			  receive and send requests
 *			- every persistent request that was successfully
 *			  created is released with MPI_Request_free on
 *			  all exit paths, including error paths
 *	Accepts:	- same as MPI_Alltoall
 *	Returns:	- MPI_SUCCESS or an MPI error code
 */
static int
c2c_alltoall(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm)

void			*sbuf;
int			scount;
MPI_Datatype		sdtype;
void			*rbuf;
int			rcount;
MPI_Datatype		rdtype;
MPI_Comm		comm;

{
	int		i;			/* favourite index */
	int		rank;			/* my rank */
	int		size;			/* group size */
	int		nreqs;			/* # requests needed */
	int		ninit;			/* # requests created */
	int		err;			/* error code */
	int		ferr;			/* request free error */
	int		syserr;			/* saved errno */
	char		*psnd;			/* ptr send buffer */
	char		*prcv;			/* ptr recv buffer */
	MPI_Aint	sndinc;			/* send increment */
	MPI_Aint	rcvinc;			/* rcv increment */
	MPI_Request	*req;			/* request array */
	MPI_Status	*stat;			/* status array */
/*
 * Initialize.
 */
	MPI_Comm_size(comm, &size);
	MPI_Comm_rank(comm, &rank);

	MPI_Type_extent(sdtype, &sndinc);
	MPI_Type_extent(rdtype, &rcvinc);

	sndinc *= scount;
	rcvinc *= rcount;
/*
 * Allocate arrays of requests and status structures:
 * one receive and one send request per remote rank.
 */
	nreqs = 2 * (size - 1);
	ninit = 0;
	req = 0;
	stat = 0;

	if (nreqs > 0) {
		req = (MPI_Request *) malloc((unsigned)
						nreqs * sizeof(MPI_Request));
		if (req == 0) {
			syserr = errno;
			return(lam_errfunc(comm, BLKMPIALLTOALL,
					lam_mkerr(MPI_ERR_OTHER, syserr)));
		}

		stat = (MPI_Status *) malloc((unsigned)
						nreqs * sizeof(MPI_Status));
		if (stat == 0) {
/*
 * Capture errno before free() gets a chance to disturb it.
 */
			syserr = errno;
			free((char *) req);
			return(lam_errfunc(comm, BLKMPIALLTOALL,
					lam_mkerr(MPI_ERR_OTHER, syserr)));
		}
	}
/*
 * Switch to collective communicator.
 */
	lam_mkcoll(comm);
/*
 * simple optimization: move my own block locally
 */
	psnd = ((char *) sbuf) + (rank * sndinc);
	prcv = ((char *) rbuf) + (rank * rcvinc);

	err = lam_dtsndrcv(psnd, scount, sdtype,
			prcv, rcount, rdtype, BLKMPIALLTOALL, comm);

	if (err != MPI_SUCCESS) goto done;
/*
 * If only one process, there is nobody else to talk to.
 */
	if (size == 1) goto done;
/*
 * Initiate all send/recv to/from others.  ninit counts only the
 * requests that were actually created so that the cleanup below
 * never touches an uninitialized slot.
 */
	prcv = (char *) rbuf;

	for (i = 0; i < size; ++i, prcv += rcvinc) {

		if (i == rank) continue;

		err = MPI_Recv_init(prcv, rcount, rdtype, i,
					BLKMPIALLTOALL, comm, &req[ninit]);

		if (err != MPI_SUCCESS) goto done;

		++ninit;
	}

	psnd = (char *) sbuf;

	for (i = 0; i < size; ++i, psnd += sndinc) {

		if (i == rank) continue;

		err = MPI_Send_init(psnd, scount, sdtype, i,
					BLKMPIALLTOALL, comm, &req[ninit]);

		if (err != MPI_SUCCESS) goto done;

		++ninit;
	}
/*
 * Start all the requests.
 */
	err = MPI_Startall(nreqs, req);

	if (err != MPI_SUCCESS) goto done;
/*
 * Wait for them all.
 */
	err = MPI_Waitall(nreqs, req, stat);
/*
 * Common exit.  Switch back from the collective communicator, then
 * free every request that was created.  All of them are freed even if
 * one free fails; the first error seen (from the communication or
 * from a free) is the one reported.
 */
done:
	lam_mkpt(comm);

	for (i = 0; i < ninit; ++i) {

		ferr = MPI_Request_free(&req[i]);

		if (err == MPI_SUCCESS) err = ferr;
	}

	if (req) free((char *) req);
	if (stat) free((char *) stat);

	if (err != MPI_SUCCESS) {
		return(lam_errfunc(comm, BLKMPIALLTOALL, err));
	}

	LAM_TRACE(lam_tr_cffend(BLKMPIALLTOALL, -1, comm, sdtype, scount));

	lam_resetfunc(BLKMPIALLTOALL);
	return(MPI_SUCCESS);
}


/*
 *	lamd_alltoall
 *
 *	Function:	- MPI_Alltoall for the LAMD RPI
 *			- copies the local block with lam_dtsndrcv, then
 *			  does one MPI_Sendrecv with each other rank in
 *			  increasing rank order
 *	Accepts:	- same as MPI_Alltoall
 *	Returns:	- MPI_SUCCESS or an MPI error code
 */
static int
lamd_alltoall(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm)

void			*sbuf;
int			scount;
MPI_Datatype		sdtype;
void			*rbuf;
int			rcount;
MPI_Datatype		rdtype;
MPI_Comm		comm;

{
    int			peer;			/* rank being served */
    int			me;			/* my rank */
    int			nprocs;			/* group size */
    int			rc;			/* error code */
    char		*sendbase;		/* start of send buffer */
    char		*recvbase;		/* start of recv buffer */
    MPI_Aint		sendstride;		/* bytes per send block */
    MPI_Aint		recvstride;		/* bytes per recv block */
    MPI_Status		status;			/* sendrecv status */
/*
 * Gather group geometry and per-rank block sizes.
 */
    MPI_Comm_size(comm, &nprocs);
    MPI_Comm_rank(comm, &me);

    MPI_Type_extent(sdtype, &sendstride);
    MPI_Type_extent(rdtype, &recvstride);

    sendstride *= scount;
    recvstride *= rcount;

    sendbase = (char *) sbuf;
    recvbase = (char *) rbuf;
/*
 * Switch to collective communicator.
 */
    lam_mkcoll(comm);
/*
 * Move my own block first, without going through a peer.
 */
    rc = lam_dtsndrcv(sendbase + (me * sendstride), scount, sdtype,
			recvbase + (me * recvstride), rcount, rdtype,
			BLKMPIALLTOALL, comm);
/*
 * Exchange blocks with every other rank.  The loop is skipped if the
 * local copy failed and stops at the first failing exchange.
 */
    for (peer = 0; (rc == MPI_SUCCESS) && (peer < nprocs); ++peer) {

	if (peer != me) {
	    rc = MPI_Sendrecv(sendbase + (peer * sendstride), scount,
			sdtype, peer, BLKMPIALLTOALL,
			recvbase + (peer * recvstride), rcount,
			rdtype, peer, BLKMPIALLTOALL, comm, &status);
	}
    }
/*
 * Back to the point-to-point communicator on every path.
 */
    lam_mkpt(comm);

    if (rc != MPI_SUCCESS) {
	return(lam_errfunc(comm, BLKMPIALLTOALL, rc));
    }

    LAM_TRACE(lam_tr_cffend(BLKMPIALLTOALL, -1, comm, sdtype, scount));

    lam_resetfunc(BLKMPIALLTOALL);
    return(MPI_SUCCESS);
}
