This source file includes the following definitions.
- mca_scoll_mpi_init_query
- mca_scoll_mpi_module_clear
- mca_scoll_mpi_module_construct
- mca_scoll_mpi_module_destruct
- mca_scoll_mpi_save_coll_handlers
- mca_scoll_mpi_module_enable
- mca_scoll_mpi_comm_query
1
2
3
4
5
6
7
8
9
10
11
12
13 #include "ompi_config.h"
14 #include "scoll_mpi.h"
15 #include "opal/util/show_help.h"
16
17 #include "oshmem/proc/proc.h"
18 #include "oshmem/runtime/runtime.h"
19 #include "ompi/mca/coll/base/base.h"
20 #include "opal/util/timings.h"
21
22 int mca_scoll_mpi_init_query(bool enable_progress_threads, bool enable_mpi_threads)
23 {
24 return OSHMEM_SUCCESS;
25 }
26
27 static void mca_scoll_mpi_module_clear(mca_scoll_mpi_module_t *mpi_module)
28 {
29 mpi_module->previous_barrier = NULL;
30 mpi_module->previous_broadcast = NULL;
31 mpi_module->previous_reduce = NULL;
32 mpi_module->previous_collect = NULL;
33 mpi_module->previous_alltoall = NULL;
34 }
35
36 static void mca_scoll_mpi_module_construct(mca_scoll_mpi_module_t *mpi_module)
37 {
38 mca_scoll_mpi_module_clear(mpi_module);
39 }
40
41 static void mca_scoll_mpi_module_destruct(mca_scoll_mpi_module_t *mpi_module)
42 {
43
44 OBJ_RELEASE(mpi_module->previous_barrier_module);
45 OBJ_RELEASE(mpi_module->previous_broadcast_module);
46 OBJ_RELEASE(mpi_module->previous_reduce_module);
47 OBJ_RELEASE(mpi_module->previous_collect_module);
48 OBJ_RELEASE(mpi_module->previous_alltoall_module);
49
50 mca_scoll_mpi_module_clear(mpi_module);
51
52 if (mpi_module->comm != &(ompi_mpi_comm_world.comm) && (NULL != mpi_module->comm)) {
53 ompi_comm_free(&mpi_module->comm);
54 }
55 }
56
57 #define MPI_SAVE_PREV_SCOLL_API(__api) do {\
58 mpi_module->previous_ ## __api = osh_group->g_scoll.scoll_ ## __api;\
59 mpi_module->previous_ ## __api ## _module = osh_group->g_scoll.scoll_ ## __api ## _module;\
60 if (!osh_group->g_scoll.scoll_ ## __api || !osh_group->g_scoll.scoll_ ## __api ## _module) {\
61 MPI_COLL_VERBOSE(1, "no underlying " # __api"; disqualifying myself");\
62 return OSHMEM_ERROR;\
63 }\
64 OBJ_RETAIN(mpi_module->previous_ ## __api ## _module);\
65 } while(0)
66
67 static int mca_scoll_mpi_save_coll_handlers(mca_scoll_base_module_t *module, oshmem_group_t *osh_group)
68 {
69 mca_scoll_mpi_module_t* mpi_module = (mca_scoll_mpi_module_t*) module;
70 MPI_SAVE_PREV_SCOLL_API(barrier);
71 MPI_SAVE_PREV_SCOLL_API(broadcast);
72 MPI_SAVE_PREV_SCOLL_API(reduce);
73 MPI_SAVE_PREV_SCOLL_API(collect);
74 MPI_SAVE_PREV_SCOLL_API(alltoall);
75 return OSHMEM_SUCCESS;
76 }
77
78
79
80
81 static int mca_scoll_mpi_module_enable(mca_scoll_base_module_t *module,
82 oshmem_group_t *osh_group)
83 {
84
85 if (OSHMEM_SUCCESS != mca_scoll_mpi_save_coll_handlers(module, osh_group)){
86 MPI_COLL_ERROR("MPI module enable failed - aborting to prevent inconsistent application state");
87
88 opal_show_help("help-oshmem-scoll-mpi.txt",
89 "module_enable:fatal", true,
90 "MPI module enable failed - aborting to prevent inconsistent application state");
91
92 oshmem_shmem_abort(-1);
93 return OSHMEM_ERROR;
94 }
95
96 return OSHMEM_SUCCESS;
97 }
98
99
100
101
102
103
104
105
106 mca_scoll_base_module_t *
107 mca_scoll_mpi_comm_query(oshmem_group_t *osh_group, int *priority)
108 {
109 mca_scoll_base_module_t *module;
110 mca_scoll_mpi_module_t *mpi_module;
111 int err, i;
112 int tag;
113 ompi_group_t* world_group, *new_group;
114 ompi_communicator_t* newcomm = NULL;
115 *priority = 0;
116 mca_scoll_mpi_component_t *cm;
117 cm = &mca_scoll_mpi_component;
118 int* ranks;
119 if (!cm->mpi_enable){
120 return NULL;
121 }
122 if ((osh_group->proc_count < 2) || (osh_group->proc_count < cm->mpi_np)) {
123 return NULL;
124 }
125 OPAL_TIMING_ENV_INIT(comm_query);
126
127
128 if (NULL == oshmem_group_all) {
129 osh_group->ompi_comm = &(ompi_mpi_comm_world.comm);
130 OPAL_TIMING_ENV_NEXT(comm_query, "ompi_mpi_comm_world");
131 } else {
132 err = ompi_comm_group(&(ompi_mpi_comm_world.comm), &world_group);
133 if (OPAL_UNLIKELY(OMPI_SUCCESS != err)) {
134 return NULL;
135 }
136 OPAL_TIMING_ENV_NEXT(comm_query, "ompi_comm_group");
137
138 ranks = (int*) malloc(osh_group->proc_count * sizeof(int));
139 if (OPAL_UNLIKELY(NULL == ranks)) {
140 return NULL;
141 }
142 tag = 1;
143
144 OPAL_TIMING_ENV_NEXT(comm_query, "malloc");
145
146
147 for (i = 0; i < osh_group->proc_count; i++) {
148 ranks[i] = osh_group->proc_array[i]->super.proc_name.vpid;
149 }
150
151 OPAL_TIMING_ENV_NEXT(comm_query, "build_ranks");
152
153 err = ompi_group_incl(world_group, osh_group->proc_count, ranks, &new_group);
154 if (OPAL_UNLIKELY(OMPI_SUCCESS != err)) {
155 free(ranks);
156 return NULL;
157 }
158 OPAL_TIMING_ENV_NEXT(comm_query, "ompi_group_incl");
159
160 err = ompi_comm_create_group(&(ompi_mpi_comm_world.comm), new_group, tag, &newcomm);
161 if (OPAL_UNLIKELY(OMPI_SUCCESS != err)) {
162 free(ranks);
163 return NULL;
164 }
165 OPAL_TIMING_ENV_NEXT(comm_query, "ompi_comm_create_group");
166
167 err = ompi_group_free(&new_group);
168 if (OPAL_UNLIKELY(OMPI_SUCCESS != err)) {
169 free(ranks);
170 return NULL;
171 }
172 OPAL_TIMING_ENV_NEXT(comm_query, "ompi_group_free");
173
174 free(ranks);
175 osh_group->ompi_comm = newcomm;
176 OPAL_TIMING_ENV_NEXT(comm_query, "set_group_comm");
177 }
178 mpi_module = OBJ_NEW(mca_scoll_mpi_module_t);
179 if (!mpi_module){
180 return NULL;
181 }
182 mpi_module->comm = osh_group->ompi_comm;
183
184 mpi_module->super.scoll_module_enable = mca_scoll_mpi_module_enable;
185 mpi_module->super.scoll_barrier = mca_scoll_mpi_barrier;
186 mpi_module->super.scoll_broadcast = mca_scoll_mpi_broadcast;
187 mpi_module->super.scoll_reduce = mca_scoll_mpi_reduce;
188 mpi_module->super.scoll_collect = mca_scoll_mpi_collect;
189 mpi_module->super.scoll_alltoall = NULL;
190
191 *priority = cm->mpi_priority;
192 module = &mpi_module->super;
193
194 return module;
195 }
196
197
/*
 * Class instance for mca_scoll_mpi_module_t: parent class
 * mca_scoll_base_module_t, with the constructor and destructor defined
 * earlier in this file.
 */
OBJ_CLASS_INSTANCE(mca_scoll_mpi_module_t, mca_scoll_base_module_t,
                   mca_scoll_mpi_module_construct, mca_scoll_mpi_module_destruct);
202
203
204