This source file includes the following definitions.
- ompi_coll_base_bcast_intra_generic
- ompi_coll_base_bcast_intra_bintree
- ompi_coll_base_bcast_intra_pipeline
- ompi_coll_base_bcast_intra_chain
- ompi_coll_base_bcast_intra_binomial
- ompi_coll_base_bcast_intra_split_bintree
- ompi_coll_base_bcast_intra_basic_linear
- ompi_coll_base_bcast_intra_knomial
- ompi_coll_base_bcast_intra_scatter_allgather
- ompi_coll_base_bcast_intra_scatter_allgather_ring
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24 #include "ompi_config.h"
25
26 #include "mpi.h"
27 #include "ompi/constants.h"
28 #include "ompi/datatype/ompi_datatype.h"
29 #include "ompi/communicator/communicator.h"
30 #include "ompi/mca/coll/coll.h"
31 #include "ompi/mca/coll/base/coll_tags.h"
32 #include "ompi/mca/pml/pml.h"
33 #include "ompi/mca/coll/base/coll_base_functions.h"
34 #include "coll_base_topo.h"
35 #include "coll_base_util.h"
36
/*
 * Segmented broadcast of 'original_count' elements over the tree 'tree'.
 *
 * The message is cut into segments of 'count_by_segment' elements (the last
 * segment carries whatever is left) and pushed down the tree one segment at
 * a time.  A rank plays one of three roles:
 *   - root:          isends every segment to each of its children;
 *   - interior node: receives from tree->tree_prev and forwards to
 *                    tree->tree_next[], keeping two receives in flight;
 *   - leaf:          only receives, again with two receives in flight.
 *
 * @param buffer            send buffer on root, receive buffer elsewhere
 * @param original_count    total number of elements to broadcast
 * @param datatype          datatype of the elements
 * @param root              rank of the broadcast root
 * @param comm              communicator
 * @param module            collective module; its base_data supplies the
 *                          send request array
 * @param count_by_segment  number of elements per segment
 * @param tree              tree describing parent (tree_prev) and children
 *                          (tree_next[0..tree_nextsize-1]) of this rank
 *
 * @return MPI_SUCCESS, OMPI_ERR_OUT_OF_RESOURCE when no send requests could
 *         be obtained, or the error reported by the failing PML/request call.
 *
 * NOTE(review): count_by_segment is used as a divisor below; callers are
 * assumed never to pass 0 -- confirm.
 */
int
ompi_coll_base_bcast_intra_generic( void* buffer,
                                    int original_count,
                                    struct ompi_datatype_t* datatype,
                                    int root,
                                    struct ompi_communicator_t* comm,
                                    mca_coll_base_module_t *module,
                                    uint32_t count_by_segment,
                                    ompi_coll_tree_t* tree )
{
    int err = 0, line, i, rank, segindex, req_index;
    int num_segments; /* number of segments the message is cut into */
    int sendcount;    /* number of elements sent for the current segment */
    size_t realsegsize, type_size;
    char *tmpbuf;
    ptrdiff_t extent, lb;
    /* Two receive requests, used alternately (index flipped with ^ 0x1). */
    ompi_request_t *recv_reqs[2] = {MPI_REQUEST_NULL, MPI_REQUEST_NULL};
    ompi_request_t **send_reqs = NULL;

#if OPAL_ENABLE_DEBUG
    int size;
    size = ompi_comm_size(comm);
    assert( size > 1 );
#endif
    rank = ompi_comm_rank(comm);

    ompi_datatype_get_extent (datatype, &lb, &extent);
    /* type_size is queried but not used further in this function. */
    ompi_datatype_type_size( datatype, &type_size );
    /* Ceiling division: number of segments needed to cover the message. */
    num_segments = (original_count + count_by_segment - 1) / count_by_segment;
    /* Distance in bytes between the start of two consecutive segments. */
    realsegsize = (ptrdiff_t)count_by_segment * extent;

    /* tmpbuf always points at the segment currently being sent/forwarded. */
    tmpbuf = (char *) buffer;

    /* Only ranks with children need send requests. */
    if( tree->tree_nextsize != 0 ) {
        send_reqs = ompi_coll_base_comm_get_reqs(module->base_data, tree->tree_nextsize);
        if( NULL == send_reqs ) { err = OMPI_ERR_OUT_OF_RESOURCE; line = __LINE__; goto error_hndl; }
    }

    /* Root code */
    if( rank == root ) {
        /*
         * For each segment:
         *  - send the segment to all children,
         *  - wait for those sends to complete before moving to the next
         *    segment.
         * The last segment may carry fewer elements than the others.
         */
        sendcount = count_by_segment;
        for( segindex = 0; segindex < num_segments; segindex++ ) {
            if( segindex == (num_segments - 1) ) {
                sendcount = original_count - segindex * count_by_segment;
            }
            for( i = 0; i < tree->tree_nextsize; i++ ) {
                err = MCA_PML_CALL(isend(tmpbuf, sendcount, datatype,
                                         tree->tree_next[i],
                                         MCA_COLL_BASE_TAG_BCAST,
                                         MCA_PML_BASE_SEND_STANDARD, comm,
                                         &send_reqs[i]));
                if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
            }

            /* complete the sends before starting the next sends */
            err = ompi_request_wait_all( tree->tree_nextsize, send_reqs,
                                         MPI_STATUSES_IGNORE );
            if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }

            /* update tmp buffer */
            tmpbuf += realsegsize;

        }
    }

    /* Intermediate nodes code */
    else if( tree->tree_nextsize > 0 ) {
        /*
         * Double-buffered receives:
         *  1) post the receive for the first segment;
         *  2) for every following segment: post its receive, wait for the
         *     previous segment, forward the previous segment to all children
         *     and wait for those sends;
         *  3) wait for the last segment and forward it.
         * Every receive is posted with count_by_segment elements; only the
         * forward of the last segment uses the (possibly smaller) remainder.
         */
        req_index = 0;
        err = MCA_PML_CALL(irecv(tmpbuf, count_by_segment, datatype,
                                 tree->tree_prev, MCA_COLL_BASE_TAG_BCAST,
                                 comm, &recv_reqs[req_index]));
        if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }

        for( segindex = 1; segindex < num_segments; segindex++ ) {

            /* switch to the other receive request slot */
            req_index = req_index ^ 0x1;

            /* post the receive for the next segment (tmpbuf has not been
             * advanced yet, hence the + realsegsize) */
            err = MCA_PML_CALL(irecv( tmpbuf + realsegsize, count_by_segment,
                                      datatype, tree->tree_prev,
                                      MCA_COLL_BASE_TAG_BCAST,
                                      comm, &recv_reqs[req_index]));
            if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }

            /* wait for the previous segment ... */
            err = ompi_request_wait( &recv_reqs[req_index ^ 0x1],
                                     MPI_STATUS_IGNORE );
            if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }

            /* ... and forward it to all children */
            for( i = 0; i < tree->tree_nextsize; i++ ) {
                err = MCA_PML_CALL(isend(tmpbuf, count_by_segment, datatype,
                                         tree->tree_next[i],
                                         MCA_COLL_BASE_TAG_BCAST,
                                         MCA_PML_BASE_SEND_STANDARD, comm,
                                         &send_reqs[i]));
                if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
            }

            /* complete the sends before starting the next iteration */
            err = ompi_request_wait_all( tree->tree_nextsize, send_reqs,
                                         MPI_STATUSES_IGNORE );
            if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }

            /* update tmp buffer */
            tmpbuf += realsegsize;

        }

        /* Process the last segment: wait for it, then forward the remainder */
        err = ompi_request_wait( &recv_reqs[req_index], MPI_STATUS_IGNORE );
        if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
        sendcount = original_count - (ptrdiff_t)(num_segments - 1) * count_by_segment;
        for( i = 0; i < tree->tree_nextsize; i++ ) {
            err = MCA_PML_CALL(isend(tmpbuf, sendcount, datatype,
                                     tree->tree_next[i],
                                     MCA_COLL_BASE_TAG_BCAST,
                                     MCA_PML_BASE_SEND_STANDARD, comm,
                                     &send_reqs[i]));
            if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
        }

        err = ompi_request_wait_all( tree->tree_nextsize, send_reqs,
                                     MPI_STATUSES_IGNORE );
        if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
    }

    /* Leaf nodes */
    else {
        /*
         * Receive all segments from the parent, keeping two receives in
         * flight: post the receive for segment k before waiting for
         * segment k-1, then wait for the final one after the loop.
         */
        req_index = 0;
        err = MCA_PML_CALL(irecv(tmpbuf, count_by_segment, datatype,
                                 tree->tree_prev, MCA_COLL_BASE_TAG_BCAST,
                                 comm, &recv_reqs[req_index]));
        if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }

        for( segindex = 1; segindex < num_segments; segindex++ ) {
            req_index = req_index ^ 0x1;
            tmpbuf += realsegsize;
            /* post receive for the next segment */
            err = MCA_PML_CALL(irecv(tmpbuf, count_by_segment, datatype,
                                     tree->tree_prev, MCA_COLL_BASE_TAG_BCAST,
                                     comm, &recv_reqs[req_index]));
            if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
            /* wait on the previous segment */
            err = ompi_request_wait( &recv_reqs[req_index ^ 0x1],
                                     MPI_STATUS_IGNORE );
            if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
        }

        /* wait for the last segment */
        err = ompi_request_wait( &recv_reqs[req_index], MPI_STATUS_IGNORE );
        if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
    }

    return (MPI_SUCCESS);

 error_hndl:
    /*
     * MPI_ERR_IN_STATUS only says "look at the individual requests": pick
     * the error of the first non-null request whose status is not
     * MPI_ERR_PENDING, first among the receives, then among the sends.
     */
    if (MPI_ERR_IN_STATUS == err) {
        for( req_index = 0; req_index < 2; req_index++ ) {
            if (MPI_REQUEST_NULL == recv_reqs[req_index]) continue;
            if (MPI_ERR_PENDING == recv_reqs[req_index]->req_status.MPI_ERROR) continue;
            err = recv_reqs[req_index]->req_status.MPI_ERROR;
            break;
        }
    }
    ompi_coll_base_free_reqs( recv_reqs, 2);
    if( NULL != send_reqs ) {
        if (MPI_ERR_IN_STATUS == err) {
            for( req_index = 0; req_index < tree->tree_nextsize; req_index++ ) {
                if (MPI_REQUEST_NULL == send_reqs[req_index]) continue;
                if (MPI_ERR_PENDING == send_reqs[req_index]->req_status.MPI_ERROR) continue;
                err = send_reqs[req_index]->req_status.MPI_ERROR;
                break;
            }
        }
        ompi_coll_base_free_reqs(send_reqs, tree->tree_nextsize);
    }
    OPAL_OUTPUT( (ompi_coll_base_framework.framework_output,"%s:%4d\tError occurred %d, rank %2d",
                  __FILE__, line, err, rank) );
    (void)line;  /* silence "unused variable" when OPAL_OUTPUT compiles to nothing */

    return err;
}
243
244 int
245 ompi_coll_base_bcast_intra_bintree ( void* buffer,
246 int count,
247 struct ompi_datatype_t* datatype,
248 int root,
249 struct ompi_communicator_t* comm,
250 mca_coll_base_module_t *module,
251 uint32_t segsize )
252 {
253 int segcount = count;
254 size_t typelng;
255 mca_coll_base_comm_t *data = module->base_data;
256
257 COLL_BASE_UPDATE_BINTREE( comm, module, root );
258
259
260
261
262 ompi_datatype_type_size( datatype, &typelng );
263 COLL_BASE_COMPUTED_SEGCOUNT( segsize, typelng, segcount );
264
265 OPAL_OUTPUT((ompi_coll_base_framework.framework_output,"coll:base:bcast_intra_binary rank %d ss %5d typelng %lu segcount %d",
266 ompi_comm_rank(comm), segsize, (unsigned long)typelng, segcount));
267
268 return ompi_coll_base_bcast_intra_generic( buffer, count, datatype, root, comm, module,
269 segcount, data->cached_bintree );
270 }
271
272 int
273 ompi_coll_base_bcast_intra_pipeline( void* buffer,
274 int count,
275 struct ompi_datatype_t* datatype,
276 int root,
277 struct ompi_communicator_t* comm,
278 mca_coll_base_module_t *module,
279 uint32_t segsize )
280 {
281 int segcount = count;
282 size_t typelng;
283 mca_coll_base_comm_t *data = module->base_data;
284
285 COLL_BASE_UPDATE_PIPELINE( comm, module, root );
286
287
288
289
290 ompi_datatype_type_size( datatype, &typelng );
291 COLL_BASE_COMPUTED_SEGCOUNT( segsize, typelng, segcount );
292
293 OPAL_OUTPUT((ompi_coll_base_framework.framework_output,"coll:base:bcast_intra_pipeline rank %d ss %5d typelng %lu segcount %d",
294 ompi_comm_rank(comm), segsize, (unsigned long)typelng, segcount));
295
296 return ompi_coll_base_bcast_intra_generic( buffer, count, datatype, root, comm, module,
297 segcount, data->cached_pipeline );
298 }
299
300 int
301 ompi_coll_base_bcast_intra_chain( void* buffer,
302 int count,
303 struct ompi_datatype_t* datatype,
304 int root,
305 struct ompi_communicator_t* comm,
306 mca_coll_base_module_t *module,
307 uint32_t segsize, int32_t chains )
308 {
309 int segcount = count;
310 size_t typelng;
311 mca_coll_base_comm_t *data = module->base_data;
312
313 COLL_BASE_UPDATE_CHAIN( comm, module, root, chains );
314
315
316
317
318 ompi_datatype_type_size( datatype, &typelng );
319 COLL_BASE_COMPUTED_SEGCOUNT( segsize, typelng, segcount );
320
321 OPAL_OUTPUT((ompi_coll_base_framework.framework_output,"coll:base:bcast_intra_chain rank %d fo %d ss %5d typelng %lu segcount %d",
322 ompi_comm_rank(comm), chains, segsize, (unsigned long)typelng, segcount));
323
324 return ompi_coll_base_bcast_intra_generic( buffer, count, datatype, root, comm, module,
325 segcount, data->cached_chain );
326 }
327
328 int
329 ompi_coll_base_bcast_intra_binomial( void* buffer,
330 int count,
331 struct ompi_datatype_t* datatype,
332 int root,
333 struct ompi_communicator_t* comm,
334 mca_coll_base_module_t *module,
335 uint32_t segsize )
336 {
337 int segcount = count;
338 size_t typelng;
339 mca_coll_base_comm_t *data = module->base_data;
340
341 COLL_BASE_UPDATE_BMTREE( comm, module, root );
342
343
344
345
346 ompi_datatype_type_size( datatype, &typelng );
347 COLL_BASE_COMPUTED_SEGCOUNT( segsize, typelng, segcount );
348
349 OPAL_OUTPUT((ompi_coll_base_framework.framework_output,"coll:base:bcast_intra_binomial rank %d ss %5d typelng %lu segcount %d",
350 ompi_comm_rank(comm), segsize, (unsigned long)typelng, segcount));
351
352 return ompi_coll_base_bcast_intra_generic( buffer, count, datatype, root, comm, module,
353 segcount, data->cached_bmtree );
354 }
355
356 int
357 ompi_coll_base_bcast_intra_split_bintree ( void* buffer,
358 int count,
359 struct ompi_datatype_t* datatype,
360 int root,
361 struct ompi_communicator_t* comm,
362 mca_coll_base_module_t *module,
363 uint32_t segsize )
364 {
365 int err=0, line, rank, size, segindex, i, lr, pair;
366 uint32_t counts[2];
367 int segcount[2];
368 int num_segments[2];
369 int sendcount[2];
370 size_t realsegsize[2], type_size;
371 char *tmpbuf[2];
372 ptrdiff_t type_extent, lb;
373 ompi_request_t *base_req, *new_req;
374 ompi_coll_tree_t *tree;
375
376 size = ompi_comm_size(comm);
377 rank = ompi_comm_rank(comm);
378
379 OPAL_OUTPUT((ompi_coll_base_framework.framework_output,"ompi_coll_base_bcast_intra_split_bintree rank %d root %d ss %5d", rank, root, segsize));
380
381 if (size == 1) {
382 return MPI_SUCCESS;
383 }
384
385
386 COLL_BASE_UPDATE_BINTREE( comm, module, root );
387 tree = module->base_data->cached_bintree;
388
389 err = ompi_datatype_type_size( datatype, &type_size );
390
391
392 counts[0] = count/2;
393 if (count % 2 != 0) counts[0]++;
394 counts[1] = count - counts[0];
395 if ( segsize > 0 ) {
396
397
398
399 if (segsize < ((uint32_t) type_size)) {
400 segsize = type_size;
401 }
402 segcount[0] = segcount[1] = segsize / type_size;
403 num_segments[0] = counts[0]/segcount[0];
404 if ((counts[0] % segcount[0]) != 0) num_segments[0]++;
405 num_segments[1] = counts[1]/segcount[1];
406 if ((counts[1] % segcount[1]) != 0) num_segments[1]++;
407 } else {
408 segcount[0] = counts[0];
409 segcount[1] = counts[1];
410 num_segments[0] = num_segments[1] = 1;
411 }
412
413
414 if( (counts[0] == 0 || counts[1] == 0) ||
415 (segsize > ((ptrdiff_t)counts[0] * type_size)) ||
416 (segsize > ((ptrdiff_t)counts[1] * type_size)) ) {
417
418 return (ompi_coll_base_bcast_intra_chain ( buffer, count, datatype,
419 root, comm, module,
420 segsize, 1 ));
421 }
422
423 err = ompi_datatype_get_extent (datatype, &lb, &type_extent);
424
425
426 realsegsize[0] = (ptrdiff_t)segcount[0] * type_extent;
427 realsegsize[1] = (ptrdiff_t)segcount[1] * type_extent;
428
429
430 tmpbuf[0] = (char *) buffer;
431 tmpbuf[1] = (char *) buffer + (ptrdiff_t)counts[0] * type_extent;
432
433
434
435
436
437
438
439
440 lr = ((rank + size - root)%size + 1)%2;
441
442
443 if( rank == root ) {
444
445 sendcount[0] = segcount[0];
446 sendcount[1] = segcount[1];
447
448 for (segindex = 0; segindex < num_segments[0]; segindex++) {
449
450 for( i = 0; i < tree->tree_nextsize && i < 2; i++ ) {
451 if (segindex >= num_segments[i]) {
452 continue;
453 }
454
455 if(segindex == (num_segments[i] - 1))
456 sendcount[i] = counts[i] - segindex*segcount[i];
457
458 MCA_PML_CALL(send(tmpbuf[i], sendcount[i], datatype,
459 tree->tree_next[i], MCA_COLL_BASE_TAG_BCAST,
460 MCA_PML_BASE_SEND_STANDARD, comm));
461 if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
462
463 tmpbuf[i] += realsegsize[i];
464 }
465 }
466 }
467
468
469 else if( tree->tree_nextsize > 0 ) {
470
471
472
473
474
475
476
477
478
479
480
481 sendcount[lr] = segcount[lr];
482 err = MCA_PML_CALL(irecv(tmpbuf[lr], sendcount[lr], datatype,
483 tree->tree_prev, MCA_COLL_BASE_TAG_BCAST,
484 comm, &base_req));
485 if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
486
487 for( segindex = 1; segindex < num_segments[lr]; segindex++ ) {
488
489 if( segindex == (num_segments[lr] - 1))
490 sendcount[lr] = counts[lr] - (ptrdiff_t)segindex * (ptrdiff_t)segcount[lr];
491
492 err = MCA_PML_CALL(irecv( tmpbuf[lr] + realsegsize[lr], sendcount[lr],
493 datatype, tree->tree_prev, MCA_COLL_BASE_TAG_BCAST,
494 comm, &new_req));
495 if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
496
497
498 err = ompi_request_wait( &base_req, MPI_STATUS_IGNORE );
499 for( i = 0; i < tree->tree_nextsize; i++ ) {
500 err = MCA_PML_CALL(send( tmpbuf[lr], segcount[lr], datatype,
501 tree->tree_next[i], MCA_COLL_BASE_TAG_BCAST,
502 MCA_PML_BASE_SEND_STANDARD, comm));
503 if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
504 }
505
506
507 base_req = new_req;
508
509 tmpbuf[lr] += realsegsize[lr];
510 }
511
512
513 err = ompi_request_wait( &base_req, MPI_STATUS_IGNORE );
514 for( i = 0; i < tree->tree_nextsize; i++ ) {
515 err = MCA_PML_CALL(send(tmpbuf[lr], sendcount[lr], datatype,
516 tree->tree_next[i], MCA_COLL_BASE_TAG_BCAST,
517 MCA_PML_BASE_SEND_STANDARD, comm));
518 if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
519 }
520 }
521
522
523 else {
524
525 sendcount[lr] = segcount[lr];
526 for (segindex = 0; segindex < num_segments[lr]; segindex++) {
527
528 if (segindex == (num_segments[lr] - 1))
529 sendcount[lr] = counts[lr] - (ptrdiff_t)segindex * (ptrdiff_t)segcount[lr];
530
531 err = MCA_PML_CALL(recv(tmpbuf[lr], sendcount[lr], datatype,
532 tree->tree_prev, MCA_COLL_BASE_TAG_BCAST,
533 comm, MPI_STATUS_IGNORE));
534 if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
535
536 tmpbuf[lr] += realsegsize[lr];
537 }
538 }
539
540
541 tmpbuf[0] = (char *) buffer;
542 tmpbuf[1] = (char *) buffer + (ptrdiff_t)counts[0] * type_extent;
543
544
545
546
547
548
549
550
551
552
553 if (lr == 0) {
554 pair = (rank+1)%size;
555 } else {
556 pair = (rank+size-1)%size;
557 }
558
559 if ( (size%2) != 0 && rank != root) {
560
561 err = ompi_coll_base_sendrecv( tmpbuf[lr], counts[lr], datatype,
562 pair, MCA_COLL_BASE_TAG_BCAST,
563 tmpbuf[(lr+1)%2], counts[(lr+1)%2], datatype,
564 pair, MCA_COLL_BASE_TAG_BCAST,
565 comm, MPI_STATUS_IGNORE, rank);
566 if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
567 } else if ( (size%2) == 0 ) {
568
569 if( rank == root ) {
570 err = MCA_PML_CALL(send(tmpbuf[1], counts[1], datatype,
571 (root+size-1)%size, MCA_COLL_BASE_TAG_BCAST,
572 MCA_PML_BASE_SEND_STANDARD, comm));
573 if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
574
575 }
576
577 else if (rank == (root+size-1)%size) {
578 err = MCA_PML_CALL(recv(tmpbuf[1], counts[1], datatype,
579 root, MCA_COLL_BASE_TAG_BCAST,
580 comm, MPI_STATUS_IGNORE));
581 if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
582 }
583
584 else {
585 err = ompi_coll_base_sendrecv( tmpbuf[lr], counts[lr], datatype,
586 pair, MCA_COLL_BASE_TAG_BCAST,
587 tmpbuf[(lr+1)%2], counts[(lr+1)%2], datatype,
588 pair, MCA_COLL_BASE_TAG_BCAST,
589 comm, MPI_STATUS_IGNORE, rank);
590 if (err != MPI_SUCCESS) { line = __LINE__; goto error_hndl; }
591 }
592 }
593 return (MPI_SUCCESS);
594
595 error_hndl:
596 OPAL_OUTPUT((ompi_coll_base_framework.framework_output,"%s:%4d\tError occurred %d, rank %2d", __FILE__,line,err,rank));
597 (void)line;
598 return (err);
599 }
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623 int
624 ompi_coll_base_bcast_intra_basic_linear(void *buff, int count,
625 struct ompi_datatype_t *datatype, int root,
626 struct ompi_communicator_t *comm,
627 mca_coll_base_module_t *module)
628 {
629 int i, size, rank, err;
630 ompi_request_t **preq, **reqs;
631
632 size = ompi_comm_size(comm);
633 rank = ompi_comm_rank(comm);
634
635 OPAL_OUTPUT((ompi_coll_base_framework.framework_output,"ompi_coll_base_bcast_intra_basic_linear rank %d root %d", rank, root));
636
637 if (1 == size) return OMPI_SUCCESS;
638
639
640
641 if (rank != root) {
642 return MCA_PML_CALL(recv(buff, count, datatype, root,
643 MCA_COLL_BASE_TAG_BCAST, comm,
644 MPI_STATUS_IGNORE));
645 }
646
647
648 preq = reqs = ompi_coll_base_comm_get_reqs(module->base_data, size-1);
649 if( NULL == reqs ) {
650 return OMPI_ERR_OUT_OF_RESOURCE;
651 }
652
653 for (i = 0; i < size; ++i) {
654 if (i == rank) {
655 continue;
656 }
657
658 err = MCA_PML_CALL(isend(buff, count, datatype, i,
659 MCA_COLL_BASE_TAG_BCAST,
660 MCA_PML_BASE_SEND_STANDARD,
661 comm, preq++));
662 if (MPI_SUCCESS != err) { goto err_hndl; }
663 }
664 --i;
665
666
667
668
669
670
671
672
673
674
675 err = ompi_request_wait_all(i, reqs, MPI_STATUSES_IGNORE);
676 err_hndl:
677 if( MPI_SUCCESS != err ) {
678
679 for( preq = reqs; preq < reqs+i; preq++ ) {
680 if (MPI_REQUEST_NULL == *preq) continue;
681 if (MPI_ERR_PENDING == (*preq)->req_status.MPI_ERROR) continue;
682 err = (*preq)->req_status.MPI_ERROR;
683 break;
684 }
685 ompi_coll_base_free_reqs(reqs, i);
686 }
687
688
689 return err;
690 }
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714 int ompi_coll_base_bcast_intra_knomial(
715 void *buf, int count, struct ompi_datatype_t *datatype, int root,
716 struct ompi_communicator_t *comm, mca_coll_base_module_t *module,
717 uint32_t segsize, int radix)
718 {
719 int segcount = count;
720 size_t typesize;
721 mca_coll_base_comm_t *data = module->base_data;
722
723 COLL_BASE_UPDATE_KMTREE(comm, module, root, radix);
724 if (NULL == data->cached_kmtree) {
725
726 return ompi_coll_base_bcast_intra_binomial(buf, count, datatype, root, comm, module,
727 segcount);
728 }
729
730
731
732
733 ompi_datatype_type_size(datatype, &typesize);
734 COLL_BASE_COMPUTED_SEGCOUNT(segsize, typesize, segcount);
735
736 OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
737 "coll:base:bcast_intra_knomial rank %d segsize %5d typesize %lu segcount %d",
738 ompi_comm_rank(comm), segsize, (unsigned long)typesize, segcount));
739
740 return ompi_coll_base_bcast_intra_generic(buf, count, datatype, root, comm, module,
741 segcount, data->cached_kmtree);
742 }
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768 int ompi_coll_base_bcast_intra_scatter_allgather(
769 void *buf, int count, struct ompi_datatype_t *datatype, int root,
770 struct ompi_communicator_t *comm, mca_coll_base_module_t *module,
771 uint32_t segsize)
772 {
773 int err = MPI_SUCCESS;
774 ptrdiff_t lb, extent;
775 size_t datatype_size;
776 MPI_Status status;
777 ompi_datatype_get_extent(datatype, &lb, &extent);
778 ompi_datatype_type_size(datatype, &datatype_size);
779 int comm_size = ompi_comm_size(comm);
780 int rank = ompi_comm_rank(comm);
781
782 OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
783 "coll:base:bcast_intra_scatter_allgather: rank %d/%d",
784 rank, comm_size));
785 if (comm_size < 2 || datatype_size == 0)
786 return MPI_SUCCESS;
787
788 if (count < comm_size) {
789 OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
790 "coll:base:bcast_intra_scatter_allgather: rank %d/%d "
791 "count %d switching to basic linear bcast",
792 rank, comm_size, count));
793 return ompi_coll_base_bcast_intra_basic_linear(buf, count, datatype,
794 root, comm, module);
795 }
796
797 int vrank = (rank - root + comm_size) % comm_size;
798 int recv_count = 0, send_count = 0;
799 int scatter_count = (count + comm_size - 1) / comm_size;
800 int curr_count = (rank == root) ? count : 0;
801
802
803 int mask = 0x1;
804 while (mask < comm_size) {
805 if (vrank & mask) {
806 int parent = (rank - mask + comm_size) % comm_size;
807
808 recv_count = count - vrank * scatter_count;
809 if (recv_count <= 0) {
810 curr_count = 0;
811 } else {
812
813 err = MCA_PML_CALL(recv((char *)buf + (ptrdiff_t)vrank * scatter_count * extent,
814 recv_count, datatype, parent,
815 MCA_COLL_BASE_TAG_BCAST, comm, &status));
816 if (MPI_SUCCESS != err) { goto cleanup_and_return; }
817
818 curr_count = (int)(status._ucount / datatype_size);
819 }
820 break;
821 }
822 mask <<= 1;
823 }
824
825
826 mask >>= 1;
827 while (mask > 0) {
828 if (vrank + mask < comm_size) {
829 send_count = curr_count - scatter_count * mask;
830 if (send_count > 0) {
831 int child = (rank + mask) % comm_size;
832 err = MCA_PML_CALL(send((char *)buf + (ptrdiff_t)scatter_count * (vrank + mask) * extent,
833 send_count, datatype, child,
834 MCA_COLL_BASE_TAG_BCAST,
835 MCA_PML_BASE_SEND_STANDARD, comm));
836 if (MPI_SUCCESS != err) { goto cleanup_and_return; }
837 curr_count -= send_count;
838 }
839 }
840 mask >>= 1;
841 }
842
843
844
845
846
847 int rem_count = count - vrank * scatter_count;
848 curr_count = (scatter_count < rem_count) ? scatter_count : rem_count;
849 if (curr_count < 0)
850 curr_count = 0;
851
852 mask = 0x1;
853 while (mask < comm_size) {
854 int vremote = vrank ^ mask;
855 int remote = (vremote + root) % comm_size;
856
857 int vrank_tree_root = ompi_rounddown(vrank, mask);
858 int vremote_tree_root = ompi_rounddown(vremote, mask);
859
860 if (vremote < comm_size) {
861 ptrdiff_t send_offset = vrank_tree_root * scatter_count * extent;
862 ptrdiff_t recv_offset = vremote_tree_root * scatter_count * extent;
863 recv_count = count - vremote_tree_root * scatter_count;
864 if (recv_count < 0)
865 recv_count = 0;
866 err = ompi_coll_base_sendrecv((char *)buf + send_offset,
867 curr_count, datatype, remote,
868 MCA_COLL_BASE_TAG_BCAST,
869 (char *)buf + recv_offset,
870 recv_count, datatype, remote,
871 MCA_COLL_BASE_TAG_BCAST,
872 comm, &status, rank);
873 if (MPI_SUCCESS != err) { goto cleanup_and_return; }
874 recv_count = (int)(status._ucount / datatype_size);
875 curr_count += recv_count;
876 }
877
878
879
880
881
882
883 if (vremote_tree_root + mask > comm_size) {
884 int nprocs_alldata = comm_size - vrank_tree_root - mask;
885 int offset = scatter_count * (vrank_tree_root + mask);
886 for (int rhalving_mask = mask >> 1; rhalving_mask > 0; rhalving_mask >>= 1) {
887 vremote = vrank ^ rhalving_mask;
888 remote = (vremote + root) % comm_size;
889 int tree_root = ompi_rounddown(vrank, rhalving_mask << 1);
890
891
892
893
894
895 if ((vremote > vrank) && (vrank < tree_root + nprocs_alldata)
896 && (vremote >= tree_root + nprocs_alldata)) {
897 err = MCA_PML_CALL(send((char *)buf + (ptrdiff_t)offset * extent,
898 recv_count, datatype, remote,
899 MCA_COLL_BASE_TAG_BCAST,
900 MCA_PML_BASE_SEND_STANDARD, comm));
901 if (MPI_SUCCESS != err) { goto cleanup_and_return; }
902
903 } else if ((vremote < vrank) && (vremote < tree_root + nprocs_alldata)
904 && (vrank >= tree_root + nprocs_alldata)) {
905 err = MCA_PML_CALL(recv((char *)buf + (ptrdiff_t)offset * extent,
906 count - offset, datatype, remote,
907 MCA_COLL_BASE_TAG_BCAST,
908 comm, &status));
909 if (MPI_SUCCESS != err) { goto cleanup_and_return; }
910 recv_count = (int)(status._ucount / datatype_size);
911 curr_count += recv_count;
912 }
913 }
914 }
915 mask <<= 1;
916 }
917
918 cleanup_and_return:
919 return err;
920 }
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945 int ompi_coll_base_bcast_intra_scatter_allgather_ring(
946 void *buf, int count, struct ompi_datatype_t *datatype, int root,
947 struct ompi_communicator_t *comm, mca_coll_base_module_t *module,
948 uint32_t segsize)
949 {
950 int err = MPI_SUCCESS;
951 ptrdiff_t lb, extent;
952 size_t datatype_size;
953 MPI_Status status;
954 ompi_datatype_get_extent(datatype, &lb, &extent);
955 ompi_datatype_type_size(datatype, &datatype_size);
956 int comm_size = ompi_comm_size(comm);
957 int rank = ompi_comm_rank(comm);
958
959 OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
960 "coll:base:bcast_intra_scatter_allgather_ring: rank %d/%d",
961 rank, comm_size));
962 if (comm_size < 2 || datatype_size == 0)
963 return MPI_SUCCESS;
964
965 if (count < comm_size) {
966 OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
967 "coll:base:bcast_intra_scatter_allgather_ring: rank %d/%d "
968 "count %d switching to basic linear bcast",
969 rank, comm_size, count));
970 return ompi_coll_base_bcast_intra_basic_linear(buf, count, datatype,
971 root, comm, module);
972 }
973
974 int vrank = (rank - root + comm_size) % comm_size;
975 int recv_count = 0, send_count = 0;
976 int scatter_count = (count + comm_size - 1) / comm_size;
977 int curr_count = (rank == root) ? count : 0;
978
979
980 int mask = 1;
981 while (mask < comm_size) {
982 if (vrank & mask) {
983 int parent = (rank - mask + comm_size) % comm_size;
984
985 recv_count = count - vrank * scatter_count;
986 if (recv_count <= 0) {
987 curr_count = 0;
988 } else {
989
990 err = MCA_PML_CALL(recv((char *)buf + (ptrdiff_t)vrank * scatter_count * extent,
991 recv_count, datatype, parent,
992 MCA_COLL_BASE_TAG_BCAST, comm, &status));
993 if (MPI_SUCCESS != err) { goto cleanup_and_return; }
994
995 curr_count = (int)(status._ucount / datatype_size);
996 }
997 break;
998 }
999 mask <<= 1;
1000 }
1001
1002
1003 mask >>= 1;
1004 while (mask > 0) {
1005 if (vrank + mask < comm_size) {
1006 send_count = curr_count - scatter_count * mask;
1007 if (send_count > 0) {
1008 int child = (rank + mask) % comm_size;
1009 err = MCA_PML_CALL(send((char *)buf + (ptrdiff_t)scatter_count * (vrank + mask) * extent,
1010 send_count, datatype, child,
1011 MCA_COLL_BASE_TAG_BCAST,
1012 MCA_PML_BASE_SEND_STANDARD, comm));
1013 if (MPI_SUCCESS != err) { goto cleanup_and_return; }
1014 curr_count -= send_count;
1015 }
1016 }
1017 mask >>= 1;
1018 }
1019
1020
1021 int left = (rank - 1 + comm_size) % comm_size;
1022 int right = (rank + 1) % comm_size;
1023 int send_block = vrank;
1024 int recv_block = (vrank - 1 + comm_size) % comm_size;
1025
1026 for (int i = 1; i < comm_size; i++) {
1027 recv_count = (scatter_count < count - recv_block * scatter_count) ?
1028 scatter_count : count - recv_block * scatter_count;
1029 if (recv_count < 0)
1030 recv_count = 0;
1031 ptrdiff_t recv_offset = recv_block * scatter_count * extent;
1032
1033 send_count = (scatter_count < count - send_block * scatter_count) ?
1034 scatter_count : count - send_block * scatter_count;
1035 if (send_count < 0)
1036 send_count = 0;
1037 ptrdiff_t send_offset = send_block * scatter_count * extent;
1038
1039 err = ompi_coll_base_sendrecv((char *)buf + send_offset, send_count,
1040 datatype, right, MCA_COLL_BASE_TAG_BCAST,
1041 (char *)buf + recv_offset, recv_count,
1042 datatype, left, MCA_COLL_BASE_TAG_BCAST,
1043 comm, MPI_STATUS_IGNORE, rank);
1044 if (MPI_SUCCESS != err) { goto cleanup_and_return; }
1045 send_block = recv_block;
1046 recv_block = (recv_block - 1 + comm_size) % comm_size;
1047 }
1048
1049 cleanup_and_return:
1050 return err;
1051 }