/*
 *  Redistribution routine for the AP+: trans_block_dot_mat_vec() moves a
 *  matrix and a vector from the "block" layout to the "dot" layout.
 */

#include <ccell.c7.h>
#include <stdio.h>
#include <assert.h>
#include "conf.h"
#include "dist.h"


/*
 * Communication bookkeeping for trans_block_dot_mat_vec().
 *
 * Without __PUT__ the routine exchanges tagged messages: TRANS_DAT is the
 * tag used for matrix-element messages and TRANS_VEC the tag used for
 * vector-element messages in the l_asend()/l_arecv()/l_aqrecv() calls.
 *
 * With __PUT__ the routine uses put()/put_stride() instead:
 *   put_f[0], put_f[1] - flag words passed to put_stride() (matrix phase)
 *                        and put() (vector phase) and later handed to
 *                        amcheck().
 *   dma_f              - second flag argument passed to every
 *                        put()/put_stride() call (presumably the
 *                        send-side completion flag -- confirm against the
 *                        cell library documentation).
 *   local_f[0..1]      - running totals that amcheck() is given as the
 *                        value to compare put_f[0..1] against; they are
 *                        only ever increased, so they accumulate across
 *                        calls.
 */
#ifndef __PUT__
#define TRANS_DAT	105
#define TRANS_VEC	106
#else /* __PUT__ */
static unsigned long put_f[2];
static unsigned long dma_f;
static unsigned long local_f[2];
#endif

#ifdef __PUT__

/*
 * Staging buffer for the redistributed vector: elements of `a` are stored
 * here (by local assignment or by a remote put()), the buffer is passed to
 * x_brd(), and its first nry entries are then copied into column nrx of pB.
 */
static double b[NUM_PER_PROC_Y];

/*
 * trans_block_dot_mat_vec -- put()/put_stride() variant (__PUT__ defined).
 *
 * Moves the matrix rows and vector elements held by the calling cell into
 * the layout used by pB.
 *
 *   cid   id of the calling cell; compared against every destination id
 *         to decide between a local copy and a remote put.
 *   tid   not used by this variant.
 *   size  problem size handed to all the layout helper calls.
 *   pA    source matrix; row index is the local row number.
 *   a     source vector, one element per local row of pA.
 *   pB    destination matrix.  Row `row` of pA lands in column
 *         tolocal_dot_x(...) of pB (rows become columns), and the vector
 *         ends up in column nrx = num_in_proc_dot_x(cidx, size).
 *
 * Phases:
 *   1. Matrix.  For each local row and each of the first
 *      min(size, NUM_PROC_Y) columns, `count` elements spaced NUM_PROC_Y
 *      apart in the source row are written to `count` consecutive rows of
 *      one column of pB on the owning cell -- by a plain loop when that
 *      cell is this one, otherwise by put_stride().
 *   2. amcheck() on put_f[0] with the number of remotely supplied pieces
 *      added to local_f[0] (presumably this blocks until they have all
 *      arrived -- confirm against the cell library documentation).
 *   3. Vector.  Each a[row] is stored in b[] on cell rec_tlin(0, y).
 *   4. Cells whose x coordinate is 0 amcheck() put_f[1]; every cell then
 *      calls x_brd(0, ...) on b and copies b into column nrx of pB.
 *
 * NOTE(review): remote destinations are given as this cell's own &pB[..]
 * and &b[..] addresses, which assumes those objects live at the same
 * address on every cell -- confirm.
 */
void
trans_block_dot_mat_vec(cid, tid, size, pA, a, pB)
    int cid, tid;
    int size;
    mat pA;
    double *a;
    dot_mat pB;
{
    int row, col, k;
    int my_x, my_y;
    int n_local_rows = num_in_proc_block(cid, size);
    int n_dot_x;
    int n_dot_y;
    int n_targets_y;
    int kept_mat = 0;
    int kept_vec = 0;

    lin_trec(cid, &my_x, &my_y);
    n_dot_x = num_in_proc_dot_x(my_x, size);
    n_dot_y = num_in_proc_dot_y(my_y, size);

    /* Number of leading columns visited per row: min(size, NUM_PROC_Y). */
    if (size > NUM_PROC_Y) {
        n_targets_y = NUM_PROC_Y;
    } else {
        n_targets_y = size;
    }

    /* Phase 1: matrix elements. */
    for (row = 0; row < n_local_rows; row++) {
        int g_row = toglobal_block(cid, row, size);
        int dst_x = procid_dot_x(g_row, size);
        int dst_col = tolocal_dot_x(dst_x, g_row, size);

        for (col = 0; col < n_targets_y; col++) {
            int dst_y = procid_dot_y(col, size);
            int dst_row = tolocal_dot_y(dst_y, col, size);
            int dst_cell = rec_tlin(dst_x, dst_y);
            int count = num_in_proc_dot_y(dst_y, size);

            if (dst_cell != cid) {
                /*
                 * IMPORTANT: the receive-side stride below must match the
                 * row length of dot_mat (NUM_PER_PROC_X + 1 doubles).
                 */
                put_stride(dst_cell, &pA[row][col], sizeof(double),
                           &pB[dst_row][dst_col], &put_f[0], &dma_f, 0,
                           NUM_PROC_Y * sizeof(double),            /* hskip */
                           count,                                  /* hcnt  */
                           (NUM_PER_PROC_X + 1) * sizeof(double),  /* recv_hskip */
                           count,                                  /* recv_hcnt  */
                           sizeof(double));                        /* recv_size  */
            } else {
                /* Destination is this cell: copy the strided run directly. */
                for (k = 0; k < count; k++) {
                    pB[dst_row + k][dst_col] = pA[row][col + NUM_PROC_Y * k];
                }
                kept_mat++;
            }
        }
    }

    /* Phase 2: account for the pieces that other cells put here. */
    local_f[0] += n_dot_x - kept_mat;
    amcheck(&put_f[0], local_f[0]);

    /* Phase 3: vector elements go to the cell with x coordinate 0. */
    for (row = 0; row < n_local_rows; row++) {
        int g_row = toglobal_block(cid, row, size);
        int dst_y = procid_dot_y(g_row, size);
        int slot = tolocal_dot_y(dst_y, g_row, size);
        int dst_cell = rec_tlin(0, dst_y);

        if (dst_cell != cid) {
            put(dst_cell, &a[row], sizeof(double), &b[slot],
                &put_f[1], &dma_f, 0);
        } else {
            b[slot] = a[row];
            kept_vec++;
        }
    }

    /* Phase 4: only cells with x coordinate 0 are put targets for b[]. */
    if (my_x == 0) {
        local_f[1] += n_dot_y - kept_vec;
        amcheck(&put_f[1], local_f[1]);
    }

    x_brd(0, &b[0], n_dot_y * sizeof(double));
    for (k = 0; k < n_dot_y; k++) {
        pB[k][n_dot_x] = b[k];
    }
}

#else /* another optimized code (this code does not use put_stride().) */

/*
 * Staging buffer for the redistributed vector: elements of `a` are stored
 * here (locally or from received TRANS_VEC messages), the buffer is passed
 * to x_brd(), and its first nry entries are then copied into column nrx
 * of pB.
 */
static double b[NUM_PER_PROC_Y];

/*
 * TRANS_DAT message, used both as send buffer and as receive buffer.
 *   size  number of valid entries in data[]
 *   adr   address of the first destination element (&pB[l_j][l_i] as seen
 *         by the sender); the receiver stores data[k] k rows below it
 *   data  payload; only the first `size` entries are transmitted
 */
static struct msg_type {
    int		size;
    double*	adr;
    double	data[NUM_PER_PROC_Y];
} msg1;

/*
 * Same leading members as msg_type but with a one-element payload.
 * This is a type only (no object is declared); it is used solely in
 * sizeof(struct dummy_type) to compute the byte length of a TRANS_DAT
 * message, so its layout must stay in step with struct msg_type.
 */
struct dummy_type {
    int		size;
    double*	adr;
    double	data[1];
};

/*
 * TRANS_VEC message: one vector element.
 *   adr   destination address (&b[l_i] as seen by the sender)
 *   data  the value the receiver stores through adr
 */
static struct vec_type {
    double*	adr;
    double	data;
} msg2;

/*
 * Read the pending TRANS_DAT message into msg1 and store its payload.
 *
 * msg1.adr is used as the address of the first destination element;
 * data[k] is written k rows of a dot_mat below it.  The caller must
 * already have selected a message with l_aqrecv() or l_arecv().
 *
 * NOTE(review): msg1.adr was computed on the sending cell, so this assumes
 * pB lives at the same address on every cell -- confirm.
 */
static void
store_trans_dat_msg()
{
    int nbytes = getmsize();
    int k;
    dot_mat *dst;

    readmsg(&msg1, nbytes);
    dst = (dot_mat *) msg1.adr;
    for (k = 0; k < msg1.size; k++) {
	(*dst)[k][0] = msg1.data[k];
    }
}

/*
 * trans_block_dot_mat_vec -- message-passing variant (__PUT__ not defined;
 * does not use put_stride()).
 *
 * Moves the matrix rows and vector elements held by the calling cell into
 * the layout used by pB.
 *
 *   cid   id of the calling cell; compared against every destination id
 *         to decide between a local copy and a message.
 *   tid   passed unchanged as the second argument of every
 *         l_asend()/l_arecv()/l_aqrecv() call.
 *   size  problem size handed to all the layout helper calls.
 *   pA    source matrix; row index is the local row number.
 *   a     source vector, one element per local row of pA.
 *   pB    destination matrix.  Row i of pA lands in column l_i of pB, and
 *         the vector ends up in column nrx = num_in_proc_dot_x(cidx, size).
 *
 * Phases:
 *   1. Matrix.  For each local row and each of the first
 *      min(size, NUM_PROC_Y) columns, hcnt elements spaced NUM_PROC_Y
 *      apart are either copied straight into pB (destination is this
 *      cell) or packed into msg1 and sent with tag TRANS_DAT.  After each
 *      row, TRANS_DAT messages for which l_aqrecv() returns non-NULL are
 *      unpacked.
 *   2. l_arecv() is used for the remaining TRANS_DAT messages until nrx
 *      pieces (local copies plus messages) have been stored.
 *   3. Vector.  Each a[i] is stored in b[] locally or sent with tag
 *      TRANS_VEC to cell rec_tlin(0, pidy).
 *   4. Cells whose x coordinate is 0 receive TRANS_VEC messages until nry
 *      elements are in place; every cell then calls x_brd(0, ...) on b
 *      and copies b into column nrx of pB.
 */
void
trans_block_dot_mat_vec(cid, tid, size, pA, a, pB)
    int cid, tid;
    int size;
    mat pA;
    double *a;
    dot_mat pB;
{
    int i, j, k;
    int cidx, cidy;
    int ns = num_in_proc_block(cid, size);
    int nrx;
    int nry;
    int num_recv = 0, num_vect = 0;
    int num_y = (size > NUM_PROC_Y) ? NUM_PROC_Y : size;

    lin_trec(cid, &cidx, &cidy);
    nrx = num_in_proc_dot_x(cidx, size);
    nry = num_in_proc_dot_y(cidy, size);

    /* Phase 1: matrix elements. */
    for (i = 0; i < ns; i++) {
	int g_i = toglobal_block(cid, i, size);
	int pidx = procid_dot_x(g_i, size);
	int l_i = tolocal_dot_x(pidx, g_i, size);

	for (j = 0; j < num_y; j++) {
	    int pidy = procid_dot_y(j, size);
	    int l_j = tolocal_dot_y(pidy, j, size);
	    int pid = rec_tlin(pidx, pidy);
	    int hcnt = num_in_proc_dot_y(pidy, size);
	    double *ap = &pA[i][j];

	    if (pid == cid) {
		/* Destination is this cell: copy the strided run directly. */
		for (k = 0; k < hcnt; k++) {
		    pB[l_j + k][l_i] = ap[NUM_PROC_Y * k];
		}
		num_recv++;
	    }
	    else {
		/* Pack the strided run and ship it with its destination. */
		msg1.size = hcnt;
		msg1.adr = &pB[l_j][l_i];
		for (k = 0; k < hcnt; k++) {
		    msg1.data[k] = ap[NUM_PROC_Y * k];
		}
		/* Header plus exactly hcnt payload doubles. */
		l_asend(pid, tid, TRANS_DAT, &msg1,
			sizeof(struct dummy_type) + (hcnt - 1) * sizeof(double));
	    }
	}

	/* Unpack whatever TRANS_DAT messages l_aqrecv() reports. */
	while (l_aqrecv(ANY_CELL, tid, TRANS_DAT) != NULL) {
	    store_trans_dat_msg();
	    num_recv++;
	}
    }

    /* Phase 2: receive the matrix pieces that are still outstanding. */
    while (num_recv < nrx) {
	l_arecv(ANY_CELL, tid, TRANS_DAT);
	store_trans_dat_msg();
	num_recv++;
    }

    /* Phase 3: vector elements go to the cell with x coordinate 0. */
    for (i = 0; i < ns; i++) {
	int g_i = toglobal_block(cid, i, size);
	int pidy = procid_dot_y(g_i, size);
	int l_i = tolocal_dot_y(pidy, g_i, size);
	int pid = rec_tlin(0, pidy);

	if (pid == cid) {
	    b[l_i] = a[i];
	    num_vect++;
	}
	else {
	    msg2.adr = &b[l_i];
	    msg2.data = a[i];
	    l_asend(pid, tid, TRANS_VEC, &msg2, sizeof(struct vec_type));
	}
    }

    /* Phase 4: only cells with x coordinate 0 are sent TRANS_VEC messages. */
    if (cidx == 0) {
	while (num_vect < nry) {
	    l_arecv(ANY_CELL, tid, TRANS_VEC);
	    readmsg(&msg2, sizeof(struct vec_type));
	    *msg2.adr = msg2.data;
	    num_vect++;
	}
    }

    x_brd(0, &b[0], nry * sizeof(double));
    for (k = 0; k < nry; k++) {
	pB[k][nrx] = b[k];
    }
}

#endif
