This source file includes the following definitions.
- mca_coll_cuda_module_construct
- mca_coll_cuda_module_destruct
- mca_coll_cuda_init_query
- mca_coll_cuda_comm_query
- mca_coll_cuda_module_enable
1
2
3
4
5
6
7
8
9
10
11
12
13 #include "ompi_config.h"
14
15 #include <string.h>
16 #include <stdio.h>
17
18 #include "coll_cuda.h"
19
20 #include "mpi.h"
21
22 #include "orte/util/show_help.h"
23 #include "orte/util/proc_info.h"
24
25 #include "ompi/constants.h"
26 #include "ompi/communicator/communicator.h"
27 #include "ompi/mca/coll/coll.h"
28 #include "ompi/mca/coll/base/base.h"
29 #include "coll_cuda.h"
30
31
32 static void mca_coll_cuda_module_construct(mca_coll_cuda_module_t *module)
33 {
34 memset(&(module->c_coll), 0, sizeof(module->c_coll));
35 }
36
37 static void mca_coll_cuda_module_destruct(mca_coll_cuda_module_t *module)
38 {
39 OBJ_RELEASE(module->c_coll.coll_allreduce_module);
40 OBJ_RELEASE(module->c_coll.coll_reduce_module);
41 OBJ_RELEASE(module->c_coll.coll_reduce_scatter_block_module);
42 OBJ_RELEASE(module->c_coll.coll_scatter_module);
43
44
45
46 if (NULL != module->c_coll.coll_exscan_module) {
47 OBJ_RELEASE(module->c_coll.coll_exscan_module);
48 OBJ_RELEASE(module->c_coll.coll_scan_module);
49 }
50 }
51
/*
 * Class instance for mca_coll_cuda_module_t: derived from
 * mca_coll_base_module_t, using the constructor and destructor
 * named below.
 */
OBJ_CLASS_INSTANCE(mca_coll_cuda_module_t,
                   mca_coll_base_module_t,
                   mca_coll_cuda_module_construct,
                   mca_coll_cuda_module_destruct);
55
56
57
58
59
60
61
62 int mca_coll_cuda_init_query(bool enable_progress_threads,
63 bool enable_mpi_threads)
64 {
65
66
67 return OMPI_SUCCESS;
68 }
69
70
71
72
73
74
75
76 mca_coll_base_module_t *
77 mca_coll_cuda_comm_query(struct ompi_communicator_t *comm,
78 int *priority)
79 {
80 mca_coll_cuda_module_t *cuda_module;
81
82 cuda_module = OBJ_NEW(mca_coll_cuda_module_t);
83 if (NULL == cuda_module) {
84 return NULL;
85 }
86
87 *priority = mca_coll_cuda_component.priority;
88
89
90 cuda_module->super.coll_module_enable = mca_coll_cuda_module_enable;
91 cuda_module->super.ft_event = NULL;
92
93 cuda_module->super.coll_allgather = NULL;
94 cuda_module->super.coll_allgatherv = NULL;
95 cuda_module->super.coll_allreduce = mca_coll_cuda_allreduce;
96 cuda_module->super.coll_alltoall = NULL;
97 cuda_module->super.coll_alltoallv = NULL;
98 cuda_module->super.coll_alltoallw = NULL;
99 cuda_module->super.coll_barrier = NULL;
100 cuda_module->super.coll_bcast = NULL;
101 cuda_module->super.coll_exscan = mca_coll_cuda_exscan;
102 cuda_module->super.coll_gather = NULL;
103 cuda_module->super.coll_gatherv = NULL;
104 cuda_module->super.coll_reduce = mca_coll_cuda_reduce;
105 cuda_module->super.coll_reduce_scatter = NULL;
106 cuda_module->super.coll_reduce_scatter_block = mca_coll_cuda_reduce_scatter_block;
107 cuda_module->super.coll_scan = mca_coll_cuda_scan;
108 cuda_module->super.coll_scatter = NULL;
109 cuda_module->super.coll_scatterv = NULL;
110
111 return &(cuda_module->super);
112 }
113
114
115
116
117
118 int mca_coll_cuda_module_enable(mca_coll_base_module_t *module,
119 struct ompi_communicator_t *comm)
120 {
121 bool good = true;
122 char *msg = NULL;
123 mca_coll_cuda_module_t *s = (mca_coll_cuda_module_t*) module;
124
125 #define CHECK_AND_RETAIN(src, dst, name) \
126 if (NULL == (src)->c_coll->coll_ ## name ## _module) { \
127 good = false; \
128 msg = #name; \
129 } else if (good) { \
130 (dst)->c_coll.coll_ ## name ## _module = (src)->c_coll->coll_ ## name ## _module; \
131 (dst)->c_coll.coll_ ## name = (src)->c_coll->coll_ ## name; \
132 OBJ_RETAIN((src)->c_coll->coll_ ## name ## _module); \
133 }
134
135 CHECK_AND_RETAIN(comm, s, allreduce);
136 CHECK_AND_RETAIN(comm, s, reduce);
137 CHECK_AND_RETAIN(comm, s, reduce_scatter_block);
138 CHECK_AND_RETAIN(comm, s, scatter);
139 if (!OMPI_COMM_IS_INTER(comm)) {
140
141 CHECK_AND_RETAIN(comm, s, exscan);
142 CHECK_AND_RETAIN(comm, s, scan);
143 }
144
145
146 if (good) {
147 return OMPI_SUCCESS;
148 }
149 orte_show_help("help-mpi-coll-cuda.txt", "missing collective", true,
150 orte_process_info.nodename,
151 mca_coll_cuda_component.priority, msg);
152 return OMPI_ERR_NOT_FOUND;
153 }
154