Given an array of size numpoints, we want to split it across num_nodes nodes. This subroutine returns two arrays: counts and displs.
The i-th element of the counts array gives the number of elements that must be calculated by the process with id (i-1). The i-th element of the displs array gives the displacement, with respect to the global array, of the portion of the array calculated locally on the process with id (i-1).
These values are those to be passed to the functions MPI_Scatterv, MPI_Gatherv and MPI_Alltoallv.
If the full array is stored on all nodes, one can use the following do loop to run over the needed elements: `do i=displs(my_node_id)+1,displs(my_node_id)+counts(my_node_id)`
| Type | Intent | Optional | Attributes | | Name | |
|---|---|---|---|---|---|---|
| integer | intent(in) | | | :: | numpoints | Number of elements of the array to be scattered |
| integer | intent(out) | | dimension(0:num_nodes - 1) | :: | counts | Array (of size num_nodes) with the number of elements of the array on each node |
| integer | intent(out) | | dimension(0:num_nodes - 1) | :: | displs | Array (of size num_nodes) with the displacement relative to the global array |
subroutine comms_array_split(numpoints, counts, displs)
  !! Compute how an array of numpoints elements is divided among num_nodes processes.
  !!
  !! Two arrays are filled, both indexed by process id (0 .. num_nodes-1):
  !!
  !! * counts(p): number of elements assigned to process p;
  !! * displs(p): offset of the first element assigned to process p, relative to
  !!   the start of the global array.
  !!
  !! Every process gets numpoints/num_nodes elements (integer division); the
  !! processes whose id is below mod(numpoints, num_nodes) get one element more.
  !!
  !! These are the values expected by MPI_Scatterv, MPI_Gatherv and MPI_Alltoallv.
  !!
  !! If the full array is stored on all nodes, the elements belonging to the local
  !! process can be visited with:
  !! do i=displs(my_node_id)+1,displs(my_node_id)+counts(my_node_id)
  !!
  !! NOTE(review): num_nodes is not declared in this subroutine; it is presumably
  !! supplied by the host scope or by w90_io - confirm.
  !!
  use w90_io

  integer, intent(in) :: numpoints
  !! Total number of elements of the array to be distributed
  integer, dimension(0:num_nodes - 1), intent(out) :: counts
  !! Number of elements handled by each process (size num_nodes)
  integer, dimension(0:num_nodes - 1), intent(out) :: displs
  !! Offset of each process' portion within the global array (size num_nodes)

  integer :: base_size, n_extra, rank

  ! Minimum share per process, and how many processes must take one extra element.
  base_size = numpoints/num_nodes
  n_extra = mod(numpoints, num_nodes)

  do rank = 0, num_nodes - 1
    ! Processes with rank < n_extra carry one additional element.
    counts(rank) = base_size + merge(1, 0, rank < n_extra)
    ! Offset = base share of all lower ranks, plus one for each lower rank that
    ! carries an extra element (there are min(rank, n_extra) of those).
    displs(rank) = rank*base_size + min(rank, n_extra)
  end do
end subroutine comms_array_split