root/ompi/mca/osc/ucx/osc_ucx_comm.c

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. check_sync_state
  2. create_iov_list
  3. ddt_put_get
  4. start_atomicity
  5. end_atomicity
  6. get_dynamic_win_info
  7. ompi_osc_ucx_put
  8. ompi_osc_ucx_get
  9. ompi_osc_ucx_accumulate
  10. ompi_osc_ucx_compare_and_swap
  11. ompi_osc_ucx_fetch_and_op
  12. ompi_osc_ucx_get_accumulate
  13. ompi_osc_ucx_rput
  14. ompi_osc_ucx_rget
  15. ompi_osc_ucx_raccumulate
  16. ompi_osc_ucx_rget_accumulate

   1 /*
   2  * Copyright (C) Mellanox Technologies Ltd. 2001-2017. ALL RIGHTS RESERVED.
   3  * $COPYRIGHT$
   4  *
   5  * Additional copyrights may follow
   6  *
   7  * $HEADER$
   8  */
   9 
  10 #include "ompi_config.h"
  11 
  12 #include "ompi/mca/osc/osc.h"
  13 #include "ompi/mca/osc/base/base.h"
  14 #include "ompi/mca/osc/base/osc_base_obj_convert.h"
  15 #include "opal/mca/common/ucx/common_ucx.h"
  16 
  17 #include "osc_ucx.h"
  18 #include "osc_ucx_request.h"
  19 
  20 
  21 #define CHECK_VALID_RKEY(_module, _target, _count)                               \
  22     if (!((_module)->win_info_array[_target]).rkey_init && ((_count) > 0)) {     \
  23         OSC_UCX_VERBOSE(1, "window with non-zero length does not have an rkey"); \
  24         return OMPI_ERROR;                                                       \
  25     }
  26 
  27 typedef struct ucx_iovec {
  28     void *addr;
  29     size_t len;
  30 } ucx_iovec_t;
  31 
  32 static inline int check_sync_state(ompi_osc_ucx_module_t *module, int target,
  33                                    bool is_req_ops) {
  34     if (is_req_ops == false) {
  35         if (module->epoch_type.access == NONE_EPOCH) {
  36             return OMPI_ERR_RMA_SYNC;
  37         } else if (module->epoch_type.access == START_COMPLETE_EPOCH) {
  38             int i, size = ompi_group_size(module->start_group);
  39             for (i = 0; i < size; i++) {
  40                 if (module->start_grp_ranks[i] == target) {
  41                     break;
  42                 }
  43             }
  44             if (i == size) {
  45                 return OMPI_ERR_RMA_SYNC;
  46             }
  47         } else if (module->epoch_type.access == PASSIVE_EPOCH) {
  48             ompi_osc_ucx_lock_t *item = NULL;
  49             opal_hash_table_get_value_uint32(&module->outstanding_locks, (uint32_t) target, (void **) &item);
  50             if (item == NULL) {
  51                 return OMPI_ERR_RMA_SYNC;
  52             }
  53         }
  54     } else {
  55         if (module->epoch_type.access != PASSIVE_EPOCH &&
  56             module->epoch_type.access != PASSIVE_ALL_EPOCH) {
  57             return OMPI_ERR_RMA_SYNC;
  58         } else if (module->epoch_type.access == PASSIVE_EPOCH) {
  59             ompi_osc_ucx_lock_t *item = NULL;
  60             opal_hash_table_get_value_uint32(&module->outstanding_locks, (uint32_t) target, (void **) &item);
  61             if (item == NULL) {
  62                 return OMPI_ERR_RMA_SYNC;
  63             }
  64         }
  65     }
  66     return OMPI_SUCCESS;
  67 }
  68 
  69 static inline int create_iov_list(const void *addr, int count, ompi_datatype_t *datatype,
  70                                   ucx_iovec_t **ucx_iov, uint32_t *ucx_iov_count) {
  71     int ret = OMPI_SUCCESS;
  72     size_t size;
  73     bool done = false;
  74     opal_convertor_t convertor;
  75     uint32_t iov_count, iov_idx;
  76     struct iovec iov[OSC_UCX_IOVEC_MAX];
  77     uint32_t ucx_iov_idx;
  78 
  79     OBJ_CONSTRUCT(&convertor, opal_convertor_t);
  80     ret = opal_convertor_copy_and_prepare_for_send(ompi_mpi_local_convertor,
  81                                                    &datatype->super, count,
  82                                                    addr, 0, &convertor);
  83     if (ret != OMPI_SUCCESS) {
  84         return ret;
  85     }
  86 
  87     (*ucx_iov_count) = 0;
  88     ucx_iov_idx = 0;
  89 
  90     do {
  91         iov_count = OSC_UCX_IOVEC_MAX;
  92         iov_idx = 0;
  93 
  94         done = opal_convertor_raw(&convertor, iov, &iov_count, &size);
  95 
  96         (*ucx_iov_count) += iov_count;
  97         (*ucx_iov) = (ucx_iovec_t *)realloc((*ucx_iov), (*ucx_iov_count) * sizeof(ucx_iovec_t));
  98         if (*ucx_iov == NULL) {
  99             return OMPI_ERR_TEMP_OUT_OF_RESOURCE;
 100         }
 101 
 102         while (iov_idx != iov_count) {
 103             (*ucx_iov)[ucx_iov_idx].addr = iov[iov_idx].iov_base;
 104             (*ucx_iov)[ucx_iov_idx].len = iov[iov_idx].iov_len;
 105             ucx_iov_idx++;
 106             iov_idx++;
 107         }
 108 
 109         assert((*ucx_iov_count) == ucx_iov_idx);
 110 
 111     } while (!done);
 112 
 113     opal_convertor_cleanup(&convertor);
 114     OBJ_DESTRUCT(&convertor);
 115 
 116     return ret;
 117 }
 118 
 119 static inline int ddt_put_get(ompi_osc_ucx_module_t *module,
 120                               const void *origin_addr, int origin_count,
 121                               struct ompi_datatype_t *origin_dt,
 122                               bool is_origin_contig, ptrdiff_t origin_lb,
 123                               int target, uint64_t remote_addr,
 124                               int target_count, struct ompi_datatype_t *target_dt,
 125                               bool is_target_contig, ptrdiff_t target_lb, bool is_get) {
 126     ucx_iovec_t *origin_ucx_iov = NULL, *target_ucx_iov = NULL;
 127     uint32_t origin_ucx_iov_count = 0, target_ucx_iov_count = 0;
 128     uint32_t origin_ucx_iov_idx = 0, target_ucx_iov_idx = 0;
 129     int status;
 130     int ret = OMPI_SUCCESS;
 131 
 132     if (!is_origin_contig) {
 133         ret = create_iov_list(origin_addr, origin_count, origin_dt,
 134                               &origin_ucx_iov, &origin_ucx_iov_count);
 135         if (ret != OMPI_SUCCESS) {
 136             return ret;
 137         }
 138     }
 139 
 140     if (!is_target_contig) {
 141         ret = create_iov_list(NULL, target_count, target_dt,
 142                               &target_ucx_iov, &target_ucx_iov_count);
 143         if (ret != OMPI_SUCCESS) {
 144             return ret;
 145         }
 146     }
 147 
 148     if (!is_origin_contig && !is_target_contig) {
 149         size_t curr_len = 0;
 150         opal_common_ucx_op_t op;
 151         while (origin_ucx_iov_idx < origin_ucx_iov_count) {
 152             curr_len = MIN(origin_ucx_iov[origin_ucx_iov_idx].len,
 153                            target_ucx_iov[target_ucx_iov_idx].len);
 154             if (is_get) {
 155                 op = OPAL_COMMON_UCX_GET;
 156             } else {
 157                 op = OPAL_COMMON_UCX_PUT;
 158             }
 159             status = opal_common_ucx_wpmem_putget(module->mem, op, target,
 160                                                 origin_ucx_iov[origin_ucx_iov_idx].addr, curr_len,
 161                                                 remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr));
 162             if (OPAL_SUCCESS != status) {
 163                 OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", status);
 164                 return OMPI_ERROR;
 165             }
 166 
 167             origin_ucx_iov[origin_ucx_iov_idx].addr = (void *)((intptr_t)origin_ucx_iov[origin_ucx_iov_idx].addr + curr_len);
 168             target_ucx_iov[target_ucx_iov_idx].addr = (void *)((intptr_t)target_ucx_iov[target_ucx_iov_idx].addr + curr_len);
 169 
 170             origin_ucx_iov[origin_ucx_iov_idx].len -= curr_len;
 171             if (origin_ucx_iov[origin_ucx_iov_idx].len == 0) {
 172                 origin_ucx_iov_idx++;
 173             }
 174             target_ucx_iov[target_ucx_iov_idx].len -= curr_len;
 175             if (target_ucx_iov[target_ucx_iov_idx].len == 0) {
 176                 target_ucx_iov_idx++;
 177             }
 178         }
 179 
 180         assert(origin_ucx_iov_idx == origin_ucx_iov_count &&
 181                target_ucx_iov_idx == target_ucx_iov_count);
 182 
 183     } else if (!is_origin_contig) {
 184         size_t prev_len = 0;
 185         opal_common_ucx_op_t op;
 186         while (origin_ucx_iov_idx < origin_ucx_iov_count) {
 187             if (is_get) {
 188                 op = OPAL_COMMON_UCX_GET;
 189             } else {
 190                 op = OPAL_COMMON_UCX_PUT;
 191             }
 192             status = opal_common_ucx_wpmem_putget(module->mem, op, target,
 193                                                 origin_ucx_iov[origin_ucx_iov_idx].addr,
 194                                                 origin_ucx_iov[origin_ucx_iov_idx].len,
 195                                                 remote_addr + target_lb + prev_len);
 196             if (OPAL_SUCCESS != status) {
 197                 OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", status);
 198                 return OMPI_ERROR;
 199             }
 200 
 201             prev_len += origin_ucx_iov[origin_ucx_iov_idx].len;
 202             origin_ucx_iov_idx++;
 203         }
 204     } else {
 205         size_t prev_len = 0;
 206         opal_common_ucx_op_t op;
 207         while (target_ucx_iov_idx < target_ucx_iov_count) {
 208             if (is_get) {
 209                 op = OPAL_COMMON_UCX_GET;
 210             } else {
 211                 op = OPAL_COMMON_UCX_PUT;
 212             }
 213 
 214             status = opal_common_ucx_wpmem_putget(module->mem, op, target,
 215                                                 (void *)((intptr_t)origin_addr + origin_lb + prev_len),
 216                                                 target_ucx_iov[target_ucx_iov_idx].len,
 217                                                 remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr));
 218             if (OPAL_SUCCESS != status) {
 219                 OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", status);
 220                 return OMPI_ERROR;
 221             }
 222 
 223             prev_len += target_ucx_iov[target_ucx_iov_idx].len;
 224             target_ucx_iov_idx++;
 225         }
 226     }
 227 
 228     if (origin_ucx_iov != NULL) {
 229         free(origin_ucx_iov);
 230     }
 231     if (target_ucx_iov != NULL) {
 232         free(target_ucx_iov);
 233     }
 234 
 235     return ret;
 236 }
 237 
 238 static inline int start_atomicity(ompi_osc_ucx_module_t *module, int target) {
 239     uint64_t result_value = -1;
 240     uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_ACC_LOCK_OFFSET;
 241     int ret = OMPI_SUCCESS;
 242 
 243     for (;;) {
 244         ret = opal_common_ucx_wpmem_cmpswp(module->state_mem,
 245                                          TARGET_LOCK_UNLOCKED, TARGET_LOCK_EXCLUSIVE,
 246                                          target, &result_value, sizeof(result_value),
 247                                          remote_addr);
 248         if (ret != OMPI_SUCCESS) {
 249             OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_cmpswp failed: %d", ret);
 250             return OMPI_ERROR;
 251         }
 252         if (result_value == TARGET_LOCK_UNLOCKED) {
 253             return OMPI_SUCCESS;
 254         }
 255 
 256         ucp_worker_progress(mca_osc_ucx_component.wpool->dflt_worker);
 257     }
 258 }
 259 
 260 static inline int end_atomicity(ompi_osc_ucx_module_t *module, int target) {
 261     uint64_t result_value = 0;
 262     uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_ACC_LOCK_OFFSET;
 263     int ret = OMPI_SUCCESS;
 264 
 265     ret = opal_common_ucx_wpmem_fetch(module->state_mem,
 266                                     UCP_ATOMIC_FETCH_OP_SWAP, TARGET_LOCK_UNLOCKED,
 267                                     target, &result_value, sizeof(result_value),
 268                                     remote_addr);
 269     if (ret != OMPI_SUCCESS) {
 270         OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fetch failed: %d", ret);
 271         return OMPI_ERROR;
 272     }
 273 
 274     assert(result_value == TARGET_LOCK_EXCLUSIVE);
 275 
 276     return ret;
 277 }
 278 
 279 static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module_t *module,
 280                                        int target) {
 281     uint64_t remote_state_addr = (module->state_addrs)[target] + OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET;
 282     size_t len = sizeof(uint64_t) + sizeof(ompi_osc_dynamic_win_info_t) * OMPI_OSC_UCX_ATTACH_MAX;
 283     char *temp_buf = malloc(len);
 284     ompi_osc_dynamic_win_info_t *temp_dynamic_wins;
 285     uint64_t win_count;
 286     int contain, insert = -1;
 287     int ret;
 288 
 289     ret = opal_common_ucx_wpmem_putget(module->state_mem, OPAL_COMMON_UCX_GET, target,
 290                                        (void *)((intptr_t)temp_buf),
 291                                        len, remote_state_addr);
 292     if (OPAL_SUCCESS != ret) {
 293         OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
 294         return OMPI_ERROR;
 295     }
 296 
 297     ret = opal_common_ucx_wpmem_flush(module->state_mem, OPAL_COMMON_UCX_SCOPE_EP, target);
 298     if (ret != OMPI_SUCCESS) {
 299         return ret;
 300     }
 301 
 302     memcpy(&win_count, temp_buf, sizeof(uint64_t));
 303     assert(win_count > 0 && win_count <= OMPI_OSC_UCX_ATTACH_MAX);
 304 
 305     temp_dynamic_wins = (ompi_osc_dynamic_win_info_t *)(temp_buf + sizeof(uint64_t));
 306     contain = ompi_osc_find_attached_region_position(temp_dynamic_wins, 0, win_count,
 307                                                      remote_addr, 1, &insert);
 308     assert(contain >= 0 && (uint64_t)contain < win_count);
 309 
 310     if (module->local_dynamic_win_info[contain].mem->mem_addrs == NULL) {
 311         module->local_dynamic_win_info[contain].mem->mem_addrs = calloc(ompi_comm_size(module->comm),
 312                                                                         OMPI_OSC_UCX_MEM_ADDR_MAX_LEN);
 313         module->local_dynamic_win_info[contain].mem->mem_displs =calloc(ompi_comm_size(module->comm),
 314                                                                         sizeof(int));
 315     }
 316 
 317     memcpy(module->local_dynamic_win_info[contain].mem->mem_addrs + target * OMPI_OSC_UCX_MEM_ADDR_MAX_LEN,
 318            temp_dynamic_wins[contain].mem_addr, OMPI_OSC_UCX_MEM_ADDR_MAX_LEN);
 319     module->local_dynamic_win_info[contain].mem->mem_displs[target] = target * OMPI_OSC_UCX_MEM_ADDR_MAX_LEN;
 320 
 321     free(temp_buf);
 322 
 323     return ret;
 324 }
 325 
 326 int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_datatype_t *origin_dt,
 327                      int target, ptrdiff_t target_disp, int target_count,
 328                      struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
 329     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
 330     uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
 331     bool is_origin_contig = false, is_target_contig = false;
 332     ptrdiff_t origin_lb, origin_extent, target_lb, target_extent;
 333     int ret = OMPI_SUCCESS;
 334 
 335     ret = check_sync_state(module, target, false);
 336     if (ret != OMPI_SUCCESS) {
 337         return ret;
 338     }
 339 
 340     if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
 341         ret = get_dynamic_win_info(remote_addr, module, target);
 342         if (ret != OMPI_SUCCESS) {
 343             return ret;
 344         }
 345     }
 346 
 347     if (!target_count) {
 348         return OMPI_SUCCESS;
 349     }
 350 
 351     ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
 352     ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent);
 353 
 354     is_origin_contig = ompi_datatype_is_contiguous_memory_layout(origin_dt, origin_count);
 355     is_target_contig = ompi_datatype_is_contiguous_memory_layout(target_dt, target_count);
 356 
 357     if (is_origin_contig && is_target_contig) {
 358         /* fast path */
 359         size_t origin_len;
 360 
 361         ompi_datatype_type_size(origin_dt, &origin_len);
 362         origin_len *= origin_count;
 363 
 364         ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_PUT, target,
 365                                          (void *)((intptr_t)origin_addr + origin_lb),
 366                                          origin_len, remote_addr + target_lb);
 367         if (OPAL_SUCCESS != ret) {
 368             OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
 369             return OMPI_ERROR;
 370         }
 371         return ret;
 372     } else {
 373         return ddt_put_get(module, origin_addr, origin_count, origin_dt, is_origin_contig,
 374                            origin_lb, target, remote_addr, target_count, target_dt,
 375                            is_target_contig, target_lb, false);
 376     }
 377 }
 378 
 379 int ompi_osc_ucx_get(void *origin_addr, int origin_count,
 380                      struct ompi_datatype_t *origin_dt,
 381                      int target, ptrdiff_t target_disp, int target_count,
 382                      struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
 383     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
 384     uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
 385     ptrdiff_t origin_lb, origin_extent, target_lb, target_extent;
 386     bool is_origin_contig = false, is_target_contig = false;
 387     int ret = OMPI_SUCCESS;
 388 
 389     ret = check_sync_state(module, target, false);
 390     if (ret != OMPI_SUCCESS) {
 391         return ret;
 392     }
 393 
 394     if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
 395         ret = get_dynamic_win_info(remote_addr, module, target);
 396         if (ret != OMPI_SUCCESS) {
 397             return ret;
 398         }
 399     }
 400 
 401     if (!target_count) {
 402         return OMPI_SUCCESS;
 403     }
 404 
 405 
 406     ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
 407     ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent);
 408 
 409     is_origin_contig = ompi_datatype_is_contiguous_memory_layout(origin_dt, origin_count);
 410     is_target_contig = ompi_datatype_is_contiguous_memory_layout(target_dt, target_count);
 411 
 412     if (is_origin_contig && is_target_contig) {
 413         /* fast path */
 414         size_t origin_len;
 415 
 416         ompi_datatype_type_size(origin_dt, &origin_len);
 417         origin_len *= origin_count;
 418 
 419         ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_GET, target,
 420                                          (void *)((intptr_t)origin_addr + origin_lb),
 421                                          origin_len, remote_addr + target_lb);
 422         if (OPAL_SUCCESS != ret) {
 423             OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
 424             return OMPI_ERROR;
 425         }
 426 
 427         return ret;
 428     } else {
 429         return ddt_put_get(module, origin_addr, origin_count, origin_dt, is_origin_contig,
 430                            origin_lb, target, remote_addr, target_count, target_dt,
 431                            is_target_contig, target_lb, true);
 432     }
 433 }
 434 
 435 int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
 436                             struct ompi_datatype_t *origin_dt,
 437                             int target, ptrdiff_t target_disp, int target_count,
 438                             struct ompi_datatype_t *target_dt,
 439                             struct ompi_op_t *op, struct ompi_win_t *win) {
 440     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
 441     int ret = OMPI_SUCCESS;
 442 
 443     ret = check_sync_state(module, target, false);
 444     if (ret != OMPI_SUCCESS) {
 445         return ret;
 446     }
 447 
 448     if (op == &ompi_mpi_op_no_op.op) {
 449         return ret;
 450     }
 451 
 452     ret = start_atomicity(module, target);
 453     if (ret != OMPI_SUCCESS) {
 454         return ret;
 455     }
 456 
 457     if (op == &ompi_mpi_op_replace.op) {
 458         ret = ompi_osc_ucx_put(origin_addr, origin_count, origin_dt, target,
 459                                target_disp, target_count, target_dt, win);
 460         if (ret != OMPI_SUCCESS) {
 461             return ret;
 462         }
 463     } else {
 464         void *temp_addr_holder = NULL;
 465         void *temp_addr = NULL;
 466         uint32_t temp_count;
 467         ompi_datatype_t *temp_dt;
 468         ptrdiff_t temp_lb, temp_extent;
 469         bool is_origin_contig = ompi_datatype_is_contiguous_memory_layout(origin_dt, origin_count);
 470 
 471         if (ompi_datatype_is_predefined(target_dt)) {
 472             temp_dt = target_dt;
 473             temp_count = target_count;
 474         } else {
 475             ret = ompi_osc_base_get_primitive_type_info(target_dt, &temp_dt, &temp_count);
 476             if (ret != OMPI_SUCCESS) {
 477                 return ret;
 478             }
 479         }
 480         ompi_datatype_get_true_extent(temp_dt, &temp_lb, &temp_extent);
 481         temp_addr = temp_addr_holder = malloc(temp_extent * temp_count);
 482         if (temp_addr == NULL) {
 483             return OMPI_ERR_TEMP_OUT_OF_RESOURCE;
 484         }
 485 
 486         ret = ompi_osc_ucx_get(temp_addr, (int)temp_count, temp_dt,
 487                                target, target_disp, target_count, target_dt, win);
 488         if (ret != OMPI_SUCCESS) {
 489             return ret;
 490         }
 491 
 492         ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
 493         if (ret != OMPI_SUCCESS) {
 494             return ret;
 495         }
 496 
 497         if (ompi_datatype_is_predefined(origin_dt) || is_origin_contig) {
 498             ompi_op_reduce(op, (void *)origin_addr, temp_addr, (int)temp_count, temp_dt);
 499         } else {
 500             ucx_iovec_t *origin_ucx_iov = NULL;
 501             uint32_t origin_ucx_iov_count = 0;
 502             uint32_t origin_ucx_iov_idx = 0;
 503 
 504             ret = create_iov_list(origin_addr, origin_count, origin_dt,
 505                                   &origin_ucx_iov, &origin_ucx_iov_count);
 506             if (ret != OMPI_SUCCESS) {
 507                 return ret;
 508             }
 509 
 510             if ((op != &ompi_mpi_op_maxloc.op && op != &ompi_mpi_op_minloc.op) ||
 511                 ompi_datatype_is_contiguous_memory_layout(temp_dt, temp_count)) {
 512                 size_t temp_size;
 513                 ompi_datatype_type_size(temp_dt, &temp_size);
 514                 while (origin_ucx_iov_idx < origin_ucx_iov_count) {
 515                     int curr_count = origin_ucx_iov[origin_ucx_iov_idx].len / temp_size;
 516                     ompi_op_reduce(op, origin_ucx_iov[origin_ucx_iov_idx].addr,
 517                                    temp_addr, curr_count, temp_dt);
 518                     temp_addr = (void *)((char *)temp_addr + curr_count * temp_size);
 519                     origin_ucx_iov_idx++;
 520                 }
 521             } else {
 522                 int i;
 523                 void *curr_origin_addr = origin_ucx_iov[origin_ucx_iov_idx].addr;
 524                 for (i = 0; i < (int)temp_count; i++) {
 525                     ompi_op_reduce(op, curr_origin_addr,
 526                                    (void *)((char *)temp_addr + i * temp_extent),
 527                                    1, temp_dt);
 528                     curr_origin_addr = (void *)((char *)curr_origin_addr + temp_extent);
 529                     origin_ucx_iov_idx++;
 530                     if (curr_origin_addr >= (void *)((char *)origin_ucx_iov[origin_ucx_iov_idx].addr + origin_ucx_iov[origin_ucx_iov_idx].len)) {
 531                         origin_ucx_iov_idx++;
 532                         curr_origin_addr = origin_ucx_iov[origin_ucx_iov_idx].addr;
 533                     }
 534                 }
 535             }
 536 
 537             free(origin_ucx_iov);
 538         }
 539 
 540         ret = ompi_osc_ucx_put(temp_addr, (int)temp_count, temp_dt, target, target_disp,
 541                                target_count, target_dt, win);
 542         if (ret != OMPI_SUCCESS) {
 543             return ret;
 544         }
 545 
 546         ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
 547         if (ret != OMPI_SUCCESS) {
 548             return ret;
 549         }
 550 
 551         free(temp_addr_holder);
 552     }
 553 
 554     return end_atomicity(module, target);
 555 }
 556 
 557 int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_addr,
 558                                   void *result_addr, struct ompi_datatype_t *dt,
 559                                   int target, ptrdiff_t target_disp,
 560                                   struct ompi_win_t *win) {
 561     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
 562     uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
 563     size_t dt_bytes;
 564     int ret = OMPI_SUCCESS;
 565 
 566     ret = check_sync_state(module, target, false);
 567     if (ret != OMPI_SUCCESS) {
 568         return ret;
 569     }
 570 
 571     ret = start_atomicity(module, target);
 572     if (ret != OMPI_SUCCESS) {
 573         return ret;
 574     }
 575 
 576     if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
 577         ret = get_dynamic_win_info(remote_addr, module, target);
 578         if (ret != OMPI_SUCCESS) {
 579             return ret;
 580         }
 581     }
 582 
 583     ompi_datatype_type_size(dt, &dt_bytes);
 584     ret = opal_common_ucx_wpmem_cmpswp(module->mem,*(uint64_t *)compare_addr,
 585                                      *(uint64_t *)origin_addr, target,
 586                                      result_addr, dt_bytes, remote_addr);
 587     if (ret != OMPI_SUCCESS) {
 588         return ret;
 589     }
 590 
 591     return end_atomicity(module, target);
 592 }
 593 
 594 int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
 595                               struct ompi_datatype_t *dt, int target,
 596                               ptrdiff_t target_disp, struct ompi_op_t *op,
 597                               struct ompi_win_t *win) {
 598     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
 599     int ret = OMPI_SUCCESS;
 600 
 601     ret = check_sync_state(module, target, false);
 602     if (ret != OMPI_SUCCESS) {
 603         return ret;
 604     }
 605 
 606     if (op == &ompi_mpi_op_no_op.op || op == &ompi_mpi_op_replace.op ||
 607         op == &ompi_mpi_op_sum.op) {
 608         uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
 609         uint64_t value = origin_addr ? *(uint64_t *)origin_addr : 0;
 610         ucp_atomic_fetch_op_t opcode;
 611         size_t dt_bytes;
 612 
 613         ret = start_atomicity(module, target);
 614         if (ret != OMPI_SUCCESS) {
 615             return ret;
 616         }
 617 
 618         if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
 619             ret = get_dynamic_win_info(remote_addr, module, target);
 620             if (ret != OMPI_SUCCESS) {
 621                 return ret;
 622             }
 623         }
 624 
 625         ompi_datatype_type_size(dt, &dt_bytes);
 626 
 627         if (op == &ompi_mpi_op_replace.op) {
 628             opcode = UCP_ATOMIC_FETCH_OP_SWAP;
 629         } else {
 630             opcode = UCP_ATOMIC_FETCH_OP_FADD;
 631             if (op == &ompi_mpi_op_no_op.op) {
 632                 value = 0;
 633             }
 634         }
 635 
 636         ret = opal_common_ucx_wpmem_fetch(module->mem, opcode, value, target,
 637                                         (void *)result_addr, dt_bytes, remote_addr);
 638         if (ret != OMPI_SUCCESS) {
 639             return ret;
 640         }
 641 
 642         return end_atomicity(module, target);
 643     } else {
 644         return ompi_osc_ucx_get_accumulate(origin_addr, 1, dt, result_addr, 1, dt,
 645                                            target, target_disp, 1, dt, op, win);
 646     }
 647 }
 648 
 649 int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
 650                                 struct ompi_datatype_t *origin_dt,
 651                                 void *result_addr, int result_count,
 652                                 struct ompi_datatype_t *result_dt,
 653                                 int target, ptrdiff_t target_disp,
 654                                 int target_count, struct ompi_datatype_t *target_dt,
 655                                 struct ompi_op_t *op, struct ompi_win_t *win) {
 656     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
 657     int ret = OMPI_SUCCESS;
 658 
 659     ret = check_sync_state(module, target, false);
 660     if (ret != OMPI_SUCCESS) {
 661         return ret;
 662     }
 663 
 664     ret = start_atomicity(module, target);
 665     if (ret != OMPI_SUCCESS) {
 666         return ret;
 667     }
 668 
 669     ret = ompi_osc_ucx_get(result_addr, result_count, result_dt, target,
 670                            target_disp, target_count, target_dt, win);
 671     if (ret != OMPI_SUCCESS) {
 672         return ret;
 673     }
 674 
 675     if (op != &ompi_mpi_op_no_op.op) {
 676         if (op == &ompi_mpi_op_replace.op) {
 677             ret = ompi_osc_ucx_put(origin_addr, origin_count, origin_dt,
 678                                    target, target_disp, target_count,
 679                                    target_dt, win);
 680             if (ret != OMPI_SUCCESS) {
 681                 return ret;
 682             }
 683         } else {
 684             void *temp_addr_holder = NULL;
 685             void *temp_addr = NULL;
 686             uint32_t temp_count;
 687             ompi_datatype_t *temp_dt;
 688             ptrdiff_t temp_lb, temp_extent;
 689             bool is_origin_contig = ompi_datatype_is_contiguous_memory_layout(origin_dt, origin_count);
 690 
 691             if (ompi_datatype_is_predefined(target_dt)) {
 692                 temp_dt = target_dt;
 693                 temp_count = target_count;
 694             } else {
 695                 ret = ompi_osc_base_get_primitive_type_info(target_dt, &temp_dt, &temp_count);
 696                 if (ret != OMPI_SUCCESS) {
 697                     return ret;
 698                 }
 699             }
 700             ompi_datatype_get_true_extent(temp_dt, &temp_lb, &temp_extent);
 701             temp_addr = temp_addr_holder = malloc(temp_extent * temp_count);
 702             if (temp_addr == NULL) {
 703                 return OMPI_ERR_TEMP_OUT_OF_RESOURCE;
 704             }
 705 
 706             ret = ompi_osc_ucx_get(temp_addr, (int)temp_count, temp_dt,
 707                                    target, target_disp, target_count, target_dt, win);
 708             if (ret != OMPI_SUCCESS) {
 709                 return ret;
 710             }
 711 
 712             ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
 713             if (ret != OMPI_SUCCESS) {
 714                 return ret;
 715             }
 716 
 717             if (ompi_datatype_is_predefined(origin_dt) || is_origin_contig) {
 718                 ompi_op_reduce(op, (void *)origin_addr, temp_addr, (int)temp_count, temp_dt);
 719             } else {
 720                 ucx_iovec_t *origin_ucx_iov = NULL;
 721                 uint32_t origin_ucx_iov_count = 0;
 722                 uint32_t origin_ucx_iov_idx = 0;
 723 
 724                 ret = create_iov_list(origin_addr, origin_count, origin_dt,
 725                                       &origin_ucx_iov, &origin_ucx_iov_count);
 726                 if (ret != OMPI_SUCCESS) {
 727                     return ret;
 728                 }
 729 
 730                 if ((op != &ompi_mpi_op_maxloc.op && op != &ompi_mpi_op_minloc.op) ||
 731                     ompi_datatype_is_contiguous_memory_layout(temp_dt, temp_count)) {
 732                     size_t temp_size;
 733                     ompi_datatype_type_size(temp_dt, &temp_size);
 734                     while (origin_ucx_iov_idx < origin_ucx_iov_count) {
 735                         int curr_count = origin_ucx_iov[origin_ucx_iov_idx].len / temp_size;
 736                         ompi_op_reduce(op, origin_ucx_iov[origin_ucx_iov_idx].addr,
 737                                        temp_addr, curr_count, temp_dt);
 738                         temp_addr = (void *)((char *)temp_addr + curr_count * temp_size);
 739                         origin_ucx_iov_idx++;
 740                     }
 741                 } else {
 742                     int i;
 743                     void *curr_origin_addr = origin_ucx_iov[origin_ucx_iov_idx].addr;
 744                     for (i = 0; i < (int)temp_count; i++) {
 745                         ompi_op_reduce(op, curr_origin_addr,
 746                                        (void *)((char *)temp_addr + i * temp_extent),
 747                                        1, temp_dt);
 748                         curr_origin_addr = (void *)((char *)curr_origin_addr + temp_extent);
 749                         origin_ucx_iov_idx++;
 750                         if (curr_origin_addr >= (void *)((char *)origin_ucx_iov[origin_ucx_iov_idx].addr + origin_ucx_iov[origin_ucx_iov_idx].len)) {
 751                             origin_ucx_iov_idx++;
 752                             curr_origin_addr = origin_ucx_iov[origin_ucx_iov_idx].addr;
 753                         }
 754                     }
 755                 }
 756                 free(origin_ucx_iov);
 757             }
 758 
 759             ret = ompi_osc_ucx_put(temp_addr, (int)temp_count, temp_dt, target, target_disp,
 760                                    target_count, target_dt, win);
 761             if (ret != OMPI_SUCCESS) {
 762                 return ret;
 763             }
 764 
 765             ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
 766             if (ret != OMPI_SUCCESS) {
 767                 return ret;
 768             }
 769 
 770             free(temp_addr_holder);
 771         }
 772     }
 773 
 774     return end_atomicity(module, target);
 775 }
 776 
 777 int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
 778                       struct ompi_datatype_t *origin_dt,
 779                       int target, ptrdiff_t target_disp, int target_count,
 780                       struct ompi_datatype_t *target_dt,
 781                       struct ompi_win_t *win, struct ompi_request_t **request) {
 782     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
 783     uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
 784     ompi_osc_ucx_request_t *ucx_req = NULL;
 785     int ret = OMPI_SUCCESS;
 786 
 787     ret = check_sync_state(module, target, true);
 788     if (ret != OMPI_SUCCESS) {
 789         return ret;
 790     }
 791 
 792     if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
 793         ret = get_dynamic_win_info(remote_addr, module, target);
 794         if (ret != OMPI_SUCCESS) {
 795             return ret;
 796         }
 797     }
 798 
 799     OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
 800     assert(NULL != ucx_req);
 801 
 802     ret = ompi_osc_ucx_put(origin_addr, origin_count, origin_dt, target, target_disp,
 803                            target_count, target_dt, win);
 804     if (ret != OMPI_SUCCESS) {
 805         return ret;
 806     }
 807 
 808     ret = opal_common_ucx_wpmem_fence(module->mem);
 809     if (ret != OMPI_SUCCESS) {
 810         OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret);
 811         return OMPI_ERROR;
 812     }
 813 
 814     mca_osc_ucx_component.num_incomplete_req_ops++;
 815     ret = opal_common_ucx_wpmem_fetch_nb(module->mem, UCP_ATOMIC_FETCH_OP_FADD,
 816                                          0, target, &(module->req_result),
 817                                          sizeof(uint64_t), remote_addr,
 818                                          req_completion, ucx_req);
 819     if (ret != OMPI_SUCCESS) {
 820         return ret;
 821     }
 822 
 823     *request = &ucx_req->super;
 824 
 825     return ret;
 826 }
 827 
 828 int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
 829                       struct ompi_datatype_t *origin_dt,
 830                       int target, ptrdiff_t target_disp, int target_count,
 831                       struct ompi_datatype_t *target_dt, struct ompi_win_t *win,
 832                       struct ompi_request_t **request) {
 833     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
 834     uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
 835     ompi_osc_ucx_request_t *ucx_req = NULL;
 836     int ret = OMPI_SUCCESS;
 837 
 838     ret = check_sync_state(module, target, true);
 839     if (ret != OMPI_SUCCESS) {
 840         return ret;
 841     }
 842 
 843     if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
 844         ret = get_dynamic_win_info(remote_addr, module, target);
 845         if (ret != OMPI_SUCCESS) {
 846             return ret;
 847         }
 848     }
 849 
 850     OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
 851     assert(NULL != ucx_req);
 852 
 853     ret = ompi_osc_ucx_get(origin_addr, origin_count, origin_dt, target, target_disp,
 854                            target_count, target_dt, win);
 855     if (ret != OMPI_SUCCESS) {
 856         return ret;
 857     }
 858 
 859     ret = opal_common_ucx_wpmem_fence(module->mem);
 860     if (ret != OMPI_SUCCESS) {
 861         OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret);
 862         return OMPI_ERROR;
 863     }
 864 
 865     mca_osc_ucx_component.num_incomplete_req_ops++;
 866     ret = opal_common_ucx_wpmem_fetch_nb(module->mem, UCP_ATOMIC_FETCH_OP_FADD,
 867                                          0, target, &(module->req_result),
 868                                          sizeof(uint64_t), remote_addr,
 869                                          req_completion, ucx_req);
 870     if (ret != OMPI_SUCCESS) {
 871         return ret;
 872     }
 873 
 874     *request = &ucx_req->super;
 875 
 876     return ret;
 877 }
 878 
 879 int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count,
 880                              struct ompi_datatype_t *origin_dt,
 881                              int target, ptrdiff_t target_disp, int target_count,
 882                              struct ompi_datatype_t *target_dt, struct ompi_op_t *op,
 883                              struct ompi_win_t *win, struct ompi_request_t **request) {
 884     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
 885     ompi_osc_ucx_request_t *ucx_req = NULL;
 886     int ret = OMPI_SUCCESS;
 887 
 888     ret = check_sync_state(module, target, true);
 889     if (ret != OMPI_SUCCESS) {
 890         return ret;
 891     }
 892 
 893     OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
 894     assert(NULL != ucx_req);
 895 
 896     ret = ompi_osc_ucx_accumulate(origin_addr, origin_count, origin_dt, target, target_disp,
 897                                   target_count, target_dt, op, win);
 898     if (ret != OMPI_SUCCESS) {
 899         return ret;
 900     }
 901 
 902     ompi_request_complete(&ucx_req->super, true);
 903     *request = &ucx_req->super;
 904 
 905     return ret;
 906 }
 907 
 908 int ompi_osc_ucx_rget_accumulate(const void *origin_addr, int origin_count,
 909                                  struct ompi_datatype_t *origin_datatype,
 910                                  void *result_addr, int result_count,
 911                                  struct ompi_datatype_t *result_datatype,
 912                                  int target, ptrdiff_t target_disp, int target_count,
 913                                  struct ompi_datatype_t *target_datatype,
 914                                  struct ompi_op_t *op, struct ompi_win_t *win,
 915                                  struct ompi_request_t **request) {
 916     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
 917     ompi_osc_ucx_request_t *ucx_req = NULL;
 918     int ret = OMPI_SUCCESS;
 919 
 920     ret = check_sync_state(module, target, true);
 921     if (ret != OMPI_SUCCESS) {
 922         return ret;
 923     }
 924 
 925     OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
 926     assert(NULL != ucx_req);
 927 
 928     ret = ompi_osc_ucx_get_accumulate(origin_addr, origin_count, origin_datatype,
 929                                       result_addr, result_count, result_datatype,
 930                                       target, target_disp, target_count,
 931                                       target_datatype, op, win);
 932     if (ret != OMPI_SUCCESS) {
 933         return ret;
 934     }
 935 
 936     ompi_request_complete(&ucx_req->super, true);
 937 
 938     *request = &ucx_req->super;
 939 
 940     return ret;
 941 }

/* [<][>][^][v][top][bottom][index][help] */