This source file includes the following definitions:
- mca_pml_ucx_send_worker_address_type
- mca_pml_ucx_send_worker_address
- mca_pml_ucx_recv_worker_address
- mca_pml_ucx_open
- mca_pml_ucx_close
- mca_pml_ucx_init
- mca_pml_ucx_cleanup
- mca_pml_ucx_add_proc_common
- mca_pml_ucx_add_proc
- mca_pml_ucx_add_procs
- mca_pml_ucx_get_ep
- mca_pml_ucx_del_procs
- mca_pml_ucx_enable
- mca_pml_ucx_progress
- mca_pml_ucx_add_comm
- mca_pml_ucx_del_comm
- mca_pml_ucx_irecv_init
- mca_pml_ucx_irecv
- mca_pml_ucx_recv
- mca_pml_ucx_send_mode_name
- mca_pml_ucx_isend_init
- mca_pml_ucx_bsend
- mca_pml_ucx_common_send
- mca_pml_ucx_isend
- mca_pml_ucx_send_nb
- mca_pml_ucx_send_nbr
- mca_pml_ucx_send
- mca_pml_ucx_iprobe
- mca_pml_ucx_probe
- mca_pml_ucx_improbe
- mca_pml_ucx_mprobe
- mca_pml_ucx_imrecv
- mca_pml_ucx_mrecv
- mca_pml_ucx_start
- mca_pml_ucx_dump
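
All of the send, receive and probe entry points listed above reduce to UCP tag-matching
calls (ucp_tag_send_nb, ucp_tag_recv_nb, ucp_tag_probe_nb and friends) driven by
ucp_worker_progress. For orientation, the sketch below is not part of this file;
blocking_tag_send and send_cb are illustrative names. It shows the bare UCP pattern that
mca_pml_ucx_send_nb wraps with OMPI requests and the MCA_COMMON_UCX_WAIT_LOOP macro:

    #include <ucp/api/ucp.h>

    /* Completion callback required by ucp_tag_send_nb(); completion is
     * detected by polling the request below, so nothing to do here. */
    static void send_cb(void *request, ucs_status_t status)
    {
        (void)request;
        (void)status;
    }

    /* Post a contiguous tagged send and progress the worker until it is done. */
    static ucs_status_t blocking_tag_send(ucp_worker_h worker, ucp_ep_h ep,
                                          const void *buf, size_t len,
                                          ucp_tag_t tag)
    {
        ucs_status_ptr_t req;
        ucs_status_t status;

        req = ucp_tag_send_nb(ep, buf, len, ucp_dt_make_contig(1), tag, send_cb);
        if (req == NULL) {
            return UCS_OK;                  /* completed immediately */
        } else if (UCS_PTR_IS_ERR(req)) {
            return UCS_PTR_STATUS(req);     /* failed to start */
        }

        do {                                /* poll until the request completes */
            ucp_worker_progress(worker);
            status = ucp_request_check_status(req);
        } while (status == UCS_INPROGRESS);

        ucp_request_free(req);
        return status;
    }

The PML versions below differ mainly in that UCP requests are laid out as ompi_request_t
objects (via the request_size/request_init/request_cleanup parameters passed to ucp_init
in mca_pml_ucx_open), so completion is reported through the normal MPI request machinery
instead of by polling a raw UCP request.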
15 #include "pml_ucx.h"
16
17 #include "opal/runtime/opal.h"
18 #include "opal/mca/pmix/pmix.h"
19 #include "ompi/attribute/attribute.h"
20 #include "ompi/message/message.h"
21 #include "ompi/mca/pml/base/pml_base_bsend.h"
22 #include "opal/mca/common/ucx/common_ucx.h"
23 #include "pml_ucx_request.h"
24
25 #include <inttypes.h>
26
27
28 #define PML_UCX_TRACE_SEND(_msg, _buf, _count, _datatype, _dst, _tag, _mode, _comm, ...) \
29 PML_UCX_VERBOSE(8, _msg " buf %p count %zu type '%s' dst %d tag %d mode %s comm %d '%s'", \
30 __VA_ARGS__, \
31 (_buf), (_count), (_datatype)->name, (_dst), (_tag), \
32 mca_pml_ucx_send_mode_name(_mode), (_comm)->c_contextid, \
33 (_comm)->c_name);
34
35 #define PML_UCX_TRACE_RECV(_msg, _buf, _count, _datatype, _src, _tag, _comm, ...) \
36 PML_UCX_VERBOSE(8, _msg " buf %p count %zu type '%s' src %d tag %d comm %d '%s'", \
37 __VA_ARGS__, \
38 (_buf), (_count), (_datatype)->name, (_src), (_tag), \
39 (_comm)->c_contextid, (_comm)->c_name);
40
41 #define PML_UCX_TRACE_PROBE(_msg, _src, _tag, _comm) \
42 PML_UCX_VERBOSE(8, _msg " src %d tag %d comm %d '%s'", \
43 _src, (_tag), (_comm)->c_contextid, (_comm)->c_name);
44
45 #define PML_UCX_TRACE_MRECV(_msg, _buf, _count, _datatype, _message) \
46 PML_UCX_VERBOSE(8, _msg " buf %p count %zu type '%s' msg *%p=%p (%p)", \
47 (_buf), (_count), (_datatype)->name, (void*)(_message), \
48 (void*)*(_message), (*(_message))->req_ptr);
49
50 #define MODEX_KEY "pml-ucx"
51
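/* PML module descriptor: maps the MPI point-to-point entry points onto the
 * UCX-based implementations below.  pml_progress is left NULL because
 * mca_pml_ucx_init() registers mca_pml_ucx_progress() with
 * opal_progress_register() instead of using the PML progress hook. */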
52 mca_pml_ucx_module_t ompi_pml_ucx = {
53 .super = {
54 .pml_add_procs = mca_pml_ucx_add_procs,
55 .pml_del_procs = mca_pml_ucx_del_procs,
56 .pml_enable = mca_pml_ucx_enable,
57 .pml_progress = NULL,
58 .pml_add_comm = mca_pml_ucx_add_comm,
59 .pml_del_comm = mca_pml_ucx_del_comm,
60 .pml_irecv_init = mca_pml_ucx_irecv_init,
61 .pml_irecv = mca_pml_ucx_irecv,
62 .pml_recv = mca_pml_ucx_recv,
63 .pml_isend_init = mca_pml_ucx_isend_init,
64 .pml_isend = mca_pml_ucx_isend,
65 .pml_send = mca_pml_ucx_send,
66 .pml_iprobe = mca_pml_ucx_iprobe,
67 .pml_probe = mca_pml_ucx_probe,
68 .pml_start = mca_pml_ucx_start,
69 .pml_improbe = mca_pml_ucx_improbe,
70 .pml_mprobe = mca_pml_ucx_mprobe,
71 .pml_imrecv = mca_pml_ucx_imrecv,
72 .pml_mrecv = mca_pml_ucx_mrecv,
73 .pml_dump = mca_pml_ucx_dump,
74 .pml_ft_event = NULL,
75 .pml_max_contextid = 1ul << (PML_UCX_CONTEXT_BITS),
76 .pml_max_tag = 1ul << (PML_UCX_TAG_BITS - 1)
77 },
78 .ucp_context = NULL,
79 .ucp_worker = NULL
80 };
81
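/* Stack-allocate space for a UCP request and return a pointer just past it,
 * as required by the ucp_tag_send_nbr()/ucp_tag_recv_nbr() calls below, which
 * expect ompi_pml_ucx.request_size bytes of UCP-internal space in front of
 * the request pointer they are given. */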
82 #define PML_UCX_REQ_ALLOCA() \
83 ((char *)alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size);
84
85 #if HAVE_UCP_WORKER_ADDRESS_FLAGS
86 static int mca_pml_ucx_send_worker_address_type(int addr_flags, int modex_scope)
87 {
88 ucs_status_t status;
89 ucp_worker_attr_t attrs;
90 int rc;
91
92 attrs.field_mask = UCP_WORKER_ATTR_FIELD_ADDRESS |
93 UCP_WORKER_ATTR_FIELD_ADDRESS_FLAGS;
94 attrs.address_flags = addr_flags;
95
96 status = ucp_worker_query(ompi_pml_ucx.ucp_worker, &attrs);
97 if (UCS_OK != status) {
98 PML_UCX_ERROR("Failed to query UCP worker address");
99 return OMPI_ERROR;
100 }
101
102 OPAL_MODEX_SEND(rc, modex_scope, &mca_pml_ucx_component.pmlm_version,
103 (void*)attrs.address, attrs.address_length);
104
105 ucp_worker_release_address(ompi_pml_ucx.ucp_worker, attrs.address);
106
107 if (OMPI_SUCCESS != rc) {
108 return OMPI_ERROR;
109 }
110
111 PML_UCX_VERBOSE(2, "Pack %s worker address, size %ld",
112 (modex_scope == OPAL_PMIX_LOCAL) ? "local" : "remote",
113 attrs.address_length);
114
115 return OMPI_SUCCESS;
116 }
117 #endif
118
119 static int mca_pml_ucx_send_worker_address(void)
120 {
121 ucs_status_t status;
122
123 #if !HAVE_UCP_WORKER_ADDRESS_FLAGS
124 ucp_address_t *address;
125 size_t addrlen;
126 int rc;
127
128 status = ucp_worker_get_address(ompi_pml_ucx.ucp_worker, &address, &addrlen);
129 if (UCS_OK != status) {
130 PML_UCX_ERROR("Failed to get worker address");
131 return OMPI_ERROR;
132 }
133
134 PML_UCX_VERBOSE(2, "Pack worker address, size %ld", addrlen);
135
136 OPAL_MODEX_SEND(rc, OPAL_PMIX_GLOBAL,
137 &mca_pml_ucx_component.pmlm_version, (void*)address, addrlen);
138
139 ucp_worker_release_address(ompi_pml_ucx.ucp_worker, address);
140
141 if (OMPI_SUCCESS != rc) {
142 goto err;
143 }
144 #else
145
146 status = mca_pml_ucx_send_worker_address_type(UCP_WORKER_ADDRESS_FLAG_NET_ONLY,
147 OPAL_PMIX_REMOTE);
148 if (UCS_OK != status) {
149 goto err;
150 }
151
152 status = mca_pml_ucx_send_worker_address_type(0, OPAL_PMIX_LOCAL);
153 if (UCS_OK != status) {
154 goto err;
155 }
156 #endif
157
158 return OMPI_SUCCESS;
159
160 err:
161 PML_UCX_ERROR("Open MPI couldn't distribute EP connection details");
162 return OMPI_ERROR;
163 }
164
165 static int mca_pml_ucx_recv_worker_address(ompi_proc_t *proc,
166 ucp_address_t **address_p,
167 size_t *addrlen_p)
168 {
169 int ret;
170
171 *address_p = NULL;
172 OPAL_MODEX_RECV(ret, &mca_pml_ucx_component.pmlm_version, &proc->super.proc_name,
173 (void**)address_p, addrlen_p);
174 if (ret < 0) {
175 PML_UCX_ERROR("Failed to receive UCX worker address: %s (%d)",
176 opal_strerror(ret), ret);
177 }
178
179 PML_UCX_VERBOSE(2, "Got proc %d address, size %ld",
180 proc->super.proc_name.vpid, *addrlen_p);
181 return ret;
182 }
183
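/* Component open: read the UCX configuration (the "MPI" prefix), create the
 * UCP context with tag-matching support, and record the UCP request size so
 * that requests can later be placed in front of user-visible storage
 * (see PML_UCX_REQ_ALLOCA). */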
184 int mca_pml_ucx_open(void)
185 {
186 ucp_context_attr_t attr;
187 ucp_params_t params;
188 ucp_config_t *config;
189 ucs_status_t status;
190
191 PML_UCX_VERBOSE(1, "mca_pml_ucx_open");
192
193
194 status = ucp_config_read("MPI", NULL, &config);
195 if (UCS_OK != status) {
196 return OMPI_ERROR;
197 }
198
199
200 params.field_mask = UCP_PARAM_FIELD_FEATURES |
201 UCP_PARAM_FIELD_REQUEST_SIZE |
202 UCP_PARAM_FIELD_REQUEST_INIT |
203 UCP_PARAM_FIELD_REQUEST_CLEANUP |
204 UCP_PARAM_FIELD_TAG_SENDER_MASK |
205 UCP_PARAM_FIELD_MT_WORKERS_SHARED |
206 UCP_PARAM_FIELD_ESTIMATED_NUM_EPS;
207 params.features = UCP_FEATURE_TAG;
208 params.request_size = sizeof(ompi_request_t);
209 params.request_init = mca_pml_ucx_request_init;
210 params.request_cleanup = mca_pml_ucx_request_cleanup;
211 params.tag_sender_mask = PML_UCX_SPECIFIC_SOURCE_MASK;
212 params.mt_workers_shared = 0;
213
214 params.estimated_num_eps = ompi_proc_world_size();
215
216 status = ucp_init(&params, config, &ompi_pml_ucx.ucp_context);
217 ucp_config_release(config);
218
219 if (UCS_OK != status) {
220 return OMPI_ERROR;
221 }
222
223
224 attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
225 status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
226 if (UCS_OK != status) {
227 ucp_cleanup(ompi_pml_ucx.ucp_context);
228 ompi_pml_ucx.ucp_context = NULL;
229 return OMPI_ERROR;
230 }
231
232 ompi_pml_ucx.request_size = attr.request_size;
233
234 return OMPI_SUCCESS;
235 }
236
237 int mca_pml_ucx_close(void)
238 {
239 PML_UCX_VERBOSE(1, "mca_pml_ucx_close");
240
241 if (ompi_pml_ucx.ucp_context != NULL) {
242 ucp_cleanup(ompi_pml_ucx.ucp_context);
243 ompi_pml_ucx.ucp_context = NULL;
244 }
245 return OMPI_SUCCESS;
246 }
247
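/* Selected-PML initialization: create the UCP worker with the thread mode
 * matching the requested MPI threading level, fail selection if
 * MPI_THREAD_MULTIPLE was requested but the worker cannot provide it, publish
 * the worker address through the modex, and hook mca_pml_ucx_progress() into
 * opal_progress(). */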
248 int mca_pml_ucx_init(int enable_mpi_threads)
249 {
250 ucp_worker_params_t params;
251 ucp_worker_attr_t attr;
252 ucs_status_t status;
253 int i, rc;
254
255 PML_UCX_VERBOSE(1, "mca_pml_ucx_init");
256
257
258 params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
259 if (enable_mpi_threads) {
260 params.thread_mode = UCS_THREAD_MODE_MULTI;
261 } else {
262 params.thread_mode = UCS_THREAD_MODE_SINGLE;
263 }
264
265 status = ucp_worker_create(ompi_pml_ucx.ucp_context, &params,
266 &ompi_pml_ucx.ucp_worker);
267 if (UCS_OK != status) {
268 PML_UCX_ERROR("Failed to create UCP worker");
269 rc = OMPI_ERROR;
270 goto err;
271 }
272
273 attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE;
274 status = ucp_worker_query(ompi_pml_ucx.ucp_worker, &attr);
275 if (UCS_OK != status) {
276 PML_UCX_ERROR("Failed to query UCP worker thread level");
277 rc = OMPI_ERROR;
278 goto err_destroy_worker;
279 }
280
281 if (enable_mpi_threads && (attr.thread_mode != UCS_THREAD_MODE_MULTI)) {
282
283
284 PML_UCX_VERBOSE(1, "UCP worker does not support MPI_THREAD_MULTIPLE. "
285 "PML UCX could not be selected");
286 rc = OMPI_ERR_NOT_SUPPORTED;
287 goto err_destroy_worker;
288 }
289
290 rc = mca_pml_ucx_send_worker_address();
291 if (rc < 0) {
292 goto err_destroy_worker;
293 }
294
295 ompi_pml_ucx.datatype_attr_keyval = MPI_KEYVAL_INVALID;
296 for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) {
297 ompi_pml_ucx.predefined_types[i] = PML_UCX_DATATYPE_INVALID;
298 }
299
300
301 OBJ_CONSTRUCT(&ompi_pml_ucx.persistent_reqs, mca_pml_ucx_freelist_t);
302 OBJ_CONSTRUCT(&ompi_pml_ucx.convs, mca_pml_ucx_freelist_t);
303
304
305 OBJ_CONSTRUCT(&ompi_pml_ucx.completed_send_req, ompi_request_t);
306 mca_pml_ucx_completed_request_init(&ompi_pml_ucx.completed_send_req);
307
308 opal_progress_register(mca_pml_ucx_progress);
309
310 PML_UCX_VERBOSE(2, "created ucp context %p, worker %p",
311 (void *)ompi_pml_ucx.ucp_context,
312 (void *)ompi_pml_ucx.ucp_worker);
313 return OMPI_SUCCESS;
314
315 err_destroy_worker:
316 ucp_worker_destroy(ompi_pml_ucx.ucp_worker);
317 ompi_pml_ucx.ucp_worker = NULL;
318 err:
319 return rc;
320 }
321
322 int mca_pml_ucx_cleanup(void)
323 {
324 int i;
325
326 PML_UCX_VERBOSE(1, "mca_pml_ucx_cleanup");
327
328 opal_progress_unregister(mca_pml_ucx_progress);
329
330 if (ompi_pml_ucx.datatype_attr_keyval != MPI_KEYVAL_INVALID) {
331 ompi_attr_free_keyval(TYPE_ATTR, &ompi_pml_ucx.datatype_attr_keyval, false);
332 }
333
334 for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) {
335 if (ompi_pml_ucx.predefined_types[i] != PML_UCX_DATATYPE_INVALID) {
336 ucp_dt_destroy(ompi_pml_ucx.predefined_types[i]);
337 ompi_pml_ucx.predefined_types[i] = PML_UCX_DATATYPE_INVALID;
338 }
339 }
340
341 ompi_pml_ucx.completed_send_req.req_state = OMPI_REQUEST_INVALID;
342 OMPI_REQUEST_FINI(&ompi_pml_ucx.completed_send_req);
343 OBJ_DESTRUCT(&ompi_pml_ucx.completed_send_req);
344
345 OBJ_DESTRUCT(&ompi_pml_ucx.convs);
346 OBJ_DESTRUCT(&ompi_pml_ucx.persistent_reqs);
347
348 if (ompi_pml_ucx.ucp_worker) {
349 ucp_worker_destroy(ompi_pml_ucx.ucp_worker);
350 ompi_pml_ucx.ucp_worker = NULL;
351 }
352
353 return OMPI_SUCCESS;
354 }
355
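/* Create a UCP endpoint to 'proc' from the worker address it published in the
 * modex, and cache it on the proc so mca_pml_ucx_get_ep() can find it later. */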
356 static ucp_ep_h mca_pml_ucx_add_proc_common(ompi_proc_t *proc)
357 {
358 ucp_ep_params_t ep_params;
359 ucp_address_t *address;
360 ucs_status_t status;
361 size_t addrlen;
362 ucp_ep_h ep;
363 int ret;
364
365 ret = mca_pml_ucx_recv_worker_address(proc, &address, &addrlen);
366 if (ret < 0) {
367 return NULL;
368 }
369
370 PML_UCX_VERBOSE(2, "connecting to proc. %d", proc->super.proc_name.vpid);
371
372 ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
373 ep_params.address = address;
374
375 status = ucp_ep_create(ompi_pml_ucx.ucp_worker, &ep_params, &ep);
376 free(address);
377 if (UCS_OK != status) {
378 PML_UCX_ERROR("ucp_ep_create(proc=%d) failed: %s",
379 proc->super.proc_name.vpid,
380 ucs_status_string(status));
381 return NULL;
382 }
383
384 proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
385 return ep;
386 }
387
388 static ucp_ep_h mca_pml_ucx_add_proc(ompi_communicator_t *comm, int dst)
389 {
390 ompi_proc_t *proc0 = ompi_comm_peer_lookup(comm, 0);
391 ompi_proc_t *proc_peer = ompi_comm_peer_lookup(comm, dst);
392 int ret;
393
394
395 if (OMPI_SUCCESS != (ret = mca_pml_base_pml_check_selected("ucx",
396 &proc0,
397 dst))) {
398 return NULL;
399 }
400
401 return mca_pml_ucx_add_proc_common(proc_peer);
402 }
403
404 int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
405 {
406 ompi_proc_t *proc;
407 ucp_ep_h ep;
408 size_t i;
409 int ret;
410
411 if (OMPI_SUCCESS != (ret = mca_pml_base_pml_check_selected("ucx",
412 procs,
413 nprocs))) {
414 return ret;
415 }
416
417 for (i = 0; i < nprocs; ++i) {
418 proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
419 ep = mca_pml_ucx_add_proc_common(proc);
420 if (ep == NULL) {
421 return OMPI_ERROR;
422 }
423 }
424
425 opal_common_ucx_mca_proc_added();
426 return OMPI_SUCCESS;
427 }
428
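/* Return the cached endpoint for 'rank' in 'comm', connecting on demand if it
 * does not exist yet (e.g. for procs added after the initial add_procs). */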
429 static inline ucp_ep_h mca_pml_ucx_get_ep(ompi_communicator_t *comm, int rank)
430 {
431 ucp_ep_h ep;
432
433 ep = ompi_comm_peer_lookup(comm, rank)->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
434 if (OPAL_LIKELY(ep != NULL)) {
435 return ep;
436 }
437
438 ep = mca_pml_ucx_add_proc(comm, rank);
439 if (OPAL_LIKELY(ep != NULL)) {
440 return ep;
441 }
442
443 if (rank >= ompi_comm_size(comm)) {
444 PML_UCX_ERROR("Rank number (%d) is larger than communicator size (%d)",
445 rank, ompi_comm_size(comm));
446 } else {
447 PML_UCX_ERROR("Failed to resolve UCX endpoint for rank %d", rank);
448 }
449
450 return NULL;
451 }
452
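/* Disconnect from all procs: collect the cached endpoints, clear them from
 * the proc structures, and hand them to opal_common_ucx_del_procs() together
 * with the ompi_pml_ucx.num_disconnect limit. */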
453 int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
454 {
455 ompi_proc_t *proc;
456 opal_common_ucx_del_proc_t *del_procs;
457 size_t i;
458 int ret;
459
460 del_procs = malloc(sizeof(*del_procs) * nprocs);
461 if (del_procs == NULL) {
462 return OMPI_ERR_OUT_OF_RESOURCE;
463 }
464
465 for (i = 0; i < nprocs; ++i) {
466 proc = procs[i];
467 del_procs[i].ep = proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
468 del_procs[i].vpid = proc->super.proc_name.vpid;
469
470
471 proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
472 }
473
474 ret = opal_common_ucx_del_procs(del_procs, nprocs, OMPI_PROC_MY_NAME->vpid,
475 ompi_pml_ucx.num_disconnect, ompi_pml_ucx.ucp_worker);
476 free(del_procs);
477
478 return ret;
479 }
480
481 int mca_pml_ucx_enable(bool enable)
482 {
483 ompi_attribute_fn_ptr_union_t copy_fn;
484 ompi_attribute_fn_ptr_union_t del_fn;
485 int ret;
486
487
488 copy_fn.attr_datatype_copy_fn =
489 (MPI_Type_internal_copy_attr_function*)MPI_TYPE_NULL_COPY_FN;
490 del_fn.attr_datatype_delete_fn = mca_pml_ucx_datatype_attr_del_fn;
491 ret = ompi_attr_create_keyval(TYPE_ATTR, copy_fn, del_fn,
492 &ompi_pml_ucx.datatype_attr_keyval, NULL, 0,
493 NULL);
494 if (ret != OMPI_SUCCESS) {
495 PML_UCX_ERROR("Failed to create keyval for UCX datatypes: %d", ret);
496 return ret;
497 }
498
499 PML_UCX_FREELIST_INIT(&ompi_pml_ucx.persistent_reqs,
500 mca_pml_ucx_persistent_request_t,
501 128, -1, 128);
502 PML_UCX_FREELIST_INIT(&ompi_pml_ucx.convs,
503 mca_pml_ucx_convertor_t,
504 128, -1, 128);
505 return OMPI_SUCCESS;
506 }
507
508 int mca_pml_ucx_progress(void)
509 {
510 ucp_worker_progress(ompi_pml_ucx.ucp_worker);
511 return OMPI_SUCCESS;
512 }
513
514 int mca_pml_ucx_add_comm(struct ompi_communicator_t* comm)
515 {
516 return OMPI_SUCCESS;
517 }
518
519 int mca_pml_ucx_del_comm(struct ompi_communicator_t* comm)
520 {
521 return OMPI_SUCCESS;
522 }
523
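/* Persistent receive: allocate a request from the freelist and record the
 * receive parameters; nothing is posted to UCX until mca_pml_ucx_start(). */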
524 int mca_pml_ucx_irecv_init(void *buf, size_t count, ompi_datatype_t *datatype,
525 int src, int tag, struct ompi_communicator_t* comm,
526 struct ompi_request_t **request)
527 {
528 mca_pml_ucx_persistent_request_t *req;
529
530 req = (mca_pml_ucx_persistent_request_t *)PML_UCX_FREELIST_GET(&ompi_pml_ucx.persistent_reqs);
531 if (req == NULL) {
532 return OMPI_ERR_OUT_OF_RESOURCE;
533 }
534
535 PML_UCX_TRACE_RECV("irecv_init request *%p=%p", buf, count, datatype, src,
536 tag, comm, (void*)request, (void*)req);
537
538 req->ompi.req_state = OMPI_REQUEST_INACTIVE;
539 req->ompi.req_mpi_object.comm = comm;
540 req->flags = 0;
541 req->buffer = buf;
542 req->count = count;
543 req->datatype.datatype = mca_pml_ucx_get_datatype(datatype);
544
545 PML_UCX_MAKE_RECV_TAG(req->tag, req->recv.tag_mask, tag, src, comm);
546
547 *request = &req->ompi;
548 return OMPI_SUCCESS;
549 }
550
551 int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype,
552 int src, int tag, struct ompi_communicator_t* comm,
553 struct ompi_request_t **request)
554 {
555 ucp_tag_t ucp_tag, ucp_tag_mask;
556 ompi_request_t *req;
557
558 PML_UCX_TRACE_RECV("irecv request *%p", buf, count, datatype, src, tag, comm,
559 (void*)request);
560
561 PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
562 req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
563 mca_pml_ucx_get_datatype(datatype),
564 ucp_tag, ucp_tag_mask,
565 mca_pml_ucx_recv_completion);
566 if (UCS_PTR_IS_ERR(req)) {
567 PML_UCX_ERROR("ucx recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
568 return OMPI_ERROR;
569 }
570
571 PML_UCX_VERBOSE(8, "got request %p", (void*)req);
572 req->req_mpi_object.comm = comm;
573 *request = req;
574 return OMPI_SUCCESS;
575 }
576
577 int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src,
578 int tag, struct ompi_communicator_t* comm,
579 ompi_status_public_t* mpi_status)
580 {
581 ucp_tag_t ucp_tag, ucp_tag_mask;
582 ucp_tag_recv_info_t info;
583 ucs_status_t status;
584 void *req;
585
586 PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");
587
588
589 PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
590 req = PML_UCX_REQ_ALLOCA();
591 status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
592 mca_pml_ucx_get_datatype(datatype),
593 ucp_tag, ucp_tag_mask, req);
594
595 MCA_COMMON_UCX_PROGRESS_LOOP(ompi_pml_ucx.ucp_worker) {
596 status = ucp_request_test(req, &info);
597 if (status != UCS_INPROGRESS) {
598 mca_pml_ucx_set_recv_status_safe(mpi_status, status, &info);
599 return OMPI_SUCCESS;
600 }
601 }
602 }
603
604 static inline const char *mca_pml_ucx_send_mode_name(mca_pml_base_send_mode_t mode)
605 {
606 switch (mode) {
607 case MCA_PML_BASE_SEND_SYNCHRONOUS:
608 return "sync";
609 case MCA_PML_BASE_SEND_COMPLETE:
610 return "complete";
611 case MCA_PML_BASE_SEND_BUFFERED:
612 return "buffered";
613 case MCA_PML_BASE_SEND_READY:
614 return "ready";
615 case MCA_PML_BASE_SEND_STANDARD:
616 return "standard";
617 case MCA_PML_BASE_SEND_SIZE:
618 return "size";
619 default:
620 return "unknown";
621 }
622 }
623
624 int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datatype,
625 int dst, int tag, mca_pml_base_send_mode_t mode,
626 struct ompi_communicator_t* comm,
627 struct ompi_request_t **request)
628 {
629 mca_pml_ucx_persistent_request_t *req;
630 ucp_ep_h ep;
631
632 req = (mca_pml_ucx_persistent_request_t *)PML_UCX_FREELIST_GET(&ompi_pml_ucx.persistent_reqs);
633 if (req == NULL) {
634 return OMPI_ERR_OUT_OF_RESOURCE;
635 }
636
637 PML_UCX_TRACE_SEND("isend_init request *%p=%p", buf, count, datatype, dst,
638 tag, mode, comm, (void*)request, (void*)req)
639
640 ep = mca_pml_ucx_get_ep(comm, dst);
641 if (OPAL_UNLIKELY(NULL == ep)) {
642 return OMPI_ERROR;
643 }
644
645 req->ompi.req_state = OMPI_REQUEST_INACTIVE;
646 req->ompi.req_mpi_object.comm = comm;
647 req->flags = MCA_PML_UCX_REQUEST_FLAG_SEND;
648 req->buffer = (void *)buf;
649 req->count = count;
650 req->tag = PML_UCX_MAKE_SEND_TAG(tag, comm);
651 req->send.mode = mode;
652 req->send.ep = ep;
653
654 if (MCA_PML_BASE_SEND_BUFFERED == mode) {
655 req->datatype.ompi_datatype = datatype;
656 OBJ_RETAIN(datatype);
657 } else {
658 req->datatype.datatype = mca_pml_ucx_get_datatype(datatype);
659 }
660
661 *request = &req->ompi;
662 return OMPI_SUCCESS;
663 }
664
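/* Buffered send: pack the user datatype into a buffer obtained from the PML
 * bsend allocator and send it as a contiguous block.  If the send completes
 * immediately the buffer is freed here; otherwise it is recorded in
 * req_complete_cb_data so the completion callback can release it later.
 * Returns NULL in both cases, since a buffered send is locally complete as
 * soon as the data has been packed. */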
665 static ucs_status_ptr_t
666 mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
667 ompi_datatype_t *datatype, uint64_t pml_tag)
668 {
669 ompi_request_t *req;
670 void *packed_data;
671 size_t packed_length;
672 size_t offset;
673 uint32_t iov_count;
674 struct iovec iov;
675 opal_convertor_t opal_conv;
676
677 OBJ_CONSTRUCT(&opal_conv, opal_convertor_t);
678 opal_convertor_copy_and_prepare_for_send(ompi_proc_local_proc->super.proc_convertor,
679 &datatype->super, count, buf, 0,
680 &opal_conv);
681 opal_convertor_get_packed_size(&opal_conv, &packed_length);
682
683 packed_data = mca_pml_base_bsend_request_alloc_buf(packed_length);
684 if (OPAL_UNLIKELY(NULL == packed_data)) {
685 OBJ_DESTRUCT(&opal_conv);
686 PML_UCX_ERROR("bsend: failed to allocate buffer");
687 return UCS_STATUS_PTR(OMPI_ERROR);
688 }
689
690 iov_count = 1;
691 iov.iov_base = packed_data;
692 iov.iov_len = packed_length;
693
694 PML_UCX_VERBOSE(8, "bsend of packed buffer %p len %zu", packed_data, packed_length);
695 offset = 0;
696 opal_convertor_set_position(&opal_conv, &offset);
697 if (0 > opal_convertor_pack(&opal_conv, &iov, &iov_count, &packed_length)) {
698 mca_pml_base_bsend_request_free(packed_data);
699 OBJ_DESTRUCT(&opal_conv);
700 PML_UCX_ERROR("bsend: failed to pack user datatype");
701 return UCS_STATUS_PTR(OMPI_ERROR);
702 }
703
704 OBJ_DESTRUCT(&opal_conv);
705
706 req = (ompi_request_t*)ucp_tag_send_nb(ep, packed_data, packed_length,
707 ucp_dt_make_contig(1), pml_tag,
708 mca_pml_ucx_bsend_completion);
709 if (NULL == req) {
710
711 mca_pml_base_bsend_request_free(packed_data);
712 return NULL;
713 }
714
715 if (OPAL_UNLIKELY(UCS_PTR_IS_ERR(req))) {
716 mca_pml_base_bsend_request_free(packed_data);
717 PML_UCX_ERROR("ucx bsend failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
718 return UCS_STATUS_PTR(OMPI_ERROR);
719 }
720
721 req->req_complete_cb_data = packed_data;
722 return NULL;
723 }
724
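/* Dispatch a send according to the MPI send mode: buffered sends go through
 * mca_pml_ucx_bsend(), synchronous sends use ucp_tag_send_sync_nb(), and all
 * other modes use the standard ucp_tag_send_nb() path. */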
725 static inline ucs_status_ptr_t mca_pml_ucx_common_send(ucp_ep_h ep, const void *buf,
726 size_t count,
727 ompi_datatype_t *datatype,
728 ucp_datatype_t ucx_datatype,
729 ucp_tag_t tag,
730 mca_pml_base_send_mode_t mode,
731 ucp_send_callback_t cb)
732 {
733 if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
734 return mca_pml_ucx_bsend(ep, buf, count, datatype, tag);
735 } else if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_SYNCHRONOUS == mode)) {
736 return ucp_tag_send_sync_nb(ep, buf, count, ucx_datatype, tag, cb);
737 } else {
738 return ucp_tag_send_nb(ep, buf, count, ucx_datatype, tag, cb);
739 }
740 }
741
742 int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
743 int dst, int tag, mca_pml_base_send_mode_t mode,
744 struct ompi_communicator_t* comm,
745 struct ompi_request_t **request)
746 {
747 ompi_request_t *req;
748 ucp_ep_h ep;
749
750 PML_UCX_TRACE_SEND("i%ssend request *%p",
751 buf, count, datatype, dst, tag, mode, comm,
752 mode == MCA_PML_BASE_SEND_BUFFERED ? "b" : "",
753 (void*)request)
754
755 ep = mca_pml_ucx_get_ep(comm, dst);
756 if (OPAL_UNLIKELY(NULL == ep)) {
757 return OMPI_ERROR;
758 }
759
760 req = (ompi_request_t*)mca_pml_ucx_common_send(ep, buf, count, datatype,
761 mca_pml_ucx_get_datatype(datatype),
762 PML_UCX_MAKE_SEND_TAG(tag, comm), mode,
763 mca_pml_ucx_send_completion);
764
765 if (req == NULL) {
766 PML_UCX_VERBOSE(8, "returning completed request");
767 *request = &ompi_pml_ucx.completed_send_req;
768 return OMPI_SUCCESS;
769 } else if (!UCS_PTR_IS_ERR(req)) {
770 PML_UCX_VERBOSE(8, "got request %p", (void*)req);
771 req->req_mpi_object.comm = comm;
772 *request = req;
773 return OMPI_SUCCESS;
774 } else {
775 PML_UCX_ERROR("ucx send failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
776 return OMPI_ERROR;
777 }
778 }
779
780 static inline __opal_attribute_always_inline__ int
781 mca_pml_ucx_send_nb(ucp_ep_h ep, const void *buf, size_t count,
782 ompi_datatype_t *datatype, ucp_datatype_t ucx_datatype,
783 ucp_tag_t tag, mca_pml_base_send_mode_t mode,
784 ucp_send_callback_t cb)
785 {
786 ompi_request_t *req;
787
788 req = (ompi_request_t*)mca_pml_ucx_common_send(ep, buf, count, datatype,
789 mca_pml_ucx_get_datatype(datatype),
790 tag, mode, cb);
791 if (OPAL_LIKELY(req == NULL)) {
792 return OMPI_SUCCESS;
793 } else if (!UCS_PTR_IS_ERR(req)) {
794 PML_UCX_VERBOSE(8, "got request %p", (void*)req);
795 MCA_COMMON_UCX_WAIT_LOOP(req, ompi_pml_ucx.ucp_worker, "ucx send", ompi_request_free(&req));
796 } else {
797 PML_UCX_ERROR("ucx send failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
798 return OMPI_ERROR;
799 }
800 }
801
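/* When ucp_tag_send_nbr() is available, blocking sends that are neither
 * buffered nor synchronous use a stack-allocated request (PML_UCX_REQ_ALLOCA)
 * instead of a callback-driven UCP request; UCS_OK means the send completed
 * in place, otherwise the wait loop polls the request until completion. */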
802 #if HAVE_DECL_UCP_TAG_SEND_NBR
803 static inline __opal_attribute_always_inline__ int
804 mca_pml_ucx_send_nbr(ucp_ep_h ep, const void *buf, size_t count,
805 ucp_datatype_t ucx_datatype, ucp_tag_t tag)
806
807 {
808 ucs_status_ptr_t req;
809 ucs_status_t status;
810
811
812 req = PML_UCX_REQ_ALLOCA();
813 status = ucp_tag_send_nbr(ep, buf, count, ucx_datatype, tag, req);
814 if (OPAL_LIKELY(status == UCS_OK)) {
815 return OMPI_SUCCESS;
816 }
817
818 MCA_COMMON_UCX_WAIT_LOOP(req, ompi_pml_ucx.ucp_worker, "ucx send", (void)0);
819 }
820 #endif
821
822 int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, int dst,
823 int tag, mca_pml_base_send_mode_t mode,
824 struct ompi_communicator_t* comm)
825 {
826 ucp_ep_h ep;
827
828 PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm,
829 mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send");
830
831 ep = mca_pml_ucx_get_ep(comm, dst);
832 if (OPAL_UNLIKELY(NULL == ep)) {
833 return OMPI_ERROR;
834 }
835
836 #if HAVE_DECL_UCP_TAG_SEND_NBR
837 if (OPAL_LIKELY((MCA_PML_BASE_SEND_BUFFERED != mode) &&
838 (MCA_PML_BASE_SEND_SYNCHRONOUS != mode))) {
839 return mca_pml_ucx_send_nbr(ep, buf, count,
840 mca_pml_ucx_get_datatype(datatype),
841 PML_UCX_MAKE_SEND_TAG(tag, comm));
842 }
843 #endif
844
845 return mca_pml_ucx_send_nb(ep, buf, count, datatype,
846 mca_pml_ucx_get_datatype(datatype),
847 PML_UCX_MAKE_SEND_TAG(tag, comm), mode,
848 mca_pml_ucx_send_completion);
849 }
850
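/* Non-blocking probe.  On a miss, only every
 * opal_common_ucx.progress_iterations-th call falls back to the full
 * opal_progress(); the other calls just progress the UCP worker, which keeps
 * tight MPI_Iprobe loops cheap. */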
851 int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
852 int *matched, ompi_status_public_t* mpi_status)
853 {
854 static unsigned progress_count = 0;
855
856 ucp_tag_t ucp_tag, ucp_tag_mask;
857 ucp_tag_recv_info_t info;
858 ucp_tag_message_h ucp_msg;
859
860 PML_UCX_TRACE_PROBE("iprobe", src, tag, comm);
861
862 PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
863 ucp_msg = ucp_tag_probe_nb(ompi_pml_ucx.ucp_worker, ucp_tag, ucp_tag_mask,
864 0, &info);
865 if (ucp_msg != NULL) {
866 *matched = 1;
867 mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
868 } else {
869 (++progress_count % opal_common_ucx.progress_iterations) ?
870 (void)ucp_worker_progress(ompi_pml_ucx.ucp_worker) : opal_progress();
871 *matched = 0;
872 }
873 return OMPI_SUCCESS;
874 }
875
876 int mca_pml_ucx_probe(int src, int tag, struct ompi_communicator_t* comm,
877 ompi_status_public_t* mpi_status)
878 {
879 ucp_tag_t ucp_tag, ucp_tag_mask;
880 ucp_tag_recv_info_t info;
881 ucp_tag_message_h ucp_msg;
882
883 PML_UCX_TRACE_PROBE("probe", src, tag, comm);
884
885 PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
886
887 MCA_COMMON_UCX_PROGRESS_LOOP(ompi_pml_ucx.ucp_worker) {
888 ucp_msg = ucp_tag_probe_nb(ompi_pml_ucx.ucp_worker, ucp_tag,
889 ucp_tag_mask, 0, &info);
890 if (ucp_msg != NULL) {
891 mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
892 return OMPI_SUCCESS;
893 }
894 }
895 }
896
897 int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm,
898 int *matched, struct ompi_message_t **message,
899 ompi_status_public_t* mpi_status)
900 {
901 static unsigned progress_count = 0;
902
903 ucp_tag_t ucp_tag, ucp_tag_mask;
904 ucp_tag_recv_info_t info;
905 ucp_tag_message_h ucp_msg;
906
907 PML_UCX_TRACE_PROBE("improbe", src, tag, comm);
908
909 PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
910 ucp_msg = ucp_tag_probe_nb(ompi_pml_ucx.ucp_worker, ucp_tag, ucp_tag_mask,
911 1, &info);
912 if (ucp_msg != NULL) {
913 PML_UCX_MESSAGE_NEW(comm, ucp_msg, &info, message);
914 PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg);
915 *matched = 1;
916 mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
917 } else {
918 (++progress_count % opal_common_ucx.progress_iterations) ?
919 (void)ucp_worker_progress(ompi_pml_ucx.ucp_worker) : opal_progress();
920 *matched = 0;
921 }
922 return OMPI_SUCCESS;
923 }
924
925 int mca_pml_ucx_mprobe(int src, int tag, struct ompi_communicator_t* comm,
926 struct ompi_message_t **message,
927 ompi_status_public_t* mpi_status)
928 {
929 ucp_tag_t ucp_tag, ucp_tag_mask;
930 ucp_tag_recv_info_t info;
931 ucp_tag_message_h ucp_msg;
932
933 PML_UCX_TRACE_PROBE("mprobe", src, tag, comm);
934
935 PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
936 MCA_COMMON_UCX_PROGRESS_LOOP(ompi_pml_ucx.ucp_worker) {
937 ucp_msg = ucp_tag_probe_nb(ompi_pml_ucx.ucp_worker, ucp_tag, ucp_tag_mask,
938 1, &info);
939 if (ucp_msg != NULL) {
940 PML_UCX_MESSAGE_NEW(comm, ucp_msg, &info, message);
941 PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg);
942 mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
943 return OMPI_SUCCESS;
944 }
945 }
946 }
947
948 int mca_pml_ucx_imrecv(void *buf, size_t count, ompi_datatype_t *datatype,
949 struct ompi_message_t **message,
950 struct ompi_request_t **request)
951 {
952 ompi_request_t *req;
953
954 PML_UCX_TRACE_MRECV("imrecv", buf, count, datatype, message);
955
956 req = (ompi_request_t*)ucp_tag_msg_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
957 mca_pml_ucx_get_datatype(datatype),
958 (*message)->req_ptr,
959 mca_pml_ucx_recv_completion);
960 if (UCS_PTR_IS_ERR(req)) {
961 PML_UCX_ERROR("ucx msg recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
962 return OMPI_ERROR;
963 }
964
965 PML_UCX_VERBOSE(8, "got request %p", (void*)req);
966 PML_UCX_MESSAGE_RELEASE(message);
967 *request = req;
968 return OMPI_SUCCESS;
969 }
970
971 int mca_pml_ucx_mrecv(void *buf, size_t count, ompi_datatype_t *datatype,
972 struct ompi_message_t **message,
973 ompi_status_public_t* status)
974 {
975 ompi_request_t *req;
976
977 PML_UCX_TRACE_MRECV("mrecv", buf, count, datatype, message);
978
979 req = (ompi_request_t*)ucp_tag_msg_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
980 mca_pml_ucx_get_datatype(datatype),
981 (*message)->req_ptr,
982 mca_pml_ucx_recv_completion);
983 if (UCS_PTR_IS_ERR(req)) {
984 PML_UCX_ERROR("ucx msg recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
985 return OMPI_ERROR;
986 }
987
988 PML_UCX_MESSAGE_RELEASE(message);
989
990 ompi_request_wait(&req, status);
991 return OMPI_SUCCESS;
992 }
993
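/* Start persistent requests.  Sends are re-posted through
 * mca_pml_ucx_common_send() and receives through ucp_tag_recv_nb().  A NULL
 * return means the send completed immediately, so the persistent request is
 * completed on the spot; otherwise the temporary UCP request completes it,
 * either here if it has already finished or later from its completion
 * callback via req_complete_cb_data. */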
994 int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
995 {
996 mca_pml_ucx_persistent_request_t *preq;
997 ompi_request_t *tmp_req;
998 size_t i;
999
1000 for (i = 0; i < count; ++i) {
1001 preq = (mca_pml_ucx_persistent_request_t *)requests[i];
1002
1003 if ((preq == NULL) || (OMPI_REQUEST_PML != preq->ompi.req_type)) {
1004
1005 continue;
1006 }
1007
1008 PML_UCX_ASSERT(preq->ompi.req_state != OMPI_REQUEST_INVALID);
1009 preq->ompi.req_state = OMPI_REQUEST_ACTIVE;
1010 mca_pml_ucx_request_reset(&preq->ompi);
1011
1012 if (preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND) {
1013 tmp_req = (ompi_request_t*)mca_pml_ucx_common_send(preq->send.ep,
1014 preq->buffer,
1015 preq->count,
1016 preq->datatype.ompi_datatype,
1017 preq->datatype.datatype,
1018 preq->tag,
1019 preq->send.mode,
1020 mca_pml_ucx_psend_completion);
1021 } else {
1022 PML_UCX_VERBOSE(8, "start recv request %p", (void*)preq);
1023 tmp_req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker,
1024 preq->buffer, preq->count,
1025 preq->datatype.datatype,
1026 preq->tag,
1027 preq->recv.tag_mask,
1028 mca_pml_ucx_precv_completion);
1029 }
1030
1031 if (tmp_req == NULL) {
1032
1033 PML_UCX_ASSERT(preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND);
1034
1035 PML_UCX_VERBOSE(8, "send completed immediately, completing persistent request %p",
1036 (void*)preq);
1037 mca_pml_ucx_set_send_status(&preq->ompi.req_status, UCS_OK);
1038 ompi_request_complete(&preq->ompi, true);
1039 } else if (!UCS_PTR_IS_ERR(tmp_req)) {
1040 if (REQUEST_COMPLETE(tmp_req)) {
1041
1042 PML_UCX_VERBOSE(8, "completing persistent request %p", (void*)preq);
1043 mca_pml_ucx_persistent_request_complete(preq, tmp_req);
1044 } else {
1045
1046
1047 PML_UCX_VERBOSE(8, "temporary request %p will complete persistent request %p",
1048 (void*)tmp_req, (void*)preq);
1049 tmp_req->req_complete_cb_data = preq;
1050 preq->tmp_req = tmp_req;
1051 }
1052 } else {
1053 PML_UCX_ERROR("ucx %s failed: %s",
1054 (preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND) ? "send" : "recv",
1055 ucs_status_string(UCS_PTR_STATUS(tmp_req)));
1056 return OMPI_ERROR;
1057 }
1058 }
1059
1060 return OMPI_SUCCESS;
1061 }
1062
1063 int mca_pml_ucx_dump(struct ompi_communicator_t* comm, int verbose)
1064 {
1065 return OMPI_SUCCESS;
1066 }