root/ompi/patterns/net/allreduce.c


DEFINITIONS

This source file includes the following definitions:
  1. send_completion
  2. recv_completion
  3. op_reduce
  4. comm_allreduce

/*
 * Copyright (c) 2009-2012 Mellanox Technologies.  All rights reserved.
 * Copyright (c) 2009-2012 Oak Ridge National Laboratory.  All rights reserved.
 * Copyright (c) 2012      Los Alamos National Security, LLC.
 *                         All rights reserved.
 * Copyright (c) 2017      IBM Corporation. All rights reserved.
 * $COPYRIGHT$
 *
 * Additional copyrights may follow
 *
 * $HEADER$
 */
/** @file */

#include "ompi_config.h"

#include <stdlib.h>
#include <string.h>

#include "ompi/constants.h"
#include "coll_sm2.h"
#include "ompi/op/op.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/communicator/communicator.h"
#include "ompi/mca/rte/rte.h"
#include "opal/runtime/opal_progress.h"

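/*
 * Overview: comm_allreduce() below implements an all-reduce over the
 * runtime messaging layer (ompi_rte_send/recv and their non-blocking
 * variants) using a recursive-doubling exchange pattern:
 *
 *   1. the user buffer is processed in stripes sized to fit the local
 *      scratch buffers (MAX_TMP_BUFFER bytes, halved for a send/recv pair);
 *   2. ranks outside the largest power-of-2 subset ("extra" ranks) first
 *      send their stripe to a partner inside the subset;
 *   3. the power-of-2 subset performs pairwise exchange-and-reduce steps;
 *   4. results are returned to the "extra" ranks and copied to the output.
 */
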
void send_completion(int status, struct ompi_process_name_t* peer, struct iovec* msg,
                     int count, ompi_rml_tag_t tag, void* cbdata)
{
    /* set the send completion flag */
    *(int *) cbdata = 1;
}


void recv_completion(int status, struct ompi_process_name_t* peer, struct iovec* msg,
                     int count, ompi_rml_tag_t tag, void* cbdata)
{
    /* make sure the received data is visible before the completion
     * flag is set */
    MB();
    *(int *) cbdata = 1;
}


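/*
 * Element-wise reduction of src_buf into src_dest_buf:
 *
 *     src_dest_buf[i] = src_dest_buf[i] (op) src_buf[i],   0 <= i < count
 *
 * Only OP_SUM on TYPE_INT4 (4-byte integer) data is currently supported;
 * any other op/type combination returns OMPI_ERROR.
 */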
static int op_reduce(int op_type, void *src_dest_buf, void *src_buf, int count,
        int data_type)
{
    /* local variables */
    int ret = OMPI_SUCCESS;

    /* op type */
    switch (op_type) {

        case OP_SUM:

            switch (data_type) {
                case TYPE_INT4:
                {
                    int *int_src_ptr = (int *) src_buf;
                    int *int_src_dst_ptr = (int *) src_dest_buf;
                    int cnt;
                    for (cnt = 0; cnt < count; cnt++) {
                        int_src_dst_ptr[cnt] += int_src_ptr[cnt];
                    }
                    break;
                }
                default:
                    ret = OMPI_ERROR;
                    goto Error;
            }

            break;

        default:
            ret = OMPI_ERROR;
            goto Error;
    }

Error:
    return ret;
}

/**
 * All-reduce for contiguous primitive types
 */
static int
comm_allreduce(void *sbuf, void *rbuf, int count, opal_datatype_t *dtype,
        int op_type, opal_list_t *peers)
{
    /* local variables */
    int rc = OMPI_SUCCESS, n_dts_per_buffer, n_data_segments, stripe_number;
    int pair_rank, exchange, extra_rank, cnt;
    netpatterns_pair_exchange_node_t my_exchange_node;
    int my_rank, count_processed, count_this_stripe;
    size_t n_peers;
    size_t dt_size;
    opal_list_item_t *item;
    char scratch_bufers[2][MAX_TMP_BUFFER];
    int send_buffer = 0, recv_buffer = 1;
    char *sbuf_current, *rbuf_current;
    ompi_proc_t **proc_array = NULL;
    struct iovec send_iov, recv_iov;
    volatile int *recv_done, *send_done;
    int recv_completion_flag, send_completion_flag;
    int data_type;

    /* get the size of the data - the scratch buffers use the same layout
     *   as the user data, so that we can apply the reduction routines
     *   directly on these buffers
     */
    rc = opal_datatype_type_size(dtype, &dt_size);
    if (OMPI_SUCCESS != rc) {
        goto Error;
    }

    /* number of data type copies that the scratch buffer can hold */
    n_dts_per_buffer = ((int) MAX_TMP_BUFFER) / dt_size;
    if (0 == n_dts_per_buffer) {
        rc = OMPI_ERROR;
        goto Error;
    }

    /* need a read and a write buffer for a pair-wise exchange of data */
    n_dts_per_buffer /= 2;

    /* compute the number of stripes needed to process this collective */
    n_data_segments = (count + n_dts_per_buffer - 1) / n_dts_per_buffer;

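    /*
     * Example of the striping arithmetic (assumed values, for illustration
     * only): with MAX_TMP_BUFFER = 16384 bytes and a 4-byte integer type,
     * n_dts_per_buffer = 16384/4 = 4096, halved to 2048 for the read/write
     * pair.  An all-reduce of count = 5000 elements is then processed in
     * n_data_segments = (5000 + 2048 - 1) / 2048 = 3 stripes of at most
     * 2048 elements each.
     */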
    /* number of ranks participating in this operation */
    n_peers = opal_list_get_size(peers);

    /* get my rank in the list */
    my_rank = 0;
    for (item = opal_list_get_first(peers);
            item != opal_list_get_end(peers);
            item = opal_list_get_next(item)) {
        if (ompi_proc_local() == (ompi_proc_t *) item) {
            /* this is the pointer to my proc structure */
            break;
        }
        my_rank++;
    }
    proc_array = (ompi_proc_t **) malloc(sizeof(ompi_proc_t *) * n_peers);
    if (NULL == proc_array) {
        rc = OMPI_ERR_OUT_OF_RESOURCE;
        goto Error;
    }
    cnt = 0;
    for (item = opal_list_get_first(peers);
            item != opal_list_get_end(peers);
            item = opal_list_get_next(item)) {
        proc_array[cnt] = (ompi_proc_t *) item;
        cnt++;
    }

    /* get my reduction communication pattern */
    rc = ompi_netpatterns_setup_recursive_doubling_tree_node(n_peers, my_rank,
            &my_exchange_node);
    if (OMPI_SUCCESS != rc) {
        goto Error;
    }

    /* setup the flags for non-blocking communications */
    recv_done = &recv_completion_flag;
    send_done = &send_completion_flag;

    /* set the data type - only 4-byte integers are supported so far */
    if (&opal_datatype_int4 == dtype) {
        data_type = TYPE_INT4;
    } else {
        rc = OMPI_ERROR;
        goto Error;
    }

    count_processed = 0;

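    /*
     * The exchange pattern computed above handles peer counts that are not
     * a power of 2: ranks beyond the largest power-of-2 subset are marked
     * EXTRA_NODE and paired with an EXCHANGE_NODE partner
     * (rank_extra_source).  Extra ranks contribute their data up front, sit
     * out the n_exchanges pairwise steps, and receive the final result at
     * the end.
     */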
    /* NOTE: starting with a rather synchronous approach */
    for (stripe_number = 0; stripe_number < n_data_segments; stripe_number++) {

        /* get the number of elements to process in this stripe */
        count_this_stripe = n_dts_per_buffer;
        if (count_processed + count_this_stripe > count) {
            count_this_stripe = count - count_processed;
        }

        /* copy data from the input buffer into the temp buffer */
        sbuf_current = (char *) sbuf + count_processed * dt_size;
        memcpy(scratch_bufers[send_buffer], sbuf_current,
                count_this_stripe * dt_size);

        /* fold in the data from the "extra" source, if need be */
        if (0 < my_exchange_node.n_extra_sources) {

            if (EXCHANGE_NODE == my_exchange_node.node_type) {

                /*
                ** Receive data from the extra node
                */

                extra_rank = my_exchange_node.rank_extra_source;
                recv_iov.iov_base = scratch_bufers[recv_buffer];
                recv_iov.iov_len = count_this_stripe * dt_size;
                rc = ompi_rte_recv(&(proc_array[extra_rank]->proc_name), &recv_iov, 1,
                        OMPI_RML_TAG_ALLREDUCE, 0);
                if (OMPI_SUCCESS != rc) {
                    goto Error;
                }

                /* apply the reduction operation to this stripe of the data */
                if (0 < count_this_stripe) {
                    rc = op_reduce(op_type, (void *) scratch_bufers[recv_buffer],
                            (void *) scratch_bufers[send_buffer], count_this_stripe,
                            data_type);
                    if (OMPI_SUCCESS != rc) {
                        goto Error;
                    }
                }

            } else {

                /*
                ** Send data to the "partner" node
                */
                extra_rank = my_exchange_node.rank_extra_source;
                send_iov.iov_base = scratch_bufers[send_buffer];
                send_iov.iov_len = count_this_stripe * dt_size;
                rc = ompi_rte_send(&(proc_array[extra_rank]->proc_name), &send_iov, 1,
                        OMPI_RML_TAG_ALLREDUCE, 0);
                if (OMPI_SUCCESS != rc) {
                    goto Error;
                }
            }

            /* swap the scratch buffers - this way we can send the data
            ** that we have summed without a memory copy, and receive data
            ** into the other buffer, without fear of overwriting data that
            ** has not yet completed being sent
            */
            recv_buffer ^= 1;
            send_buffer ^= 1;
        }

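        /*
         * Each exchange step below follows the same sequence: post a
         * non-blocking receive and a non-blocking send to the current
         * partner, drive progress until the receive completes, reduce the
         * received data into the accumulated result, wait for the send to
         * drain, and then swap the scratch buffers for the next step.
         */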
        /* loop over the pairwise data exchanges */
        for (exchange = 0; exchange < my_exchange_node.n_exchanges; exchange++) {

            /* rank to exchange data with in this step */
            pair_rank = my_exchange_node.rank_exchanges[exchange];

            *recv_done = 0;
            *send_done = 0;
            MB();

            /* post the non-blocking receive */
            recv_iov.iov_base = scratch_bufers[recv_buffer];
            recv_iov.iov_len = count_this_stripe * dt_size;
            rc = ompi_rte_recv_nb(&(proc_array[pair_rank]->proc_name), &recv_iov, 1,
                        OMPI_RML_TAG_ALLREDUCE, 0, recv_completion, (void *) recv_done);
            if (OMPI_SUCCESS != rc) {
                goto Error;
            }

            /* post the non-blocking send */
            send_iov.iov_base = scratch_bufers[send_buffer];
            send_iov.iov_len = count_this_stripe * dt_size;
            rc = ompi_rte_send_nb(&(proc_array[pair_rank]->proc_name), &send_iov, 1,
                        OMPI_RML_TAG_ALLREDUCE, 0, send_completion, (void *) send_done);
            if (OMPI_SUCCESS != rc) {
                goto Error;
            }

            /* wait on receive completion */
            while (!(*recv_done)) {
                opal_progress();
            }

            /* reduce the data */
            if (0 < count_this_stripe) {
                rc = op_reduce(op_type, (void *) scratch_bufers[recv_buffer],
                        (void *) scratch_bufers[send_buffer], count_this_stripe,
                        data_type);
                if (OMPI_SUCCESS != rc) {
                    goto Error;
                }
            }

            /* wait on send completion */
            while (!(*send_done)) {
                opal_progress();
            }

            /* swap buffers so that the reduced result is the send buffer
             * of the next exchange */
            recv_buffer ^= 1;
            send_buffer ^= 1;
        }

        /* return the result to the "extra" source, if need be */
        if (0 < my_exchange_node.n_extra_sources) {

            if (EXTRA_NODE == my_exchange_node.node_type) {
                /*
                ** receive the reduced data
                */
                extra_rank = my_exchange_node.rank_extra_source;

                recv_iov.iov_base = scratch_bufers[send_buffer];
                recv_iov.iov_len = count_this_stripe * dt_size;
                rc = ompi_rte_recv(&(proc_array[extra_rank]->proc_name), &recv_iov, 1,
                        OMPI_RML_TAG_ALLREDUCE, 0);
                if (OMPI_SUCCESS != rc) {
                    goto Error;
                }

            } else {
                /* send the result to the pair-rank outside of the power of 2
                ** set of ranks
                */

                extra_rank = my_exchange_node.rank_extra_source;
                send_iov.iov_base = scratch_bufers[send_buffer];
                send_iov.iov_len = count_this_stripe * dt_size;
                rc = ompi_rte_send(&(proc_array[extra_rank]->proc_name), &send_iov, 1,
                        OMPI_RML_TAG_ALLREDUCE, 0);
                if (OMPI_SUCCESS != rc) {
                    goto Error;
                }
            }
        }

        /* copy the result from the temp buffer into the output buffer;
         * after the exchanges (and buffer swaps) the final result always
         * sits in scratch_bufers[send_buffer] */
        rbuf_current = (char *) rbuf + count_processed * dt_size;
        memcpy(rbuf_current, scratch_bufers[send_buffer],
                count_this_stripe * dt_size);

        /* update the count of elements processed */
        count_processed += count_this_stripe;
    }

    /* return */
    free(proc_array);
    return rc;

Error:
    if (NULL != proc_array) {
        free(proc_array);
    }
    return rc;
}
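
/*
 * Usage sketch (illustrative only - the peer list and buffer names are
 * hypothetical): reduce a contiguous array of 4-byte integers across all
 * processes on a caller-supplied list of ompi_proc_t's.
 *
 *     int in[1024], out[1024];
 *     opal_list_t *peers = ...;  // list of ompi_proc_t, including this proc
 *     rc = comm_allreduce(in, out, 1024, &opal_datatype_int4,
 *                         OP_SUM, peers);
 *     if (OMPI_SUCCESS != rc) { ... }
 */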
