This source file includes the following definitions.
- check_sync_state
- create_iov_list
- ddt_put_get
- start_atomicity
- end_atomicity
- get_dynamic_win_info
- ompi_osc_ucx_put
- ompi_osc_ucx_get
- ompi_osc_ucx_accumulate
- ompi_osc_ucx_compare_and_swap
- ompi_osc_ucx_fetch_and_op
- ompi_osc_ucx_get_accumulate
- ompi_osc_ucx_rput
- ompi_osc_ucx_rget
- ompi_osc_ucx_raccumulate
- ompi_osc_ucx_rget_accumulate
1
2
3
4
5
6
7
8
9
10 #include "ompi_config.h"
11
12 #include "ompi/mca/osc/osc.h"
13 #include "ompi/mca/osc/base/base.h"
14 #include "ompi/mca/osc/base/osc_base_obj_convert.h"
15 #include "opal/mca/common/ucx/common_ucx.h"
16
17 #include "osc_ucx.h"
18 #include "osc_ucx_request.h"
19
20
/*
 * Return OMPI_ERROR from the *calling* function when the target's window
 * entry has no initialized remote key although the access length is
 * non-zero.
 *
 * The body is wrapped in do { } while (0) so that `CHECK_VALID_RKEY(...);`
 * expands to exactly one statement and is safe inside an unbraced if/else.
 */
#define CHECK_VALID_RKEY(_module, _target, _count)                                  \
    do {                                                                            \
        if (!((_module)->win_info_array[(_target)]).rkey_init && ((_count) > 0)) { \
            OSC_UCX_VERBOSE(1, "window with non-zero length does not have an rkey"); \
            return OMPI_ERROR;                                                      \
        }                                                                           \
    } while (0)
26
/* A single contiguous memory segment: start address plus length in bytes. */
struct ucx_iovec {
    void *addr;  /* first byte of the segment */
    size_t len;  /* number of bytes in the segment */
};
typedef struct ucx_iovec ucx_iovec_t;
31
32 static inline int check_sync_state(ompi_osc_ucx_module_t *module, int target,
33 bool is_req_ops) {
34 if (is_req_ops == false) {
35 if (module->epoch_type.access == NONE_EPOCH) {
36 return OMPI_ERR_RMA_SYNC;
37 } else if (module->epoch_type.access == START_COMPLETE_EPOCH) {
38 int i, size = ompi_group_size(module->start_group);
39 for (i = 0; i < size; i++) {
40 if (module->start_grp_ranks[i] == target) {
41 break;
42 }
43 }
44 if (i == size) {
45 return OMPI_ERR_RMA_SYNC;
46 }
47 } else if (module->epoch_type.access == PASSIVE_EPOCH) {
48 ompi_osc_ucx_lock_t *item = NULL;
49 opal_hash_table_get_value_uint32(&module->outstanding_locks, (uint32_t) target, (void **) &item);
50 if (item == NULL) {
51 return OMPI_ERR_RMA_SYNC;
52 }
53 }
54 } else {
55 if (module->epoch_type.access != PASSIVE_EPOCH &&
56 module->epoch_type.access != PASSIVE_ALL_EPOCH) {
57 return OMPI_ERR_RMA_SYNC;
58 } else if (module->epoch_type.access == PASSIVE_EPOCH) {
59 ompi_osc_ucx_lock_t *item = NULL;
60 opal_hash_table_get_value_uint32(&module->outstanding_locks, (uint32_t) target, (void **) &item);
61 if (item == NULL) {
62 return OMPI_ERR_RMA_SYNC;
63 }
64 }
65 }
66 return OMPI_SUCCESS;
67 }
68
69 static inline int create_iov_list(const void *addr, int count, ompi_datatype_t *datatype,
70 ucx_iovec_t **ucx_iov, uint32_t *ucx_iov_count) {
71 int ret = OMPI_SUCCESS;
72 size_t size;
73 bool done = false;
74 opal_convertor_t convertor;
75 uint32_t iov_count, iov_idx;
76 struct iovec iov[OSC_UCX_IOVEC_MAX];
77 uint32_t ucx_iov_idx;
78
79 OBJ_CONSTRUCT(&convertor, opal_convertor_t);
80 ret = opal_convertor_copy_and_prepare_for_send(ompi_mpi_local_convertor,
81 &datatype->super, count,
82 addr, 0, &convertor);
83 if (ret != OMPI_SUCCESS) {
84 return ret;
85 }
86
87 (*ucx_iov_count) = 0;
88 ucx_iov_idx = 0;
89
90 do {
91 iov_count = OSC_UCX_IOVEC_MAX;
92 iov_idx = 0;
93
94 done = opal_convertor_raw(&convertor, iov, &iov_count, &size);
95
96 (*ucx_iov_count) += iov_count;
97 (*ucx_iov) = (ucx_iovec_t *)realloc((*ucx_iov), (*ucx_iov_count) * sizeof(ucx_iovec_t));
98 if (*ucx_iov == NULL) {
99 return OMPI_ERR_TEMP_OUT_OF_RESOURCE;
100 }
101
102 while (iov_idx != iov_count) {
103 (*ucx_iov)[ucx_iov_idx].addr = iov[iov_idx].iov_base;
104 (*ucx_iov)[ucx_iov_idx].len = iov[iov_idx].iov_len;
105 ucx_iov_idx++;
106 iov_idx++;
107 }
108
109 assert((*ucx_iov_count) == ucx_iov_idx);
110
111 } while (!done);
112
113 opal_convertor_cleanup(&convertor);
114 OBJ_DESTRUCT(&convertor);
115
116 return ret;
117 }
118
119 static inline int ddt_put_get(ompi_osc_ucx_module_t *module,
120 const void *origin_addr, int origin_count,
121 struct ompi_datatype_t *origin_dt,
122 bool is_origin_contig, ptrdiff_t origin_lb,
123 int target, uint64_t remote_addr,
124 int target_count, struct ompi_datatype_t *target_dt,
125 bool is_target_contig, ptrdiff_t target_lb, bool is_get) {
126 ucx_iovec_t *origin_ucx_iov = NULL, *target_ucx_iov = NULL;
127 uint32_t origin_ucx_iov_count = 0, target_ucx_iov_count = 0;
128 uint32_t origin_ucx_iov_idx = 0, target_ucx_iov_idx = 0;
129 int status;
130 int ret = OMPI_SUCCESS;
131
132 if (!is_origin_contig) {
133 ret = create_iov_list(origin_addr, origin_count, origin_dt,
134 &origin_ucx_iov, &origin_ucx_iov_count);
135 if (ret != OMPI_SUCCESS) {
136 return ret;
137 }
138 }
139
140 if (!is_target_contig) {
141 ret = create_iov_list(NULL, target_count, target_dt,
142 &target_ucx_iov, &target_ucx_iov_count);
143 if (ret != OMPI_SUCCESS) {
144 return ret;
145 }
146 }
147
148 if (!is_origin_contig && !is_target_contig) {
149 size_t curr_len = 0;
150 opal_common_ucx_op_t op;
151 while (origin_ucx_iov_idx < origin_ucx_iov_count) {
152 curr_len = MIN(origin_ucx_iov[origin_ucx_iov_idx].len,
153 target_ucx_iov[target_ucx_iov_idx].len);
154 if (is_get) {
155 op = OPAL_COMMON_UCX_GET;
156 } else {
157 op = OPAL_COMMON_UCX_PUT;
158 }
159 status = opal_common_ucx_wpmem_putget(module->mem, op, target,
160 origin_ucx_iov[origin_ucx_iov_idx].addr, curr_len,
161 remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr));
162 if (OPAL_SUCCESS != status) {
163 OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", status);
164 return OMPI_ERROR;
165 }
166
167 origin_ucx_iov[origin_ucx_iov_idx].addr = (void *)((intptr_t)origin_ucx_iov[origin_ucx_iov_idx].addr + curr_len);
168 target_ucx_iov[target_ucx_iov_idx].addr = (void *)((intptr_t)target_ucx_iov[target_ucx_iov_idx].addr + curr_len);
169
170 origin_ucx_iov[origin_ucx_iov_idx].len -= curr_len;
171 if (origin_ucx_iov[origin_ucx_iov_idx].len == 0) {
172 origin_ucx_iov_idx++;
173 }
174 target_ucx_iov[target_ucx_iov_idx].len -= curr_len;
175 if (target_ucx_iov[target_ucx_iov_idx].len == 0) {
176 target_ucx_iov_idx++;
177 }
178 }
179
180 assert(origin_ucx_iov_idx == origin_ucx_iov_count &&
181 target_ucx_iov_idx == target_ucx_iov_count);
182
183 } else if (!is_origin_contig) {
184 size_t prev_len = 0;
185 opal_common_ucx_op_t op;
186 while (origin_ucx_iov_idx < origin_ucx_iov_count) {
187 if (is_get) {
188 op = OPAL_COMMON_UCX_GET;
189 } else {
190 op = OPAL_COMMON_UCX_PUT;
191 }
192 status = opal_common_ucx_wpmem_putget(module->mem, op, target,
193 origin_ucx_iov[origin_ucx_iov_idx].addr,
194 origin_ucx_iov[origin_ucx_iov_idx].len,
195 remote_addr + target_lb + prev_len);
196 if (OPAL_SUCCESS != status) {
197 OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", status);
198 return OMPI_ERROR;
199 }
200
201 prev_len += origin_ucx_iov[origin_ucx_iov_idx].len;
202 origin_ucx_iov_idx++;
203 }
204 } else {
205 size_t prev_len = 0;
206 opal_common_ucx_op_t op;
207 while (target_ucx_iov_idx < target_ucx_iov_count) {
208 if (is_get) {
209 op = OPAL_COMMON_UCX_GET;
210 } else {
211 op = OPAL_COMMON_UCX_PUT;
212 }
213
214 status = opal_common_ucx_wpmem_putget(module->mem, op, target,
215 (void *)((intptr_t)origin_addr + origin_lb + prev_len),
216 target_ucx_iov[target_ucx_iov_idx].len,
217 remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr));
218 if (OPAL_SUCCESS != status) {
219 OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", status);
220 return OMPI_ERROR;
221 }
222
223 prev_len += target_ucx_iov[target_ucx_iov_idx].len;
224 target_ucx_iov_idx++;
225 }
226 }
227
228 if (origin_ucx_iov != NULL) {
229 free(origin_ucx_iov);
230 }
231 if (target_ucx_iov != NULL) {
232 free(target_ucx_iov);
233 }
234
235 return ret;
236 }
237
238 static inline int start_atomicity(ompi_osc_ucx_module_t *module, int target) {
239 uint64_t result_value = -1;
240 uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_ACC_LOCK_OFFSET;
241 int ret = OMPI_SUCCESS;
242
243 for (;;) {
244 ret = opal_common_ucx_wpmem_cmpswp(module->state_mem,
245 TARGET_LOCK_UNLOCKED, TARGET_LOCK_EXCLUSIVE,
246 target, &result_value, sizeof(result_value),
247 remote_addr);
248 if (ret != OMPI_SUCCESS) {
249 OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_cmpswp failed: %d", ret);
250 return OMPI_ERROR;
251 }
252 if (result_value == TARGET_LOCK_UNLOCKED) {
253 return OMPI_SUCCESS;
254 }
255
256 ucp_worker_progress(mca_osc_ucx_component.wpool->dflt_worker);
257 }
258 }
259
260 static inline int end_atomicity(ompi_osc_ucx_module_t *module, int target) {
261 uint64_t result_value = 0;
262 uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_ACC_LOCK_OFFSET;
263 int ret = OMPI_SUCCESS;
264
265 ret = opal_common_ucx_wpmem_fetch(module->state_mem,
266 UCP_ATOMIC_FETCH_OP_SWAP, TARGET_LOCK_UNLOCKED,
267 target, &result_value, sizeof(result_value),
268 remote_addr);
269 if (ret != OMPI_SUCCESS) {
270 OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fetch failed: %d", ret);
271 return OMPI_ERROR;
272 }
273
274 assert(result_value == TARGET_LOCK_EXCLUSIVE);
275
276 return ret;
277 }
278
279 static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module_t *module,
280 int target) {
281 uint64_t remote_state_addr = (module->state_addrs)[target] + OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET;
282 size_t len = sizeof(uint64_t) + sizeof(ompi_osc_dynamic_win_info_t) * OMPI_OSC_UCX_ATTACH_MAX;
283 char *temp_buf = malloc(len);
284 ompi_osc_dynamic_win_info_t *temp_dynamic_wins;
285 uint64_t win_count;
286 int contain, insert = -1;
287 int ret;
288
289 ret = opal_common_ucx_wpmem_putget(module->state_mem, OPAL_COMMON_UCX_GET, target,
290 (void *)((intptr_t)temp_buf),
291 len, remote_state_addr);
292 if (OPAL_SUCCESS != ret) {
293 OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
294 return OMPI_ERROR;
295 }
296
297 ret = opal_common_ucx_wpmem_flush(module->state_mem, OPAL_COMMON_UCX_SCOPE_EP, target);
298 if (ret != OMPI_SUCCESS) {
299 return ret;
300 }
301
302 memcpy(&win_count, temp_buf, sizeof(uint64_t));
303 assert(win_count > 0 && win_count <= OMPI_OSC_UCX_ATTACH_MAX);
304
305 temp_dynamic_wins = (ompi_osc_dynamic_win_info_t *)(temp_buf + sizeof(uint64_t));
306 contain = ompi_osc_find_attached_region_position(temp_dynamic_wins, 0, win_count,
307 remote_addr, 1, &insert);
308 assert(contain >= 0 && (uint64_t)contain < win_count);
309
310 if (module->local_dynamic_win_info[contain].mem->mem_addrs == NULL) {
311 module->local_dynamic_win_info[contain].mem->mem_addrs = calloc(ompi_comm_size(module->comm),
312 OMPI_OSC_UCX_MEM_ADDR_MAX_LEN);
313 module->local_dynamic_win_info[contain].mem->mem_displs =calloc(ompi_comm_size(module->comm),
314 sizeof(int));
315 }
316
317 memcpy(module->local_dynamic_win_info[contain].mem->mem_addrs + target * OMPI_OSC_UCX_MEM_ADDR_MAX_LEN,
318 temp_dynamic_wins[contain].mem_addr, OMPI_OSC_UCX_MEM_ADDR_MAX_LEN);
319 module->local_dynamic_win_info[contain].mem->mem_displs[target] = target * OMPI_OSC_UCX_MEM_ADDR_MAX_LEN;
320
321 free(temp_buf);
322
323 return ret;
324 }
325
326 int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_datatype_t *origin_dt,
327 int target, ptrdiff_t target_disp, int target_count,
328 struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
329 ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
330 uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
331 bool is_origin_contig = false, is_target_contig = false;
332 ptrdiff_t origin_lb, origin_extent, target_lb, target_extent;
333 int ret = OMPI_SUCCESS;
334
335 ret = check_sync_state(module, target, false);
336 if (ret != OMPI_SUCCESS) {
337 return ret;
338 }
339
340 if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
341 ret = get_dynamic_win_info(remote_addr, module, target);
342 if (ret != OMPI_SUCCESS) {
343 return ret;
344 }
345 }
346
347 if (!target_count) {
348 return OMPI_SUCCESS;
349 }
350
351 ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
352 ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent);
353
354 is_origin_contig = ompi_datatype_is_contiguous_memory_layout(origin_dt, origin_count);
355 is_target_contig = ompi_datatype_is_contiguous_memory_layout(target_dt, target_count);
356
357 if (is_origin_contig && is_target_contig) {
358
359 size_t origin_len;
360
361 ompi_datatype_type_size(origin_dt, &origin_len);
362 origin_len *= origin_count;
363
364 ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_PUT, target,
365 (void *)((intptr_t)origin_addr + origin_lb),
366 origin_len, remote_addr + target_lb);
367 if (OPAL_SUCCESS != ret) {
368 OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
369 return OMPI_ERROR;
370 }
371 return ret;
372 } else {
373 return ddt_put_get(module, origin_addr, origin_count, origin_dt, is_origin_contig,
374 origin_lb, target, remote_addr, target_count, target_dt,
375 is_target_contig, target_lb, false);
376 }
377 }
378
379 int ompi_osc_ucx_get(void *origin_addr, int origin_count,
380 struct ompi_datatype_t *origin_dt,
381 int target, ptrdiff_t target_disp, int target_count,
382 struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
383 ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
384 uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
385 ptrdiff_t origin_lb, origin_extent, target_lb, target_extent;
386 bool is_origin_contig = false, is_target_contig = false;
387 int ret = OMPI_SUCCESS;
388
389 ret = check_sync_state(module, target, false);
390 if (ret != OMPI_SUCCESS) {
391 return ret;
392 }
393
394 if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
395 ret = get_dynamic_win_info(remote_addr, module, target);
396 if (ret != OMPI_SUCCESS) {
397 return ret;
398 }
399 }
400
401 if (!target_count) {
402 return OMPI_SUCCESS;
403 }
404
405
406 ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
407 ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent);
408
409 is_origin_contig = ompi_datatype_is_contiguous_memory_layout(origin_dt, origin_count);
410 is_target_contig = ompi_datatype_is_contiguous_memory_layout(target_dt, target_count);
411
412 if (is_origin_contig && is_target_contig) {
413
414 size_t origin_len;
415
416 ompi_datatype_type_size(origin_dt, &origin_len);
417 origin_len *= origin_count;
418
419 ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_GET, target,
420 (void *)((intptr_t)origin_addr + origin_lb),
421 origin_len, remote_addr + target_lb);
422 if (OPAL_SUCCESS != ret) {
423 OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
424 return OMPI_ERROR;
425 }
426
427 return ret;
428 } else {
429 return ddt_put_get(module, origin_addr, origin_count, origin_dt, is_origin_contig,
430 origin_lb, target, remote_addr, target_count, target_dt,
431 is_target_contig, target_lb, true);
432 }
433 }
434
435 int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
436 struct ompi_datatype_t *origin_dt,
437 int target, ptrdiff_t target_disp, int target_count,
438 struct ompi_datatype_t *target_dt,
439 struct ompi_op_t *op, struct ompi_win_t *win) {
440 ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
441 int ret = OMPI_SUCCESS;
442
443 ret = check_sync_state(module, target, false);
444 if (ret != OMPI_SUCCESS) {
445 return ret;
446 }
447
448 if (op == &ompi_mpi_op_no_op.op) {
449 return ret;
450 }
451
452 ret = start_atomicity(module, target);
453 if (ret != OMPI_SUCCESS) {
454 return ret;
455 }
456
457 if (op == &ompi_mpi_op_replace.op) {
458 ret = ompi_osc_ucx_put(origin_addr, origin_count, origin_dt, target,
459 target_disp, target_count, target_dt, win);
460 if (ret != OMPI_SUCCESS) {
461 return ret;
462 }
463 } else {
464 void *temp_addr_holder = NULL;
465 void *temp_addr = NULL;
466 uint32_t temp_count;
467 ompi_datatype_t *temp_dt;
468 ptrdiff_t temp_lb, temp_extent;
469 bool is_origin_contig = ompi_datatype_is_contiguous_memory_layout(origin_dt, origin_count);
470
471 if (ompi_datatype_is_predefined(target_dt)) {
472 temp_dt = target_dt;
473 temp_count = target_count;
474 } else {
475 ret = ompi_osc_base_get_primitive_type_info(target_dt, &temp_dt, &temp_count);
476 if (ret != OMPI_SUCCESS) {
477 return ret;
478 }
479 }
480 ompi_datatype_get_true_extent(temp_dt, &temp_lb, &temp_extent);
481 temp_addr = temp_addr_holder = malloc(temp_extent * temp_count);
482 if (temp_addr == NULL) {
483 return OMPI_ERR_TEMP_OUT_OF_RESOURCE;
484 }
485
486 ret = ompi_osc_ucx_get(temp_addr, (int)temp_count, temp_dt,
487 target, target_disp, target_count, target_dt, win);
488 if (ret != OMPI_SUCCESS) {
489 return ret;
490 }
491
492 ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
493 if (ret != OMPI_SUCCESS) {
494 return ret;
495 }
496
497 if (ompi_datatype_is_predefined(origin_dt) || is_origin_contig) {
498 ompi_op_reduce(op, (void *)origin_addr, temp_addr, (int)temp_count, temp_dt);
499 } else {
500 ucx_iovec_t *origin_ucx_iov = NULL;
501 uint32_t origin_ucx_iov_count = 0;
502 uint32_t origin_ucx_iov_idx = 0;
503
504 ret = create_iov_list(origin_addr, origin_count, origin_dt,
505 &origin_ucx_iov, &origin_ucx_iov_count);
506 if (ret != OMPI_SUCCESS) {
507 return ret;
508 }
509
510 if ((op != &ompi_mpi_op_maxloc.op && op != &ompi_mpi_op_minloc.op) ||
511 ompi_datatype_is_contiguous_memory_layout(temp_dt, temp_count)) {
512 size_t temp_size;
513 ompi_datatype_type_size(temp_dt, &temp_size);
514 while (origin_ucx_iov_idx < origin_ucx_iov_count) {
515 int curr_count = origin_ucx_iov[origin_ucx_iov_idx].len / temp_size;
516 ompi_op_reduce(op, origin_ucx_iov[origin_ucx_iov_idx].addr,
517 temp_addr, curr_count, temp_dt);
518 temp_addr = (void *)((char *)temp_addr + curr_count * temp_size);
519 origin_ucx_iov_idx++;
520 }
521 } else {
522 int i;
523 void *curr_origin_addr = origin_ucx_iov[origin_ucx_iov_idx].addr;
524 for (i = 0; i < (int)temp_count; i++) {
525 ompi_op_reduce(op, curr_origin_addr,
526 (void *)((char *)temp_addr + i * temp_extent),
527 1, temp_dt);
528 curr_origin_addr = (void *)((char *)curr_origin_addr + temp_extent);
529 origin_ucx_iov_idx++;
530 if (curr_origin_addr >= (void *)((char *)origin_ucx_iov[origin_ucx_iov_idx].addr + origin_ucx_iov[origin_ucx_iov_idx].len)) {
531 origin_ucx_iov_idx++;
532 curr_origin_addr = origin_ucx_iov[origin_ucx_iov_idx].addr;
533 }
534 }
535 }
536
537 free(origin_ucx_iov);
538 }
539
540 ret = ompi_osc_ucx_put(temp_addr, (int)temp_count, temp_dt, target, target_disp,
541 target_count, target_dt, win);
542 if (ret != OMPI_SUCCESS) {
543 return ret;
544 }
545
546 ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
547 if (ret != OMPI_SUCCESS) {
548 return ret;
549 }
550
551 free(temp_addr_holder);
552 }
553
554 return end_atomicity(module, target);
555 }
556
557 int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_addr,
558 void *result_addr, struct ompi_datatype_t *dt,
559 int target, ptrdiff_t target_disp,
560 struct ompi_win_t *win) {
561 ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
562 uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
563 size_t dt_bytes;
564 int ret = OMPI_SUCCESS;
565
566 ret = check_sync_state(module, target, false);
567 if (ret != OMPI_SUCCESS) {
568 return ret;
569 }
570
571 ret = start_atomicity(module, target);
572 if (ret != OMPI_SUCCESS) {
573 return ret;
574 }
575
576 if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
577 ret = get_dynamic_win_info(remote_addr, module, target);
578 if (ret != OMPI_SUCCESS) {
579 return ret;
580 }
581 }
582
583 ompi_datatype_type_size(dt, &dt_bytes);
584 ret = opal_common_ucx_wpmem_cmpswp(module->mem,*(uint64_t *)compare_addr,
585 *(uint64_t *)origin_addr, target,
586 result_addr, dt_bytes, remote_addr);
587 if (ret != OMPI_SUCCESS) {
588 return ret;
589 }
590
591 return end_atomicity(module, target);
592 }
593
594 int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
595 struct ompi_datatype_t *dt, int target,
596 ptrdiff_t target_disp, struct ompi_op_t *op,
597 struct ompi_win_t *win) {
598 ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
599 int ret = OMPI_SUCCESS;
600
601 ret = check_sync_state(module, target, false);
602 if (ret != OMPI_SUCCESS) {
603 return ret;
604 }
605
606 if (op == &ompi_mpi_op_no_op.op || op == &ompi_mpi_op_replace.op ||
607 op == &ompi_mpi_op_sum.op) {
608 uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
609 uint64_t value = origin_addr ? *(uint64_t *)origin_addr : 0;
610 ucp_atomic_fetch_op_t opcode;
611 size_t dt_bytes;
612
613 ret = start_atomicity(module, target);
614 if (ret != OMPI_SUCCESS) {
615 return ret;
616 }
617
618 if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
619 ret = get_dynamic_win_info(remote_addr, module, target);
620 if (ret != OMPI_SUCCESS) {
621 return ret;
622 }
623 }
624
625 ompi_datatype_type_size(dt, &dt_bytes);
626
627 if (op == &ompi_mpi_op_replace.op) {
628 opcode = UCP_ATOMIC_FETCH_OP_SWAP;
629 } else {
630 opcode = UCP_ATOMIC_FETCH_OP_FADD;
631 if (op == &ompi_mpi_op_no_op.op) {
632 value = 0;
633 }
634 }
635
636 ret = opal_common_ucx_wpmem_fetch(module->mem, opcode, value, target,
637 (void *)result_addr, dt_bytes, remote_addr);
638 if (ret != OMPI_SUCCESS) {
639 return ret;
640 }
641
642 return end_atomicity(module, target);
643 } else {
644 return ompi_osc_ucx_get_accumulate(origin_addr, 1, dt, result_addr, 1, dt,
645 target, target_disp, 1, dt, op, win);
646 }
647 }
648
649 int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
650 struct ompi_datatype_t *origin_dt,
651 void *result_addr, int result_count,
652 struct ompi_datatype_t *result_dt,
653 int target, ptrdiff_t target_disp,
654 int target_count, struct ompi_datatype_t *target_dt,
655 struct ompi_op_t *op, struct ompi_win_t *win) {
656 ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
657 int ret = OMPI_SUCCESS;
658
659 ret = check_sync_state(module, target, false);
660 if (ret != OMPI_SUCCESS) {
661 return ret;
662 }
663
664 ret = start_atomicity(module, target);
665 if (ret != OMPI_SUCCESS) {
666 return ret;
667 }
668
669 ret = ompi_osc_ucx_get(result_addr, result_count, result_dt, target,
670 target_disp, target_count, target_dt, win);
671 if (ret != OMPI_SUCCESS) {
672 return ret;
673 }
674
675 if (op != &ompi_mpi_op_no_op.op) {
676 if (op == &ompi_mpi_op_replace.op) {
677 ret = ompi_osc_ucx_put(origin_addr, origin_count, origin_dt,
678 target, target_disp, target_count,
679 target_dt, win);
680 if (ret != OMPI_SUCCESS) {
681 return ret;
682 }
683 } else {
684 void *temp_addr_holder = NULL;
685 void *temp_addr = NULL;
686 uint32_t temp_count;
687 ompi_datatype_t *temp_dt;
688 ptrdiff_t temp_lb, temp_extent;
689 bool is_origin_contig = ompi_datatype_is_contiguous_memory_layout(origin_dt, origin_count);
690
691 if (ompi_datatype_is_predefined(target_dt)) {
692 temp_dt = target_dt;
693 temp_count = target_count;
694 } else {
695 ret = ompi_osc_base_get_primitive_type_info(target_dt, &temp_dt, &temp_count);
696 if (ret != OMPI_SUCCESS) {
697 return ret;
698 }
699 }
700 ompi_datatype_get_true_extent(temp_dt, &temp_lb, &temp_extent);
701 temp_addr = temp_addr_holder = malloc(temp_extent * temp_count);
702 if (temp_addr == NULL) {
703 return OMPI_ERR_TEMP_OUT_OF_RESOURCE;
704 }
705
706 ret = ompi_osc_ucx_get(temp_addr, (int)temp_count, temp_dt,
707 target, target_disp, target_count, target_dt, win);
708 if (ret != OMPI_SUCCESS) {
709 return ret;
710 }
711
712 ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
713 if (ret != OMPI_SUCCESS) {
714 return ret;
715 }
716
717 if (ompi_datatype_is_predefined(origin_dt) || is_origin_contig) {
718 ompi_op_reduce(op, (void *)origin_addr, temp_addr, (int)temp_count, temp_dt);
719 } else {
720 ucx_iovec_t *origin_ucx_iov = NULL;
721 uint32_t origin_ucx_iov_count = 0;
722 uint32_t origin_ucx_iov_idx = 0;
723
724 ret = create_iov_list(origin_addr, origin_count, origin_dt,
725 &origin_ucx_iov, &origin_ucx_iov_count);
726 if (ret != OMPI_SUCCESS) {
727 return ret;
728 }
729
730 if ((op != &ompi_mpi_op_maxloc.op && op != &ompi_mpi_op_minloc.op) ||
731 ompi_datatype_is_contiguous_memory_layout(temp_dt, temp_count)) {
732 size_t temp_size;
733 ompi_datatype_type_size(temp_dt, &temp_size);
734 while (origin_ucx_iov_idx < origin_ucx_iov_count) {
735 int curr_count = origin_ucx_iov[origin_ucx_iov_idx].len / temp_size;
736 ompi_op_reduce(op, origin_ucx_iov[origin_ucx_iov_idx].addr,
737 temp_addr, curr_count, temp_dt);
738 temp_addr = (void *)((char *)temp_addr + curr_count * temp_size);
739 origin_ucx_iov_idx++;
740 }
741 } else {
742 int i;
743 void *curr_origin_addr = origin_ucx_iov[origin_ucx_iov_idx].addr;
744 for (i = 0; i < (int)temp_count; i++) {
745 ompi_op_reduce(op, curr_origin_addr,
746 (void *)((char *)temp_addr + i * temp_extent),
747 1, temp_dt);
748 curr_origin_addr = (void *)((char *)curr_origin_addr + temp_extent);
749 origin_ucx_iov_idx++;
750 if (curr_origin_addr >= (void *)((char *)origin_ucx_iov[origin_ucx_iov_idx].addr + origin_ucx_iov[origin_ucx_iov_idx].len)) {
751 origin_ucx_iov_idx++;
752 curr_origin_addr = origin_ucx_iov[origin_ucx_iov_idx].addr;
753 }
754 }
755 }
756 free(origin_ucx_iov);
757 }
758
759 ret = ompi_osc_ucx_put(temp_addr, (int)temp_count, temp_dt, target, target_disp,
760 target_count, target_dt, win);
761 if (ret != OMPI_SUCCESS) {
762 return ret;
763 }
764
765 ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
766 if (ret != OMPI_SUCCESS) {
767 return ret;
768 }
769
770 free(temp_addr_holder);
771 }
772 }
773
774 return end_atomicity(module, target);
775 }
776
777 int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
778 struct ompi_datatype_t *origin_dt,
779 int target, ptrdiff_t target_disp, int target_count,
780 struct ompi_datatype_t *target_dt,
781 struct ompi_win_t *win, struct ompi_request_t **request) {
782 ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
783 uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
784 ompi_osc_ucx_request_t *ucx_req = NULL;
785 int ret = OMPI_SUCCESS;
786
787 ret = check_sync_state(module, target, true);
788 if (ret != OMPI_SUCCESS) {
789 return ret;
790 }
791
792 if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
793 ret = get_dynamic_win_info(remote_addr, module, target);
794 if (ret != OMPI_SUCCESS) {
795 return ret;
796 }
797 }
798
799 OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
800 assert(NULL != ucx_req);
801
802 ret = ompi_osc_ucx_put(origin_addr, origin_count, origin_dt, target, target_disp,
803 target_count, target_dt, win);
804 if (ret != OMPI_SUCCESS) {
805 return ret;
806 }
807
808 ret = opal_common_ucx_wpmem_fence(module->mem);
809 if (ret != OMPI_SUCCESS) {
810 OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret);
811 return OMPI_ERROR;
812 }
813
814 mca_osc_ucx_component.num_incomplete_req_ops++;
815 ret = opal_common_ucx_wpmem_fetch_nb(module->mem, UCP_ATOMIC_FETCH_OP_FADD,
816 0, target, &(module->req_result),
817 sizeof(uint64_t), remote_addr,
818 req_completion, ucx_req);
819 if (ret != OMPI_SUCCESS) {
820 return ret;
821 }
822
823 *request = &ucx_req->super;
824
825 return ret;
826 }
827
828 int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
829 struct ompi_datatype_t *origin_dt,
830 int target, ptrdiff_t target_disp, int target_count,
831 struct ompi_datatype_t *target_dt, struct ompi_win_t *win,
832 struct ompi_request_t **request) {
833 ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
834 uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
835 ompi_osc_ucx_request_t *ucx_req = NULL;
836 int ret = OMPI_SUCCESS;
837
838 ret = check_sync_state(module, target, true);
839 if (ret != OMPI_SUCCESS) {
840 return ret;
841 }
842
843 if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
844 ret = get_dynamic_win_info(remote_addr, module, target);
845 if (ret != OMPI_SUCCESS) {
846 return ret;
847 }
848 }
849
850 OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
851 assert(NULL != ucx_req);
852
853 ret = ompi_osc_ucx_get(origin_addr, origin_count, origin_dt, target, target_disp,
854 target_count, target_dt, win);
855 if (ret != OMPI_SUCCESS) {
856 return ret;
857 }
858
859 ret = opal_common_ucx_wpmem_fence(module->mem);
860 if (ret != OMPI_SUCCESS) {
861 OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret);
862 return OMPI_ERROR;
863 }
864
865 mca_osc_ucx_component.num_incomplete_req_ops++;
866 ret = opal_common_ucx_wpmem_fetch_nb(module->mem, UCP_ATOMIC_FETCH_OP_FADD,
867 0, target, &(module->req_result),
868 sizeof(uint64_t), remote_addr,
869 req_completion, ucx_req);
870 if (ret != OMPI_SUCCESS) {
871 return ret;
872 }
873
874 *request = &ucx_req->super;
875
876 return ret;
877 }
878
879 int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count,
880 struct ompi_datatype_t *origin_dt,
881 int target, ptrdiff_t target_disp, int target_count,
882 struct ompi_datatype_t *target_dt, struct ompi_op_t *op,
883 struct ompi_win_t *win, struct ompi_request_t **request) {
884 ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
885 ompi_osc_ucx_request_t *ucx_req = NULL;
886 int ret = OMPI_SUCCESS;
887
888 ret = check_sync_state(module, target, true);
889 if (ret != OMPI_SUCCESS) {
890 return ret;
891 }
892
893 OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
894 assert(NULL != ucx_req);
895
896 ret = ompi_osc_ucx_accumulate(origin_addr, origin_count, origin_dt, target, target_disp,
897 target_count, target_dt, op, win);
898 if (ret != OMPI_SUCCESS) {
899 return ret;
900 }
901
902 ompi_request_complete(&ucx_req->super, true);
903 *request = &ucx_req->super;
904
905 return ret;
906 }
907
908 int ompi_osc_ucx_rget_accumulate(const void *origin_addr, int origin_count,
909 struct ompi_datatype_t *origin_datatype,
910 void *result_addr, int result_count,
911 struct ompi_datatype_t *result_datatype,
912 int target, ptrdiff_t target_disp, int target_count,
913 struct ompi_datatype_t *target_datatype,
914 struct ompi_op_t *op, struct ompi_win_t *win,
915 struct ompi_request_t **request) {
916 ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
917 ompi_osc_ucx_request_t *ucx_req = NULL;
918 int ret = OMPI_SUCCESS;
919
920 ret = check_sync_state(module, target, true);
921 if (ret != OMPI_SUCCESS) {
922 return ret;
923 }
924
925 OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
926 assert(NULL != ucx_req);
927
928 ret = ompi_osc_ucx_get_accumulate(origin_addr, origin_count, origin_datatype,
929 result_addr, result_count, result_datatype,
930 target, target_disp, target_count,
931 target_datatype, op, win);
932 if (ret != OMPI_SUCCESS) {
933 return ret;
934 }
935
936 ompi_request_complete(&ucx_req->super, true);
937
938 *request = &ucx_req->super;
939
940 return ret;
941 }