1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 #ifndef MCA_COLL_SM_EXPORT_H
24 #define MCA_COLL_SM_EXPORT_H
25
26 #include "ompi_config.h"
27
28 #include "mpi.h"
29 #include "ompi/mca/mca.h"
30 #include "opal/datatype/opal_convertor.h"
31 #include "opal/mca/common/sm/common_sm.h"
32 #include "ompi/mca/coll/coll.h"
33
34 BEGIN_C_DECLS
35
36
37
38
/**
 * Number of back-to-back polls of the condition that SPIN_CONDITION
 * performs between two calls to opal_progress().
 */
#define SPIN_CONDITION_MAX 100000

/**
 * Busy-wait until @p cond evaluates to true.
 *
 * The condition is checked once, then polled up to SPIN_CONDITION_MAX
 * times in a tight loop.  If it is still false, opal_progress() is called
 * once and the whole sequence starts over.  There is no timeout: the macro
 * only finishes once @p cond holds.
 *
 * @param cond        Expression to poll.  It is evaluated many times, so
 *                    it must be safe to re-evaluate.
 * @param exit_label  Label name, unique within the enclosing function.
 *                    The macro defines it right after the loop and jumps
 *                    to it as soon as @p cond is true.
 *
 * The expansion ends in a label, so the invocation must be followed by a
 * statement (the usual trailing ';' is enough).
 *
 * The poll counter deliberately has a macro-private name: a generic name
 * such as `i` would shadow the caller's variable whenever @p cond refers
 * to it.
 */
#define SPIN_CONDITION(cond, exit_label)                                  \
    do {                                                                  \
        int spin_condition_iter_;                                         \
        if (cond) goto exit_label;                                        \
        for (spin_condition_iter_ = 0;                                    \
             spin_condition_iter_ < SPIN_CONDITION_MAX;                   \
             ++spin_condition_iter_) {                                    \
            if (cond) { goto exit_label; }                                \
        }                                                                 \
        opal_progress();                                                  \
    } while (1);                                                          \
    exit_label:
49
50
51
52
53
54
55
56 typedef struct mca_coll_sm_component_t {
57
58 mca_coll_base_component_2_0_0_t super;
59
60
61 int sm_priority;
62
63
64 int sm_control_size;
65
66
67
68 int sm_comm_num_in_use_flags;
69
70
71
72 int sm_comm_num_segments;
73
74
75 int sm_fragment_size;
76
77
78 int sm_tree_degree;
79
80
81
82 int sm_info_comm_size;
83
84
85
86
87
88
89
90 int sm_segs_per_inuse_flag;
91 } mca_coll_sm_component_t;
92
93
94
95
/**
 * One node of a tree of processes.
 *
 * The notify macros in this header read mcstn_id from child nodes to pick
 * the control slot to write to.
 */
typedef struct mca_coll_sm_tree_node_t mca_coll_sm_tree_node_t;
struct mca_coll_sm_tree_node_t {
    /** Identifier of this node (combined with the root rank by
        PARENT_NOTIFY_CHILDREN to select a control slot). */
    int mcstn_id;

    /** Parent node; presumably NULL for the tree root -- confirm. */
    mca_coll_sm_tree_node_t *mcstn_parent;

    /** Number of entries in mcstn_children. */
    int mcstn_num_children;

    /** Array of pointers to the child nodes. */
    mca_coll_sm_tree_node_t **mcstn_children;
};
107
108
109
110
111
112
113
114 typedef struct mca_coll_sm_in_use_flag_t {
115
116
117 opal_atomic_uint32_t mcsiuf_num_procs_using;
118
119 volatile uint32_t mcsiuf_operation_count;
120 } mca_coll_sm_in_use_flag_t;
121
122
123
124
125
126
127
128
129
/**
 * Locates the control area and the data area of one segment.
 */
struct mca_coll_sm_data_index_t {
    /** Start of the control area.  The notify macros address it in bytes,
        one slot of sm_control_size bytes per rank. */
    volatile uint32_t *mcbmi_control;

    /** Start of the data area.  The COPY_FRAGMENT_* macros address it in
        bytes, one fragment of sm_fragment_size bytes per rank. */
    char *mcbmi_data;
};
typedef struct mca_coll_sm_data_index_t mca_coll_sm_data_index_t;
136
137
138
139
140
141
142
143 typedef struct mca_coll_sm_comm_t {
144
145
146 mca_common_sm_module_t *sm_bootstrap_meta;
147
148
149
150 uint32_t *mcb_barrier_control_me;
151
152
153
154
155 opal_atomic_uint32_t *mcb_barrier_control_parent;
156
157
158
159
160
161
162 uint32_t *mcb_barrier_control_children;
163
164
165
166 int mcb_barrier_count;
167
168
169 mca_coll_sm_in_use_flag_t *mcb_in_use_flags;
170
171
172
173
174 mca_coll_sm_data_index_t *mcb_data_index;
175
176
177
178 mca_coll_sm_tree_node_t *mcb_tree;
179
180
181 uint32_t mcb_operation_count;
182 } mca_coll_sm_comm_t;
183
184
185 typedef struct mca_coll_sm_module_t {
186
187 mca_coll_base_module_t super;
188
189
190 bool enabled;
191
192
193 mca_coll_sm_comm_t *sm_comm_data;
194
195
196 mca_coll_base_module_reduce_fn_t previous_reduce;
197 mca_coll_base_module_t *previous_reduce_module;
198 } mca_coll_sm_module_t;
199 OBJ_CLASS_DECLARATION(mca_coll_sm_module_t);
200
201
202
203
204 OMPI_MODULE_DECLSPEC extern mca_coll_sm_component_t mca_coll_sm_component;
205
206
207
208
209 int mca_coll_sm_init_query(bool enable_progress_threads,
210 bool enable_mpi_threads);
211
212 mca_coll_base_module_t *
213 mca_coll_sm_comm_query(struct ompi_communicator_t *comm, int *priority);
214
215
216
217 int ompi_coll_sm_lazy_enable(mca_coll_base_module_t *module,
218 struct ompi_communicator_t *comm);
219
220 int mca_coll_sm_allgather_intra(const void *sbuf, int scount,
221 struct ompi_datatype_t *sdtype,
222 void *rbuf, int rcount,
223 struct ompi_datatype_t *rdtype,
224 struct ompi_communicator_t *comm,
225 mca_coll_base_module_t *module);
226
227 int mca_coll_sm_allgatherv_intra(const void *sbuf, int scount,
228 struct ompi_datatype_t *sdtype,
229 void * rbuf, const int *rcounts, const int *disps,
230 struct ompi_datatype_t *rdtype,
231 struct ompi_communicator_t *comm,
232 mca_coll_base_module_t *module);
233 int mca_coll_sm_allreduce_intra(const void *sbuf, void *rbuf, int count,
234 struct ompi_datatype_t *dtype,
235 struct ompi_op_t *op,
236 struct ompi_communicator_t *comm,
237 mca_coll_base_module_t *module);
238 int mca_coll_sm_alltoall_intra(const void *sbuf, int scount,
239 struct ompi_datatype_t *sdtype,
240 void* rbuf, int rcount,
241 struct ompi_datatype_t *rdtype,
242 struct ompi_communicator_t *comm,
243 mca_coll_base_module_t *module);
244 int mca_coll_sm_alltoallv_intra(const void *sbuf, const int *scounts, const int *sdisps,
245 struct ompi_datatype_t *sdtype,
246 void *rbuf, const int *rcounts, const int *rdisps,
247 struct ompi_datatype_t *rdtype,
248 struct ompi_communicator_t *comm,
249 mca_coll_base_module_t *module);
250 int mca_coll_sm_alltoallw_intra(const void *sbuf, const int *scounts, const int *sdisps,
251 struct ompi_datatype_t * const *sdtypes,
252 void *rbuf, const int *rcounts, const int *rdisps,
253 struct ompi_datatype_t * const *rdtypes,
254 struct ompi_communicator_t *comm,
255 mca_coll_base_module_t *module);
256 int mca_coll_sm_barrier_intra(struct ompi_communicator_t *comm,
257 mca_coll_base_module_t *module);
258 int mca_coll_sm_bcast_intra(void *buff, int count,
259 struct ompi_datatype_t *datatype,
260 int root,
261 struct ompi_communicator_t *comm,
262 mca_coll_base_module_t *module);
263 int mca_coll_sm_bcast_log_intra(void *buff, int count,
264 struct ompi_datatype_t *datatype,
265 int root,
266 struct ompi_communicator_t *comm,
267 mca_coll_base_module_t *module);
268 int mca_coll_sm_exscan_intra(const void *sbuf, void *rbuf, int count,
269 struct ompi_datatype_t *dtype,
270 struct ompi_op_t *op,
271 struct ompi_communicator_t *comm,
272 mca_coll_base_module_t *module);
273 int mca_coll_sm_gather_intra(void *sbuf, int scount,
274 struct ompi_datatype_t *sdtype, void *rbuf,
275 int rcount, struct ompi_datatype_t *rdtype,
276 int root, struct ompi_communicator_t *comm,
277 mca_coll_base_module_t *module);
278 int mca_coll_sm_gatherv_intra(void *sbuf, int scount,
279 struct ompi_datatype_t *sdtype, void *rbuf,
280 int *rcounts, int *disps,
281 struct ompi_datatype_t *rdtype, int root,
282 struct ompi_communicator_t *comm,
283 mca_coll_base_module_t *module);
284 int mca_coll_sm_reduce_intra(const void *sbuf, void* rbuf, int count,
285 struct ompi_datatype_t *dtype,
286 struct ompi_op_t *op,
287 int root,
288 struct ompi_communicator_t *comm,
289 mca_coll_base_module_t *module);
290 int mca_coll_sm_reduce_log_intra(const void *sbuf, void* rbuf, int count,
291 struct ompi_datatype_t *dtype,
292 struct ompi_op_t *op,
293 int root,
294 struct ompi_communicator_t *comm,
295 mca_coll_base_module_t *module);
296 int mca_coll_sm_reduce_scatter_intra(const void *sbuf, void *rbuf,
297 int *rcounts,
298 struct ompi_datatype_t *dtype,
299 struct ompi_op_t *op,
300 struct ompi_communicator_t *comm,
301 mca_coll_base_module_t *module);
302 int mca_coll_sm_scan_intra(const void *sbuf, void *rbuf, int count,
303 struct ompi_datatype_t *dtype,
304 struct ompi_op_t *op,
305 struct ompi_communicator_t *comm,
306 mca_coll_base_module_t *module);
307 int mca_coll_sm_scatter_intra(const void *sbuf, int scount,
308 struct ompi_datatype_t *sdtype, void *rbuf,
309 int rcount, struct ompi_datatype_t *rdtype,
310 int root, struct ompi_communicator_t *comm,
311 mca_coll_base_module_t *module);
312 int mca_coll_sm_scatterv_intra(const void *sbuf, const int *scounts, const int *disps,
313 struct ompi_datatype_t *sdtype,
314 void* rbuf, int rcount,
315 struct ompi_datatype_t *rdtype, int root,
316 struct ompi_communicator_t *comm,
317 mca_coll_base_module_t *module);
318
319 int mca_coll_sm_ft_event(int state);
320
321
322
323
324
325 extern uint32_t mca_coll_sm_one;
326
327
328
329
330
/**
 * Point @p flag at in-use flag number @p flag_num of a communicator.
 *
 * The flags are laid out sm_control_size bytes apart, starting at
 * (data)->mcb_in_use_flags, so the address is computed in bytes rather
 * than with ordinary pointer arithmetic on the flag type.
 *
 * @param flag_num  Zero-based index of the wanted flag.
 * @param flag      Lvalue of type mca_coll_sm_in_use_flag_t*; assigned.
 * @param data      Pointer to the mca_coll_sm_comm_t owning the flags.
 */
#define FLAG_SETUP(flag_num, flag, data)                                   \
    (flag) = (mca_coll_sm_in_use_flag_t *)                                 \
        (((char *) (data)->mcb_in_use_flags) +                             \
         (mca_coll_sm_component.sm_control_size * (flag_num)))
335
336
337
338
/**
 * Spin (see SPIN_CONDITION) until no process is using @p flag, i.e. until
 * its mcsiuf_num_procs_using reads zero.
 *
 * @param flag   Pointer to the in-use flag to watch.
 * @param label  Unique label name forwarded to SPIN_CONDITION.
 */
#define FLAG_WAIT_FOR_IDLE(flag, label)                                    \
    SPIN_CONDITION((flag)->mcsiuf_num_procs_using == 0, label)
341
342
343
344
345
346
/**
 * Spin (see SPIN_CONDITION) until the operation counter stored in @p flag
 * equals @p op.
 *
 * @param flag   Pointer to the in-use flag to watch.  It is parenthesised
 *               in the expansion, so any pointer expression is accepted
 *               (e.g. `&flags[k]` or `base + k`).
 * @param op     Operation count to wait for; re-evaluated on every poll.
 * @param label  Unique label name forwarded to SPIN_CONDITION.
 */
#define FLAG_WAIT_FOR_OP(flag, op, label)                                  \
    SPIN_CONDITION((op) == (flag)->mcsiuf_operation_count, label)
349
350
351
352
/**
 * Mark @p flag as in use: record how many processes will use the guarded
 * segments and which operation they belong to.
 *
 * The process count is stored before the operation count, which is the
 * field FLAG_WAIT_FOR_OP polls.
 *
 * The two stores are wrapped in do { } while (0) so the macro behaves as
 * a single statement, e.g. as the body of an unbraced `if`.
 *
 * @param flag       Pointer to the in-use flag.
 * @param num_procs  Value stored into mcsiuf_num_procs_using.
 * @param op_count   Value stored into mcsiuf_operation_count.
 */
#define FLAG_RETAIN(flag, num_procs, op_count)                             \
    do {                                                                   \
        (flag)->mcsiuf_num_procs_using = (num_procs);                      \
        (flag)->mcsiuf_operation_count = (op_count);                       \
    } while (0)
356
357
358
359
/**
 * Signal that the calling process is done with the segments guarded by
 * @p flag: adds -1 to mcsiuf_num_procs_using through opal_atomic_add().
 *
 * @param flag  Pointer to the in-use flag being released.
 */
#define FLAG_RELEASE(flag)                                                 \
    opal_atomic_add(&((flag)->mcsiuf_num_procs_using), -1)
362
363
364
365
366
/**
 * Pack up to @p max_data bytes from @p convertor into the fragment that
 * belongs to @p rank inside the data area of @p index.
 *
 * Fragments are sm_fragment_size bytes apart.  The macro points @p iov at
 * the fragment, sets its length to @p max_data and then calls
 * opal_convertor_pack() with &(max_data) as its last argument.
 *
 * All three steps are wrapped in do { } while (0) so the macro is a
 * single statement; the return value of opal_convertor_pack() is not
 * available to the caller.
 *
 * @param convertor  opal convertor object (passed by address).
 * @param index      Pointer to the mca_coll_sm_data_index_t to write to.
 * @param rank       Rank whose fragment is filled.
 * @param iov        iovec-like lvalue used as scratch; overwritten.
 * @param max_data   Modifiable lvalue holding the byte count; its address
 *                   is handed to opal_convertor_pack().
 */
#define COPY_FRAGMENT_IN(convertor, index, rank, iov, max_data)            \
    do {                                                                   \
        (iov).iov_base =                                                   \
            (index)->mcbmi_data +                                          \
            ((rank) * mca_coll_sm_component.sm_fragment_size);             \
        (iov).iov_len = (max_data);                                        \
        opal_convertor_pack(&(convertor), &(iov), &mca_coll_sm_one,        \
                            &(max_data));                                  \
    } while (0)
374
375
376
377
378
/**
 * Unpack up to @p max_data bytes from the fragment that belongs to
 * @p src_rank inside the data area of @p index into @p convertor.
 *
 * Fragments are sm_fragment_size bytes apart.  The macro points @p iov at
 * the fragment, sets its length to @p max_data and then calls
 * opal_convertor_unpack() with &(max_data) as its last argument.
 *
 * All three steps are wrapped in do { } while (0) so the macro is a
 * single statement; the return value of opal_convertor_unpack() is not
 * available to the caller.
 *
 * @param convertor  opal convertor object (passed by address).
 * @param src_rank   Rank whose fragment is read.
 * @param index      Pointer to the mca_coll_sm_data_index_t to read from.
 * @param iov        iovec-like lvalue used as scratch; overwritten.
 * @param max_data   Modifiable lvalue holding the byte count; its address
 *                   is handed to opal_convertor_unpack().
 */
#define COPY_FRAGMENT_OUT(convertor, src_rank, index, iov, max_data)       \
    do {                                                                   \
        (iov).iov_base = (((char *) (index)->mcbmi_data) +                 \
                          ((src_rank) *                                    \
                           (mca_coll_sm_component.sm_fragment_size)));     \
        (iov).iov_len = (max_data);                                        \
        opal_convertor_unpack(&(convertor), &(iov), &mca_coll_sm_one,      \
                              &(max_data));                                \
    } while (0)
385
386
387
388
/**
 * Copy @p len bytes from the fragment of @p src_rank to the fragment of
 * @p dest_rank, both inside the data area of @p index.
 *
 * Fragments are sm_fragment_size bytes apart.  The copy is done with
 * memcpy(), so the two ranks must differ.
 *
 * @param src_rank   Rank whose fragment is read.
 * @param dest_rank  Rank whose fragment is written.
 * @param index      Pointer to the mca_coll_sm_data_index_t holding both.
 * @param len        Number of bytes to copy.
 */
#define COPY_FRAGMENT_BETWEEN(src_rank, dest_rank, index, len)             \
    memcpy((index)->mcbmi_data +                                           \
               (mca_coll_sm_component.sm_fragment_size * (dest_rank)),     \
           (index)->mcbmi_data +                                           \
               (mca_coll_sm_component.sm_fragment_size * (src_rank)),      \
           (len))
396
397
398
399
400
401
/**
 * Write @p value into the control word of every child so that each child
 * spinning in CHILD_WAIT_FOR_NOTIFY wakes up.
 *
 * The slot of a child lives in the control area of @p index at byte
 * offset sm_control_size * ((child->mcstn_id + root) % size).
 *
 * The store is a single `uint32_t volatile` write.  That is the declared
 * type of mcbmi_control and the width CHILD_WAIT_FOR_NOTIFY reads and
 * clears; a wider (size_t) store would touch bytes the child never
 * resets, and on big-endian 64-bit targets the child would read zero.
 *
 * The macro relies on three names from the calling scope: an integer
 * lvalue `i` (used as the loop counter and left equal to @p num_children)
 * and the integers `root` and `size`.
 *
 * @param children      Array of mca_coll_sm_tree_node_t pointers.
 * @param num_children  Number of entries in @p children.
 * @param index         Pointer to the mca_coll_sm_data_index_t in use.
 * @param value         Non-zero value to publish (converted to uint32_t).
 */
#define PARENT_NOTIFY_CHILDREN(children, num_children, index, value)       \
    do {                                                                   \
        for (i = 0; i < (num_children); ++i) {                             \
            *((uint32_t volatile *)                                        \
              (((char *) (index)->mcbmi_control) +                         \
               (mca_coll_sm_component.sm_control_size *                    \
                (((children)[i]->mcstn_id + root) % size)))) = (value);    \
        }                                                                  \
    } while (0)
411
412
413
414
415
416
/**
 * Wait until the control word of @p rank becomes non-zero, hand the value
 * to the caller and reset the word to zero.
 *
 * The word lives in the control area of @p index at byte offset
 * rank * sm_control_size and is accessed as a `uint32_t volatile`.
 * Waiting is done with SPIN_CONDITION.
 *
 * @param rank   Rank whose control word is watched.
 * @param index  Pointer to the mca_coll_sm_data_index_t in use; any
 *               pointer expression is accepted.
 * @param value  Lvalue that receives the non-zero word.
 * @param label  Unique label name forwarded to SPIN_CONDITION.
 *
 * The local pointer has a macro-private name so that it cannot capture a
 * caller variable passed as @p value.
 */
#define CHILD_WAIT_FOR_NOTIFY(rank, index, value, label)                   \
    do {                                                                   \
        uint32_t volatile *child_wait_slot_ =                              \
            ((uint32_t volatile *)                                         \
             (((char *) (index)->mcbmi_control) +                          \
              ((rank) * mca_coll_sm_component.sm_control_size)));          \
        SPIN_CONDITION(0 != *child_wait_slot_, label);                     \
        (value) = *child_wait_slot_;                                       \
        *child_wait_slot_ = 0;                                             \
    } while (0)
426
427
428
429
430
/**
 * Publish @p value from a child to its parent.
 *
 * The parent's control slot starts at byte offset
 * sm_control_size * parent_rank inside the control area of @p index; it is
 * treated as an array of `size_t volatile` and entry @p child_rank is
 * written.  PARENT_WAIT_FOR_NOTIFY_SPECIFIC reads the same entry.
 *
 * NOTE(review): this pair uses size_t words while mcbmi_control is
 * declared as uint32_t; confirm that a control slot is large enough for
 * one size_t per child.
 *
 * @param child_rank   Index of the entry inside the parent's slot.
 * @param parent_rank  Rank owning the slot.
 * @param index        Pointer to the mca_coll_sm_data_index_t in use.
 * @param value        Value to store.
 */
#define CHILD_NOTIFY_PARENT(child_rank, parent_rank, index, value)         \
    *(((size_t volatile *)                                                 \
       (((char *) (index)->mcbmi_control) +                                \
        (mca_coll_sm_component.sm_control_size * (parent_rank)))) +        \
      (child_rank)) = (value)
436
437
438
439
440
441
/**
 * Wait until one specific child has reported to its parent, hand the
 * reported value to the caller and reset the entry to zero.
 *
 * The entry is element @p child_rank of the `size_t volatile` array that
 * starts at byte offset sm_control_size * parent_rank inside the control
 * area of @p index -- the location CHILD_NOTIFY_PARENT writes.  Waiting is
 * done with SPIN_CONDITION.
 *
 * @param child_rank   Index of the entry inside the parent's slot; any
 *                     integer expression is accepted.
 * @param parent_rank  Rank owning the slot.
 * @param index        Pointer to the mca_coll_sm_data_index_t in use; any
 *                     pointer expression is accepted.
 * @param value        Lvalue that receives the non-zero entry.
 * @param label        Unique label name forwarded to SPIN_CONDITION.
 *
 * The local pointer has a macro-private name so that it cannot capture a
 * caller variable passed as @p value.
 */
#define PARENT_WAIT_FOR_NOTIFY_SPECIFIC(child_rank, parent_rank, index, value, label) \
    do {                                                                   \
        size_t volatile *parent_wait_slot_ =                               \
            ((size_t volatile *)                                           \
             (((char *) (index)->mcbmi_control) +                          \
              (mca_coll_sm_component.sm_control_size *                     \
               (parent_rank)))) + (child_rank);                            \
        SPIN_CONDITION(0 != *parent_wait_slot_, label);                    \
        (value) = *parent_wait_slot_;                                      \
        *parent_wait_slot_ = 0;                                            \
    } while (0)
452
453 END_C_DECLS
454
455 #endif