Line data Source code
1 : #if 0
2 : ! This file is part of ELPA.
3 : !
4 : ! The ELPA library was originally created by the ELPA consortium,
5 : ! consisting of the following organizations:
6 : !
7 : ! - Max Planck Computing and Data Facility (MPCDF), formerly known as
8 : ! Rechenzentrum Garching der Max-Planck-Gesellschaft (RZG),
9 : ! - Bergische Universität Wuppertal, Lehrstuhl für angewandte
10 : ! Informatik,
11 : ! - Technische Universität München, Lehrstuhl für Informatik mit
12 : ! Schwerpunkt Wissenschaftliches Rechnen ,
13 : ! - Fritz-Haber-Institut, Berlin, Abt. Theorie,
14 : ! - Max-Planck-Institut für Mathematik in den Naturwissenschaften,
15 : ! Leipzig, Abt. Komplexe Strukturen in Biologie und Kognition,
16 : ! and
17 : ! - IBM Deutschland GmbH
18 : !
19 : !
20 : ! More information can be found here:
21 : ! http://elpa.mpcdf.mpg.de/
22 : !
23 : ! ELPA is free software: you can redistribute it and/or modify
24 : ! it under the terms of the version 3 of the license of the
25 : ! GNU Lesser General Public License as published by the Free
26 : ! Software Foundation.
27 : !
28 : ! ELPA is distributed in the hope that it will be useful,
29 : ! but WITHOUT ANY WARRANTY; without even the implied warranty of
30 : ! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
31 : ! GNU Lesser General Public License for more details.
32 : !
33 : ! You should have received a copy of the GNU Lesser General Public License
34 : ! along with ELPA. If not, see <http://www.gnu.org/licenses/>
35 : !
36 : ! ELPA reflects a substantial effort on the part of the original
37 : ! ELPA consortium, and we ask you to respect the spirit of the
38 : ! license that we chose: i.e., please contribute any changes you
39 : ! may have back to the original ELPA library distribution, and keep
40 : ! any derivatives of ELPA under the same license that we chose for
41 : ! the original distribution, the GNU Lesser General Public License.
42 : !
43 : ! Author: Andreas Marek, MPCDF
44 : #endif
45 : ! --------------------------------------------------------------------------------------------------
46 : ! redist_band: redistributes band from 2D block cyclic form to 1D band
47 :
48 : #include "config-f90.h"
49 :
! redist_band_<MATH_DATATYPE>_<PRECISION>: preprocessor-templated routine. The name is
! assembled from the MATH_DATATYPE and PRECISION macros; REALCASE / COMPLEXCASE select
! the r_* or the c_* variant of the matrix arguments and of the work buffers.
!
! NOTE: this text is a line-coverage listing. The leading "<line> [<hits>] :" column on
! every line is annotation of the listing, not part of the Fortran source.
!
! Steps performed, in order:
!   1. if useGPU, copy lda*matrixCols elements from a_dev into the host matrix
!   2. query rank and size in communicator, mpi_comm_rows and mpi_comm_cols
!   3. build the table global_id(prow, pcol) -> rank in communicator
!   4. obtain the block distribution block_limits(0:n_pes) from divide_band
!   5. pack the locally stored nblk x nblk blocks (block row j, block columns
!      j .. j+nfact) into a send buffer, grouped by destination rank
!   6. exchange the blocks (MPI_Alltoallv, or a plain copy when WITH_MPI is not defined)
!   7. unpack the received blocks, transposed, into r_ab / c_ab
!
! Arguments:
!   obj             - supplies the timer used by the start/stop calls; also passed to divide_band
!   r_a / c_a       - host matrix of shape (lda, matrixCols); source of the packed blocks
!   a_dev           - integer(c_intptr_t) handle handed to cuda_memcpy when useGPU is set
!                     (declared without an intent attribute)
!   lda, matrixCols - dimensions of r_a / c_a
!   na              - matrix size used for the block counts and the local_index extents
!   nblk            - edge length of the blocks that are packed and exchanged
!   nbw             - nbw+1 rows of r_ab / c_ab are written per column; nfact = nbw/nblk
!                     (presumably the band width - confirm against callers)
!   mpi_comm_rows, mpi_comm_cols - communicators yielding my_prow/np_rows and my_pcol/np_cols
!   communicator    - communicator used for the rank table and the all-to-all exchange
!   r_ab / c_ab     - output, assumed shape; rows 1:nbw+1 of the columns handled by this
!                     rank are assigned
!   useGPU          - when .true., a_dev is first copied to the host matrix
50 : subroutine redist_band_&
51 : &MATH_DATATYPE&
52 : &_&
53 47448 : &PRECISION &
54 : (obj, &
55 : #if REALCASE == 1
56 28584 : r_a, &
57 : #endif
58 : #if COMPLEXCASE == 1
59 18864 : c_a, &
60 : #endif
61 : a_dev, lda, na, nblk, nbw, matrixCols, mpi_comm_rows, mpi_comm_cols, communicator, &
62 : #if REALCASE == 1
63 28584 : r_ab, useGPU)
64 : #endif
65 : #if COMPLEXCASE == 1
66 18864 : c_ab, useGPU)
67 : #endif
68 :
69 : use elpa_abstract_impl
70 : use elpa2_workload
71 : use precision
72 : use iso_c_binding
73 : use cuda_functions
74 : use elpa_utilities, only : local_index
75 : use elpa_mpi
76 : implicit none
77 :
78 : class(elpa_abstract_impl_t), intent(inout) :: obj
79 : logical, intent(in) :: useGPU
80 : integer(kind=ik), intent(in) :: lda, na, nblk, nbw, matrixCols, mpi_comm_rows, mpi_comm_cols, communicator
! NOTE(review): the host matrix is declared intent(in), yet in the useGPU branch below its
! address is passed to cuda_memcpy with cudaMemcpyDeviceToHost, i.e. apparently as the
! destination of the copy - confirm that this is intended.
81 : #if REALCASE == 1
82 : MATH_DATATYPE(kind=C_DATATYPE_KIND), intent(in) :: r_a(lda, matrixCols)
83 : #endif
84 : #if COMPLEXCASE == 1
85 : MATH_DATATYPE(kind=C_DATATYPE_KIND), intent(in) :: c_a(lda, matrixCols)
86 : #endif
87 :
88 : #if REALCASE == 1
89 : MATH_DATATYPE(kind=C_DATATYPE_KIND), intent(out) :: r_ab(:,:)
90 : #endif
91 :
92 : #if COMPLEXCASE == 1
93 : MATH_DATATYPE(kind=C_DATATYPE_KIND), intent(out) :: c_ab(:,:)
94 : #endif
95 :
! ncnt_* / nstart_*: per-rank counts and offsets for sending (_s) and receiving (_r).
! global_id: rank in communicator for every (process row, process column) pair.
! block_limits: distribution boundaries filled in by divide_band.
96 47448 : integer(kind=ik), allocatable :: ncnt_s(:), nstart_s(:), ncnt_r(:), nstart_r(:), &
97 66312 : global_id(:,:), global_id_tmp(:,:), block_limits(:)
! *_sbuf / *_rbuf: send / receive buffers, one nblk x nblk block per index of the third
! dimension. *_buf: scratch holding the nfact+1 transposed blocks of one block row.
98 : #if REALCASE == 1
99 28584 : MATH_DATATYPE(kind=C_DATATYPE_KIND), allocatable :: r_sbuf(:,:,:), r_rbuf(:,:,:), r_buf(:,:)
100 : #endif
101 :
102 : #if COMPLEXCASE == 1
103 18864 : MATH_DATATYPE(kind=C_DATATYPE_KIND), allocatable :: c_sbuf(:,:,:), c_rbuf(:,:,:), c_buf(:,:)
104 : #endif
105 : integer(kind=ik) :: i, j, my_pe, n_pes, my_prow, np_rows, my_pcol, np_cols, &
106 : nfact, np, npr, npc, mpierr, is, js
107 : integer(kind=ik) :: nblocks_total, il, jl, l_rows, l_cols, n_off
108 :
109 : logical :: successCUDA
110 : integer(kind=c_intptr_t) :: a_dev
! Size factor used to turn the element count lda*matrixCols into the length argument
! of cuda_memcpy; the constant size_of_<PRECISION>_<MATH_DATATYPE> comes from a used module.
111 : integer(kind=c_intptr_t), parameter :: size_of_datatype = size_of_&
112 : &PRECISION&
113 : &_&
114 : &MATH_DATATYPE
115 :
116 : call obj%timer%start("redist_band_&
117 : &MATH_DATATYPE&
118 : &" // &
119 : &PRECISION_SUFFIX &
120 47448 : )
121 :
! GPU path: bring the matrix held at a_dev back to the host array before packing.
! On failure the routine prints a message and terminates the program with "stop 1".
! NOTE(review): lda*matrixCols appears to be evaluated in integer(kind=ik) before it is
! multiplied by the c_intptr_t constant; if ik is a 32-bit kind this product could
! overflow for large local matrices - TODO confirm.
122 47448 : if (useGPU) then
123 : ! copy a_dev to aMatrix
124 : successCUDA = cuda_memcpy ( &
125 : #if REALCASE == 1
126 : loc(r_a), &
127 : #endif
128 : #if COMPLEXCASE == 1
129 : loc(c_a(1,1)), &
130 : #endif
131 : int(a_dev,kind=c_intptr_t), int(lda*matrixCols* size_of_datatype, kind=c_intptr_t), &
132 0 : cudaMemcpyDeviceToHost)
133 0 : if (.not.(successCUDA)) then
134 : print *,"redist_band_&
135 : &MATH_DATATYPE&
136 0 : &: error in cudaMemcpy"
137 0 : stop 1
138 : endif
139 : endif ! useGPU
140 :
! Rank and size in the overall communicator and in the row / column communicators.
! mpierr is received from every MPI call but never inspected in this routine.
141 47448 : call obj%timer%start("mpi_communication")
142 47448 : call mpi_comm_rank(communicator,my_pe,mpierr)
143 47448 : call mpi_comm_size(communicator,n_pes,mpierr)
144 :
145 47448 : call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
146 47448 : call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
147 47448 : call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
148 47448 : call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)
149 47448 : call obj%timer%stop("mpi_communication")
150 :
151 : ! Get global_id mapping 2D procssor coordinates to global id
152 :
! Each rank writes only its own entry into an otherwise zero table; the sum-allreduce
! then yields the complete table on every rank. With WITH_OPENMP a separate send copy
! (global_id_tmp) is used instead of mpi_in_place. Without WITH_MPI the table keeps
! just the local entry.
153 47448 : allocate(global_id(0:np_rows-1,0:np_cols-1))
154 : #ifdef WITH_OPENMP
155 23724 : allocate(global_id_tmp(0:np_rows-1,0:np_cols-1))
156 : #endif
157 47448 : global_id(:,:) = 0
158 47448 : global_id(my_prow, my_pcol) = my_pe
159 : #ifdef WITH_MPI
160 31632 : call obj%timer%start("mpi_communication")
161 : #ifdef WITH_OPENMP
162 15816 : global_id_tmp(:,:) = global_id(:,:)
163 15816 : call mpi_allreduce(global_id_tmp, global_id, np_rows*np_cols, mpi_integer, mpi_sum, communicator, mpierr)
164 15816 : deallocate(global_id_tmp)
165 : #else
166 15816 : call mpi_allreduce(mpi_in_place, global_id, np_rows*np_cols, mpi_integer, mpi_sum, communicator, mpierr)
167 : #endif
168 31632 : call obj%timer%stop("mpi_communication")
169 : #endif /* WITH_MPI */
170 : ! Set work distribution
171 :
! Number of nbw-sized blocks needed to cover na (rounded up).
172 47448 : nblocks_total = (na-1)/nbw + 1
173 :
! block_limits is used below such that rank p handles the nbw-sized blocks
! block_limits(p) .. block_limits(p+1)-1.
174 47448 : allocate(block_limits(0:n_pes))
175 47448 : call divide_band(obj, nblocks_total, n_pes, block_limits)
176 :
177 :
178 47448 : allocate(ncnt_s(0:n_pes-1))
179 47448 : allocate(nstart_s(0:n_pes-1))
180 47448 : allocate(ncnt_r(0:n_pes-1))
181 47448 : allocate(nstart_r(0:n_pes-1))
182 :
183 :
! Number of nblk-sized blocks per nbw-sized block.
! NOTE(review): nfact is used as a divisor (j/nfact) below, so this assumes nbw >= nblk;
! presumably nbw is a multiple of nblk - TODO confirm against callers.
184 47448 : nfact = nbw/nblk
185 :
186 : ! Count how many blocks go to which PE
187 :
! Pass 1 (counting only). j runs over all block rows; np is the receiving rank and is
! advanced when j/nfact reaches the next entry of block_limits. For a block row owned
! by this process row, the nfact+1 blocks in block columns j .. j+nfact are considered,
! and those owned by this process column are counted for receiver np.
188 47448 : ncnt_s(:) = 0
189 47448 : np = 0 ! receiver PE number
190 644040 : do j=0,(na-1)/nblk ! loop over rows of blocks
191 596592 : if (j/nfact==block_limits(np+1)) np = np+1
192 596592 : if (mod(j,np_rows) == my_prow) then
193 2041920 : do i=0,nfact
194 1644192 : if (mod(i+j,np_cols) == my_pcol) then
195 1644192 : ncnt_s(np) = ncnt_s(np) + 1
196 : endif
197 : enddo
198 : endif
199 : enddo
200 :
201 : ! Allocate send buffer
202 :
! The buffer is zero-filled so that blocks clipped at the matrix edge (see jl / il
! below) are padded with zeros.
203 : #if REALCASE==1
204 28584 : allocate(r_sbuf(nblk,nblk,sum(ncnt_s)))
205 28584 : r_sbuf(:,:,:) = 0.
206 : #endif
207 : #if COMPLEXCASE==1
208 18864 : allocate(c_sbuf(nblk,nblk,sum(ncnt_s)))
209 18864 : c_sbuf(:,:,:) = 0.
210 : #endif
211 :
212 : ! Determine start offsets in send buffer
213 :
! Exclusive prefix sum of the block counts: offsets here are in units of blocks.
214 47448 : nstart_s(0) = 0
215 79080 : do i=1,n_pes-1
216 31632 : nstart_s(i) = nstart_s(i-1) + ncnt_s(i-1)
217 : enddo
218 :
219 : ! Fill send buffer
220 :
221 47448 : l_rows = local_index(na, my_prow, np_rows, nblk, -1) ! Local rows of a
222 47448 : l_cols = local_index(na, my_pcol, np_cols, nblk, -1) ! Local columns of a
223 :
! Pass 2 (packing): same traversal as pass 1. nstart_s(np) is incremented first and then
! serves as the 1-based slot in the send buffer, so it acts as a running write cursor
! (it is reset to real offsets after the loop). js / is are the local row / column
! offsets of the block in the host matrix; jl / il clip the block to l_rows / l_cols.
224 47448 : np = 0
225 644040 : do j=0,(na-1)/nblk ! loop over rows of blocks
226 596592 : if (j/nfact==block_limits(np+1)) np = np+1
227 596592 : if (mod(j,np_rows) == my_prow) then
228 2041920 : do i=0,nfact
229 1644192 : if (mod(i+j,np_cols) == my_pcol) then
230 1644192 : nstart_s(np) = nstart_s(np) + 1
231 1644192 : js = (j/np_rows)*nblk
232 1644192 : is = ((i+j)/np_cols)*nblk
233 1644192 : jl = MIN(nblk,l_rows-js)
234 1644192 : il = MIN(nblk,l_cols-is)
235 :
236 : #if REALCASE==1
237 1147680 : r_sbuf(1:jl,1:il,nstart_s(np)) = r_a(js+1:js+jl,is+1:is+il)
238 : #endif
239 : #if COMPLEXCASE==1
240 496512 : c_sbuf(1:jl,1:il,nstart_s(np)) = c_a(js+1:js+jl,is+1:is+il)
241 : #endif
242 : endif
243 : enddo
244 : endif
245 : enddo
246 :
247 : ! Count how many blocks we get from which PE
248 :
! Only the block rows assigned to this rank are visited (clamped to the last block row
! of the matrix); global_id gives the rank that stores block (j, i+j).
249 47448 : ncnt_r(:) = 0
250 445176 : do j=block_limits(my_pe)*nfact,min(block_limits(my_pe+1)*nfact-1,(na-1)/nblk)
251 397728 : npr = mod(j,np_rows)
252 2041920 : do i=0,nfact
253 1644192 : npc = mod(i+j,np_cols)
254 1644192 : np = global_id(npr,npc)
255 1644192 : ncnt_r(np) = ncnt_r(np) + 1
256 : enddo
257 : enddo
258 :
259 : ! Allocate receive buffer
260 :
261 : #if REALCASE==1
262 28584 : allocate(r_rbuf(nblk,nblk,sum(ncnt_r)))
263 : #endif
264 : #if COMPLEXCASE==1
265 18864 : allocate(c_rbuf(nblk,nblk,sum(ncnt_r)))
266 : #endif
267 :
268 : ! Set send counts/send offsets, receive counts/receive offsets
269 : ! now actually in variables, not in blocks
270 :
! Counts are scaled from blocks to matrix elements (nblk*nblk per block) and the offsets
! are rebuilt from the scaled counts, as needed for the MPI_Alltoallv arguments.
271 47448 : ncnt_s(:) = ncnt_s(:)*nblk*nblk
272 :
273 47448 : nstart_s(0) = 0
274 79080 : do i=1,n_pes-1
275 31632 : nstart_s(i) = nstart_s(i-1) + ncnt_s(i-1)
276 : enddo
277 :
278 47448 : ncnt_r(:) = ncnt_r(:)*nblk*nblk
279 :
280 47448 : nstart_r(0) = 0
281 79080 : do i=1,n_pes-1
282 31632 : nstart_r(i) = nstart_r(i-1) + ncnt_r(i-1)
283 : enddo
284 :
285 : ! Exchange all data with MPI_Alltoallv
! The MPI datatype is chosen per variant: MPI_REAL8 / MPI_REAL4 for the real case,
! MPI_COMPLEX16 / MPI_COMPLEX for the complex case. Without WITH_MPI the send buffer is
! simply copied into the receive buffer.
286 : #ifdef WITH_MPI
287 31632 : call obj%timer%start("mpi_communication")
288 :
289 : #if REALCASE==1
290 :
291 : #ifdef DOUBLE_PRECISION_REAL
292 13824 : call MPI_Alltoallv(r_sbuf, ncnt_s, nstart_s, MPI_REAL8, r_rbuf, ncnt_r, nstart_r, MPI_REAL8, communicator, mpierr)
293 : #else
294 5232 : call MPI_Alltoallv(r_sbuf, ncnt_s, nstart_s, MPI_REAL4, r_rbuf, ncnt_r, nstart_r, MPI_REAL4, communicator, mpierr)
295 : #endif
296 :
297 : #endif /* REALCASE==1 */
298 :
299 : #if COMPLEXCASE==1
300 :
301 : #ifdef DOUBLE_PRECISION_COMPLEX
302 8448 : call MPI_Alltoallv(c_sbuf, ncnt_s, nstart_s, MPI_COMPLEX16, c_rbuf, ncnt_r, nstart_r, MPI_COMPLEX16, communicator, mpierr)
303 : #else
304 4128 : call MPI_Alltoallv(c_sbuf, ncnt_s, nstart_s, MPI_COMPLEX, c_rbuf, ncnt_r, nstart_r, MPI_COMPLEX, communicator, mpierr)
305 : #endif
306 :
307 : #endif /* COMPLEXCASE==1 */
308 :
309 31632 : call obj%timer%stop("mpi_communication")
310 : #else /* WITH_MPI */
311 :
312 : #if REALCASE==1
313 9528 : r_rbuf = r_sbuf
314 : #endif
315 :
316 : #if COMPLEXCASE==1
317 6288 : c_rbuf = c_sbuf
318 : #endif
319 :
320 : #endif /* WITH_MPI */
321 :
322 : ! set band from receive buffer
323 :
! Back from element counts to block counts, with offsets in blocks again; nstart_r is
! used below as a per-sender read cursor into the third dimension of the receive buffer.
324 47448 : ncnt_r(:) = ncnt_r(:)/(nblk*nblk)
325 :
326 47448 : nstart_r(0) = 0
327 79080 : do i=1,n_pes-1
328 31632 : nstart_r(i) = nstart_r(i-1) + ncnt_r(i-1)
329 : enddo
330 :
! Scratch for one block row: nfact+1 blocks of nblk rows each, stacked vertically.
331 : #if REALCASE==1
332 28584 : allocate(r_buf((nfact+1)*nblk,nblk))
333 : #endif
334 : #if COMPLEXCASE==1
335 18864 : allocate(c_buf((nfact+1)*nblk,nblk))
336 : #endif
337 :
338 : ! n_off: Offset of ab within band
339 47448 : n_off = block_limits(my_pe)*nbw
340 :
! Unpacking visits (j, i) in the same order as the receive count above, so the blocks of
! each sender are consumed in the order in which that sender packed them. Every block is
! stored transposed (conjugate-transposed in the complex case) in rows i*nblk+1 ..
! (i+1)*nblk of the scratch. Afterwards, for each column i of the block row (clipped at
! na), the nbw+1 scratch entries starting at row i are copied into column
! i+j*nblk-n_off of the output.
341 445176 : do j=block_limits(my_pe)*nfact,min(block_limits(my_pe+1)*nfact-1,(na-1)/nblk)
342 397728 : npr = mod(j,np_rows)
343 2041920 : do i=0,nfact
344 1644192 : npc = mod(i+j,np_cols)
345 1644192 : np = global_id(npr,npc)
346 1644192 : nstart_r(np) = nstart_r(np) + 1
347 : #if REALCASE==1
348 1147680 : r_buf(i*nblk+1:i*nblk+nblk,:) = transpose(r_rbuf(:,:,nstart_r(np)))
349 : #endif
350 : #if COMPLEXCASE==1
351 496512 : c_buf(i*nblk+1:i*nblk+nblk,:) = conjg(transpose(c_rbuf(:,:,nstart_r(np))))
352 : #endif
353 : enddo
354 6448128 : do i=1,MIN(nblk,na-j*nblk)
355 : #if REALCASE==1
356 3511200 : r_ab(1:nbw+1,i+j*nblk-n_off) = r_buf(i:i+nbw,i)
357 : #endif
358 : #if COMPLEXCASE==1
359 2539200 : c_ab(1:nbw+1,i+j*nblk-n_off) = c_buf(i:i+nbw,i)
360 : #endif
361 : enddo
362 : enddo
363 :
! Release all work arrays explicitly before stopping the timer.
364 47448 : deallocate(ncnt_s, nstart_s)
365 47448 : deallocate(ncnt_r, nstart_r)
366 47448 : deallocate(global_id)
367 47448 : deallocate(block_limits)
368 :
369 : #if REALCASE==1
370 28584 : deallocate(r_sbuf, r_rbuf, r_buf)
371 : #endif
372 : #if COMPLEXCASE==1
373 18864 : deallocate(c_sbuf, c_rbuf, c_buf)
374 : #endif
375 :
376 : call obj%timer%stop("redist_band_&
377 : &MATH_DATATYPE&
378 : &" // &
379 : &PRECISION_SUFFIX &
380 47448 : )
381 :
382 47448 : end subroutine
383 :
|