/*
 *  Pivot search and pivot row/column broadcast for distributed Gaussian elimination
 */

#include "dot.h"

#include <sys/types.h>
#include <stdio.h>
#include <math.h>

#include "conf.h"
#include "dist.h"
#include "lu.h"
#include "misc.h"

/*  pivot information (storage defined in another translation unit)  */

/* piv[i]: global row selected as the pivot of step i (written by p_pivot_y) */
extern int piv[N];
/* l_piv[]: this process's ordering of its local rows of pA; p_pivot_y scans
 * positions st_j.. and swaps the chosen row into position st_j */
extern int l_piv[NUM_PER_PROC_Y];
/* inverse_piv[i]: reciprocal of the pivot value of step i */
extern double inverse_piv[N];

/*  array for index conversion  */

/* g_to_l[i]: used by broad_piv_1 as the first local column of the pivot row
 * to send when this process does not own global column i.
 * NOTE(review): presumably the first local column after global column i --
 * confirm where it is filled. */
extern int g_to_l[N];

/*  message buffers, double-buffered by the parity of the step (i & 1)  */

/* piv_A: receive buffer for the pivot row in get_pivot1_x (COPY_RECV_MSG).
 * The "+ 1" presumably leaves room for the extra element in the row length
 * that broad_piv_1 sends -- TODO confirm. */
double piv_A[2][NUM_PER_PROC_X + 1];
/* buf: send buffer for the gathered pivot column in broad_piv_1, and
 * receive buffer for it in get_pivot1_y (COPY_RECV_MSG) */
double buf[2][NUM_PER_PROC_Y];

/*
 *  p_pivot_y -- select the pivot row for elimination step i.
 *
 *  Scans global column i (local column local_ix) over the rows of pA that
 *  are still pivot candidates on this process, i.e. l_piv[st_j .. num_data-1],
 *  and keeps the entry of largest magnitude.  The local candidate (its global
 *  row, |value| and value) is then reduced across processes by max_y(); the
 *  code afterwards treats piv[i] and pivt as the overall winner.
 *  inverse_piv[i] is set to 1 / pivt.
 *
 *  The process that owns the winning row broadcasts the winner's position in
 *  l_piv[] along x (tag BROAD_LOC_PIV1 + i * SKIP).  It then swaps that entry
 *  into position st_j of l_piv[].  When py is non-NULL it swaps py[0] and
 *  py[tmp_j - st_j] the same way; py[0] corresponds to position st_j.
 *
 *  cid   : linear process id; lin_trec() derives cidx/cidy from it
 *  size  : size argument forwarded to the distribution helpers
 *  pA    : local block of the matrix, indexed pA[local_row][local_col]
 *  i     : global column / elimination step
 *  st_j  : first position of l_piv[] still eligible as a pivot
 *  py    : optional vector aligned with l_piv[st_j..]; may be NULL
 */
void
p_pivot_y(cid, size, pA, i, st_j, py)
    int cid, size;
    dot_mat pA;
    int i, st_j;
    double *py;
{
    int j, piv_i, tmp_j;
    int local_ix;
    register double tmp_max = 0;
    double tmp_piv = 0;
    double local_max;
    int cidx, cidy;
    int num_data;
    double pivt;

    lin_trec(cid, cidx, cidy);
    num_data = num_in_proc_y(cidy, size);
    local_ix = tolocal_x(cidx, i, size);

    /*
     * Default the local candidate to the first eligible position st_j, not 0.
     * Suppose every remaining entry is zero (or NaN, which fails the '<' test
     * below) and this process still wins the reduction.  The swap at the end
     * must then be a no-op.  It must not pull the already-pivoted row
     * l_piv[0] back into play, and it must not index py[] at -st_j.
     * When no candidate rows remain (st_j >= num_data), fall back to 0 so the
     * l_piv[] read below stays in bounds.
     */
    tmp_j = (st_j < num_data) ? st_j : 0;

    /* Local partial-pivot search: largest |A| in column i among candidates. */
    for (j = st_j; j < num_data; j++) {
        int piv_j = l_piv[j];
        double tmp_A, abs_A;

        tmp_A = pA[piv_j][local_ix];
        abs_A = fabs(tmp_A);
        if (tmp_max < abs_A) {
            tmp_max = abs_A;
            tmp_piv = tmp_A;
            tmp_j = j;
        }
    }

    /* Offer the local candidate to the cross-process reduction. */
    piv[i] = toglobal_y(cidy, l_piv[tmp_j], size);
    pivt = tmp_piv;
    local_max = tmp_max;
    max_y(cid, &piv[i], &local_max, &pivt);
    piv_i = piv[i];

    /* NOTE: no singularity check here; pivt == 0 yields an infinite
       reciprocal under IEEE arithmetic. */
    inverse_piv[i] = 1 / pivt;

    if (procid_y(piv_i, size) == cidy) {
        int tmp_i, t_j;
        double tmp_py;
        int local_i = tolocal_y(cidy, piv_i, size);

        /* Broadcast the winner's position in l_piv[] along x. */
        t_j = tmp_j;
        broad_send_x(cid, BROAD_LOC_PIV1 + i * SKIP,
                     &t_j, sizeof(int));

        /* swap l_piv[]: move the pivot row to position st_j */
        tmp_i = l_piv[st_j];
        l_piv[st_j] = local_i;
        l_piv[tmp_j] = tmp_i;

        if (py != NULL) {
            /* swap py[] the same way (py[0] <-> position st_j) */
            tmp_py = py[0];
            py[0] = py[tmp_j - st_j];
            py[tmp_j - st_j] = tmp_py;
        }
    }
}

/*
 *  broad_piv_1 -- broadcast the pivot row and pivot column of step i.
 *
 *  Pivot row: the process holding global row piv[i] sends that row of pA
 *  along y with tag BROAD_PIV_1 + i * SKIP.  The row starts at local column
 *  start_col and is (size_x - start_col + 1) doubles long.  On that process
 *  piv1[0] points straight into pA; everywhere else piv1[0] is NULL.  The
 *  receive with the same tag is get_pivot1_x().
 *
 *  Pivot column: the process owning global column i copies column i over
 *  the rows l_piv[st_j .. size_y-1] into buf[i & 1].  It sends that buffer
 *  along x with tag BROAD_PIV_X + i * SKIP.  On that process piv1[1] points
 *  at the buffer; everywhere else piv1[1] is NULL.  The receive with the
 *  same tag is get_pivot1_y().
 *
 *  The row send is always issued before the column send.
 */
void
broad_piv_1(cid, size, size_x, pA, i, st_j, piv1)
    int cid, size, size_x;
    dot_mat pA;
    int i, st_j;
    double *piv1[];
{
    int cidx, cidy;
    int col_i;          /* local index of global column i */
    int start_col;      /* first local column of the pivot row to send */
    int n_rows;         /* local row count from num_in_proc_y() */
    int owns_col;       /* nonzero if global column i lives on this process */
    int k;
    double *row = NULL; /* pivot-row slice inside pA, if held here */
    double *col = NULL; /* gathered pivot column, if column i is here */

    lin_trec(cid, cidx, cidy);
    col_i = tolocal_x(cidx, i, size);
    owns_col = (procid_x(i, size) == cidx);
    n_rows = num_in_proc_y(cidy, size);

    /* The column owner starts just past column i; others use g_to_l[i]. */
    start_col = owns_col ? col_i + 1 : g_to_l[i];

    if (procid_y(piv[i], size) == cidy) {
        row = &pA[tolocal_y(cidy, piv[i], size)][start_col];
        broad_send_y(cid, BROAD_PIV_1 + i * SKIP,
                     row, (size_x - start_col + 1) * sizeof(double));
    }
    piv1[0] = row;

    if (owns_col) {
        col = buf[i & 1];
        for (k = 0; st_j + k < n_rows; k++) {
            col[k] = pA[l_piv[st_j + k]][col_i];
        }
        broad_send_x(cid, BROAD_PIV_X + i * SKIP,
                     col, (n_rows - st_j) * sizeof(double));
    }
    piv1[1] = col;
}

/*
 *  get_pivot1_x -- fetch the pivot row of step i sent by broad_piv_1().
 *
 *  Receives the message tagged BROAD_PIV_1 + i * SKIP along y.
 *  - With COPY_RECV_MSG, the row is copied into piv_A[i & 1] and a pointer
 *    to that slot is returned.  The slot alternates with the parity of i,
 *    presumably so step i+1 does not overwrite step i's row.
 *  - Otherwise, the pointer returned by fbroad_recv_y() is returned as is.
 *  NOTE(review): presumably used on processes where broad_piv_1() left
 *  piv1[0] == NULL.
 */
double*
get_pivot1_x(cid, i)
int cid, i;
{
#ifdef COPY_RECV_MSG
    double *dst = piv_A[i & 1];

    broad_recv_y(cid, BROAD_PIV_1 + i * SKIP, dst, sizeof piv_A[0]);
    return dst;
#else
    return (double *) fbroad_recv_y(cid, BROAD_PIV_1 + i * SKIP);
#endif
}

/*
 *  get_pivot1_y -- fetch the pivot column of step i sent by broad_piv_1().
 *
 *  Receives the message tagged BROAD_PIV_X + i * SKIP along x.
 *  - With COPY_RECV_MSG, the column is copied into buf[i & 1] (alternating
 *    with the parity of i) and a pointer to that slot is returned.
 *  - Otherwise, the pointer returned by fbroad_recv_x() is returned as is.
 *  NOTE(review): presumably used on processes where broad_piv_1() left
 *  piv1[1] == NULL, i.e. those that do not own column i.
 */
double*
get_pivot1_y(cid, i)
int cid, i;
{
#ifdef COPY_RECV_MSG
    double *dst = buf[i & 1];

    broad_recv_x(cid, BROAD_PIV_X + i * SKIP, dst, sizeof buf[0]);
    return dst;
#else
    return (double *) fbroad_recv_x(cid, BROAD_PIV_X + i * SKIP);
#endif
}
