This source file includes the following definitions.
- mca_spml_ucx_enable
- mca_spml_ucx_del_procs
- oshmem_shmem_xchng
- dump_address
- mca_spml_ucx_add_procs
- mca_spml_ucx_rmkey_free
- mca_spml_ucx_rmkey_ptr
- mca_spml_ucx_rmkey_unpack
- mca_spml_ucx_memuse_hook
- mca_spml_ucx_register
- mca_spml_ucx_deregister
- _ctx_add
- _ctx_remove
- mca_spml_ucx_ctx_create_common
- mca_spml_ucx_ctx_create
- mca_spml_ucx_ctx_destroy
- mca_spml_ucx_get
- mca_spml_ucx_get_nb
- mca_spml_ucx_put
- mca_spml_ucx_put_nb
- mca_spml_ucx_fence
- mca_spml_ucx_quiet
- mca_spml_ucx_recv
- mca_spml_ucx_send
- mca_spml_ucx_put_all_complete_cb
- mca_spml_ucx_create_aux_ctx
- mca_spml_ucx_put_all_nb
1
2
3
4
5
6
7
8
9
10
11
12
13
14 #define _GNU_SOURCE
15 #include <stdio.h>
16
17 #include <sys/types.h>
18 #include <unistd.h>
19 #include <stdint.h>
20
21 #include "oshmem_config.h"
22 #include "opal/datatype/opal_convertor.h"
23 #include "opal/mca/common/ucx/common_ucx.h"
24 #include "ompi/datatype/ompi_datatype.h"
25 #include "ompi/mca/pml/pml.h"
26
27
28 #include "oshmem/mca/spml/ucx/spml_ucx.h"
29 #include "oshmem/include/shmem.h"
30 #include "oshmem/mca/memheap/memheap.h"
31 #include "oshmem/mca/memheap/base/base.h"
32 #include "oshmem/proc/proc.h"
33 #include "oshmem/mca/spml/base/base.h"
34 #include "oshmem/mca/spml/base/spml_base_putreq.h"
35 #include "oshmem/mca/atomic/atomic.h"
36 #include "oshmem/runtime/runtime.h"
37
38 #include "oshmem/mca/spml/ucx/spml_ucx_component.h"
39 #include "oshmem/mca/sshmem/ucx/sshmem_ucx.h"
40
41
42 #ifndef SPML_UCX_PUT_DEBUG
43 #define SPML_UCX_PUT_DEBUG 0
44 #endif
45
46 mca_spml_ucx_t mca_spml_ucx = {
47 .super = {
48
49 .spml_add_procs = mca_spml_ucx_add_procs,
50 .spml_del_procs = mca_spml_ucx_del_procs,
51 .spml_enable = mca_spml_ucx_enable,
52 .spml_register = mca_spml_ucx_register,
53 .spml_deregister = mca_spml_ucx_deregister,
54 .spml_oob_get_mkeys = mca_spml_base_oob_get_mkeys,
55 .spml_ctx_create = mca_spml_ucx_ctx_create,
56 .spml_ctx_destroy = mca_spml_ucx_ctx_destroy,
57 .spml_put = mca_spml_ucx_put,
58 .spml_put_nb = mca_spml_ucx_put_nb,
59 .spml_get = mca_spml_ucx_get,
60 .spml_get_nb = mca_spml_ucx_get_nb,
61 .spml_recv = mca_spml_ucx_recv,
62 .spml_send = mca_spml_ucx_send,
63 .spml_wait = mca_spml_base_wait,
64 .spml_wait_nb = mca_spml_base_wait_nb,
65 .spml_test = mca_spml_base_test,
66 .spml_fence = mca_spml_ucx_fence,
67 .spml_quiet = mca_spml_ucx_quiet,
68 .spml_rmkey_unpack = mca_spml_ucx_rmkey_unpack,
69 .spml_rmkey_free = mca_spml_ucx_rmkey_free,
70 .spml_rmkey_ptr = mca_spml_ucx_rmkey_ptr,
71 .spml_memuse_hook = mca_spml_ucx_memuse_hook,
72 .spml_put_all_nb = mca_spml_ucx_put_all_nb,
73 .self = (void*)&mca_spml_ucx
74 },
75
76 .ucp_context = NULL,
77 .num_disconnect = 1,
78 .heap_reg_nb = 0,
79 .enabled = 0,
80 .get_mkey_slow = NULL
81 };
82
83 mca_spml_ucx_ctx_t mca_spml_ucx_ctx_default = {
84 .ucp_worker = NULL,
85 .ucp_peers = NULL,
86 .options = 0
87 };
88
89 int mca_spml_ucx_enable(bool enable)
90 {
91 SPML_UCX_VERBOSE(50, "*** ucx ENABLED ****");
92 if (false == enable) {
93 return OSHMEM_SUCCESS;
94 }
95
96 mca_spml_ucx.enabled = true;
97
98 return OSHMEM_SUCCESS;
99 }
100
101 int mca_spml_ucx_del_procs(ompi_proc_t** procs, size_t nprocs)
102 {
103 opal_common_ucx_del_proc_t *del_procs;
104 size_t i;
105 int ret;
106
107 oshmem_shmem_barrier();
108
109 if (!mca_spml_ucx_ctx_default.ucp_peers) {
110 return OSHMEM_SUCCESS;
111 }
112
113 del_procs = malloc(sizeof(*del_procs) * nprocs);
114 if (del_procs == NULL) {
115 return OMPI_ERR_OUT_OF_RESOURCE;
116 }
117
118 for (i = 0; i < nprocs; ++i) {
119 del_procs[i].ep = mca_spml_ucx_ctx_default.ucp_peers[i].ucp_conn;
120 del_procs[i].vpid = i;
121
122
123 mca_spml_ucx_ctx_default.ucp_peers[i].ucp_conn = NULL;
124 }
125
126 ret = opal_common_ucx_del_procs(del_procs, nprocs, oshmem_my_proc_id(),
127 mca_spml_ucx.num_disconnect,
128 mca_spml_ucx_ctx_default.ucp_worker);
129
130 free(del_procs);
131 free(mca_spml_ucx.remote_addrs_tbl);
132 free(mca_spml_ucx_ctx_default.ucp_peers);
133
134 mca_spml_ucx_ctx_default.ucp_peers = NULL;
135
136 opal_common_ucx_mca_proc_added();
137
138 return ret;
139 }
140
141
142 static int oshmem_shmem_xchng(
143 void *local_data, int local_size, int nprocs,
144 void **rdata_p, int **roffsets_p, int **rsizes_p)
145 {
146 int *rcv_sizes = NULL;
147 int *rcv_offsets = NULL;
148 void *rcv_buf = NULL;
149 int rc;
150 int i;
151
152
153 rcv_offsets = malloc(nprocs * sizeof(*rcv_offsets));
154 if (NULL == rcv_offsets) {
155 goto err;
156 }
157
158
159 rcv_sizes = malloc(nprocs * sizeof(*rcv_sizes));
160 if (NULL == rcv_sizes) {
161 goto err;
162 }
163
164 rc = oshmem_shmem_allgather(&local_size, rcv_sizes, sizeof(int));
165 if (MPI_SUCCESS != rc) {
166 goto err;
167 }
168
169
170 rcv_offsets[0] = 0;
171 for (i = 1; i < nprocs; i++) {
172 rcv_offsets[i] = rcv_offsets[i - 1] + rcv_sizes[i - 1];
173 }
174
175 rcv_buf = malloc(rcv_offsets[nprocs - 1] + rcv_sizes[nprocs - 1]);
176 if (NULL == rcv_buf) {
177 goto err;
178 }
179
180 rc = oshmem_shmem_allgatherv(local_data, rcv_buf, local_size, rcv_sizes, rcv_offsets);
181 if (MPI_SUCCESS != rc) {
182 goto err;
183 }
184
185 *rdata_p = rcv_buf;
186 *roffsets_p = rcv_offsets;
187 *rsizes_p = rcv_sizes;
188 return OSHMEM_SUCCESS;
189
190 err:
191 if (rcv_buf)
192 free(rcv_buf);
193 if (rcv_offsets)
194 free(rcv_offsets);
195 if (rcv_sizes)
196 free(rcv_sizes);
197 return OSHMEM_ERROR;
198 }
199
/**
 * Debug helper: print a worker address as a hex dump on stdout.
 *
 * Compiled to an empty function unless SPML_UCX_DEBUG is defined.
 *
 * @param pe    PE the address belongs to (printed as dest_pe)
 * @param addr  address bytes
 * @param len   number of bytes in addr
 */
static void dump_address(int pe, char *addr, size_t len)
{
#ifdef SPML_UCX_DEBUG
    int my_rank = oshmem_my_proc_id();
    size_t i;

    /* len is a size_t, so it must be printed with %zu */
    printf("me=%d dest_pe=%d addr=%p len=%zu\n", my_rank, pe, (void *)addr, len);
    for (i = 0; i < len; i++) {
        printf("%02X ", (unsigned)0xFF & addr[i]);
    }
    printf("\n");
#endif
}
213
214 static char spml_ucx_transport_ids[1] = { 0 };
215
216 int mca_spml_ucx_add_procs(ompi_proc_t** procs, size_t nprocs)
217 {
218 size_t i, j, n;
219 int rc = OSHMEM_ERROR;
220 int my_rank = oshmem_my_proc_id();
221 ucs_status_t err;
222 ucp_address_t *wk_local_addr;
223 size_t wk_addr_len;
224 int *wk_roffs = NULL;
225 int *wk_rsizes = NULL;
226 char *wk_raddrs = NULL;
227 ucp_ep_params_t ep_params;
228
229
230 mca_spml_ucx_ctx_default.ucp_peers = (ucp_peer_t *) calloc(nprocs, sizeof(*(mca_spml_ucx_ctx_default.ucp_peers)));
231 if (NULL == mca_spml_ucx_ctx_default.ucp_peers) {
232 goto error;
233 }
234
235 err = ucp_worker_get_address(mca_spml_ucx_ctx_default.ucp_worker, &wk_local_addr, &wk_addr_len);
236 if (err != UCS_OK) {
237 goto error;
238 }
239 dump_address(my_rank, (char *)wk_local_addr, wk_addr_len);
240
241 rc = oshmem_shmem_xchng(wk_local_addr, wk_addr_len, nprocs,
242 (void **)&wk_raddrs, &wk_roffs, &wk_rsizes);
243 if (rc != OSHMEM_SUCCESS) {
244 goto error;
245 }
246
247 opal_progress_register(spml_ucx_default_progress);
248
249 mca_spml_ucx.remote_addrs_tbl = (char **)calloc(nprocs, sizeof(char *));
250 memset(mca_spml_ucx.remote_addrs_tbl, 0, nprocs * sizeof(char *));
251
252
253 for (n = 0; n < nprocs; ++n) {
254 i = (my_rank + n) % nprocs;
255 dump_address(i, (char *)(wk_raddrs + wk_roffs[i]), wk_rsizes[i]);
256
257 ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
258 ep_params.address = (ucp_address_t *)(wk_raddrs + wk_roffs[i]);
259
260 err = ucp_ep_create(mca_spml_ucx_ctx_default.ucp_worker, &ep_params,
261 &mca_spml_ucx_ctx_default.ucp_peers[i].ucp_conn);
262 if (UCS_OK != err) {
263 SPML_UCX_ERROR("ucp_ep_create(proc=%zu/%zu) failed: %s", n, nprocs,
264 ucs_status_string(err));
265 goto error2;
266 }
267
268 OSHMEM_PROC_DATA(procs[i])->num_transports = 1;
269 OSHMEM_PROC_DATA(procs[i])->transport_ids = spml_ucx_transport_ids;
270
271 for (j = 0; j < MCA_MEMHEAP_MAX_SEGMENTS; j++) {
272 mca_spml_ucx_ctx_default.ucp_peers[i].mkeys[j].key.rkey = NULL;
273 }
274
275 mca_spml_ucx.remote_addrs_tbl[i] = (char *)malloc(wk_rsizes[i]);
276 memcpy(mca_spml_ucx.remote_addrs_tbl[i], (char *)(wk_raddrs + wk_roffs[i]),
277 wk_rsizes[i]);
278 }
279
280 ucp_worker_release_address(mca_spml_ucx_ctx_default.ucp_worker, wk_local_addr);
281 free(wk_raddrs);
282 free(wk_rsizes);
283 free(wk_roffs);
284
285 SPML_UCX_VERBOSE(50, "*** ADDED PROCS ***");
286 return OSHMEM_SUCCESS;
287
288 error2:
289 for (i = 0; i < nprocs; ++i) {
290 if (mca_spml_ucx_ctx_default.ucp_peers[i].ucp_conn) {
291 ucp_ep_destroy(mca_spml_ucx_ctx_default.ucp_peers[i].ucp_conn);
292 }
293 if (mca_spml_ucx.remote_addrs_tbl[i]) {
294 free(mca_spml_ucx.remote_addrs_tbl[i]);
295 }
296 }
297 if (mca_spml_ucx_ctx_default.ucp_peers)
298 free(mca_spml_ucx_ctx_default.ucp_peers);
299 if (mca_spml_ucx.remote_addrs_tbl)
300 free(mca_spml_ucx.remote_addrs_tbl);
301 free(wk_raddrs);
302 free(wk_rsizes);
303 free(wk_roffs);
304 error:
305 rc = OSHMEM_ERR_OUT_OF_RESOURCE;
306 SPML_UCX_ERROR("add procs FAILED rc=%d", rc);
307 return rc;
308
309 }
310
311 void mca_spml_ucx_rmkey_free(sshmem_mkey_t *mkey)
312 {
313 spml_ucx_mkey_t *ucx_mkey;
314
315 if (!mkey->spml_context) {
316 return;
317 }
318 ucx_mkey = (spml_ucx_mkey_t *)(mkey->spml_context);
319 ucp_rkey_destroy(ucx_mkey->rkey);
320 }
321
322 void *mca_spml_ucx_rmkey_ptr(const void *dst_addr, sshmem_mkey_t *mkey, int pe)
323 {
324 #if (((UCP_API_MAJOR >= 1) && (UCP_API_MINOR >= 3)) || (UCP_API_MAJOR >= 2))
325 void *rva;
326 ucs_status_t err;
327 spml_ucx_mkey_t *ucx_mkey = (spml_ucx_mkey_t *)(mkey->spml_context);
328
329 err = ucp_rkey_ptr(ucx_mkey->rkey, (uint64_t)dst_addr, &rva);
330 if (UCS_OK != err) {
331 return NULL;
332 }
333 return rva;
334 #else
335 return NULL;
336 #endif
337 }
338
339 void mca_spml_ucx_rmkey_unpack(shmem_ctx_t ctx, sshmem_mkey_t *mkey, uint32_t segno, int pe, int tr_id)
340 {
341 spml_ucx_mkey_t *ucx_mkey;
342 mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx;
343 ucs_status_t err;
344
345 ucx_mkey = &ucx_ctx->ucp_peers[pe].mkeys[segno].key;
346
347 err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn,
348 mkey->u.data,
349 &ucx_mkey->rkey);
350 if (UCS_OK != err) {
351 SPML_UCX_ERROR("failed to unpack rkey: %s", ucs_status_string(err));
352 goto error_fatal;
353 }
354
355 if (ucx_ctx == &mca_spml_ucx_ctx_default) {
356 mkey->spml_context = ucx_mkey;
357 }
358 mca_spml_ucx_cache_mkey(ucx_ctx, mkey, segno, pe);
359 return;
360
361 error_fatal:
362 oshmem_shmem_abort(-1);
363 return;
364 }
365
366 void mca_spml_ucx_memuse_hook(void *addr, size_t length)
367 {
368 int my_pe;
369 spml_ucx_mkey_t *ucx_mkey;
370 ucp_mem_advise_params_t params;
371 ucs_status_t status;
372
373 if (!(mca_spml_ucx.heap_reg_nb && memheap_is_va_in_segment(addr, HEAP_SEG_INDEX))) {
374 return;
375 }
376
377 my_pe = oshmem_my_proc_id();
378 ucx_mkey = &mca_spml_ucx_ctx_default.ucp_peers[my_pe].mkeys[HEAP_SEG_INDEX].key;
379
380 params.field_mask = UCP_MEM_ADVISE_PARAM_FIELD_ADDRESS |
381 UCP_MEM_ADVISE_PARAM_FIELD_LENGTH |
382 UCP_MEM_ADVISE_PARAM_FIELD_ADVICE;
383
384 params.address = addr;
385 params.length = length;
386 params.advice = UCP_MADV_WILLNEED;
387
388 status = ucp_mem_advise(mca_spml_ucx.ucp_context, ucx_mkey->mem_h, ¶ms);
389 if (UCS_OK != status) {
390 SPML_UCX_ERROR("ucp_mem_advise failed addr %p len %llu : %s",
391 addr, (unsigned long long)length, ucs_status_string(status));
392 }
393 }
394
395 sshmem_mkey_t *mca_spml_ucx_register(void* addr,
396 size_t size,
397 uint64_t shmid,
398 int *count)
399 {
400 sshmem_mkey_t *mkeys;
401 ucs_status_t status;
402 spml_ucx_mkey_t *ucx_mkey;
403 size_t len;
404 ucp_mem_map_params_t mem_map_params;
405 int segno;
406 map_segment_t *mem_seg;
407 unsigned flags;
408 int my_pe = oshmem_my_proc_id();
409
410 *count = 0;
411 mkeys = (sshmem_mkey_t *) calloc(1, sizeof(*mkeys));
412 if (!mkeys) {
413 return NULL;
414 }
415
416 segno = memheap_find_segnum(addr);
417 mem_seg = memheap_find_seg(segno);
418
419 ucx_mkey = &mca_spml_ucx_ctx_default.ucp_peers[my_pe].mkeys[segno].key;
420 mkeys[0].spml_context = ucx_mkey;
421
422
423 if (MAP_SEGMENT_ALLOC_UCX != mem_seg->type) {
424 flags = 0;
425 if (mca_spml_ucx.heap_reg_nb && memheap_is_va_in_segment(addr, HEAP_SEG_INDEX)) {
426 flags = UCP_MEM_MAP_NONBLOCK;
427 }
428
429 mem_map_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
430 UCP_MEM_MAP_PARAM_FIELD_LENGTH |
431 UCP_MEM_MAP_PARAM_FIELD_FLAGS;
432 mem_map_params.address = addr;
433 mem_map_params.length = size;
434 mem_map_params.flags = flags;
435
436 status = ucp_mem_map(mca_spml_ucx.ucp_context, &mem_map_params, &ucx_mkey->mem_h);
437 if (UCS_OK != status) {
438 goto error_out;
439 }
440
441 } else {
442 mca_sshmem_ucx_segment_context_t *ctx = mem_seg->context;
443 ucx_mkey->mem_h = ctx->ucp_memh;
444 }
445
446 status = ucp_rkey_pack(mca_spml_ucx.ucp_context, ucx_mkey->mem_h,
447 &mkeys[0].u.data, &len);
448 if (UCS_OK != status) {
449 goto error_unmap;
450 }
451 if (len >= 0xffff) {
452 SPML_UCX_ERROR("packed rkey is too long: %llu >= %d",
453 (unsigned long long)len,
454 0xffff);
455 oshmem_shmem_abort(-1);
456 }
457
458 status = ucp_ep_rkey_unpack(mca_spml_ucx_ctx_default.ucp_peers[oshmem_group_self->my_pe].ucp_conn,
459 mkeys[0].u.data,
460 &ucx_mkey->rkey);
461 if (UCS_OK != status) {
462 SPML_UCX_ERROR("failed to unpack rkey");
463 goto error_unmap;
464 }
465
466 mkeys[0].len = len;
467 mkeys[0].va_base = addr;
468 *count = 1;
469 mca_spml_ucx_cache_mkey(&mca_spml_ucx_ctx_default, &mkeys[0], segno, my_pe);
470 return mkeys;
471
472 error_unmap:
473 ucp_mem_unmap(mca_spml_ucx.ucp_context, ucx_mkey->mem_h);
474 error_out:
475 free(mkeys);
476
477 return NULL ;
478 }
479
480 int mca_spml_ucx_deregister(sshmem_mkey_t *mkeys)
481 {
482 spml_ucx_mkey_t *ucx_mkey;
483 map_segment_t *mem_seg;
484
485 MCA_SPML_CALL(quiet(oshmem_ctx_default));
486 if (!mkeys)
487 return OSHMEM_SUCCESS;
488
489 if (!mkeys[0].spml_context)
490 return OSHMEM_SUCCESS;
491
492 mem_seg = memheap_find_va(mkeys[0].va_base);
493 ucx_mkey = (spml_ucx_mkey_t*)mkeys[0].spml_context;
494
495 if (OPAL_UNLIKELY(NULL == mem_seg)) {
496 return OSHMEM_ERROR;
497 }
498
499 if (MAP_SEGMENT_ALLOC_UCX != mem_seg->type) {
500 ucp_mem_unmap(mca_spml_ucx.ucp_context, ucx_mkey->mem_h);
501 }
502 ucp_rkey_destroy(ucx_mkey->rkey);
503 ucx_mkey->rkey = NULL;
504
505 if (0 < mkeys[0].len) {
506 ucp_rkey_buffer_release(mkeys[0].u.data);
507 }
508
509 free(mkeys);
510
511 return OSHMEM_SUCCESS;
512 }
513
514 static inline void _ctx_add(mca_spml_ucx_ctx_array_t *array, mca_spml_ucx_ctx_t *ctx)
515 {
516 int i;
517
518 if (array->ctxs_count < array->ctxs_num) {
519 array->ctxs[array->ctxs_count] = ctx;
520 } else {
521 array->ctxs = realloc(array->ctxs, (array->ctxs_num + MCA_SPML_UCX_CTXS_ARRAY_INC) * sizeof(mca_spml_ucx_ctx_t *));
522 opal_atomic_wmb ();
523 for (i = array->ctxs_num; i < array->ctxs_num + MCA_SPML_UCX_CTXS_ARRAY_INC; i++) {
524 array->ctxs[i] = NULL;
525 }
526 array->ctxs[array->ctxs_num] = ctx;
527 array->ctxs_num += MCA_SPML_UCX_CTXS_ARRAY_INC;
528 }
529
530 opal_atomic_wmb ();
531 array->ctxs_count++;
532 }
533
534 static inline void _ctx_remove(mca_spml_ucx_ctx_array_t *array, mca_spml_ucx_ctx_t *ctx)
535 {
536 int i;
537
538 for (i = 0; i < array->ctxs_count; i++) {
539 if (array->ctxs[i] == ctx) {
540 array->ctxs[i] = array->ctxs[array->ctxs_count-1];
541 array->ctxs[array->ctxs_count-1] = NULL;
542 break;
543 }
544 }
545
546 array->ctxs_count--;
547 opal_atomic_wmb ();
548 }
549
550 static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx_ctx_p)
551 {
552 ucp_worker_params_t params;
553 ucp_ep_params_t ep_params;
554 size_t i, nprocs = oshmem_num_procs();
555 int j;
556 ucs_status_t err;
557 spml_ucx_mkey_t *ucx_mkey;
558 sshmem_mkey_t *mkey;
559 mca_spml_ucx_ctx_t *ucx_ctx;
560 int rc = OSHMEM_ERROR;
561
562 ucx_ctx = malloc(sizeof(mca_spml_ucx_ctx_t));
563 ucx_ctx->options = options;
564
565 params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
566 if (oshmem_mpi_thread_provided == SHMEM_THREAD_SINGLE || options & SHMEM_CTX_PRIVATE || options & SHMEM_CTX_SERIALIZED) {
567 params.thread_mode = UCS_THREAD_MODE_SINGLE;
568 } else {
569 params.thread_mode = UCS_THREAD_MODE_MULTI;
570 }
571
572 err = ucp_worker_create(mca_spml_ucx.ucp_context, ¶ms,
573 &ucx_ctx->ucp_worker);
574 if (UCS_OK != err) {
575 free(ucx_ctx);
576 return OSHMEM_ERROR;
577 }
578
579 ucx_ctx->ucp_peers = (ucp_peer_t *) calloc(nprocs, sizeof(*(ucx_ctx->ucp_peers)));
580 if (NULL == ucx_ctx->ucp_peers) {
581 goto error;
582 }
583
584 for (i = 0; i < nprocs; i++) {
585 ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
586 ep_params.address = (ucp_address_t *)(mca_spml_ucx.remote_addrs_tbl[i]);
587 err = ucp_ep_create(ucx_ctx->ucp_worker, &ep_params,
588 &ucx_ctx->ucp_peers[i].ucp_conn);
589 if (UCS_OK != err) {
590 SPML_ERROR("ucp_ep_create(proc=%d/%d) failed: %s", i, nprocs,
591 ucs_status_string(err));
592 goto error2;
593 }
594
595 for (j = 0; j < memheap_map->n_segments; j++) {
596 mkey = &memheap_map->mem_segs[j].mkeys_cache[i][0];
597 ucx_mkey = &ucx_ctx->ucp_peers[i].mkeys[j].key;
598 if (mkey->u.data) {
599 err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[i].ucp_conn,
600 mkey->u.data,
601 &ucx_mkey->rkey);
602 if (UCS_OK != err) {
603 SPML_UCX_ERROR("failed to unpack rkey");
604 goto error2;
605 }
606 mca_spml_ucx_cache_mkey(ucx_ctx, mkey, j, i);
607 }
608 }
609 }
610
611 *ucx_ctx_p = ucx_ctx;
612
613 return OSHMEM_SUCCESS;
614
615 error2:
616 for (i = 0; i < nprocs; i++) {
617 if (ucx_ctx->ucp_peers[i].ucp_conn) {
618 ucp_ep_destroy(ucx_ctx->ucp_peers[i].ucp_conn);
619 }
620 }
621
622 if (ucx_ctx->ucp_peers)
623 free(ucx_ctx->ucp_peers);
624
625 error:
626 ucp_worker_destroy(ucx_ctx->ucp_worker);
627 free(ucx_ctx);
628 rc = OSHMEM_ERR_OUT_OF_RESOURCE;
629 SPML_ERROR("ctx create FAILED rc=%d", rc);
630 return rc;
631 }
632
633 int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
634 {
635 mca_spml_ucx_ctx_t *ucx_ctx;
636 int rc;
637
638
639
640
641 pthread_mutex_lock(&mca_spml_ucx.ctx_create_mutex);
642 rc = mca_spml_ucx_ctx_create_common(options, &ucx_ctx);
643 pthread_mutex_unlock(&mca_spml_ucx.ctx_create_mutex);
644 if (rc != OSHMEM_SUCCESS) {
645 return rc;
646 }
647
648 if (mca_spml_ucx.active_array.ctxs_count == 0) {
649 opal_progress_register(spml_ucx_ctx_progress);
650 }
651
652 SHMEM_MUTEX_LOCK(mca_spml_ucx.internal_mutex);
653 _ctx_add(&mca_spml_ucx.active_array, ucx_ctx);
654 SHMEM_MUTEX_UNLOCK(mca_spml_ucx.internal_mutex);
655
656 (*ctx) = (shmem_ctx_t)ucx_ctx;
657 return OSHMEM_SUCCESS;
658 }
659
660 void mca_spml_ucx_ctx_destroy(shmem_ctx_t ctx)
661 {
662 MCA_SPML_CALL(quiet(ctx));
663
664 SHMEM_MUTEX_LOCK(mca_spml_ucx.internal_mutex);
665 _ctx_remove(&mca_spml_ucx.active_array, (mca_spml_ucx_ctx_t *)ctx);
666 _ctx_add(&mca_spml_ucx.idle_array, (mca_spml_ucx_ctx_t *)ctx);
667 SHMEM_MUTEX_UNLOCK(mca_spml_ucx.internal_mutex);
668
669 if (!mca_spml_ucx.active_array.ctxs_count) {
670 opal_progress_unregister(spml_ucx_ctx_progress);
671 }
672 }
673
674 int mca_spml_ucx_get(shmem_ctx_t ctx, void *src_addr, size_t size, void *dst_addr, int src)
675 {
676 void *rva;
677 spml_ucx_mkey_t *ucx_mkey;
678 mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx;
679 #if HAVE_DECL_UCP_GET_NB
680 ucs_status_ptr_t request;
681 #else
682 ucs_status_t status;
683 #endif
684
685 ucx_mkey = mca_spml_ucx_get_mkey(ctx, src, src_addr, &rva, &mca_spml_ucx);
686 #if HAVE_DECL_UCP_GET_NB
687 request = ucp_get_nb(ucx_ctx->ucp_peers[src].ucp_conn, dst_addr, size,
688 (uint64_t)rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb);
689 return opal_common_ucx_wait_request(request, ucx_ctx->ucp_worker, "ucp_get_nb");
690 #else
691 status = ucp_get(ucx_ctx->ucp_peers[src].ucp_conn, dst_addr, size,
692 (uint64_t)rva, ucx_mkey->rkey);
693 return ucx_status_to_oshmem(status);
694 #endif
695 }
696
697 int mca_spml_ucx_get_nb(shmem_ctx_t ctx, void *src_addr, size_t size, void *dst_addr, int src, void **handle)
698 {
699 void *rva;
700 ucs_status_t status;
701 spml_ucx_mkey_t *ucx_mkey;
702 mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx;
703
704 ucx_mkey = mca_spml_ucx_get_mkey(ctx, src, src_addr, &rva, &mca_spml_ucx);
705 status = ucp_get_nbi(ucx_ctx->ucp_peers[src].ucp_conn, dst_addr, size,
706 (uint64_t)rva, ucx_mkey->rkey);
707
708 return ucx_status_to_oshmem_nb(status);
709 }
710
711 int mca_spml_ucx_put(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_addr, int dst)
712 {
713 void *rva;
714 spml_ucx_mkey_t *ucx_mkey;
715 mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx;
716 #if HAVE_DECL_UCP_PUT_NB
717 ucs_status_ptr_t request;
718 #else
719 ucs_status_t status;
720 #endif
721
722 ucx_mkey = mca_spml_ucx_get_mkey(ctx, dst, dst_addr, &rva, &mca_spml_ucx);
723 #if HAVE_DECL_UCP_PUT_NB
724 request = ucp_put_nb(ucx_ctx->ucp_peers[dst].ucp_conn, src_addr, size,
725 (uint64_t)rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb);
726 return opal_common_ucx_wait_request(request, ucx_ctx->ucp_worker, "ucp_put_nb");
727 #else
728 status = ucp_put(ucx_ctx->ucp_peers[dst].ucp_conn, src_addr, size,
729 (uint64_t)rva, ucx_mkey->rkey);
730 return ucx_status_to_oshmem(status);
731 #endif
732 }
733
734 int mca_spml_ucx_put_nb(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_addr, int dst, void **handle)
735 {
736 void *rva;
737 ucs_status_t status;
738 spml_ucx_mkey_t *ucx_mkey;
739 mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx;
740
741 ucx_mkey = mca_spml_ucx_get_mkey(ctx, dst, dst_addr, &rva, &mca_spml_ucx);
742 status = ucp_put_nbi(ucx_ctx->ucp_peers[dst].ucp_conn, src_addr, size,
743 (uint64_t)rva, ucx_mkey->rkey);
744
745 return ucx_status_to_oshmem_nb(status);
746 }
747
748
749
750 int mca_spml_ucx_fence(shmem_ctx_t ctx)
751 {
752 ucs_status_t err;
753 mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx;
754
755 opal_atomic_wmb();
756
757 err = ucp_worker_fence(ucx_ctx->ucp_worker);
758 if (UCS_OK != err) {
759 SPML_UCX_ERROR("fence failed: %s", ucs_status_string(err));
760 oshmem_shmem_abort(-1);
761 return OSHMEM_ERROR;
762 }
763 return OSHMEM_SUCCESS;
764 }
765
766 int mca_spml_ucx_quiet(shmem_ctx_t ctx)
767 {
768 int ret;
769 mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx;
770
771 opal_atomic_wmb();
772
773 ret = opal_common_ucx_worker_flush(ucx_ctx->ucp_worker);
774 if (OMPI_SUCCESS != ret) {
775 oshmem_shmem_abort(-1);
776 return ret;
777 }
778
779
780
781 if (ctx == oshmem_ctx_default) {
782 while (mca_spml_ucx.aux_refcnt) {
783 opal_progress();
784 }
785 }
786
787 return OSHMEM_SUCCESS;
788 }
789
790
791 int mca_spml_ucx_recv(void* buf, size_t size, int src)
792 {
793 int rc = OSHMEM_SUCCESS;
794
795 rc = MCA_PML_CALL(recv(buf,
796 size,
797 &(ompi_mpi_unsigned_char.dt),
798 src,
799 0,
800 &(ompi_mpi_comm_world.comm),
801 NULL));
802
803 return rc;
804 }
805
806
807 int mca_spml_ucx_send(void* buf,
808 size_t size,
809 int dst,
810 mca_spml_base_put_mode_t mode)
811 {
812 int rc = OSHMEM_SUCCESS;
813
814 rc = MCA_PML_CALL(send(buf,
815 size,
816 &(ompi_mpi_unsigned_char.dt),
817 dst,
818 0,
819 (mca_pml_base_send_mode_t)mode,
820 &(ompi_mpi_comm_world.comm)));
821
822 return rc;
823 }
824
825
826 static void mca_spml_ucx_put_all_complete_cb(void *request, ucs_status_t status)
827 {
828 if (mca_spml_ucx.async_progress && (--mca_spml_ucx.aux_refcnt == 0)) {
829 opal_event_evtimer_del(mca_spml_ucx.tick_event);
830 opal_progress_unregister(spml_ucx_progress_aux_ctx);
831 }
832
833 if (request != NULL) {
834 ucp_request_free(request);
835 }
836 }
837
838
839 static int mca_spml_ucx_create_aux_ctx(void)
840 {
841 unsigned major = 0;
842 unsigned minor = 0;
843 unsigned rel_number = 0;
844 int rc;
845 bool rand_dci_supp;
846
847 ucp_get_version(&major, &minor, &rel_number);
848 rand_dci_supp = UCX_VERSION(major, minor, rel_number) >= UCX_VERSION(1, 6, 0);
849
850 if (rand_dci_supp) {
851 pthread_mutex_lock(&mca_spml_ucx.ctx_create_mutex);
852 opal_setenv("UCX_DC_MLX5_TX_POLICY", "rand", 0, &environ);
853 }
854
855 rc = mca_spml_ucx_ctx_create_common(SHMEM_CTX_PRIVATE, &mca_spml_ucx.aux_ctx);
856
857 if (rand_dci_supp) {
858 opal_unsetenv("UCX_DC_MLX5_TX_POLICY", &environ);
859 pthread_mutex_unlock(&mca_spml_ucx.ctx_create_mutex);
860 }
861
862 return rc;
863 }
864
865 int mca_spml_ucx_put_all_nb(void *dest, const void *source, size_t size, long *counter)
866 {
867 int my_pe = oshmem_my_proc_id();
868 long val = 1;
869 int peer, dst_pe, rc;
870 shmem_ctx_t ctx;
871 struct timeval tv;
872 void *request;
873
874 mca_spml_ucx_aux_lock();
875 if (mca_spml_ucx.async_progress) {
876 if (mca_spml_ucx.aux_ctx == NULL) {
877 rc = mca_spml_ucx_create_aux_ctx();
878 if (rc != OMPI_SUCCESS) {
879 mca_spml_ucx_aux_unlock();
880 oshmem_shmem_abort(-1);
881 }
882 }
883
884 if (mca_spml_ucx.aux_refcnt++ == 0) {
885 tv.tv_sec = 0;
886 tv.tv_usec = mca_spml_ucx.async_tick;
887 opal_event_evtimer_add(mca_spml_ucx.tick_event, &tv);
888 opal_progress_register(spml_ucx_progress_aux_ctx);
889 }
890 ctx = (shmem_ctx_t)mca_spml_ucx.aux_ctx;
891 } else {
892 ctx = oshmem_ctx_default;
893 }
894
895 assert(ctx != NULL);
896
897 for (peer = 0; peer < oshmem_num_procs(); peer++) {
898 dst_pe = (peer + my_pe) % oshmem_num_procs();
899 rc = mca_spml_ucx_put_nb(ctx,
900 (void*)((uintptr_t)dest + my_pe * size),
901 size,
902 (void*)((uintptr_t)source + dst_pe * size),
903 dst_pe, NULL);
904 RUNTIME_CHECK_RC(rc);
905
906 mca_spml_ucx_fence(ctx);
907
908 rc = MCA_ATOMIC_CALL(add(ctx, (void*)counter, val, sizeof(val), dst_pe));
909 RUNTIME_CHECK_RC(rc);
910 }
911
912 request = ucp_worker_flush_nb(((mca_spml_ucx_ctx_t*)ctx)->ucp_worker, 0,
913 mca_spml_ucx_put_all_complete_cb);
914 if (!UCS_PTR_IS_PTR(request)) {
915 mca_spml_ucx_put_all_complete_cb(NULL, UCS_PTR_STATUS(request));
916 }
917
918 mca_spml_ucx_aux_unlock();
919
920 return OSHMEM_SUCCESS;
921 }