root/ompi/mca/coll/cuda/coll_cuda_module.c

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. mca_coll_cuda_module_construct
  2. mca_coll_cuda_module_destruct
  3. mca_coll_cuda_init_query
  4. mca_coll_cuda_comm_query
  5. mca_coll_cuda_module_enable

   1 /*
   2  * Copyright (c) 2014-2017 The University of Tennessee and The University
   3  *                         of Tennessee Research Foundation.  All rights
   4  *                         reserved.
   5  * Copyright (c) 2014      NVIDIA Corporation.  All rights reserved.
   6  * $COPYRIGHT$
   7  *
   8  * Additional copyrights may follow
   9  *
  10  * $HEADER$
  11  */
  12 
  13 #include "ompi_config.h"
  14 
  15 #include <string.h>
  16 #include <stdio.h>
  17 
  18 #include "coll_cuda.h"
  19 
  20 #include "mpi.h"
  21 
  22 #include "orte/util/show_help.h"
  23 #include "orte/util/proc_info.h"
  24 
  25 #include "ompi/constants.h"
  26 #include "ompi/communicator/communicator.h"
  27 #include "ompi/mca/coll/coll.h"
  28 #include "ompi/mca/coll/base/base.h"
  29 #include "coll_cuda.h"
  30 
  31 
  32 static void mca_coll_cuda_module_construct(mca_coll_cuda_module_t *module)
  33 {
  34     memset(&(module->c_coll), 0, sizeof(module->c_coll));
  35 }
  36 
  37 static void mca_coll_cuda_module_destruct(mca_coll_cuda_module_t *module)
  38 {
  39     OBJ_RELEASE(module->c_coll.coll_allreduce_module);
  40     OBJ_RELEASE(module->c_coll.coll_reduce_module);
  41     OBJ_RELEASE(module->c_coll.coll_reduce_scatter_block_module);
  42     OBJ_RELEASE(module->c_coll.coll_scatter_module);
  43     /* If the exscan module is not NULL, then this was an
  44        intracommunicator, and therefore scan will have a module as
  45        well. */
  46     if (NULL != module->c_coll.coll_exscan_module) {
  47         OBJ_RELEASE(module->c_coll.coll_exscan_module);
  48         OBJ_RELEASE(module->c_coll.coll_scan_module);
  49     }
  50 }
  51 
  52 OBJ_CLASS_INSTANCE(mca_coll_cuda_module_t, mca_coll_base_module_t,
  53                    mca_coll_cuda_module_construct,
  54                    mca_coll_cuda_module_destruct);
  55 
  56 
  57 /*
  58  * Initial query function that is invoked during MPI_INIT, allowing
  59  * this component to disqualify itself if it doesn't support the
  60  * required level of thread support.
  61  */
  62 int mca_coll_cuda_init_query(bool enable_progress_threads,
  63                              bool enable_mpi_threads)
  64 {
  65     /* Nothing to do */
  66 
  67     return OMPI_SUCCESS;
  68 }
  69 
  70 
  71 /*
  72  * Invoked when there's a new communicator that has been created.
  73  * Look at the communicator and decide which set of functions and
  74  * priority we want to return.
  75  */
  76 mca_coll_base_module_t *
  77 mca_coll_cuda_comm_query(struct ompi_communicator_t *comm,
  78                          int *priority)
  79 {
  80     mca_coll_cuda_module_t *cuda_module;
  81 
  82     cuda_module = OBJ_NEW(mca_coll_cuda_module_t);
  83     if (NULL == cuda_module) {
  84         return NULL;
  85     }
  86 
  87     *priority = mca_coll_cuda_component.priority;
  88 
  89     /* Choose whether to use [intra|inter] */
  90     cuda_module->super.coll_module_enable = mca_coll_cuda_module_enable;
  91     cuda_module->super.ft_event = NULL;
  92 
  93     cuda_module->super.coll_allgather  = NULL;
  94     cuda_module->super.coll_allgatherv = NULL;
  95     cuda_module->super.coll_allreduce  = mca_coll_cuda_allreduce;
  96     cuda_module->super.coll_alltoall   = NULL;
  97     cuda_module->super.coll_alltoallv  = NULL;
  98     cuda_module->super.coll_alltoallw  = NULL;
  99     cuda_module->super.coll_barrier    = NULL;
 100     cuda_module->super.coll_bcast      = NULL;
 101     cuda_module->super.coll_exscan     = mca_coll_cuda_exscan;
 102     cuda_module->super.coll_gather     = NULL;
 103     cuda_module->super.coll_gatherv    = NULL;
 104     cuda_module->super.coll_reduce     = mca_coll_cuda_reduce;
 105     cuda_module->super.coll_reduce_scatter = NULL;
 106     cuda_module->super.coll_reduce_scatter_block = mca_coll_cuda_reduce_scatter_block;
 107     cuda_module->super.coll_scan       = mca_coll_cuda_scan;
 108     cuda_module->super.coll_scatter    = NULL;
 109     cuda_module->super.coll_scatterv   = NULL;
 110 
 111     return &(cuda_module->super);
 112 }
 113 
 114 
 115 /*
 116  * Init module on the communicator
 117  */
 118 int mca_coll_cuda_module_enable(mca_coll_base_module_t *module,
 119                                 struct ompi_communicator_t *comm)
 120 {
 121     bool good = true;
 122     char *msg = NULL;
 123     mca_coll_cuda_module_t *s = (mca_coll_cuda_module_t*) module;
 124 
 125 #define CHECK_AND_RETAIN(src, dst, name)                                                   \
 126     if (NULL == (src)->c_coll->coll_ ## name ## _module) {                                 \
 127         good = false;                                                                      \
 128         msg = #name;                                                                       \
 129     } else if (good) {                                                                     \
 130         (dst)->c_coll.coll_ ## name ## _module = (src)->c_coll->coll_ ## name ## _module;  \
 131         (dst)->c_coll.coll_ ## name = (src)->c_coll->coll_ ## name;                        \
 132         OBJ_RETAIN((src)->c_coll->coll_ ## name ## _module);                               \
 133     }
 134 
 135     CHECK_AND_RETAIN(comm, s, allreduce);
 136     CHECK_AND_RETAIN(comm, s, reduce);
 137     CHECK_AND_RETAIN(comm, s, reduce_scatter_block);
 138     CHECK_AND_RETAIN(comm, s, scatter);
 139     if (!OMPI_COMM_IS_INTER(comm)) {
 140         /* MPI does not define scan/exscan on intercommunicators */
 141         CHECK_AND_RETAIN(comm, s, exscan);
 142         CHECK_AND_RETAIN(comm, s, scan);
 143     }
 144 
 145     /* All done */
 146     if (good) {
 147         return OMPI_SUCCESS;
 148     }
 149     orte_show_help("help-mpi-coll-cuda.txt", "missing collective", true,
 150                    orte_process_info.nodename,
 151                    mca_coll_cuda_component.priority, msg);
 152     return OMPI_ERR_NOT_FOUND;
 153 }
 154 

/* [<][>][^][v][top][bottom][index][help] */