This source file includes the following definitions.
- mca_coll_cuda_module_construct
- mca_coll_cuda_module_destruct
- mca_coll_cuda_init_query
- mca_coll_cuda_comm_query
- mca_coll_cuda_module_enable
   1 
   2 
   3 
   4 
   5 
   6 
   7 
   8 
   9 
  10 
  11 
  12 
  13 #include "ompi_config.h"
  14 
  15 #include <string.h>
  16 #include <stdio.h>
  17 
  18 #include "coll_cuda.h"
  19 
  20 #include "mpi.h"
  21 
  22 #include "orte/util/show_help.h"
  23 #include "orte/util/proc_info.h"
  24 
  25 #include "ompi/constants.h"
  26 #include "ompi/communicator/communicator.h"
  27 #include "ompi/mca/coll/coll.h"
  28 #include "ompi/mca/coll/base/base.h"
  29 #include "coll_cuda.h"
  30 
  31 
  32 static void mca_coll_cuda_module_construct(mca_coll_cuda_module_t *module)
  33 {
  34     memset(&(module->c_coll), 0, sizeof(module->c_coll));
  35 }
  36 
  37 static void mca_coll_cuda_module_destruct(mca_coll_cuda_module_t *module)
  38 {
  39     OBJ_RELEASE(module->c_coll.coll_allreduce_module);
  40     OBJ_RELEASE(module->c_coll.coll_reduce_module);
  41     OBJ_RELEASE(module->c_coll.coll_reduce_scatter_block_module);
  42     OBJ_RELEASE(module->c_coll.coll_scatter_module);
  43     
  44 
  45 
  46     if (NULL != module->c_coll.coll_exscan_module) {
  47         OBJ_RELEASE(module->c_coll.coll_exscan_module);
  48         OBJ_RELEASE(module->c_coll.coll_scan_module);
  49     }
  50 }
  51 
  52 OBJ_CLASS_INSTANCE(mca_coll_cuda_module_t, mca_coll_base_module_t,
  53                    mca_coll_cuda_module_construct,
  54                    mca_coll_cuda_module_destruct);
  55 
  56 
  57 
  58 
  59 
  60 
  61 
  62 int mca_coll_cuda_init_query(bool enable_progress_threads,
  63                              bool enable_mpi_threads)
  64 {
  65     
  66 
  67     return OMPI_SUCCESS;
  68 }
  69 
  70 
  71 
  72 
  73 
  74 
  75 
  76 mca_coll_base_module_t *
  77 mca_coll_cuda_comm_query(struct ompi_communicator_t *comm,
  78                          int *priority)
  79 {
  80     mca_coll_cuda_module_t *cuda_module;
  81 
  82     cuda_module = OBJ_NEW(mca_coll_cuda_module_t);
  83     if (NULL == cuda_module) {
  84         return NULL;
  85     }
  86 
  87     *priority = mca_coll_cuda_component.priority;
  88 
  89     
  90     cuda_module->super.coll_module_enable = mca_coll_cuda_module_enable;
  91     cuda_module->super.ft_event = NULL;
  92 
  93     cuda_module->super.coll_allgather  = NULL;
  94     cuda_module->super.coll_allgatherv = NULL;
  95     cuda_module->super.coll_allreduce  = mca_coll_cuda_allreduce;
  96     cuda_module->super.coll_alltoall   = NULL;
  97     cuda_module->super.coll_alltoallv  = NULL;
  98     cuda_module->super.coll_alltoallw  = NULL;
  99     cuda_module->super.coll_barrier    = NULL;
 100     cuda_module->super.coll_bcast      = NULL;
 101     cuda_module->super.coll_exscan     = mca_coll_cuda_exscan;
 102     cuda_module->super.coll_gather     = NULL;
 103     cuda_module->super.coll_gatherv    = NULL;
 104     cuda_module->super.coll_reduce     = mca_coll_cuda_reduce;
 105     cuda_module->super.coll_reduce_scatter = NULL;
 106     cuda_module->super.coll_reduce_scatter_block = mca_coll_cuda_reduce_scatter_block;
 107     cuda_module->super.coll_scan       = mca_coll_cuda_scan;
 108     cuda_module->super.coll_scatter    = NULL;
 109     cuda_module->super.coll_scatterv   = NULL;
 110 
 111     return &(cuda_module->super);
 112 }
 113 
 114 
 115 
 116 
 117 
 118 int mca_coll_cuda_module_enable(mca_coll_base_module_t *module,
 119                                 struct ompi_communicator_t *comm)
 120 {
 121     bool good = true;
 122     char *msg = NULL;
 123     mca_coll_cuda_module_t *s = (mca_coll_cuda_module_t*) module;
 124 
 125 #define CHECK_AND_RETAIN(src, dst, name)                                                   \
 126     if (NULL == (src)->c_coll->coll_ ## name ## _module) {                                 \
 127         good = false;                                                                      \
 128         msg = #name;                                                                       \
 129     } else if (good) {                                                                     \
 130         (dst)->c_coll.coll_ ## name ## _module = (src)->c_coll->coll_ ## name ## _module;  \
 131         (dst)->c_coll.coll_ ## name = (src)->c_coll->coll_ ## name;                        \
 132         OBJ_RETAIN((src)->c_coll->coll_ ## name ## _module);                               \
 133     }
 134 
 135     CHECK_AND_RETAIN(comm, s, allreduce);
 136     CHECK_AND_RETAIN(comm, s, reduce);
 137     CHECK_AND_RETAIN(comm, s, reduce_scatter_block);
 138     CHECK_AND_RETAIN(comm, s, scatter);
 139     if (!OMPI_COMM_IS_INTER(comm)) {
 140         
 141         CHECK_AND_RETAIN(comm, s, exscan);
 142         CHECK_AND_RETAIN(comm, s, scan);
 143     }
 144 
 145     
 146     if (good) {
 147         return OMPI_SUCCESS;
 148     }
 149     orte_show_help("help-mpi-coll-cuda.txt", "missing collective", true,
 150                    orte_process_info.nodename,
 151                    mca_coll_cuda_component.priority, msg);
 152     return OMPI_ERR_NOT_FOUND;
 153 }
 154