Line data Source code
1 : #if 0
2 : ! This file is part of ELPA.
3 : !
4 : ! The ELPA library was originally created by the ELPA consortium,
5 : ! consisting of the following organizations:
6 : !
7 : ! - Max Planck Computing and Data Facility (MPCDF), formerly known as
8 : ! Rechenzentrum Garching der Max-Planck-Gesellschaft (RZG),
9 : ! - Bergische Universität Wuppertal, Lehrstuhl für angewandte
10 : ! Informatik,
11 : ! - Technische Universität München, Lehrstuhl für Informatik mit
12 : ! Schwerpunkt Wissenschaftliches Rechnen,
13 : ! - Fritz-Haber-Institut, Berlin, Abt. Theorie,
14 : ! - Max-Planck-Institut für Mathematik in den Naturwissenschaften,
15 : ! Leipzig, Abt. Komplexe Strukturen in Biologie und Kognition,
16 : ! and
17 : ! - IBM Deutschland GmbH
18 : !
19 : ! This particular source code file contains additions, changes and
20 : ! enhancements authored by Intel Corporation, which is not part of
21 : ! the ELPA consortium.
22 : !
23 : ! More information can be found here:
24 : ! http://elpa.mpcdf.mpg.de/
25 : !
26 : ! ELPA is free software: you can redistribute it and/or modify
27 : ! it under the terms of version 3 of the GNU Lesser General
28 : ! Public License as published by the Free Software
29 : ! Foundation.
30 : !
31 : ! ELPA is distributed in the hope that it will be useful,
32 : ! but WITHOUT ANY WARRANTY; without even the implied warranty of
33 : ! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
34 : ! GNU Lesser General Public License for more details.
35 : !
36 : ! You should have received a copy of the GNU Lesser General Public License
37 : ! along with ELPA. If not, see <http://www.gnu.org/licenses/>
38 : !
39 : ! ELPA reflects a substantial effort on the part of the original
40 : ! ELPA consortium, and we ask you to respect the spirit of the
41 : ! license that we chose: i.e., please contribute any changes you
42 : ! may have back to the original ELPA library distribution, and keep
43 : ! any derivatives of ELPA under the same license that we chose for
44 : ! the original distribution, the GNU Lesser General Public License.
45 : !
46 : ! Copyright of the original code rests with the authors inside the ELPA
47 : ! consortium. The copyright of any additional modifications shall rest
48 : ! with their original authors, but shall adhere to the licensing terms
49 : ! distributed along with the original code in the file "COPYING".
50 : #endif
51 :
52 : #include "../general/sanity.F90"
53 :
54 : subroutine trans_ev_band_to_full_&
55 : &MATH_DATATYPE&
56 : &_&
57 45432 : &PRECISION &
58 45432 : (obj, na, nqc, nblk, nbw, a, a_dev, lda, tmat, tmat_dev, q, &
59 : q_dev, ldq, matrixCols, numBlocks, mpi_comm_rows, mpi_comm_cols, useGPU &
60 : #if REALCASE == 1
61 : ,useQr)
62 : #endif
63 : #if COMPLEXCASE == 1
64 : )
65 : #endif
66 :
67 : !-------------------------------------------------------------------------------
68 : ! trans_ev_band_to_full_real/complex:
69 : ! Transforms the eigenvectors of a band matrix back to the eigenvectors of the original matrix
70 : !
71 : ! Parameters
72 : !
73 : ! na Order of matrix a, number of rows of matrix q
74 : !
75 : ! nqc Number of columns of matrix q
76 : !
77 : ! nblk blocksize of cyclic distribution, must be the same in both directions!
78 : !
79 : ! nbw semi bandwidth
80 : !
81 : ! a(lda,matrixCols) Matrix containing the Householder vectors (i.e. matrix a after bandred_real/complex)
82 : ! Distribution is like in ScaLAPACK.
83 : !
84 : ! lda Leading dimension of a
85 : ! matrixCols local columns of matrix a and q
86 : !
87 : ! tmat(nbw,nbw,numBlocks) Factors returned by bandred_real/complex
88 : !
89 : ! q On input: Eigenvectors of band matrix
90 : ! On output: Transformed eigenvectors
91 : ! Distribution is like in ScaLAPACK.
92 : !
93 : ! ldq Leading dimension of q
94 : !
95 : ! mpi_comm_rows
96 : ! mpi_comm_cols
97 : ! MPI-Communicators for rows/columns
98 : !
99 : !-------------------------------------------------------------------------------
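       : ! Algorithm sketch (a summary of the code below): bandred stores, for each
       : ! block column, Householder vectors V (in a) and a triangular factor T
       : ! (in tmat). Each step applies the blocked back transformation
       : !     Q <- Q - V * T**T * V**T * Q
       : ! via GEMM (tmp = V**T * Q), TRMM (tmp = T**T * tmp) and GEMM
       : ! (Q = Q - V * tmp); with MPI, tmp is allreduced across process rows
       : ! between the first GEMM and the TRMM.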
100 : use precision
101 : use cuda_functions
102 : use iso_c_binding
103 : use elpa_abstract_impl
104 : implicit none
105 : #include "../general/precision_kinds.F90"
106 : class(elpa_abstract_impl_t), intent(inout) :: obj
107 : logical, intent(in) :: useGPU
108 : #if REALCASE == 1
109 : logical, intent(in) :: useQR
110 : #endif
111 : integer(kind=ik) :: na, nqc, lda, ldq, nblk, nbw, matrixCols, numBlocks, mpi_comm_rows, mpi_comm_cols
112 : #ifdef USE_ASSUMED_SIZE
113 : MATH_DATATYPE(kind=rck) :: a(lda,*), q(ldq,*), tmat(nbw,nbw,*)
114 : #else
115 : MATH_DATATYPE(kind=rck) :: a(lda,matrixCols), q(ldq,matrixCols), tmat(nbw, nbw, numBlocks)
116 : #endif
117 : integer(kind=C_intptr_T) :: a_dev ! passed from bandred_real; at the moment not used, since it is copied in bandred_real
118 :
119 : integer(kind=ik) :: my_prow, my_pcol, np_rows, np_cols, mpierr
120 : integer(kind=ik) :: max_blocks_row, max_blocks_col, max_local_rows, &
121 : max_local_cols
122 : integer(kind=ik) :: l_cols, l_rows, l_colh, n_cols
123 : integer(kind=ik) :: istep, lc, ncol, nrow, nb, ns
124 :
125 68148 : MATH_DATATYPE(kind=rck), allocatable :: tmp1(:), tmp2(:), hvb(:), hvm(:,:)
126 : ! hvm_dev is first used and set in this routine
127 : ! q is changed in trans_ev_tridi on the host, copied to the device and passed here. This can be adapted.
128 : ! tmp_dev is first used in this routine
129 : ! tmat_dev is passed along from bandred_real
130 : integer(kind=C_intptr_T) :: hvm_dev, q_dev, tmp_dev, tmat_dev
131 :
132 : integer(kind=ik) :: i
133 :
134 : #ifdef BAND_TO_FULL_BLOCKING
135 22716 : MATH_DATATYPE(kind=rck), allocatable :: tmat_complete(:,:), t_tmp(:,:), t_tmp2(:,:)
136 : integer(kind=ik) :: cwy_blocking, t_blocking, t_cols, t_rows
137 : #endif
138 :
139 : integer(kind=ik) :: istat
140 : character(200) :: errorMessage
141 : logical :: successCUDA
142 : integer(kind=c_intptr_t), parameter :: size_of_datatype = size_of_&
143 : &PRECISION&
144 : &_&
145 : &MATH_DATATYPE
146 : integer :: blocking_factor, error
147 : call obj%timer%start("trans_ev_band_to_full_&
148 : &MATH_DATATYPE&
149 : &" // &
150 : &PRECISION_SUFFIX &
151 45432 : )
152 : #ifdef BAND_TO_FULL_BLOCKING
153 22716 : call obj%get("blocking_in_band_to_full",blocking_factor,error)
154 22716 : if (error .ne. ELPA_OK) then
155 0 : print *,"Problem getting option. Aborting..."
156 0 : stop
157 : endif
158 : #endif
159 45432 : call obj%timer%start("mpi_communication")
160 :
161 45432 : call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
162 45432 : call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
163 45432 : call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
164 45432 : call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)
165 :
166 45432 : call obj%timer%stop("mpi_communication")
167 :
168 45432 : max_blocks_row = ((na -1)/nblk)/np_rows + 1 ! Rows of A
169 45432 : max_blocks_col = ((nqc-1)/nblk)/np_cols + 1 ! Columns of q!
170 :
171 45432 : max_local_rows = max_blocks_row*nblk
172 45432 : max_local_cols = max_blocks_col*nblk
173 :
174 45432 : if (useGPU) then
175 :
176 : #if REALCASE == 1
177 : ! here the GPU and CPU versions diverged: the CPU version now always uses the useQR path, which
178 : ! is not implemented in the GPU version
179 : #endif
180 :
181 : ! the GPU version does not (yet) support blocking
182 : ! but the handling is the same for the real and complex cases
183 :
184 0 : allocate(tmp1(max_local_cols*nbw), stat=istat, errmsg=errorMessage)
185 0 : if (istat .ne. 0) then
186 : print *,"trans_ev_band_to_full_&
187 : &MATH_DATATYPE&
188 0 : &: error when allocating tmp1 "//errorMessage
189 0 : stop 1
190 : endif
191 :
192 0 : allocate(tmp2(max_local_cols*nbw), stat=istat, errmsg=errorMessage)
193 0 : if (istat .ne. 0) then
194 : print *,"trans_ev_band_to_full_&
195 : &MATH_DATATYPE&
196 0 : &: error when allocating tmp2 "//errorMessage
197 0 : stop 1
198 : endif
199 :
200 0 : allocate(hvb(max_local_rows*nbw), stat=istat, errmsg=errorMessage)
201 0 : if (istat .ne. 0) then
202 : print *,"trans_ev_band_to_full_&
203 : &MATH_DATATYPE&
204 0 : &: error when allocating hvb "//errorMessage
205 0 : stop 1
206 : endif
207 :
208 0 : allocate(hvm(max_local_rows,nbw), stat=istat, errmsg=errorMessage)
209 0 : if (istat .ne. 0) then
210 : print *,"trans_ev_band_to_full_&
211 : &MATH_DATATYPE&
212 0 : &: error when allocating hvm "//errorMessage
213 0 : stop 1
214 : endif
215 :
216 0 : successCUDA = cuda_malloc(hvm_dev, (max_local_rows)*nbw* size_of_datatype)
217 0 : if (.not.(successCUDA)) then
218 : print *,"trans_ev_band_to_full_&
219 : &MATH_DATATYPE&
220 0 : &: error in cudaMalloc"
221 0 : stop 1
222 : endif
223 :
224 0 : successCUDA = cuda_malloc(tmp_dev, (max_local_cols)*nbw* size_of_datatype)
225 0 : if (.not.(successCUDA)) then
226 : print *,"trans_ev_band_to_full_&
227 : &MATH_DATATYPE&
228 0 : &: error in cudaMalloc"
229 0 : stop 1
230 : endif
231 :
232 : !#ifdef WITH_MPI
233 : !! it should be possible to keep tmat dev on the device and not copy it around
234 : !! already existent on GPU
235 : ! successCUDA = cuda_malloc(tmat_dev, nbw*nbw* &
236 : !#if REALCASE == 1
237 : ! size_of_PRECISION_real)
238 : !#endif
239 : !#if COMPLEXCASE == 1
240 : ! size_of_PRECISION_complex)
241 : !#endif
242 : !
243 : ! if (.not.(successCUDA)) then
244 : ! print *,"trans_ev_band_to_full_&
245 : ! &MATH_DATATYPE&
246 : ! &: error in cudaMalloc"
247 : ! stop 1
248 : ! endif
249 : !#endif
250 :
251 : #if REALCASE == 1
252 : ! q_dev already living on device
253 : ! successCUDA = cuda_malloc(q_dev, ldq*matrixCols*size_of_datatype)
254 : ! if (.not.(successCUDA)) then
255 : ! print *,"trans_ev_band_to_full_real: error in cudaMalloc"
256 : ! stop 1
257 : ! endif
258 : ! q_temp(:,:) = 0.0
259 : ! q_temp(1:ldq,1:na_cols) = q(1:ldq,1:na_cols)
260 :
261 : ! ! copy q_dev to device, maybe this can be avoided if q_dev can be kept on device in trans_ev_tridi_to_band
262 : ! successCUDA = cuda_memcpy(q_dev, loc(q), (ldq)*(matrixCols)*size_of_PRECISION_real, cudaMemcpyHostToDevice)
263 : ! if (.not.(successCUDA)) then
264 : ! print *,"trans_ev_band_to_full_real: error in cudaMalloc"
265 : ! stop 1
266 : ! endif
267 : #endif
268 : #if COMPLEXCASE == 1
269 : ! successCUDA = cuda_malloc(q_dev, ldq*matrixCols*size_of_PRECISION_complex)
270 : ! if (.not.(successCUDA)) then
271 : ! print *,"trans_ev_band_to_full_complex: error in cudaMalloc"
272 : ! stop 1
273 : ! endif
274 : !
275 : ! successCUDA = cuda_memcpy(q_dev, loc(q),ldq*matrixCols*size_of_PRECISION_complex, cudaMemcpyHostToDevice)
276 : ! if (.not.(successCUDA)) then
277 : ! print *,"trans_ev_band_to_full_complex: error in cudaMemcpy"
278 : ! stop 1
279 : ! endif
280 : #endif
281 :
282 : ! if MPI is NOT used the following steps could be done on the GPU and memory transfers could be avoided
283 0 : successCUDA = cuda_memset(hvm_dev, 0, (max_local_rows)*(nbw)* size_of_datatype)
284 0 : if (.not.(successCUDA)) then
285 : print *,"trans_ev_band_to_full_&
286 : &MATH_DATATYPE&
287 0 : &: error in cudaMemset"
288 0 : stop 1
289 : endif
290 :
291 0 : hvm = 0.0_rck ! Must be set to 0 !!!
292 0 : hvb = 0.0_rck ! Safety only
293 0 : l_cols = local_index(nqc, my_pcol, np_cols, nblk, -1) ! Local columns of q
294 :
295 0 : do istep=1,(na-1)/nbw
296 :
297 0 : n_cols = MIN(na,(istep+1)*nbw) - istep*nbw ! Number of columns in current step
298 :
299 : ! Broadcast all Householder vectors for current step compressed in hvb
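       : ! (the vectors have varying local row counts, so they are packed
       : !  contiguously into hvb; one MPI_Bcast is issued per nblk block of
       : !  columns, rooted at the process column owning those columns)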
300 :
301 0 : nb = 0
302 0 : ns = 0
303 :
304 0 : do lc = 1, n_cols
305 0 : ncol = istep*nbw + lc ! absolute column number of householder Vector
306 0 : nrow = ncol - nbw ! absolute number of pivot row
307 :
308 0 : l_rows = local_index(nrow-1, my_prow, np_rows, nblk, -1) ! row length for bcast
309 0 : l_colh = local_index(ncol , my_pcol, np_cols, nblk, -1) ! HV local column number
310 :
311 0 : if (my_pcol==pcol(ncol, nblk, np_cols)) hvb(nb+1:nb+l_rows) = a(1:l_rows,l_colh)
312 :
313 0 : nb = nb+l_rows
314 :
315 0 : if (lc==n_cols .or. mod(ncol,nblk)==0) then
316 : #ifdef WITH_MPI
317 0 : call obj%timer%start("mpi_communication")
318 : call MPI_Bcast(hvb(ns+1), nb-ns, MPI_MATH_DATATYPE_PRECISION,&
319 0 : pcol(ncol, nblk, np_cols), mpi_comm_cols, mpierr)
320 :
321 0 : call obj%timer%stop("mpi_communication")
322 :
323 : #endif /* WITH_MPI */
324 0 : ns = nb
325 : endif
326 : enddo
327 :
328 : ! Expand compressed Householder vectors into matrix hvm
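       : ! (the vectors are stored without their unit leading element; the 1.0
       : !  on the pivot row is re-inserted explicitly below)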
329 :
330 0 : nb = 0
331 0 : do lc = 1, n_cols
332 0 : nrow = (istep-1)*nbw+lc ! absolute number of pivot row
333 0 : l_rows = local_index(nrow-1, my_prow, np_rows, nblk, -1) ! row length for bcast
334 :
335 0 : hvm(1:l_rows,lc) = hvb(nb+1:nb+l_rows)
336 0 : if (my_prow==prow(nrow, nblk, np_rows)) hvm(l_rows+1,lc) = 1.0_rck
337 0 : nb = nb+l_rows
338 : enddo
339 :
340 0 : successCUDA = cuda_memcpy(hvm_dev, loc(hvm), max_local_rows*nbw* size_of_datatype, cudaMemcpyHostToDevice)
341 :
342 0 : if (.not.(successCUDA)) then
343 0 : print *,"trans_ev_band_to_full: error in cudaMemcpy"
344 0 : stop 1
345 :
346 : endif
347 :
348 0 : l_rows = local_index(MIN(na,(istep+1)*nbw), my_prow, np_rows, nblk, -1)
349 :
350 : ! Q = Q - V * T**T * V**T * Q
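       : ! (device path: cublas GEMM forms tmp_dev = hvm_dev**T * q_dev; with
       : !  MPI the result is staged through the host for the allreduce, then
       : !  a TRMM applies tmat_dev**T and a second GEMM updates q_dev)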
351 :
352 0 : if (l_rows>0) then
353 0 : call obj%timer%start("cublas")
354 : call cublas_PRECISION_GEMM(BLAS_TRANS_OR_CONJ, 'N', &
355 : n_cols, l_cols, l_rows, ONE, hvm_dev, max_local_rows, &
356 0 : q_dev, ldq , ZERO, tmp_dev, n_cols)
357 0 : call obj%timer%stop("cublas")
358 :
359 : #ifdef WITH_MPI
360 :
361 : ! copy data from device to host for a later MPI_ALLREDUCE
362 : ! copy to host; maybe this can be avoided, but it is needed if MPI is used (allreduce)
363 0 : successCUDA = cuda_memcpy(loc(tmp1), tmp_dev, l_cols*n_cols*size_of_datatype, cudaMemcpyDeviceToHost)
364 0 : if (.not.(successCUDA)) then
365 0 : print *,"trans_ev_band_to_full: error in cudaMemcpy"
366 0 : stop 1
367 : endif
368 :
369 :
370 : #else /* WITH_MPI */
371 : ! in the real case no copy is needed; don't do it in the complex case either
372 : #endif /* WITH_MPI */
373 :
374 : else ! l_rows>0
375 0 : tmp1(1:l_cols*n_cols) = 0.0_rck
376 : endif ! l_rows>0
377 :
378 : #ifdef WITH_MPI
379 0 : call obj%timer%start("mpi_communication")
380 : call mpi_allreduce(tmp1, tmp2, n_cols*l_cols, MPI_MATH_DATATYPE_PRECISION, &
381 0 : MPI_SUM, mpi_comm_rows, mpierr)
382 0 : call obj%timer%stop("mpi_communication")
383 :
384 : #else /* WITH_MPI */
385 : ! tmp2(1:n_cols*l_cols) = tmp1(1:n_cols*l_cols)
386 : #endif /* WITH_MPI */
387 :
388 0 : if (l_rows>0) then
389 : #ifdef WITH_MPI
390 : ! after the mpi_allreduce we have to copy back to the device
391 : ! copy back to device
392 : successCUDA = cuda_memcpy(tmp_dev, loc(tmp2), n_cols*l_cols* size_of_datatype, &
393 0 : cudaMemcpyHostToDevice)
394 0 : if (.not.(successCUDA)) then
395 : print *,"trans_ev_band_to_full_&
396 : &MATH_DATATYPE&
397 0 : &: error in cudaMemcpy"
398 0 : stop 1
399 : endif
400 : #else /* WITH_MPI */
401 : ! in the real case no memcpy is needed; don't do it in the complex case either
402 : #endif /* WITH_MPI */
403 :
404 : !#ifdef WITH_MPI
405 : ! IMPORTANT: even though tmat_dev is transferred from the previous routine, we have to copy from tmat again
406 : ! tmat is a 3-dimensional array, while tmat_dev contains only one 2-dimensional slice of it - and here we
407 : ! need to upload another slice
408 0 : successCUDA = cuda_memcpy(tmat_dev, loc(tmat(1,1,istep)), nbw*nbw*size_of_datatype, cudaMemcpyHostToDevice)
409 :
410 0 : if (.not.(successCUDA)) then
411 : print *,"trans_ev_band_to_full_&
412 : &MATH_DATATYPE&
413 0 : &: error in cudaMemcpy"
414 0 : stop 1
415 : endif
416 : !#endif /* WITH_MPI */
417 :
418 0 : call obj%timer%start("cublas")
419 : call cublas_PRECISION_TRMM('L', 'U', BLAS_TRANS_OR_CONJ, 'N', &
420 0 : n_cols, l_cols, ONE, tmat_dev, nbw, tmp_dev, n_cols)
421 :
422 : call cublas_PRECISION_GEMM('N', 'N', l_rows, l_cols, n_cols, -ONE, hvm_dev, max_local_rows, &
423 0 : tmp_dev, n_cols, one, q_dev, ldq)
424 0 : call obj%timer%stop("cublas")
425 :
426 : ! copy to host; maybe this can be avoided
427 : ! (this is not necessary since hvm is not used anymore)
428 0 : successCUDA = cuda_memcpy(loc(hvm), hvm_dev, ((max_local_rows)*nbw*size_of_datatype),cudaMemcpyDeviceToHost)
429 0 : if (.not.(successCUDA)) then
430 0 : print *,"trans_ev_band_to_full: error in cudaMemcpy"
431 0 : stop 1
432 : endif
433 : endif ! l_rows > 0
434 :
435 : enddo ! istep
436 :
437 :
438 :
439 : else ! do not useGPU
440 :
441 : #ifdef BAND_TO_FULL_BLOCKING
442 : ! t_blocking was formerly 2; 3 is a better choice
443 22716 : t_blocking = blocking_factor ! number of matrices T (tmat) which are aggregated into a new (larger) T matrix (tmat_complete) and applied at once
444 :
445 : ! we only use t_blocking if we can apply it fully; this might be better but needs to be benchmarked.
446 : ! if ( na >= ((t_blocking+1)*nbw) ) then
447 22716 : cwy_blocking = t_blocking * nbw
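       : ! with blocking, t_blocking consecutive reflector blocks of width nbw are
       : ! merged, so one aggregated update spans cwy_blocking Householder vectors
       : ! and needs the cwy_blocking x cwy_blocking factor tmat_complete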
448 :
449 22716 : allocate(tmp1(max_local_cols*cwy_blocking), stat=istat, errmsg=errorMessage)
450 22716 : if (istat .ne. 0) then
451 : print *,"trans_ev_band_to_full_&
452 : &MATH_DATATYPE&
453 0 : &: error when allocating tmp1 "//errorMessage
454 0 : stop 1
455 : endif
456 :
457 22716 : allocate(tmp2(max_local_cols*cwy_blocking), stat=istat, errmsg=errorMessage)
458 22716 : if (istat .ne. 0) then
459 : print *,"trans_ev_band_to_full_&
460 : &MATH_DATATYPE&
461 0 : &: error when allocating tmp2 "//errorMessage
462 0 : stop 1
463 : endif
464 :
465 22716 : allocate(hvb(max_local_rows*cwy_blocking), stat=istat, errmsg=errorMessage)
466 22716 : if (istat .ne. 0) then
467 : print *,"trans_ev_band_to_full_&
468 : &MATH_DATATYPE&
469 0 : &: error when allocating hvb "//errorMessage
470 0 : stop 1
471 : endif
472 :
473 22716 : allocate(hvm(max_local_rows,cwy_blocking), stat=istat, errmsg=errorMessage)
474 22716 : if (istat .ne. 0) then
475 : print *,"trans_ev_band_to_full_&
476 : &MATH_DATATYPE&
477 0 : &: error when allocating hvm "//errorMessage
478 0 : stop 1
479 : endif
480 :
481 : #else /* BAND_TO_FULL_BLOCKING */
482 :
483 22716 : allocate(tmp1(max_local_cols*nbw), stat=istat, errmsg=errorMessage)
484 22716 : if (istat .ne. 0) then
485 : print *,"trans_ev_band_to_full_&
486 : &MATH_DATATYPE&
487 0 : &: error when allocating tmp1 "//errorMessage
488 0 : stop 1
489 : endif
490 :
491 22716 : allocate(tmp2(max_local_cols*nbw), stat=istat, errmsg=errorMessage)
492 22716 : if (istat .ne. 0) then
493 : print *,"trans_ev_band_to_full_&
494 0 : &MATH_DATATYPE&: error when allocating tmp2 "//errorMessage
495 0 : stop 1
496 : endif
497 :
498 22716 : allocate(hvb(max_local_rows*nbw), stat=istat, errmsg=errorMessage)
499 22716 : if (istat .ne. 0) then
500 : print *,"trans_ev_band_to_full_&
501 : &MATH_DATATYPE&
502 0 : &: error when allocating hvb "//errorMessage
503 0 : stop 1
504 : endif
505 :
506 22716 : allocate(hvm(max_local_rows,nbw), stat=istat, errmsg=errorMessage)
507 22716 : if (istat .ne. 0) then
508 : print *,"trans_ev_band_to_full_&
509 : &MATH_DATATYPE&
510 0 : &: error when allocating hvm "//errorMessage
511 0 : stop 1
512 : endif
513 : #endif /* BAND_TO_FULL_BLOCKING */
514 :
515 : #ifdef BAND_TO_FULL_BLOCKING
516 22716 : allocate(tmat_complete(cwy_blocking,cwy_blocking), stat=istat, errmsg=errorMessage)
517 22716 : if (istat .ne. 0) then
518 : print *,"trans_ev_band_to_full_&
519 : &MATH_DATATYPE&
520 0 : &: error when allocating tmat_complete "//errorMessage
521 0 : stop 1
522 : endif
523 22716 : allocate(t_tmp(cwy_blocking,nbw), stat=istat, errmsg=errorMessage)
524 22716 : if (istat .ne. 0) then
525 : print *,"trans_ev_band_to_full_&
526 : &MATH_DATATYPE&
527 0 : &: error when allocating t_tmp "//errorMessage
528 0 : stop 1
529 : endif
530 22716 : allocate(t_tmp2(cwy_blocking,nbw), stat=istat, errmsg=errorMessage)
531 22716 : if (istat .ne. 0) then
532 : print *,"trans_ev_band_to_full_&
533 : &MATH_DATATYPE&
534 0 : &: error when allocating t_tmp2 "//errorMessage
535 0 : stop 1
536 : endif
537 : #endif
538 : ! else
539 : ! allocate(tmp1(max_local_cols*nbw))
540 : ! allocate(tmp2(max_local_cols*nbw))
541 : ! allocate(hvb(max_local_rows*nbw))
542 : ! allocate(hvm(max_local_rows,nbw))
543 : ! endif
544 :
545 45432 : hvm = 0.0_rck ! Must be set to 0 !!!
546 45432 : hvb = 0.0_rck ! Safety only
547 45432 : l_cols = local_index(nqc, my_pcol, np_cols, nblk, -1) ! Local columns of q
548 :
549 : ! if ( na >= ((t_blocking+1)*nbw) ) then
550 :
551 : #ifdef BAND_TO_FULL_BLOCKING
552 61920 : do istep=1,((na-1)/nbw-1)/t_blocking + 1
553 : #else
554 109188 : do istep=1,(na-1)/nbw
555 : #endif
556 :
557 : #ifdef BAND_TO_FULL_BLOCKING
558 : ! This the call when using na >= ((t_blocking+1)*nbw)
559 : ! n_cols = MIN(na,istep*cwy_blocking+nbw) - (istep-1)*cwy_blocking - nbw
560 : ! Number of columns in current step
561 : ! As an alternative we add some special case handling if na < cwy_blocking
562 39204 : IF (na < cwy_blocking) THEN
563 13140 : n_cols = MAX(0, na-nbw)
564 13140 : IF ( n_cols .eq. 0 ) THEN
565 0 : EXIT
566 : END IF
567 : ELSE
568 26064 : n_cols = MIN(na,istep*cwy_blocking+nbw) - (istep-1)*cwy_blocking - nbw ! Number of columns in current step
569 : END IF
570 : #else /* BAND_TO_FULL_BLOCKING */
571 86472 : n_cols = MIN(na,(istep+1)*nbw) - istep*nbw ! Number of columns in current step
572 : #endif /* BAND_TO_FULL_BLOCKING */
573 : ! Broadcast all Householder vectors for current step compressed in hvb
574 :
575 125676 : nb = 0
576 125676 : ns = 0
577 :
578 6567228 : do lc = 1, n_cols
579 : #ifdef BAND_TO_FULL_BLOCKING
580 3220776 : ncol = (istep-1)*cwy_blocking + nbw + lc ! absolute column number of householder Vector
581 : #else
582 3220776 : ncol = istep*nbw + lc ! absolute column number of householder Vector
583 : #endif
584 6441552 : nrow = ncol - nbw ! absolute number of pivot row
585 :
586 6441552 : l_rows = local_index(nrow-1, my_prow, np_rows, nblk, -1) ! row length for bcast
587 6441552 : l_colh = local_index(ncol , my_pcol, np_cols, nblk, -1) ! HV local column number
588 :
589 6441552 : if (my_pcol==pcol(ncol, nblk, np_cols)) hvb(nb+1:nb+l_rows) = a(1:l_rows,l_colh)
590 :
591 6441552 : nb = nb+l_rows
592 :
593 6441552 : if (lc==n_cols .or. mod(ncol,nblk)==0) then
594 : #ifdef WITH_MPI
595 287136 : call obj%timer%start("mpi_communication")
596 : call MPI_Bcast(hvb(ns+1), nb-ns, MPI_MATH_DATATYPE_PRECISION, &
597 287136 : pcol(ncol, nblk, np_cols), mpi_comm_cols, mpierr)
598 :
599 287136 : call obj%timer%stop("mpi_communication")
600 :
601 : #endif /* WITH_MPI */
602 430704 : ns = nb
603 : endif
604 : enddo ! lc
605 :
606 : ! Expand compressed Householder vectors into matrix hvm
607 :
608 125676 : nb = 0
609 6567228 : do lc = 1, n_cols
610 : #ifdef BAND_TO_FULL_BLOCKING
611 3220776 : nrow = (istep-1)*cwy_blocking + lc ! absolute number of pivot row
612 : #else
613 3220776 : nrow = (istep-1)*nbw+lc ! absolute number of pivot row
614 : #endif
615 6441552 : l_rows = local_index(nrow-1, my_prow, np_rows, nblk, -1) ! row length for bcast
616 :
617 6441552 : hvm(1:l_rows,lc) = hvb(nb+1:nb+l_rows)
618 6441552 : if (my_prow==prow(nrow, nblk, np_rows)) hvm(l_rows+1,lc) = 1.0_rck
619 6441552 : nb = nb+l_rows
620 : enddo
621 :
622 : #ifdef BAND_TO_FULL_BLOCKING
623 39204 : l_rows = local_index(MIN(na,(istep+1)*cwy_blocking), my_prow, np_rows, nblk, -1)
624 :
625 : ! compute tmat_complete out of tmat(:,:,:)
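       : ! (standard compact-WY aggregation: (I - V1*T1*V1**T)*(I - V2*T2*V2**T)
       : !  = I - [V1 V2] * Tc * [V1 V2]**T, where Tc has diagonal blocks T1, T2
       : !  and off-diagonal block -T1 * (V1**T * V2) * T2; the GEMM below forms
       : !  V1**T * V2 and the two TRMMs multiply by T1 and -T2)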
626 39204 : tmat_complete = 0
627 125676 : do i = 1, t_blocking
628 108612 : t_cols = MIN(nbw, n_cols - (i-1)*nbw)
629 108612 : if (t_cols <= 0) exit
630 86472 : t_rows = (i - 1) * nbw
631 86472 : tmat_complete(t_rows+1:t_rows+t_cols,t_rows+1:t_rows+t_cols) = tmat(1:t_cols,1:t_cols,(istep-1)*t_blocking + i)
632 :
633 86472 : if (i > 1) then
634 47268 : call obj%timer%start("blas")
635 : call PRECISION_GEMM(BLAS_TRANS_OR_CONJ, 'N', &
636 : t_rows, t_cols, l_rows, ONE, hvm(1,1), max_local_rows, hvm(1,(i-1)*nbw+1), &
637 47268 : max_local_rows, ZERO, t_tmp, cwy_blocking)
638 :
639 47268 : call obj%timer%stop("blas")
640 : #ifdef WITH_MPI
641 31512 : call obj%timer%start("mpi_communication")
642 :
643 : call mpi_allreduce(t_tmp, t_tmp2, cwy_blocking*nbw, MPI_MATH_DATATYPE_PRECISION, &
644 31512 : MPI_SUM, mpi_comm_rows, mpierr)
645 31512 : call obj%timer%stop("mpi_communication")
646 31512 : call obj%timer%start("blas")
647 31512 : call PRECISION_TRMM('L', 'U', 'N', 'N', t_rows, t_cols, ONE, tmat_complete, cwy_blocking, t_tmp2, cwy_blocking)
648 : call PRECISION_TRMM('R', 'U', 'N', 'N', t_rows, t_cols, -ONE, tmat_complete(t_rows+1,t_rows+1), cwy_blocking, &
649 31512 : t_tmp2, cwy_blocking)
650 31512 : call obj%timer%stop("blas")
651 :
652 31512 : tmat_complete(1:t_rows,t_rows+1:t_rows+t_cols) = t_tmp2(1:t_rows,1:t_cols)
653 :
654 : #else /* WITH_MPI */
655 : ! t_tmp2(1:cwy_blocking,1:nbw) = t_tmp(1:cwy_blocking,1:nbw)
656 15756 : call obj%timer%start("blas")
657 15756 : call PRECISION_TRMM('L', 'U', 'N', 'N', t_rows, t_cols, ONE, tmat_complete, cwy_blocking, t_tmp, cwy_blocking)
658 : call PRECISION_TRMM('R', 'U', 'N', 'N', t_rows, t_cols, -ONE, tmat_complete(t_rows+1,t_rows+1), cwy_blocking, &
659 15756 : t_tmp, cwy_blocking)
660 15756 : call obj%timer%stop("blas")
661 :
662 15756 : tmat_complete(1:t_rows,t_rows+1:t_rows+t_cols) = t_tmp(1:t_rows,1:t_cols)
663 :
664 : #endif /* WITH_MPI */
665 :
666 : ! call PRECISION_TRMM('L', 'U', 'N', 'N', t_rows, t_cols, ONE, tmat_complete, cwy_blocking, t_tmp2, cwy_blocking)
667 : ! call PRECISION_TRMM('R', 'U', 'N', 'N', t_rows, t_cols, -ONE, tmat_complete(t_rows+1,t_rows+1), cwy_blocking, &
668 : ! t_tmp2, cwy_blocking)
669 :
670 : ! tmat_complete(1:t_rows,t_rows+1:t_rows+t_cols) = t_tmp2(1:t_rows,1:t_cols)
671 : endif
672 : enddo
673 : #else /* BAND_TO_FULL_BLOCKING */
674 86472 : l_rows = local_index(MIN(na,(istep+1)*nbw), my_prow, np_rows, nblk, -1)
675 : #endif
676 :
677 : ! Q = Q - V * T**T * V**T * Q
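       : ! (host path: GEMM forms tmp1 = hvm**T * q; with MPI, tmp1 is
       : !  allreduced across process rows into tmp2, then a TRMM applies the
       : !  transposed triangular factor and a second GEMM subtracts hvm * tmp
       : !  from q)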
678 :
679 125676 : if (l_rows>0) then
680 125676 : call obj%timer%start("blas")
681 :
682 : call PRECISION_GEMM(BLAS_TRANS_OR_CONJ, 'N', &
683 : n_cols, l_cols, l_rows, ONE, hvm, ubound(hvm,dim=1), &
684 125676 : q, ldq, ZERO, tmp1, n_cols)
685 125676 : call obj%timer%stop("blas")
686 :
687 : else ! l_rows>0
688 :
689 0 : tmp1(1:l_cols*n_cols) = 0.0_rck
690 : endif ! l_rows>0
691 :
692 : #ifdef WITH_MPI
693 83784 : call obj%timer%start("mpi_communication")
694 83784 : call mpi_allreduce(tmp1, tmp2, n_cols*l_cols, MPI_MATH_DATATYPE_PRECISION, MPI_SUM, mpi_comm_rows ,mpierr)
695 83784 : call obj%timer%stop("mpi_communication")
696 :
697 83784 : call obj%timer%start("blas")
698 :
699 83784 : if (l_rows>0) then
700 : #ifdef BAND_TO_FULL_BLOCKING
701 :
702 : call PRECISION_TRMM('L', 'U', BLAS_TRANS_OR_CONJ, 'N', &
703 26136 : n_cols, l_cols, ONE, tmat_complete, cwy_blocking, tmp2, n_cols)
704 26136 : call PRECISION_GEMM('N', 'N', l_rows, l_cols, n_cols, -ONE, hvm, ubound(hvm,dim=1), tmp2, n_cols, ONE, q, ldq)
705 :
706 : #else /* BAND_TO_FULL_BLOCKING */
707 :
708 : call PRECISION_TRMM('L', 'U', BLAS_TRANS_OR_CONJ, 'N', &
709 57648 : n_cols, l_cols, ONE, tmat(1,1,istep), ubound(tmat,dim=1), tmp2, n_cols)
710 : call PRECISION_GEMM('N', 'N', l_rows, l_cols, n_cols, -ONE, hvm, ubound(hvm,dim=1), &
711 57648 : tmp2, n_cols, ONE, q, ldq)
712 :
713 : #endif /* BAND_TO_FULL_BLOCKING */
714 :
715 : endif
716 83784 : call obj%timer%stop("blas")
717 : #else /* WITH_MPI */
718 : ! tmp2 = tmp1
719 41892 : call obj%timer%start("blas")
720 41892 : if (l_rows>0) then
721 : #ifdef BAND_TO_FULL_BLOCKING
722 : call PRECISION_TRMM('L', 'U', BLAS_TRANS_OR_CONJ, 'N', &
723 13068 : n_cols, l_cols, ONE, tmat_complete, cwy_blocking, tmp1, n_cols)
724 13068 : call PRECISION_GEMM('N', 'N', l_rows, l_cols, n_cols, -ONE, hvm, ubound(hvm,dim=1), tmp1, n_cols, ONE, q, ldq)
725 : #else /* BAND_TO_FULL_BLOCKING */
726 :
727 : call PRECISION_TRMM('L', 'U', BLAS_TRANS_OR_CONJ, 'N', &
728 28824 : n_cols, l_cols, ONE, tmat(1,1,istep), ubound(tmat,dim=1), tmp1, n_cols)
729 : call PRECISION_GEMM('N', 'N', l_rows, l_cols, n_cols, -ONE, hvm, ubound(hvm,dim=1), &
730 28824 : tmp1, n_cols, ONE, q, ldq)
731 :
732 : #endif /* BAND_TO_FULL_BLOCKING */
733 : endif
734 41892 : call obj%timer%stop("blas")
735 : #endif /* WITH_MPI */
736 :
737 : ! if (l_rows>0) then
738 : ! call PRECISION_TRMM('L', 'U', 'T', 'N', n_cols, l_cols, ONE, tmat_complete, cwy_blocking, tmp2, n_cols)
739 : ! call PRECISION_GEMM('N', 'N', l_rows, l_cols, n_cols, -ONE, hvm, ubound(hvm,dim=1), tmp2, n_cols, ONE, q, ldq)
740 : ! endif
741 :
742 : enddo ! istep
743 :
744 : endif ! useGPU
745 :
746 45432 : deallocate(tmp1, tmp2, hvb, stat=istat, errmsg=errorMessage)
747 45432 : if (istat .ne. 0) then
748 : print *,"trans_ev_band_to_full_&
749 : &MATH_DATATYPE&
750 0 : &: error when deallocating tmp1 tmp2 hvb "//errorMessage
751 0 : stop 1
752 : endif
753 :
754 45432 : if (useGPU) then
755 0 : successCUDA = cuda_free(hvm_dev)
756 0 : if (.not.(successCUDA)) then
757 : print *,"trans_ev_band_to_full_&
758 : &MATH_DATATYPE&
759 0 : &: error in cudaFree"
760 0 : stop 1
761 : endif
762 :
763 0 : successCUDA = cuda_free(tmp_dev)
764 0 : if (.not.(successCUDA)) then
765 : print *,"trans_ev_band_to_full_&
766 : &MATH_DATATYPE&
767 0 : &: error in cudaFree"
768 0 : stop 1
769 : endif
770 :
771 0 : successCUDA = cuda_free(tmat_dev)
772 0 : if (.not.(successCUDA)) then
773 : print *,"trans_ev_band_to_full_&
774 : &MATH_DATATYPE&
775 0 : &: error in cudaFree"
776 0 : stop 1
777 : endif
778 :
779 : ! final transfer of q_dev
780 0 : successCUDA = cuda_memcpy(loc(q), q_dev, ldq*matrixCols* size_of_datatype, cudaMemcpyDeviceToHost)
781 :
782 0 : if (.not.(successCUDA)) then
783 : print *,"trans_ev_band_to_full_&
784 : &MATH_DATATYPE&
785 0 : &: error in cudaMemcpy q_dev"
786 0 : stop 1
787 : endif
788 :
789 : ! q(1:ldq,1:na_cols) = q_temp(1:ldq,1:na_cols)
790 :
791 0 : successCUDA = cuda_free(q_dev)
792 0 : if (.not.(successCUDA)) then
793 : print *,"trans_ev_band_to_full_&
794 : &MATH_DATATYPE&
795 0 : &: error in cudaFree"
796 0 : stop 1
797 : endif
798 :
799 : ! deallocate(q_temp, stat=istat, errmsg=errorMessage)
800 : ! if (istat .ne. 0) then
801 : ! print *,"error when deallocating q_temp "//errorMessage
802 : ! stop 1
803 : ! endif
804 : ! deallocate(tmat_temp, stat=istat, errmsg=errorMessage)
805 : ! if (istat .ne. 0) then
806 : ! print *,"trans_ev_band_to_full_real: error when deallocating tmat_temp "//errorMessage
807 : ! stop 1
808 : ! endif
809 :
810 : endif ! useGPU
811 :
812 45432 : deallocate(hvm, stat=istat, errmsg=errorMessage)
813 45432 : if (istat .ne. 0) then
814 : print *,"trans_ev_band_to_full_&
815 : &MATH_DATATYPE&
816 0 : &: error when deallocating hvm "//errorMessage
817 0 : stop 1
818 : endif
819 :
820 : #ifdef BAND_TO_FULL_BLOCKING
821 22716 : if (.not.(useGPU)) then
822 22716 : deallocate(tmat_complete, t_tmp, t_tmp2, stat=istat, errmsg=errorMessage)
823 22716 : if (istat .ne. 0) then
824 : print *,"trans_ev_band_to_full_&
825 : &MATH_DATATYPE&
826 0 : &: error when deallocating tmat_complete, t_tmp, t_tmp2 "//errorMessage
827 0 : stop 1
828 : endif
829 : endif
830 : #endif
831 :
832 : call obj%timer%stop("trans_ev_band_to_full_&
833 : &MATH_DATATYPE&
834 : &" // &
835 : &PRECISION_SUFFIX&
836 45432 : )
837 :
838 : end subroutine trans_ev_band_to_full_&
839 : &MATH_DATATYPE&
840 : &_&
841 45432 : &PRECISION
842 :
843 :
|