root/ompi/mca/coll/basic/coll_basic_reduce_scatter.c


DEFINITIONS

This source file includes the following definitions:
  1. mca_coll_basic_reduce_scatter_intra
  2. mca_coll_basic_reduce_scatter_inter

/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil -*- */
/*
 * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
 *                         University Research and Technology
 *                         Corporation.  All rights reserved.
 * Copyright (c) 2004-2017 The University of Tennessee and The University
 *                         of Tennessee Research Foundation.  All rights
 *                         reserved.
 * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
 *                         University of Stuttgart.  All rights reserved.
 * Copyright (c) 2004-2005 The Regents of the University of California.
 *                         All rights reserved.
 * Copyright (c) 2008      Sun Microsystems, Inc.  All rights reserved.
 * Copyright (c) 2012      Oak Ridge National Labs.  All rights reserved.
 * Copyright (c) 2013      Los Alamos National Security, LLC. All rights
 *                         reserved.
 * Copyright (c) 2014-2016 Research Organization for Information Science
 *                         and Technology (RIST). All rights reserved.
 * $COPYRIGHT$
 *
 * Additional copyrights may follow
 *
 * $HEADER$
 */

#include "ompi_config.h"
#include "coll_basic.h"

#include <stdio.h>
#include <errno.h>

#include "mpi.h"
#include "opal/util/bit_ops.h"
#include "ompi/constants.h"
#include "ompi/mca/coll/coll.h"
#include "ompi/mca/coll/base/coll_tags.h"
#include "ompi/mca/pml/pml.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/op/op.h"

#define COMMUTATIVE_LONG_MSG (8 * 1024 * 1024)

/*
 *      reduce_scatter
 *
 *      Function:       - reduce then scatter
 *      Accepts:        - same as MPI_Reduce_scatter()
 *      Returns:        - MPI_SUCCESS or error code
 *
 * Algorithm:
 *   Commutative, reasonably sized messages
 *     recursive halving algorithm
 *   Others:
 *     reduce and scatterv (needs to be cleaned
 *     up at some point)
 *
 * NOTE: the recursive halving algorithm should be faster than the
 * reduce/scatterv for all message sizes.  However, its memory usage is
 * msg_size + 2 * comm_size greater, so its use is limited here to be
 * nice to the app memory-wise.  There are much better algorithms for
 * large messages with commutative operations, so this should be
 * investigated further.
 */
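/*
 * For example: with COMMUTATIVE_LONG_MSG at 8 MiB (defined above), a
 * commutative reduction on 4-byte integers takes the recursive-halving
 * branch whenever the total reduction length stays under roughly two
 * million elements; longer vectors, and any non-commutative operation,
 * fall through to the reduce + scatterv path.
 */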
int
mca_coll_basic_reduce_scatter_intra(const void *sbuf, void *rbuf, const int *rcounts,
                                    struct ompi_datatype_t *dtype,
                                    struct ompi_op_t *op,
                                    struct ompi_communicator_t *comm,
                                    mca_coll_base_module_t *module)
{
    int i, rank, size, count, err = OMPI_SUCCESS;
    ptrdiff_t extent, buf_size, gap;
    int *disps = NULL;
    char *recv_buf = NULL, *recv_buf_free = NULL;
    char *result_buf = NULL, *result_buf_free = NULL;

    /* Initialize */
    rank = ompi_comm_rank(comm);
    size = ompi_comm_size(comm);

    /* Find displacements and the like */
    disps = (int*) malloc(sizeof(int) * size);
    if (NULL == disps) return OMPI_ERR_OUT_OF_RESOURCE;

    disps[0] = 0;
    for (i = 0; i < (size - 1); ++i) {
        disps[i + 1] = disps[i] + rcounts[i];
    }
    count = disps[size - 1] + rcounts[size - 1];

    /* short cut the trivial case */
    if (0 == count) {
        free(disps);
        return OMPI_SUCCESS;
    }

    /* get datatype information */
    ompi_datatype_type_extent(dtype, &extent);
    buf_size = opal_datatype_span(&dtype->super, count, &gap);

    /* Handle MPI_IN_PLACE */
    if (MPI_IN_PLACE == sbuf) {
        sbuf = rbuf;
    }

    if ((op->o_flags & OMPI_OP_FLAGS_COMMUTE) &&
        (buf_size < COMMUTATIVE_LONG_MSG)) {
        int tmp_size, remain = 0, tmp_rank;

        /* temporary receive buffer.  See coll_basic_reduce.c for details on sizing */
        recv_buf_free = (char*) malloc(buf_size);
        if (NULL == recv_buf_free) {
            err = OMPI_ERR_OUT_OF_RESOURCE;
            goto cleanup;
        }
        recv_buf = recv_buf_free - gap;

        /* allocate temporary buffer for results */
        result_buf_free = (char*) malloc(buf_size);
        if (NULL == result_buf_free) {
            err = OMPI_ERR_OUT_OF_RESOURCE;
            goto cleanup;
        }
        result_buf = result_buf_free - gap;

        /* copy local buffer into the temporary results */
        err = ompi_datatype_sndrcv(sbuf, count, dtype, result_buf, count, dtype);
        if (OMPI_SUCCESS != err) goto cleanup;

        /* figure out power of two mapping: grow until larger than
           comm size, then go back one, to get the largest power of
           two less than or equal to comm size */
        tmp_size = opal_next_poweroftwo(size);
        tmp_size >>= 1;
        remain = size - tmp_size;

        /* If comm size is not a power of two, have the first "remain"
           procs with an even rank send to rank + 1, leaving a power of
           two procs to do the rest of the algorithm */
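        /* Worked example: with size = 6, tmp_size = 4 and remain = 2.
           Ranks 0 and 2 fold their data into ranks 1 and 3 and then sit
           out, so the recursive halving runs over tmp_ranks 0..3, held
           by real ranks 1, 3, 4 and 5. */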
        if (rank < 2 * remain) {
            if ((rank & 1) == 0) {
                err = MCA_PML_CALL(send(result_buf, count, dtype, rank + 1,
                                        MCA_COLL_BASE_TAG_REDUCE_SCATTER,
                                        MCA_PML_BASE_SEND_STANDARD,
                                        comm));
                if (OMPI_SUCCESS != err) goto cleanup;

                /* we don't participate from here on out */
                tmp_rank = -1;
            } else {
                err = MCA_PML_CALL(recv(recv_buf, count, dtype, rank - 1,
                                        MCA_COLL_BASE_TAG_REDUCE_SCATTER,
                                        comm, MPI_STATUS_IGNORE));
                if (OMPI_SUCCESS != err) goto cleanup;

                /* integrate their results into our temp results */
                ompi_op_reduce(op, recv_buf, result_buf, count, dtype);

                /* adjust rank to be the bottom "remain" ranks */
                tmp_rank = rank / 2;
            }
        } else {
            /* just need to adjust rank to account for the "remain"
               even-numbered ranks at the bottom that dropped out */
            tmp_rank = rank - remain;
        }

        /* For ranks not kicked out by the above code, perform the
           recursive halving */
        if (tmp_rank >= 0) {
            int *tmp_disps = NULL, *tmp_rcounts = NULL;
            int mask, send_index, recv_index, last_index;

            /* recalculate disps and rcounts to account for the
               special "remainder" processes that are no longer doing
               anything */
            tmp_rcounts = (int*) malloc(tmp_size * sizeof(int));
            if (NULL == tmp_rcounts) {
                err = OMPI_ERR_OUT_OF_RESOURCE;
                goto cleanup;
            }
            tmp_disps = (int*) malloc(tmp_size * sizeof(int));
            if (NULL == tmp_disps) {
                free(tmp_rcounts);
                err = OMPI_ERR_OUT_OF_RESOURCE;
                goto cleanup;
            }

            for (i = 0 ; i < tmp_size ; ++i) {
                if (i < remain) {
                    /* need to include old neighbor as well */
                    tmp_rcounts[i] = rcounts[i * 2 + 1] + rcounts[i * 2];
                } else {
                    tmp_rcounts[i] = rcounts[i + remain];
                }
            }

            tmp_disps[0] = 0;
            for (i = 0; i < tmp_size - 1; ++i) {
                tmp_disps[i + 1] = tmp_disps[i] + tmp_rcounts[i];
            }

            /* do the recursive halving communication.  Don't use the
               dimension information on the communicator because I
               think the information is invalidated by our "shrinking"
               of the communicator */
            mask = tmp_size >> 1;
            send_index = recv_index = 0;
            last_index = tmp_size;
            while (mask > 0) {
                int tmp_peer, peer, send_count, recv_count;
                struct ompi_request_t *request;

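                /* tmp_peer is our partner inside the shrunken
                   power-of-two group; map it back to a real rank:
                   folded pairs are represented by their odd member
                   (2 * tmp_peer + 1), all other ranks are simply
                   shifted up by "remain" */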
                tmp_peer = tmp_rank ^ mask;
                peer = (tmp_peer < remain) ? tmp_peer * 2 + 1 : tmp_peer + remain;

                /* figure out if we're sending, receiving, or both */
                send_count = recv_count = 0;
                if (tmp_rank < tmp_peer) {
                    send_index = recv_index + mask;
                    for (i = send_index ; i < last_index ; ++i) {
                        send_count += tmp_rcounts[i];
                    }
                    for (i = recv_index ; i < send_index ; ++i) {
                        recv_count += tmp_rcounts[i];
                    }
                } else {
                    recv_index = send_index + mask;
                    for (i = send_index ; i < recv_index ; ++i) {
                        send_count += tmp_rcounts[i];
                    }
                    for (i = recv_index ; i < last_index ; ++i) {
                        recv_count += tmp_rcounts[i];
                    }
                }

                /* actual data transfer.  Send from result_buf,
                   receive into recv_buf */
                if (recv_count > 0) {
                    err = MCA_PML_CALL(irecv(recv_buf + tmp_disps[recv_index] * extent,
                                             recv_count, dtype, peer,
                                             MCA_COLL_BASE_TAG_REDUCE_SCATTER,
                                             comm, &request));
                    if (OMPI_SUCCESS != err) {
                        free(tmp_rcounts);
                        free(tmp_disps);
                        goto cleanup;
                    }
                }
                if (send_count > 0) {
                    err = MCA_PML_CALL(send(result_buf + tmp_disps[send_index] * extent,
                                            send_count, dtype, peer,
                                            MCA_COLL_BASE_TAG_REDUCE_SCATTER,
                                            MCA_PML_BASE_SEND_STANDARD,
                                            comm));
                    if (OMPI_SUCCESS != err) {
                        free(tmp_rcounts);
                        free(tmp_disps);
                        goto cleanup;
                    }
                }

                /* if we received something on this step, push it into
                   the results buffer */
                if (recv_count > 0) {
                    err = ompi_request_wait(&request, MPI_STATUS_IGNORE);
                    if (OMPI_SUCCESS != err) {
                        free(tmp_rcounts);
                        free(tmp_disps);
                        goto cleanup;
                    }

                    ompi_op_reduce(op,
                                   recv_buf + tmp_disps[recv_index] * extent,
                                   result_buf + tmp_disps[recv_index] * extent,
                                   recv_count, dtype);
                }

                /* update for next iteration */
                send_index = recv_index;
                last_index = recv_index + mask;
                mask >>= 1;
            }

            /* copy local results from results buffer into real receive buffer */
            if (0 != rcounts[rank]) {
                err = ompi_datatype_sndrcv(result_buf + disps[rank] * extent,
                                           rcounts[rank], dtype,
                                           rbuf, rcounts[rank], dtype);
                if (OMPI_SUCCESS != err) {
                    free(tmp_rcounts);
                    free(tmp_disps);
                    goto cleanup;
                }
            }

            free(tmp_rcounts);
            free(tmp_disps);
        }

        /* Now fix up the non-power of two case, by having the odd
           procs send the even procs the proper results */
        if (rank < 2 * remain) {
            if ((rank & 1) == 0) {
                if (rcounts[rank]) {
                    err = MCA_PML_CALL(recv(rbuf, rcounts[rank], dtype, rank + 1,
                                            MCA_COLL_BASE_TAG_REDUCE_SCATTER,
                                            comm, MPI_STATUS_IGNORE));
                    if (OMPI_SUCCESS != err) goto cleanup;
                }
            } else {
                if (rcounts[rank - 1]) {
                    err = MCA_PML_CALL(send(result_buf + disps[rank - 1] * extent,
                                            rcounts[rank - 1], dtype, rank - 1,
                                            MCA_COLL_BASE_TAG_REDUCE_SCATTER,
                                            MCA_PML_BASE_SEND_STANDARD,
                                            comm));
                    if (OMPI_SUCCESS != err) goto cleanup;
                }
            }
        }

    } else {
        if (0 == rank) {
            /* temporary receive buffer.  See coll_basic_reduce.c for
               details on sizing */
            recv_buf_free = (char*) malloc(buf_size);
            if (NULL == recv_buf_free) {
                err = OMPI_ERR_OUT_OF_RESOURCE;
                goto cleanup;
            }
            recv_buf = recv_buf_free - gap;
        }

        /* reduction */
        err = comm->c_coll->coll_reduce(sbuf, recv_buf, count, dtype, op, 0,
                                        comm, comm->c_coll->coll_reduce_module);

        /* scatter */
        if (MPI_SUCCESS == err) {
            err = comm->c_coll->coll_scatterv(recv_buf, rcounts, disps, dtype,
                                              rbuf, rcounts[rank], dtype, 0,
                                              comm, comm->c_coll->coll_scatterv_module);
        }
    }

 cleanup:
    if (NULL != disps) free(disps);
    if (NULL != recv_buf_free) free(recv_buf_free);
    if (NULL != result_buf_free) free(result_buf_free);

    return err;
}


/*
 *      reduce_scatter_inter
 *
 *      Function:       - reduce/scatter operation
 *      Accepts:        - same arguments as MPI_Reduce_scatter()
 *      Returns:        - MPI_SUCCESS or error code
 */
int
mca_coll_basic_reduce_scatter_inter(const void *sbuf, void *rbuf, const int *rcounts,
                                    struct ompi_datatype_t *dtype,
                                    struct ompi_op_t *op,
                                    struct ompi_communicator_t *comm,
                                    mca_coll_base_module_t *module)
{
    int err, i, rank, root = 0, rsize, lsize, totalcounts;
    char *tmpbuf = NULL, *tmpbuf2 = NULL, *lbuf = NULL, *buf;
    ptrdiff_t gap, span;
    ompi_request_t *req;
    int *disps = NULL;

    rank = ompi_comm_rank(comm);
    rsize = ompi_comm_remote_size(comm);
    lsize = ompi_comm_size(comm);

    /* Figure out the total amount of data for the reduction. */
    for (totalcounts = 0, i = 0; i < lsize; i++) {
        totalcounts += rcounts[i];
    }

    /*
     * The following code basically does an inter-communicator reduce
     * followed by an intra-communicator scatterv.  This is implemented
     * by having the roots of each group exchange their sbuf.  Then, the
     * roots receive the data from each of the remote ranks and execute
     * the reduce.  When this is complete, they have the reduced data
     * available to them for doing the scatterv.  They do this on the
     * local communicator associated with the intercommunicator.
     *
     * Note: There are other ways to implement MPI_Reduce_scatter on
     * intercommunicators.  For example, one could do an MPI_Reduce locally,
     * then send the results to the other root which could scatter it.
     *
     * Note: It is also worth pointing out that the rcounts argument
     * represents how the data is going to be scattered locally.  Therefore,
     * its size is the same as the local communicator size.
     */
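    /* Concrete illustration: with local group {a0, a1} and remote group
     * {b0, b1}, roots a0 and b0 swap their full input vectors; a0 then
     * also receives b1's vector (non-roots send to the remote root) and
     * reduces it in, so a0 ends up holding the reduction of the remote
     * group's data, which it finally scatters over {a0, a1} according
     * to rcounts.  The same happens symmetrically at b0. */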
    if (rank == root) {
        span = opal_datatype_span(&dtype->super, totalcounts, &gap);

        /* Generate displacements for the scatterv part */
        disps = (int*) malloc(sizeof(int) * lsize);
        if (NULL == disps) {
            return OMPI_ERR_OUT_OF_RESOURCE;
        }
        disps[0] = 0;
        for (i = 0; i < (lsize - 1); ++i) {
            disps[i + 1] = disps[i] + rcounts[i];
        }

        tmpbuf = (char *) malloc(span);
        tmpbuf2 = (char *) malloc(span);
        if (NULL == tmpbuf || NULL == tmpbuf2) {
            err = OMPI_ERR_OUT_OF_RESOURCE;
            goto exit;
        }
        lbuf = tmpbuf - gap;
        buf = tmpbuf2 - gap;

        /* Do a send-recv between the two root procs to avoid deadlock */
        err = MCA_PML_CALL(isend(sbuf, totalcounts, dtype, 0,
                                 MCA_COLL_BASE_TAG_REDUCE_SCATTER,
                                 MCA_PML_BASE_SEND_STANDARD, comm, &req));
        if (OMPI_SUCCESS != err) {
            goto exit;
        }

        err = MCA_PML_CALL(recv(lbuf, totalcounts, dtype, 0,
                                MCA_COLL_BASE_TAG_REDUCE_SCATTER, comm,
                                MPI_STATUS_IGNORE));
        if (OMPI_SUCCESS != err) {
            goto exit;
        }

        err = ompi_request_wait(&req, MPI_STATUS_IGNORE);
        if (OMPI_SUCCESS != err) {
            goto exit;
        }

        /* Loop receiving and calling the reduction function (C or Fortran).
         * The result of these reduction operations ends up in lbuf.
         */
        for (i = 1; i < rsize; i++) {
            char *tbuf;
            err = MCA_PML_CALL(recv(buf, totalcounts, dtype, i,
                                    MCA_COLL_BASE_TAG_REDUCE_SCATTER, comm,
                                    MPI_STATUS_IGNORE));
            if (MPI_SUCCESS != err) {
                goto exit;
            }

            /* Perform the reduction */
            ompi_op_reduce(op, lbuf, buf, totalcounts, dtype);
            /* swap the buffers */
            tbuf = lbuf; lbuf = buf; buf = tbuf;
        }
    } else {
        /* If not root, send data to the root. */
        err = MCA_PML_CALL(send(sbuf, totalcounts, dtype, root,
                                MCA_COLL_BASE_TAG_REDUCE_SCATTER,
                                MCA_PML_BASE_SEND_STANDARD, comm));
        if (OMPI_SUCCESS != err) {
            goto exit;
        }
    }

    /* Now do a scatterv on the local communicator */
    err = comm->c_local_comm->c_coll->coll_scatterv(lbuf, rcounts, disps, dtype,
                                                    rbuf, rcounts[rank], dtype, 0,
                                                    comm->c_local_comm,
                                                    comm->c_local_comm->c_coll->coll_scatterv_module);

  exit:
    if (NULL != tmpbuf) {
        free(tmpbuf);
    }

    if (NULL != tmpbuf2) {
        free(tmpbuf2);
    }

    if (NULL != disps) {
        free(disps);
    }

    return err;
}
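
For reference, a minimal caller sketch (not part of the file above): applications
reach these routines only indirectly, through MPI_Reduce_scatter, and only when
the coll framework selects the basic component at runtime. The rcounts layout and
MPI_SUM operation below are illustrative assumptions, not taken from this file.

#include <stdio.h>
#include <stdlib.h>
#include <mpi.h>

int main(int argc, char **argv)
{
    int rank, size;

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    /* each rank contributes "size" ints and receives the one reduced
       element that corresponds to its own rank */
    int *rcounts = malloc(size * sizeof(int));
    int *sendbuf = malloc(size * sizeof(int));
    int result;

    for (int i = 0; i < size; ++i) {
        rcounts[i] = 1;          /* one result element per rank */
        sendbuf[i] = rank + i;   /* arbitrary test data */
    }

    /* MPI_SUM is commutative, so on an intra-communicator the basic
       component would take the recursive-halving branch for totals
       under COMMUTATIVE_LONG_MSG (8 MiB) */
    MPI_Reduce_scatter(sendbuf, &result, rcounts, MPI_INT, MPI_SUM,
                       MPI_COMM_WORLD);

    printf("rank %d: reduced element = %d\n", rank, result);

    free(sendbuf);
    free(rcounts);
    MPI_Finalize();
    return 0;
}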
