source: src/mpi/coll/iscatter.c

Last change on this file was 0ad10e, checked in by Ken Raffenetti <raffenet@…>, 4 months ago

Make MPIR, MPIU, MPID, MPII, MPICH namespace consistent.

MPIR should only be used for functionality that is exposed by the
MPI-layer downward to the device. Other functionality owned by the
MPI layer that is used internally within that layer should be called
MPII. MPID functionality is device-specific functionality that is
exposed to the MPI layer. MPIU namespace is being removed for now.

Signed-off-by: Ken Raffenetti <raffenet@…>

  • Property mode set to 100644
File size: 29.7 KB
Line 
1/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil ; -*- */
2/*
3 *  (C) 2010 by Argonne National Laboratory.
4 *      See COPYRIGHT in top-level directory.
5 */
6
7#include "mpiimpl.h"
8
9/* -- Begin Profiling Symbol Block for routine MPI_Iscatter */
10#if defined(HAVE_PRAGMA_WEAK)
11#pragma weak MPI_Iscatter = PMPI_Iscatter
12#elif defined(HAVE_PRAGMA_HP_SEC_DEF)
13#pragma _HP_SECONDARY_DEF PMPI_Iscatter  MPI_Iscatter
14#elif defined(HAVE_PRAGMA_CRI_DUP)
15#pragma _CRI duplicate MPI_Iscatter as PMPI_Iscatter
16#elif defined(HAVE_WEAK_ATTRIBUTE)
17int MPI_Iscatter(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf,
18                 int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm,
19                 MPI_Request *request)
20                 __attribute__((weak,alias("PMPI_Iscatter")));
21#endif
22/* -- End Profiling Symbol Block */
23
24/* Define MPICH_MPI_FROM_PMPI if weak symbols are not supported to build
25   the MPI routines */
26#ifndef MPICH_MPI_FROM_PMPI
27#undef MPI_Iscatter
28#define MPI_Iscatter PMPI_Iscatter
29
30/* helper callbacks and associated state structures */
31struct shared_state {
32    int sendcount;
33    int curr_count;
34    MPI_Aint send_subtree_count;
35    int nbytes;
36    MPI_Status status;
37};
38static int get_count(MPIR_Comm *comm, int tag, void *state)
39{
40    struct shared_state *ss = state;
41    MPIR_Get_count_impl(&ss->status, MPI_BYTE, &ss->curr_count);
42    return MPI_SUCCESS;
43}
44static int calc_send_count_root(MPIR_Comm *comm, int tag, void *state, void *state2)
45{
46    struct shared_state *ss = state;
47    int mask = (int)(size_t)state2;
48    ss->send_subtree_count = ss->curr_count - ss->sendcount * mask;
49    return MPI_SUCCESS;
50}
51static int calc_send_count_non_root(MPIR_Comm *comm, int tag, void *state, void *state2)
52{
53    struct shared_state *ss = state;
54    int mask = (int)(size_t)state2;
55    ss->send_subtree_count = ss->curr_count - ss->nbytes * mask;
56    return MPI_SUCCESS;
57}
58static int calc_curr_count(MPIR_Comm *comm, int tag, void *state)
59{
60    struct shared_state *ss = state;
61    ss->curr_count -= ss->send_subtree_count;
62    return MPI_SUCCESS;
63}
64
65/* any non-MPI functions go here, especially non-static ones */
66
67/* This is the default implementation of scatter. The algorithm is:
68
69   Algorithm: MPI_Scatter
70
71   We use a binomial tree algorithm for both short and
72   long messages. At nodes other than leaf nodes we need to allocate
73   a temporary buffer to store the incoming message. If the root is
74   not rank 0, we reorder the sendbuf in order of relative ranks by
75   copying it into a temporary buffer, so that all the sends from the
76   root are contiguous and in the right order. In the heterogeneous
77   case, we first pack the buffer by using MPI_Pack and then do the
78   scatter.
79
80   Cost = lgp.alpha + n.((p-1)/p).beta
81   where n is the total size of the data to be scattered from the root.
82
83   Possible improvements:
84
85   End Algorithm: MPI_Scatter
86*/
87#undef FUNCNAME
88#define FUNCNAME MPIR_Iscatter_intra
89#undef FCNAME
90#define FCNAME MPL_QUOTE(FUNCNAME)
91int MPIR_Iscatter_intra(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
92                        void *recvbuf, int recvcount, MPI_Datatype recvtype,
93                        int root, MPIR_Comm *comm_ptr, MPIR_Sched_t s)
94{
95    int mpi_errno = MPI_SUCCESS;
96    MPI_Aint extent = 0;
97    int rank, comm_size, is_homogeneous, sendtype_size;
98    int relative_rank;
99    int mask, recvtype_size=0, src, dst;
100    int tmp_buf_size = 0;
101    void *tmp_buf = NULL;
102    struct shared_state *ss = NULL;
103    MPIR_SCHED_CHKPMEM_DECL(4);
104
105    comm_size = comm_ptr->local_size;
106    rank = comm_ptr->rank;
107
108    if (((rank == root) && (sendcount == 0)) || ((rank != root) && (recvcount == 0)))
109        goto fn_exit;
110
111    is_homogeneous = 1;
112#ifdef MPID_HAS_HETERO
113    if (comm_ptr->is_hetero)
114        is_homogeneous = 0;
115#endif
116
117/* Use binomial tree algorithm */
118
119    MPIR_SCHED_CHKPMEM_MALLOC(ss, struct shared_state *, sizeof(struct shared_state), mpi_errno, "shared_state");
120    ss->sendcount = sendcount;
121
122    if (rank == root)
123        MPID_Datatype_get_extent_macro(sendtype, extent);
124
125    relative_rank = (rank >= root) ? rank - root : rank - root + comm_size;
126
127    if (is_homogeneous) {
128        /* communicator is homogeneous */
129        if (rank == root) {
130            /* We separate the two cases (root and non-root) because
131               in the event of recvbuf=MPI_IN_PLACE on the root,
132               recvcount and recvtype are not valid */
133            MPID_Datatype_get_size_macro(sendtype, sendtype_size);
134            MPIR_Ensure_Aint_fits_in_pointer(MPIR_VOID_PTR_CAST_TO_MPI_AINT sendbuf +
135                                             extent*sendcount*comm_size);
136
137            ss->nbytes = sendtype_size * sendcount;
138        }
139        else {
140            MPID_Datatype_get_size_macro(recvtype, recvtype_size);
141            MPIR_Ensure_Aint_fits_in_pointer(extent*recvcount*comm_size);
142            ss->nbytes = recvtype_size * recvcount;
143        }
144
145        ss->curr_count = 0;
146
147        /* all even nodes other than root need a temporary buffer to
148           receive data of max size (ss->nbytes*comm_size)/2 */
149        if (relative_rank && !(relative_rank % 2)) {
150            tmp_buf_size = (ss->nbytes*comm_size)/2;
151            MPIR_SCHED_CHKPMEM_MALLOC(tmp_buf, void *, tmp_buf_size, mpi_errno, "tmp_buf");
152        }
153
154        /* if the root is not rank 0, we reorder the sendbuf in order of
155           relative ranks and copy it into a temporary buffer, so that
156           all the sends from the root are contiguous and in the right
157           order. */
158        if (rank == root) {
159            if (root != 0) {
160                tmp_buf_size = ss->nbytes*comm_size;
161                MPIR_SCHED_CHKPMEM_MALLOC(tmp_buf, void *, tmp_buf_size, mpi_errno, "tmp_buf");
162
163                if (recvbuf != MPI_IN_PLACE)
164                    mpi_errno = MPIR_Sched_copy(((char *) sendbuf + extent*sendcount*rank),
165                                                sendcount*(comm_size-rank), sendtype,
166                                                tmp_buf, ss->nbytes*(comm_size-rank), MPI_BYTE, s);
167                else
168                    mpi_errno = MPIR_Sched_copy(((char *) sendbuf + extent*sendcount*(rank+1)),
169                                                sendcount*(comm_size-rank-1), sendtype,
170                                                ((char *)tmp_buf + ss->nbytes),
171                                                ss->nbytes*(comm_size-rank-1), MPI_BYTE, s);
172                if (mpi_errno) MPIR_ERR_POP(mpi_errno);
173
174                mpi_errno = MPIR_Sched_copy(sendbuf, sendcount*rank, sendtype,
175                                            ((char *) tmp_buf + ss->nbytes*(comm_size-rank)),
176                                            ss->nbytes*rank, MPI_BYTE, s);
177                if (mpi_errno) MPIR_ERR_POP(mpi_errno);
178
179                MPIR_SCHED_BARRIER(s);
180                ss->curr_count = ss->nbytes*comm_size;
181            }
182            else
183                ss->curr_count = sendcount*comm_size;
184        }
185
186        /* root has all the data; others have zero so far */
187
188        mask = 0x1;
189        while (mask < comm_size) {
190            if (relative_rank & mask) {
191                src = rank - mask;
192                if (src < 0) src += comm_size;
193
194                /* The leaf nodes receive directly into recvbuf because
195                   they don't have to forward data to anyone. Others
196                   receive data into a temporary buffer. */
197                if (relative_rank % 2) {
198                    mpi_errno = MPIR_Sched_recv(recvbuf, recvcount, recvtype, src, comm_ptr, s);
199                    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
200                    MPIR_SCHED_BARRIER(s);
201                }
202                else {
203
204                    /* the recv size is larger than what may be sent in
205                       some cases. query amount of data actually received */
206                    mpi_errno = MPIR_Sched_recv_status(tmp_buf, tmp_buf_size, MPI_BYTE, src, comm_ptr, &ss->status, s);
207                    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
208                    MPIR_SCHED_BARRIER(s);
209                    mpi_errno = MPIR_Sched_cb(&get_count, ss, s);
210                    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
211                    MPIR_SCHED_BARRIER(s);
212                }
213                break;
214            }
215            mask <<= 1;
216        }
217
218        /* This process is responsible for all processes that have bits
219           set from the LSB upto (but not including) mask.  Because of
220           the "not including", we start by shifting mask back down
221           one. */
222
223        mask >>= 1;
224        while (mask > 0) {
225            if (relative_rank + mask < comm_size) {
226                dst = rank + mask;
227                if (dst >= comm_size) dst -= comm_size;
228
229                if ((rank == root) && (root == 0))
230                {
231#if 0
232                    /* FIXME how can this be right? shouldn't (sendcount*mask)
233                     * be the amount sent and curr_cnt be reduced by that?  Or
234                     * is it always true the (curr_cnt/2==sendcount*mask)? */
235                    send_subtree_cnt = curr_cnt - sendcount * mask;
236#endif
237                    mpi_errno = MPIR_Sched_cb2(&calc_send_count_root, ss, ((void *)(size_t)mask), s);
238                    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
239                    MPIR_SCHED_BARRIER(s);
240
241                    /* mask is also the size of this process's subtree */
242                    mpi_errno = MPIR_Sched_send_defer(((char *)sendbuf + extent*sendcount*mask),
243                                                      &ss->send_subtree_count, sendtype, dst,
244                                                      comm_ptr, s);
245                    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
246                    MPIR_SCHED_BARRIER(s);
247                }
248                else
249                {
250                    /* non-zero root and others */
251                    mpi_errno = MPIR_Sched_cb2(&calc_send_count_non_root, ss, ((void *)(size_t)mask), s);
252                    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
253                    MPIR_SCHED_BARRIER(s);
254
255                    /* mask is also the size of this process's subtree */
256                    mpi_errno = MPIR_Sched_send_defer(((char *)tmp_buf + ss->nbytes*mask),
257                                                      &ss->send_subtree_count, MPI_BYTE, dst,
258                                                      comm_ptr, s);
259                    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
260                    MPIR_SCHED_BARRIER(s);
261                }
262                mpi_errno = MPIR_Sched_cb(&calc_curr_count, ss, s);
263                if (mpi_errno) MPIR_ERR_POP(mpi_errno);
264                MPIR_SCHED_BARRIER(s);
265            }
266            mask >>= 1;
267        }
268
269        if ((rank == root) && (root == 0) && (recvbuf != MPI_IN_PLACE)) {
270            /* for root=0, put root's data in recvbuf if not MPI_IN_PLACE */
271            mpi_errno = MPIR_Sched_copy(sendbuf, sendcount, sendtype,
272                                        recvbuf, recvcount, recvtype, s);
273            if (mpi_errno) MPIR_ERR_POP(mpi_errno);
274            MPIR_SCHED_BARRIER(s);
275        }
276        else if (!(relative_rank % 2) && (recvbuf != MPI_IN_PLACE)) {
277            /* for non-zero root and non-leaf nodes, copy from tmp_buf
278               into recvbuf */
279            mpi_errno = MPIR_Sched_copy(tmp_buf, ss->nbytes, MPI_BYTE,
280                                        recvbuf, recvcount, recvtype, s);
281            if (mpi_errno) MPIR_ERR_POP(mpi_errno);
282            MPIR_SCHED_BARRIER(s);
283        }
284
285    }
286#ifdef MPID_HAS_HETERO
287    else { /* communicator is heterogeneous */
288        int position;
289        MPIR_Assertp(FALSE); /* hetero case not yet implemented */
290
291        if (rank == root) {
292            MPIR_Pack_size_impl(sendcount*comm_size, sendtype, &tmp_buf_size);
293
294            MPIR_CHKLMEM_MALLOC(tmp_buf, void *, tmp_buf_size, mpi_errno, "tmp_buf");
295
296          /* calculate the value of nbytes, the number of bytes in packed
297             representation that each process receives. We can't
298             accurately calculate that from tmp_buf_size because
299             MPI_Pack_size returns an upper bound on the amount of memory
300             required. (For example, for a single integer, MPICH-1 returns
301             pack_size=12.) Therefore, we actually pack some data into
302             tmp_buf and see by how much 'position' is incremented. */
303
304            position = 0;
305            mpi_errno = MPIR_Pack_impl(sendbuf, 1, sendtype, tmp_buf, tmp_buf_size, &position);
306            if (mpi_errno) MPIR_ERR_POP(mpi_errno);
307
308            nbytes = position*sendcount;
309
310            curr_cnt = nbytes*comm_size;
311
312            if (root == 0) {
313                if (recvbuf != MPI_IN_PLACE) {
314                    position = 0;
315                    mpi_errno = MPIR_Pack_impl(sendbuf, sendcount*comm_size, sendtype, tmp_buf,
316                                               tmp_buf_size, &position);
317                    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
318                }
319                else {
320                    position = nbytes;
321                    mpi_errno = MPIR_Pack_impl(((char *) sendbuf + extent*sendcount),
322                                               sendcount*(comm_size-1), sendtype, tmp_buf,
323                                               tmp_buf_size, &position);
324                    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
325                }
326            }
327            else {
328                if (recvbuf != MPI_IN_PLACE) {
329                    position = 0;
330                    mpi_errno = MPIR_Pack_impl(((char *) sendbuf + extent*sendcount*rank),
331                                               sendcount*(comm_size-rank), sendtype, tmp_buf,
332                                               tmp_buf_size, &position);
333                    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
334                }
335                else {
336                    position = nbytes;
337                    mpi_errno = MPIR_Pack_impl(((char *) sendbuf + extent*sendcount*(rank+1)),
338                                               sendcount*(comm_size-rank-1), sendtype, tmp_buf,
339                                               tmp_buf_size, &position);
340                    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
341                }
342                mpi_errno = MPIR_Pack_impl(sendbuf, sendcount*rank, sendtype, tmp_buf,
343                                           tmp_buf_size, &position);
344                if (mpi_errno) MPIR_ERR_POP(mpi_errno);
345            }
346        }
347        else {
348            MPIR_Pack_size_impl(recvcount*(comm_size/2), recvtype, &tmp_buf_size);
349            MPIR_CHKLMEM_MALLOC(tmp_buf, void *, tmp_buf_size, mpi_errno, "tmp_buf");
350
351            /* calculate nbytes */
352            position = 0;
353            mpi_errno = MPIR_Pack_impl(recvbuf, 1, recvtype, tmp_buf, tmp_buf_size, &position);
354            if (mpi_errno) MPIR_ERR_POP(mpi_errno);
355            nbytes = position*recvcount;
356
357            curr_cnt = 0;
358        }
359
360        mask = 0x1;
361        while (mask < comm_size) {
362            if (relative_rank & mask) {
363                src = rank - mask;
364                if (src < 0) src += comm_size;
365
366                mpi_errno = MPIC_Recv(tmp_buf, tmp_buf_size, MPI_BYTE, src,
367                                         MPIR_SCATTER_TAG, comm_ptr, &status, errflag);
368                if (mpi_errno) {
369                    /* for communication errors, just record the error but continue */
370                    *errflag = MPIR_ERR_GET_CLASS(mpi_errno);
371                    MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
372                    MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
373                    curr_cnt = 0;
374                } else
375                    /* the recv size is larger than what may be sent in
376                       some cases. query amount of data actually received */
377                    MPIR_Get_count_impl(&status, MPI_BYTE, &curr_cnt);
378                break;
379            }
380            mask <<= 1;
381        }
382
383        /* This process is responsible for all processes that have bits
384           set from the LSB upto (but not including) mask.  Because of
385           the "not including", we start by shifting mask back down
386           one. */
387
388        mask >>= 1;
389        while (mask > 0) {
390            if (relative_rank + mask < comm_size) {
391                dst = rank + mask;
392                if (dst >= comm_size) dst -= comm_size;
393
394                send_subtree_cnt = curr_cnt - nbytes * mask;
395                /* mask is also the size of this process's subtree */
396                mpi_errno = MPIC_Send(((char *)tmp_buf + nbytes*mask),
397                                         send_subtree_cnt, MPI_BYTE, dst,
398                                         MPIR_SCATTER_TAG, comm_ptr, errflag);
399                if (mpi_errno) {
400                    /* for communication errors, just record the error but continue */
401                    *errflag = MPIR_ERR_GET_CLASS(mpi_errno);
402                    MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
403                    MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
404                }
405                curr_cnt -= send_subtree_cnt;
406            }
407            mask >>= 1;
408        }
409
410        /* copy local data into recvbuf */
411        position = 0;
412        if (recvbuf != MPI_IN_PLACE) {
413            mpi_errno = MPIR_Unpack_impl(tmp_buf, tmp_buf_size, &position, recvbuf,
414                                         recvcount, recvtype);
415            if (mpi_errno) MPIR_ERR_POP(mpi_errno);
416        }
417    }
418#endif /* MPID_HAS_HETERO */
419
420
421    MPIR_SCHED_CHKPMEM_COMMIT(s);
422 fn_exit:
423    return mpi_errno;
424 fn_fail:
425    MPIR_SCHED_CHKPMEM_REAP(s);
426    goto fn_exit;
427}
428
429#undef FUNCNAME
430#define FUNCNAME MPIR_Iscatter_inter
431#undef FCNAME
432#define FCNAME MPL_QUOTE(FUNCNAME)
433int MPIR_Iscatter_inter(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
434                        void *recvbuf, int recvcount, MPI_Datatype recvtype,
435                        int root, MPIR_Comm *comm_ptr, MPIR_Sched_t s)
436{
437/*  Intercommunicator scatter.
438    For short messages, root sends to rank 0 in remote group. rank 0
439    does local intracommunicator scatter (binomial tree).
440    Cost: (lgp+1).alpha + n.((p-1)/p).beta + n.beta
441
442    For long messages, we use linear scatter to avoid the extra n.beta.
443    Cost: p.alpha + n.beta
444*/
445    int mpi_errno = MPI_SUCCESS;
446    int rank, local_size, remote_size;
447    int i, nbytes, sendtype_size, recvtype_size;
448    MPI_Aint extent, true_extent, true_lb = 0;
449    void *tmp_buf = NULL;
450    MPIR_Comm *newcomm_ptr = NULL;
451    MPIR_SCHED_CHKPMEM_DECL(1);
452
453    if (root == MPI_PROC_NULL) {
454        /* local processes other than root do nothing */
455        goto fn_exit;
456    }
457
458    remote_size = comm_ptr->remote_size;
459    local_size  = comm_ptr->local_size;
460
461    if (root == MPI_ROOT) {
462        MPID_Datatype_get_size_macro(sendtype, sendtype_size);
463        nbytes = sendtype_size * sendcount * remote_size;
464    }
465    else {
466        /* remote side */
467        MPID_Datatype_get_size_macro(recvtype, recvtype_size);
468        nbytes = recvtype_size * recvcount * local_size;
469    }
470
471    if (nbytes < MPIR_CVAR_SCATTER_INTER_SHORT_MSG_SIZE) {
472        if (root == MPI_ROOT) {
473            /* root sends all data to rank 0 on remote group and returns */
474            mpi_errno = MPIR_Sched_send(sendbuf, sendcount*remote_size, sendtype, 0, comm_ptr, s);
475            if (mpi_errno) MPIR_ERR_POP(mpi_errno);
476            MPIR_SCHED_BARRIER(s);
477            goto fn_exit;
478        }
479        else {
480            /* remote group. rank 0 receives data from root. need to
481               allocate temporary buffer to store this data. */
482            rank = comm_ptr->rank;
483
484            if (rank == 0) {
485                MPIR_Type_get_true_extent_impl(recvtype, &true_lb, &true_extent);
486
487                MPID_Datatype_get_extent_macro(recvtype, extent);
488                MPIR_Ensure_Aint_fits_in_pointer(extent*recvcount*local_size);
489                MPIR_Ensure_Aint_fits_in_pointer(MPIR_VOID_PTR_CAST_TO_MPI_AINT sendbuf +
490                                                 sendcount*remote_size*extent);
491
492                MPIR_SCHED_CHKPMEM_MALLOC(tmp_buf, void *, recvcount*local_size*(MPL_MAX(extent,true_extent)),
493                                          mpi_errno, "tmp_buf");
494
495                /* adjust for potential negative lower bound in datatype */
496                tmp_buf = (void *)((char*)tmp_buf - true_lb);
497
498                mpi_errno = MPIR_Sched_recv(tmp_buf, recvcount*local_size, recvtype, root, comm_ptr, s);
499                if (mpi_errno) MPIR_ERR_POP(mpi_errno);
500                MPIR_SCHED_BARRIER(s);
501            }
502
503            /* Get the local intracommunicator */
504            if (!comm_ptr->local_comm)
505                MPII_Setup_intercomm_localcomm(comm_ptr);
506
507            newcomm_ptr = comm_ptr->local_comm;
508
509            /* now do the usual scatter on this intracommunicator */
510            MPIR_Assert(newcomm_ptr->coll_fns != NULL);
511            MPIR_Assert(newcomm_ptr->coll_fns->Iscatter_sched != NULL);
512            mpi_errno = newcomm_ptr->coll_fns->Iscatter_sched(tmp_buf, recvcount, recvtype,
513                                                        recvbuf, recvcount, recvtype,
514                                                        0, newcomm_ptr, s);
515            if (mpi_errno) MPIR_ERR_POP(mpi_errno);
516            MPIR_SCHED_BARRIER(s);
517        }
518    }
519    else {
520        /* long message. use linear algorithm. */
521        if (root == MPI_ROOT) {
522            MPID_Datatype_get_extent_macro(sendtype, extent);
523            for (i = 0; i < remote_size; i++) {
524                mpi_errno = MPIR_Sched_send(((char *)sendbuf+sendcount*i*extent),
525                                            sendcount, sendtype, i, comm_ptr, s);
526                if (mpi_errno) MPIR_ERR_POP(mpi_errno);
527            }
528            MPIR_SCHED_BARRIER(s);
529        }
530        else {
531            mpi_errno = MPIR_Sched_recv(recvbuf, recvcount, recvtype, root, comm_ptr, s);
532            if (mpi_errno) MPIR_ERR_POP(mpi_errno);
533            MPIR_SCHED_BARRIER(s);
534        }
535    }
536
537
538    MPIR_SCHED_CHKPMEM_COMMIT(s);
539fn_exit:
540    return mpi_errno;
541fn_fail:
542    MPIR_SCHED_CHKPMEM_REAP(s);
543    goto fn_exit;
544}
545
546#undef FUNCNAME
547#define FUNCNAME MPIR_Iscatter_impl
548#undef FCNAME
549#define FCNAME MPL_QUOTE(FUNCNAME)
550int MPIR_Iscatter_impl(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPIR_Comm *comm_ptr, MPI_Request *request)
551{
552    int mpi_errno = MPI_SUCCESS;
553    MPIR_Request *reqp = NULL;
554    int tag = -1;
555    MPIR_Sched_t s = MPIR_SCHED_NULL;
556
557    *request = MPI_REQUEST_NULL;
558
559    mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag);
560    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
561    mpi_errno = MPIR_Sched_create(&s);
562    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
563
564    MPIR_Assert(comm_ptr->coll_fns->Iscatter_sched != NULL);
565    mpi_errno = comm_ptr->coll_fns->Iscatter_sched(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, s);
566    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
567
568    mpi_errno = MPIR_Sched_start(&s, comm_ptr, tag, &reqp);
569    if (reqp)
570        *request = reqp->handle;
571    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
572
573fn_exit:
574    return mpi_errno;
575fn_fail:
576    goto fn_exit;
577}
578
579#endif /* MPICH_MPI_FROM_PMPI */
580
581#undef FUNCNAME
582#define FUNCNAME MPI_Iscatter
583#undef FCNAME
584#define FCNAME MPL_QUOTE(FUNCNAME)
585/*@
586MPI_Iscatter - Sends data from one process to all other processes in a
587               communicator in a nonblocking way
588
589Input Parameters:
590+ sendbuf - address of send buffer (significant only at root) (choice)
591. sendcount - number of elements sent to each process (significant only at root) (non-negative integer)
592. sendtype - data type of send buffer elements (significant only at root) (handle)
593. recvcount - number of elements in receive buffer (non-negative integer)
594. recvtype - data type of receive buffer elements (handle)
595. root - rank of sending process (integer)
596- comm - communicator (handle)
597
598Output Parameters:
599+ recvbuf - starting address of the receive buffer (choice)
600- request - communication request (handle)
601
602.N ThreadSafe
603
604.N Fortran
605
606.N Errors
607@*/
608int MPI_Iscatter(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
609                 void *recvbuf, int recvcount, MPI_Datatype recvtype, int root,
610                 MPI_Comm comm, MPI_Request *request)
611{
612    int mpi_errno = MPI_SUCCESS;
613    MPIR_Comm *comm_ptr = NULL;
614    MPIR_FUNC_TERSE_STATE_DECL(MPID_STATE_MPI_ISCATTER);
615
616    MPID_THREAD_CS_ENTER(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
617    MPIR_FUNC_TERSE_ENTER(MPID_STATE_MPI_ISCATTER);
618
619    /* Validate parameters, especially handles needing to be converted */
620#   ifdef HAVE_ERROR_CHECKING
621    {
622        MPID_BEGIN_ERROR_CHECKS
623        {
624            MPIR_ERRTEST_COMM(comm, mpi_errno);
625
626            /* TODO more checks may be appropriate */
627        }
628        MPID_END_ERROR_CHECKS
629    }
630#   endif /* HAVE_ERROR_CHECKING */
631
632    /* Convert MPI object handles to object pointers */
633    MPIR_Comm_get_ptr(comm, comm_ptr);
634
635    /* Validate parameters and objects (post conversion) */
636#   ifdef HAVE_ERROR_CHECKING
637    {
638        MPID_BEGIN_ERROR_CHECKS
639        {
640            MPIR_Datatype *sendtype_ptr, *recvtype_ptr;
641            MPIR_Comm_valid_ptr( comm_ptr, mpi_errno, FALSE );
642            if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) {
643                MPIR_ERRTEST_INTRA_ROOT(comm_ptr, root, mpi_errno);
644
645                if (comm_ptr->rank == root) {
646                    MPIR_ERRTEST_COUNT(sendcount, mpi_errno);
647                    MPIR_ERRTEST_DATATYPE(sendtype, "sendtype", mpi_errno);
648                    if (HANDLE_GET_KIND(sendtype) != HANDLE_KIND_BUILTIN) {
649                        MPID_Datatype_get_ptr(sendtype, sendtype_ptr);
650                        MPIR_Datatype_valid_ptr(sendtype_ptr, mpi_errno);
651                        if (mpi_errno != MPI_SUCCESS) goto fn_fail;
652                        MPID_Datatype_committed_ptr(sendtype_ptr, mpi_errno);
653                        if (mpi_errno != MPI_SUCCESS) goto fn_fail;
654                    }
655                    MPIR_ERRTEST_USERBUFFER(sendbuf,sendcount,sendtype,mpi_errno);
656                    MPIR_ERRTEST_SENDBUF_INPLACE(sendbuf, sendcount, mpi_errno);
657
658                    /* catch common aliasing cases */
659                    if (recvbuf != MPI_IN_PLACE && sendtype == recvtype && sendcount == recvcount && recvcount != 0) {
660                        int sendtype_size;
661                        MPID_Datatype_get_size_macro(sendtype, sendtype_size);
662                        MPIR_ERRTEST_ALIAS_COLL(recvbuf, (char*)sendbuf + comm_ptr->rank*sendcount*sendtype_size, mpi_errno);
663                    }
664                }
665                else
666                    MPIR_ERRTEST_RECVBUF_INPLACE(recvbuf, recvcount, mpi_errno);
667
668                if (recvbuf != MPI_IN_PLACE) {
669                    MPIR_ERRTEST_COUNT(recvcount, mpi_errno);
670                    MPIR_ERRTEST_DATATYPE(recvtype, "recvtype", mpi_errno);
671                    if (HANDLE_GET_KIND(recvtype) != HANDLE_KIND_BUILTIN) {
672                        MPID_Datatype_get_ptr(recvtype, recvtype_ptr);
673                        MPIR_Datatype_valid_ptr(recvtype_ptr, mpi_errno);
674                        if (mpi_errno != MPI_SUCCESS) goto fn_fail;
675                        MPID_Datatype_committed_ptr(recvtype_ptr, mpi_errno);
676                        if (mpi_errno != MPI_SUCCESS) goto fn_fail;
677                    }
678                    MPIR_ERRTEST_USERBUFFER(recvbuf,recvcount,recvtype,mpi_errno);
679                }
680            }
681
682            if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) {
683                MPIR_ERRTEST_INTER_ROOT(comm_ptr, root, mpi_errno);
684
685                if (root == MPI_ROOT) {
686                    MPIR_ERRTEST_COUNT(sendcount, mpi_errno);
687                    MPIR_ERRTEST_DATATYPE(sendtype, "sendtype", mpi_errno);
688                    if (HANDLE_GET_KIND(sendtype) != HANDLE_KIND_BUILTIN) {
689                        MPID_Datatype_get_ptr(sendtype, sendtype_ptr);
690                        MPIR_Datatype_valid_ptr(sendtype_ptr, mpi_errno);
691                        if (mpi_errno != MPI_SUCCESS) goto fn_fail;
692                        MPID_Datatype_committed_ptr(sendtype_ptr, mpi_errno);
693                        if (mpi_errno != MPI_SUCCESS) goto fn_fail;
694                    }
695                    MPIR_ERRTEST_SENDBUF_INPLACE(sendbuf, sendcount, mpi_errno);
696                    MPIR_ERRTEST_USERBUFFER(sendbuf,sendcount,sendtype,mpi_errno);
697                }
698                else if (root != MPI_PROC_NULL) {
699                    MPIR_ERRTEST_COUNT(recvcount, mpi_errno);
700                    MPIR_ERRTEST_DATATYPE(recvtype, "recvtype", mpi_errno);
701                    if (HANDLE_GET_KIND(recvtype) != HANDLE_KIND_BUILTIN) {
702                        MPID_Datatype_get_ptr(recvtype, recvtype_ptr);
703                        MPIR_Datatype_valid_ptr(recvtype_ptr, mpi_errno);
704                        if (mpi_errno != MPI_SUCCESS) goto fn_fail;
705                        MPID_Datatype_committed_ptr(recvtype_ptr, mpi_errno);
706                        if (mpi_errno != MPI_SUCCESS) goto fn_fail;
707                    }
708                    MPIR_ERRTEST_RECVBUF_INPLACE(recvbuf, recvcount, mpi_errno);
709                    MPIR_ERRTEST_USERBUFFER(recvbuf,recvcount,recvtype,mpi_errno);
710                }
711            }
712        }
713        MPID_END_ERROR_CHECKS
714    }
715#   endif /* HAVE_ERROR_CHECKING */
716
717    /* ... body of routine ...  */
718
719    mpi_errno = MPID_Iscatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, request);
720    if (mpi_errno) MPIR_ERR_POP(mpi_errno);
721
722    /* ... end of body of routine ... */
723
724fn_exit:
725    MPIR_FUNC_TERSE_EXIT(MPID_STATE_MPI_ISCATTER);
726    MPID_THREAD_CS_EXIT(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
727    return mpi_errno;
728
729fn_fail:
730    /* --BEGIN ERROR HANDLING-- */
731#   ifdef HAVE_ERROR_CHECKING
732    {
733        mpi_errno = MPIR_Err_create_code(
734            mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER,
735            "**mpi_iscatter", "**mpi_iscatter %p %d %D %p %d %D %d %C %p", sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, request);
736    }
737#   endif
738    mpi_errno = MPIR_Err_return_comm(comm_ptr, FCNAME, mpi_errno);
739    goto fn_exit;
740    /* --END ERROR HANDLING-- */
741    goto fn_exit;
742}
Note: See TracBrowser for help on using the repository browser.