1
2
3
4
5
6
7
8
9
10 #ifndef OMPI_OSC_UCX_H
11 #define OMPI_OSC_UCX_H
12
13 #include <ucp/api/ucp.h>
14
15 #include "ompi/group/group.h"
16 #include "ompi/communicator/communicator.h"
17 #include "opal/mca/common/ucx/common_ucx.h"
18 #include "opal/mca/common/ucx/common_ucx_wpool.h"
19
/* Component-local short names for the MCA_COMMON_UCX_* macros of the same suffix. */
#define OSC_UCX_ASSERT  MCA_COMMON_UCX_ASSERT
#define OSC_UCX_ERROR   MCA_COMMON_UCX_ERROR
#define OSC_UCX_VERBOSE MCA_COMMON_UCX_VERBOSE

/* Fixed capacities used to size the arrays declared in this header. */
#define OMPI_OSC_UCX_POST_PEER_MAX    32   /* entries in ompi_osc_ucx_state_t::post_state */
#define OMPI_OSC_UCX_ATTACH_MAX       32   /* entries in dynamic_wins / local_dynamic_win_info */
#define OMPI_OSC_UCX_MEM_ADDR_MAX_LEN 1024 /* bytes in ompi_osc_dynamic_win_info_t::mem_addr */
27
28 typedef struct ompi_osc_ucx_component {
29 ompi_osc_base_component_t super;
30 opal_common_ucx_wpool_t *wpool;
31 bool enable_mpi_threads;
32 opal_free_list_t requests;
33 bool env_initialized;
34 int num_incomplete_req_ops;
35 int num_modules;
36 unsigned int priority;
37 } ompi_osc_ucx_component_t;
38
39 OMPI_DECLSPEC extern ompi_osc_ucx_component_t mca_osc_ucx_component;
40
/*
 * Kinds of synchronization epoch.  Enumerators take the default values
 * 0..5 in declaration order.
 *
 * NOTE(review): the mapping of each enumerator to the MPI synchronization
 * call that opens it is implied by the names only -- confirm against the
 * synchronization code.
 */
typedef enum ompi_osc_ucx_epoch {
    NONE_EPOCH,             /* 0: no epoch */
    FENCE_EPOCH,            /* 1 */
    POST_WAIT_EPOCH,        /* 2 */
    START_COMPLETE_EPOCH,   /* 3 */
    PASSIVE_EPOCH,          /* 4 */
    PASSIVE_ALL_EPOCH       /* 5 */
} ompi_osc_ucx_epoch_t;
49
50 typedef struct ompi_osc_ucx_epoch_type {
51 ompi_osc_ucx_epoch_t access;
52 ompi_osc_ucx_epoch_t exposure;
53 } ompi_osc_ucx_epoch_type_t;
54
/*
 * Values of a 64-bit lock word.
 * NOTE(review): presumably stored in ompi_osc_ucx_state_t::lock, with the
 * low 32 bits left free for a shared-holder count -- confirm against the
 * locking code.
 */
#define TARGET_LOCK_UNLOCKED  ((uint64_t)(0x0000000000000000ULL))   /* all bits clear */
#define TARGET_LOCK_EXCLUSIVE ((uint64_t)(0x0000000100000000ULL))   /* only bit 32 set */

/* NOTE(review): presumably an upper bound on iovec entries; not used in this header. */
#define OSC_UCX_IOVEC_MAX 128
59
/*
 * Byte offsets of the leading members of ompi_osc_ucx_state_t, in member
 * order: lock, req_flag, acc_lock, complete_count, post_index, post_state[],
 * dynamic_win_count.  Each scalar is one uint64_t and post_state[] holds
 * OMPI_OSC_UCX_POST_PEER_MAX of them, so these values track the struct only
 * as long as its member order and sizes stay unchanged.
 *
 * Every replacement list is fully parenthesized so the macros can be used
 * inside larger expressions without precedence surprises.
 */
#define OSC_UCX_STATE_LOCK_OFFSET 0
#define OSC_UCX_STATE_REQ_FLAG_OFFSET (sizeof(uint64_t))
#define OSC_UCX_STATE_ACC_LOCK_OFFSET (sizeof(uint64_t) * 2)
#define OSC_UCX_STATE_COMPLETE_COUNT_OFFSET (sizeof(uint64_t) * 3)
#define OSC_UCX_STATE_POST_INDEX_OFFSET (sizeof(uint64_t) * 4)
#define OSC_UCX_STATE_POST_STATE_OFFSET (sizeof(uint64_t) * 5)
#define OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET (sizeof(uint64_t) * (5 + OMPI_OSC_UCX_POST_PEER_MAX))
67
68 typedef struct ompi_osc_dynamic_win_info {
69 uint64_t base;
70 size_t size;
71 char mem_addr[OMPI_OSC_UCX_MEM_ADDR_MAX_LEN];
72 } ompi_osc_dynamic_win_info_t;
73
74 typedef struct ompi_osc_local_dynamic_win_info {
75 opal_common_ucx_wpmem_t *mem;
76 char *my_mem_addr;
77 int my_mem_addr_size;
78 int refcnt;
79 } ompi_osc_local_dynamic_win_info_t;
80
81 typedef struct ompi_osc_ucx_state {
82 volatile uint64_t lock;
83 volatile uint64_t req_flag;
84 volatile uint64_t acc_lock;
85 volatile uint64_t complete_count;
86 volatile uint64_t post_index;
87 volatile uint64_t post_state[OMPI_OSC_UCX_POST_PEER_MAX];
88 volatile uint64_t dynamic_win_count;
89 volatile ompi_osc_dynamic_win_info_t dynamic_wins[OMPI_OSC_UCX_ATTACH_MAX];
90 } ompi_osc_ucx_state_t;
91
92 typedef struct ompi_osc_ucx_module {
93 ompi_osc_base_module_t super;
94 struct ompi_communicator_t *comm;
95 int flavor;
96 size_t size;
97 uint64_t *addrs;
98 uint64_t *state_addrs;
99 int disp_unit;
100
101
102 int *disp_units;
103
104 ompi_osc_ucx_state_t state;
105 ompi_osc_local_dynamic_win_info_t local_dynamic_win_info[OMPI_OSC_UCX_ATTACH_MAX];
106 ompi_osc_ucx_epoch_type_t epoch_type;
107 ompi_group_t *start_group;
108 ompi_group_t *post_group;
109 opal_hash_table_t outstanding_locks;
110 opal_list_t pending_posts;
111 int lock_count;
112 int post_count;
113 uint64_t req_result;
114 int *start_grp_ranks;
115 bool lock_all_is_nocheck;
116 opal_common_ucx_ctx_t *ctx;
117 opal_common_ucx_wpmem_t *mem;
118 opal_common_ucx_wpmem_t *state_mem;
119 } ompi_osc_ucx_module_t;
120
/* Kind of lock recorded in ompi_osc_ucx_lock_t::type. */
typedef enum locktype {
    LOCK_EXCLUSIVE,   /* 0 */
    LOCK_SHARED       /* 1 */
} lock_type_t;
125
126 typedef struct ompi_osc_ucx_lock {
127 opal_object_t super;
128 int target_rank;
129 lock_type_t type;
130 bool is_nocheck;
131 } ompi_osc_ucx_lock_t;
132
/*
 * OSC_UCX_GET_EP(comm_, rank_)
 *   Looks up peer rank_ in comm_ via ompi_comm_peer_lookup() and yields its
 *   proc_endpoints[OMPI_PROC_ENDPOINT_TAG_UCX] entry.
 *
 * OSC_UCX_GET_DISP(module_, rank_)
 *   Yields the displacement unit for rank_: module_->disp_unit when that is
 *   non-negative, otherwise module_->disp_units[rank_].
 *
 * Every macro argument is parenthesized so that non-trivial argument
 * expressions (e.g. &module, a cast, rank + 1) expand correctly.  Note that
 * OSC_UCX_GET_DISP evaluates module_ twice; avoid side effects in it.
 */
#define OSC_UCX_GET_EP(comm_, rank_) (ompi_comm_peer_lookup((comm_), (rank_))->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_UCX])
#define OSC_UCX_GET_DISP(module_, rank_) (((module_)->disp_unit < 0) ? (module_)->disp_units[(rank_)] : (module_)->disp_unit)
135
/*
 * Window memory management entry points (defined elsewhere).
 *
 * NOTE(review): these presumably back MPI_Win_attach, MPI_Win_detach and
 * MPI_Win_free, and presumably return an OMPI status code -- confirm
 * against the implementation.
 *
 *   win  - window to operate on
 *   base - start of the region to attach / detach
 *   len  - length of the region to attach
 */
int ompi_osc_ucx_win_attach(struct ompi_win_t *win, void *base, size_t len);
int ompi_osc_ucx_win_detach(struct ompi_win_t *win, const void *base);
int ompi_osc_ucx_free(struct ompi_win_t *win);
139
/*
 * Communication entry points without a request argument (defined elsewhere).
 *
 * Common parameters:
 *   origin_addr / origin_count / origin_dt  - origin buffer, element count, datatype
 *   result_addr / result_count / result_dt  - result buffer (fetching calls only)
 *   compare_addr                            - comparison value (compare_and_swap only)
 *   target / target_rank                    - rank of the target process
 *   target_disp                             - displacement into the target window
 *   target_count / target_dt                - element count and datatype at the target
 *   op                                      - reduction operation (accumulate-style calls)
 *   win                                     - window the operation is issued on
 *
 * NOTE(review): these presumably implement MPI_Put, MPI_Get, MPI_Accumulate,
 * MPI_Compare_and_swap, MPI_Fetch_and_op and MPI_Get_accumulate, and
 * presumably return an OMPI status code -- confirm against the
 * implementation.
 */
int ompi_osc_ucx_put(const void *origin_addr, int origin_count,
                     struct ompi_datatype_t *origin_dt,
                     int target, ptrdiff_t target_disp, int target_count,
                     struct ompi_datatype_t *target_dt, struct ompi_win_t *win);
int ompi_osc_ucx_get(void *origin_addr, int origin_count,
                     struct ompi_datatype_t *origin_dt,
                     int target, ptrdiff_t target_disp, int target_count,
                     struct ompi_datatype_t *target_dt, struct ompi_win_t *win);
int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
                            struct ompi_datatype_t *origin_dt,
                            int target, ptrdiff_t target_disp, int target_count,
                            struct ompi_datatype_t *target_dt,
                            struct ompi_op_t *op, struct ompi_win_t *win);
int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_addr,
                                  void *result_addr, struct ompi_datatype_t *dt,
                                  int target, ptrdiff_t target_disp,
                                  struct ompi_win_t *win);
int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
                              struct ompi_datatype_t *dt, int target,
                              ptrdiff_t target_disp, struct ompi_op_t *op,
                              struct ompi_win_t *win);
int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
                                struct ompi_datatype_t *origin_datatype,
                                void *result_addr, int result_count,
                                struct ompi_datatype_t *result_datatype,
                                int target_rank, ptrdiff_t target_disp,
                                int target_count, struct ompi_datatype_t *target_datatype,
                                struct ompi_op_t *op, struct ompi_win_t *win);
/*
 * Request-based communication entry points (defined elsewhere).  Each takes
 * the same buffer/target/window arguments as its counterpart without the
 * leading 'r' in the name, plus a trailing output parameter:
 *
 *   request - receives a pointer to the request object for the operation
 *
 * NOTE(review): these presumably implement MPI_Rput, MPI_Rget,
 * MPI_Raccumulate and MPI_Rget_accumulate, and presumably return an OMPI
 * status code -- confirm against the implementation.
 */
int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
                      struct ompi_datatype_t *origin_dt,
                      int target, ptrdiff_t target_disp, int target_count,
                      struct ompi_datatype_t *target_dt,
                      struct ompi_win_t *win, struct ompi_request_t **request);
int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
                      struct ompi_datatype_t *origin_dt,
                      int target, ptrdiff_t target_disp, int target_count,
                      struct ompi_datatype_t *target_dt, struct ompi_win_t *win,
                      struct ompi_request_t **request);
int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count,
                             struct ompi_datatype_t *origin_dt,
                             int target, ptrdiff_t target_disp, int target_count,
                             struct ompi_datatype_t *target_dt, struct ompi_op_t *op,
                             struct ompi_win_t *win, struct ompi_request_t **request);
int ompi_osc_ucx_rget_accumulate(const void *origin_addr, int origin_count,
                                 struct ompi_datatype_t *origin_datatype,
                                 void *result_addr, int result_count,
                                 struct ompi_datatype_t *result_datatype,
                                 int target_rank, ptrdiff_t target_disp, int target_count,
                                 struct ompi_datatype_t *target_datatype,
                                 struct ompi_op_t *op, struct ompi_win_t *win,
                                 struct ompi_request_t **request);
191
/*
 * Fence and post/start/complete/wait synchronization entry points (defined
 * elsewhere).
 *
 *   group  - process group for start / post
 *   assert - assertion flags supplied by the caller
 *   win    - window being synchronized
 *   flag   - output of ompi_osc_ucx_test()
 *
 * NOTE(review): these presumably implement MPI_Win_fence, MPI_Win_start,
 * MPI_Win_complete, MPI_Win_post, MPI_Win_wait and MPI_Win_test, and
 * presumably return an OMPI status code -- confirm against the
 * implementation.
 */
int ompi_osc_ucx_fence(int assert, struct ompi_win_t *win);
int ompi_osc_ucx_start(struct ompi_group_t *group, int assert, struct ompi_win_t *win);
int ompi_osc_ucx_complete(struct ompi_win_t *win);
int ompi_osc_ucx_post(struct ompi_group_t *group, int assert, struct ompi_win_t *win);
int ompi_osc_ucx_wait(struct ompi_win_t *win);
int ompi_osc_ucx_test(struct ompi_win_t *win, int *flag);
198
/*
 * Lock/unlock, sync and flush entry points (defined elsewhere).
 *
 *   lock_type - requested lock kind, passed as an int
 *   target    - rank of the target process
 *   assert    - assertion flags supplied by the caller
 *   win       - window being synchronized
 *
 * NOTE(review): these presumably implement MPI_Win_lock, MPI_Win_unlock,
 * MPI_Win_lock_all, MPI_Win_unlock_all, MPI_Win_sync and the MPI_Win_flush
 * family, and presumably return an OMPI status code -- confirm against the
 * implementation.  Whether lock_type carries MPI lock constants or
 * lock_type_t values cannot be determined from this header.
 */
int ompi_osc_ucx_lock(int lock_type, int target, int assert, struct ompi_win_t *win);
int ompi_osc_ucx_unlock(int target, struct ompi_win_t *win);
int ompi_osc_ucx_lock_all(int assert, struct ompi_win_t *win);
int ompi_osc_ucx_unlock_all(struct ompi_win_t *win);
int ompi_osc_ucx_sync(struct ompi_win_t *win);
int ompi_osc_ucx_flush(int target, struct ompi_win_t *win);
int ompi_osc_ucx_flush_all(struct ompi_win_t *win);
int ompi_osc_ucx_flush_local(int target, struct ompi_win_t *win);
int ompi_osc_ucx_flush_local_all(struct ompi_win_t *win);
208
209 int ompi_osc_find_attached_region_position(ompi_osc_dynamic_win_info_t *dynamic_wins,
210 int min_index, int max_index,
211 uint64_t base, size_t len, int *insert);
212
213 #endif