Line data Source code
1 : ! This file is part of ELPA.
2 : !
3 : ! The ELPA library was originally created by the ELPA consortium,
4 : ! consisting of the following organizations:
5 : !
6 : ! - Max Planck Computing and Data Facility (MPCDF), formerly known as
7 : ! Rechenzentrum Garching der Max-Planck-Gesellschaft (RZG),
8 : ! - Bergische Universität Wuppertal, Lehrstuhl für angewandte
9 : ! Informatik,
10 : ! - Technische Universität München, Lehrstuhl für Informatik mit
11 : ! Schwerpunkt Wissenschaftliches Rechnen ,
12 : ! - Fritz-Haber-Institut, Berlin, Abt. Theorie,
13 : ! - Max-Planck-Institut für Mathematik in den Naturwissenschaften,
14 : ! Leipzig, Abt. Komplexe Strukturen in Biologie und Kognition,
15 : ! and
16 : ! - IBM Deutschland GmbH
17 : !
18 : !
19 : ! More information can be found here:
20 : ! http://elpa.mpcdf.mpg.de/
21 : !
22 : ! ELPA is free software: you can redistribute it and/or modify
23 : ! it under the terms of the version 3 of the license of the
24 : ! GNU Lesser General Public License as published by the Free
25 : ! Software Foundation.
26 : !
27 : ! ELPA is distributed in the hope that it will be useful,
28 : ! but WITHOUT ANY WARRANTY; without even the implied warranty of
29 : ! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
30 : ! GNU Lesser General Public License for more details.
31 : !
32 : ! You should have received a copy of the GNU Lesser General Public License
33 : ! along with ELPA. If not, see <http://www.gnu.org/licenses/>
34 : !
35 : ! ELPA reflects a substantial effort on the part of the original
36 : ! ELPA consortium, and we ask you to respect the spirit of the
37 : ! license that we chose: i.e., please contribute any changes you
38 : ! may have back to the original ELPA library distribution, and keep
39 : ! any derivatives of ELPA under the same license that we chose for
40 : ! the original distribution, the GNU Lesser General Public License.
41 : !
42 : !
43 : #include "config-f90.h"
44 :
45 : subroutine qr_pdlarfb_1dcomm_&
46 0 : &PRECISION &
47 0 : (m,mb,n,k,a,lda,v,ldv,tau,t,ldt,baseidx,idx,rev,mpicomm,work,lwork)
48 : use precision
49 : use qr_utils_mod
50 :
51 : implicit none
52 :
53 : ! input variables (local)
54 : integer(kind=ik) :: lda,ldv,ldt,lwork
55 : real(kind=C_DATATYPE_KIND) :: a(lda,*),v(ldv,*),tau(*),t(ldt,*),work(k,*)
56 :
57 : ! input variables (global)
58 : integer(kind=ik) :: m,mb,n,k,baseidx,idx,rev,mpicomm
59 :
60 : ! output variables (global)
61 :
62 : ! derived input variables from QR_PQRPARAM
63 :
64 : ! local scalars
65 : integer(kind=ik) :: localsize,offset,baseoffset
66 : integer(kind=ik) :: mpirank,mpiprocs,mpierr
67 :
68 0 : if (idx .le. 1) return
69 :
70 0 : if (n .le. 0) return ! nothing to do
71 :
72 0 : if (k .eq. 1) then
73 : call qr_pdlarfl_1dcomm_&
74 : &PRECISION &
75 : (v,1,baseidx,a,lda,tau(1), &
76 0 : work,lwork,m,n,idx,mb,rev,mpicomm)
77 0 : return
78 0 : else if (k .eq. 2) then
79 : call qr_pdlarfl2_tmatrix_1dcomm_&
80 : &PRECISION &
81 : (v,ldv,baseidx,a,lda,t,ldt, &
82 0 : work,lwork,m,n,idx,mb,rev,mpicomm)
83 0 : return
84 : end if
85 :
86 0 : if (lwork .eq. -1) then
87 : #ifdef DOUBLE_PRECISION_REAL
88 0 : work(1,1) =real(2*k*n,kind=rk8)
89 : #else
90 0 : work(1,1) =real(2*k*n,kind=rk4)
91 : #endif
92 0 : return
93 : end if
94 :
95 : !print *,'updating trailing matrix with k=',k
96 0 : call MPI_Comm_rank(mpicomm,mpirank,mpierr)
97 0 : call MPI_Comm_size(mpicomm,mpiprocs,mpierr)
98 : ! use baseidx as idx here, otherwise the upper triangle part will be lost
99 : ! during the calculation, especially in the reversed case
100 : call local_size_offset_1d(m,mb,baseidx,baseidx,rev,mpirank,mpiprocs, &
101 0 : localsize,baseoffset,offset)
102 :
103 : ! Z' = Y' * A
104 0 : if (localsize .gt. 0) then
105 : #ifdef DOUBLE_PRECISION_REAL
106 0 : call dgemm("Trans","Notrans",k,n,localsize,1.0_rk8,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk8,work(1,1),k)
107 : #else
108 0 : call sgemm("Trans","Notrans",k,n,localsize,1.0_rk4,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk4,work(1,1),k)
109 : #endif
110 : else
111 : #ifdef DOUBLE_PRECISION_REAL
112 0 : work(1:k,1:n) = 0.0_rk8
113 : #else
114 0 : work(1:k,1:n) = 0.0_rk4
115 : #endif
116 : end if
117 :
118 : ! data exchange
119 : #ifdef WITH_MPI
120 :
121 : #ifdef DOUBLE_PRECISION_REAL
122 0 : call mpi_allreduce(work(1,1),work(1,n+1),k*n,mpi_real8,mpi_sum,mpicomm,mpierr)
123 : #else
124 0 : call mpi_allreduce(work(1,1),work(1,n+1),k*n,mpi_real4,mpi_sum,mpicomm,mpierr)
125 : #endif
126 :
127 : #else /* WITH_MPI */
128 0 : work(1:k*n,n+1) = work(1:k*n,1)
129 : #endif
130 : call qr_pdlarfb_kernel_local_&
131 : &PRECISION &
132 0 : (localsize,n,k,a(offset,1),lda,v(baseoffset,1),ldv,t,ldt,work(1,n+1),k)
133 : end subroutine
134 :
135 : ! generalized pdlarfl2 version
136 : ! TODO: include T merge here (separate by "old" and "new" index)
137 : subroutine qr_pdlarft_pdlarfb_1dcomm_&
138 0 : &PRECISION &
139 0 : (m,mb,n,oldk,k,v,ldv,tau,t,ldt,a,lda,baseidx,rev,mpicomm,work,lwork)
140 : use precision
141 : use qr_utils_mod
142 :
143 : implicit none
144 :
145 : ! input variables (local)
146 : integer(kind=ik) :: ldv,ldt,lda,lwork
147 : real(kind=C_DATATYPE_KIND) :: v(ldv,*),tau(*),t(ldt,*),work(k,*),a(lda,*)
148 :
149 : ! input variables (global)
150 : integer(kind=ik) :: m,mb,n,k,oldk,baseidx,rev,mpicomm
151 :
152 : ! output variables (global)
153 :
154 : ! derived input variables from QR_PQRPARAM
155 :
156 : ! local scalars
157 : integer(kind=ik) :: localsize,offset,baseoffset
158 : integer(kind=ik) :: mpirank,mpiprocs,mpierr
159 : integer(kind=ik) :: icol
160 :
161 : integer(kind=ik) :: sendoffset,recvoffset,sendsize
162 :
163 0 : sendoffset = 1
164 0 : sendsize = k*(k+n+oldk)
165 0 : recvoffset = sendoffset+(k+n+oldk)
166 :
167 0 : if (lwork .eq. -1) then
168 : #ifdef DOUBLE_PRECISION_REAL
169 0 : work(1,1) = real(2*(k*k+k*n+oldk), kind=rk8)
170 : #else
171 0 : work(1,1) = real(2*(k*k+k*n+oldk), kind=rk4)
172 : #endif
173 0 : return
174 : end if
175 0 : call MPI_Comm_rank(mpicomm,mpirank,mpierr)
176 0 : call MPI_Comm_size(mpicomm,mpiprocs,mpierr)
177 : call local_size_offset_1d(m,mb,baseidx,baseidx,rev,mpirank,mpiprocs, &
178 0 : localsize,baseoffset,offset)
179 :
180 : #ifdef DOUBLE_PRECISION_REAL
181 0 : if (localsize .gt. 0) then
182 : ! calculate inner product of householdervectors
183 0 : call dsyrk("Upper","Trans",k,localsize,1.0_rk8,v(baseoffset,1),ldv,0.0_rk8,work(1,1),k)
184 :
185 : ! calculate matrix matrix product of householder vectors and target matrix
186 : ! Z' = Y' * A
187 0 : call dgemm("Trans","Notrans",k,n,localsize,1.0_rk8,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk8,work(1,k+1),k)
188 :
189 : ! TODO: reserved for T merge parts
190 0 : work(1:k,n+k+1:n+k+oldk) = 0.0_rk8
191 : else
192 0 : work(1:k,1:(n+k+oldk)) = 0.0_rk8
193 : end if
194 : #else /* DOUBLE_PRECISION_REAL */
195 0 : if (localsize .gt. 0) then
196 : ! calculate inner product of householdervectors
197 0 : call ssyrk("Upper","Trans",k,localsize,1.0_rk4,v(baseoffset,1),ldv,0.0_rk4,work(1,1),k)
198 :
199 : ! calculate matrix matrix product of householder vectors and target matrix
200 : ! Z' = Y' * A
201 0 : call sgemm("Trans","Notrans",k,n,localsize,1.0_rk4,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk4,work(1,k+1),k)
202 :
203 : ! TODO: reserved for T merge parts
204 0 : work(1:k,n+k+1:n+k+oldk) = 0.0_rk4
205 : else
206 0 : work(1:k,1:(n+k+oldk)) = 0.0_rk4
207 : end if
208 : #endif /* DOUBLE_PRECISION_REAL */
209 :
210 : ! exchange data
211 : #ifdef WITH_MPI
212 :
213 : #ifdef DOUBLE_PRECISION_REAL
214 0 : call mpi_allreduce(work(1,sendoffset),work(1,recvoffset),sendsize,mpi_real8,mpi_sum,mpicomm,mpierr)
215 : #else
216 0 : call mpi_allreduce(work(1,sendoffset),work(1,recvoffset),sendsize,mpi_real4,mpi_sum,mpicomm,mpierr)
217 : #endif
218 :
219 : #else /* WITH_MPI */
220 0 : work(1:sendsize,recvoffset) = work(1:sendsize,sendoffset)
221 : #endif
222 : ! generate T matrix (pdlarft)
223 : #ifdef DOUBLE_PRECISION_REAL
224 0 : t(1:k,1:k) = 0.0_rk8 ! DEBUG: clear buffer first
225 : #else
226 0 : t(1:k,1:k) = 0.0_rk4 ! DEBUG: clear buffer first
227 : #endif
228 : ! T1 = tau1
229 : ! | tauk Tk-1' * (-tauk * Y(:,1,k+1:n) * Y(:,k))' |
230 : ! | 0 Tk-1 |
231 0 : t(k,k) = tau(k)
232 0 : do icol=k-1,1,-1
233 0 : t(icol,icol+1:k) = -tau(icol)*work(icol,recvoffset+icol:recvoffset+k-1)
234 : #ifdef DOUBLE_PRECISION_REAL
235 0 : call dtrmv("Upper","Trans","Nonunit",k-icol,t(icol+1,icol+1),ldt,t(icol,icol+1),ldt)
236 : #else
237 0 : call strmv("Upper","Trans","Nonunit",k-icol,t(icol+1,icol+1),ldt,t(icol,icol+1),ldt)
238 : #endif
239 0 : t(icol,icol) = tau(icol)
240 : end do
241 :
242 : ! TODO: elmroth and gustavson
243 :
244 : ! update matrix (pdlarfb)
245 : ! Z' = T * Z'
246 : #ifdef DOUBLE_PRECISION_REAL
247 0 : call strmm("Left","Upper","Notrans","Nonunit",k,n,1.0_rk8,t,ldt,work(1,recvoffset+k),k)
248 :
249 : ! A = A - Y * V'
250 0 : call sgemm("Notrans","Notrans",localsize,n,k,-1.0_rk8,v(baseoffset,1),ldv,work(1,recvoffset+k),k,1.0_rk8,a(offset,1),lda)
251 : #else
252 0 : call strmm("Left","Upper","Notrans","Nonunit",k,n,1.0_rk4,t,ldt,work(1,recvoffset+k),k)
253 :
254 : ! A = A - Y * V'
255 0 : call sgemm("Notrans","Notrans",localsize,n,k,-1.0_rk4,v(baseoffset,1),ldv,work(1,recvoffset+k),k,1.0_rk4,a(offset,1),lda)
256 :
257 : #endif
258 : end subroutine
259 :
 : ! Build/merge the T factor of a set of Householder blocks ("set" variant):
 : ! computes the global Gram matrix V'*V over the 1D communicator and then
 : ! merges the per-block T matrices via qr_tmerge_set_kernel.
 : ! work(n,*) needs 2*n*n elements: an n x n send block (local V'*V) and
 : ! an n x n receive block (global V'*V). lwork == -1 is a workspace query.
260 : subroutine qr_pdlarft_set_merge_1dcomm_&
261 0 : &PRECISION &
262 0 : (m,mb,n,blocksize,v,ldv,t,ldt,baseidx,rev,mpicomm,work,lwork)
263 : use precision
264 : use qr_utils_mod
265 :
266 : implicit none
267 :
268 : ! input variables (local)
269 : integer(kind=ik) :: ldv,ldt,lwork
270 : real(kind=C_DATATYPE_KIND) :: v(ldv,*),t(ldt,*),work(n,*)
271 :
272 : ! input variables (global)
273 : integer(kind=ik) :: m,mb,n,blocksize,baseidx,rev,mpicomm
274 :
275 : ! output variables (global)
276 :
277 : ! derived input variables from QR_PQRPARAM
278 :
279 : ! local scalars
280 : integer(kind=ik) :: localsize,offset,baseoffset
281 : integer(kind=ik) :: mpirank,mpiprocs,mpierr
282 :
 : ! workspace query only
283 0 : if (lwork .eq. -1) then
284 : #ifdef DOUBLE_PRECISION_REAL
285 0 : work(1,1) = real(2*n*n,kind=rk8)
286 : #else
287 0 : work(1,1) = real(2*n*n,kind=rk4)
288 :
289 : #endif
290 0 : return
291 : end if
292 0 : call MPI_Comm_rank(mpicomm,mpirank,mpierr)
293 0 : call MPI_Comm_size(mpicomm,mpiprocs,mpierr)
 : ! baseidx is used for both base and current index so the local range
 : ! covers the complete Householder block owned by this rank
294 : call local_size_offset_1d(m,mb,baseidx,baseidx,rev,mpirank,mpiprocs, &
295 0 : localsize,baseoffset,offset)
 : ! local Gram matrix: work(:,1:n) = V_local' * V_local (upper triangle);
 : ! ranks without local rows contribute zeros to the reduction
296 : #ifdef DOUBLE_PRECISION_REAL
297 0 : if (localsize .gt. 0) then
298 0 : call dsyrk("Upper","Trans",n,localsize,1.0_rk8,v(baseoffset,1),ldv,0.0_rk8,work(1,1),n)
299 : else
300 0 : work(1:n,1:n) = 0.0_rk8
301 : end if
302 : #else
303 0 : if (localsize .gt. 0) then
304 0 : call ssyrk("Upper","Trans",n,localsize,1.0_rk4,v(baseoffset,1),ldv,0.0_rk4,work(1,1),n)
305 : else
306 0 : work(1:n,1:n) = 0.0_rk4
307 : end if
308 :
309 : #endif
310 :
 : ! sum the local Gram matrices; the global result lands in columns n+1..2n
311 : #ifdef WITH_MPI
312 :
313 : #ifdef DOUBLE_PRECISION_REAL
314 0 : call mpi_allreduce(work(1,1),work(1,n+1),n*n,mpi_real8,mpi_sum,mpicomm,mpierr)
315 : #else
316 0 : call mpi_allreduce(work(1,1),work(1,n+1),n*n,mpi_real4,mpi_sum,mpicomm,mpierr)
317 : #endif
318 :
319 : #else
320 0 : work(1:n,n+1:n+1+n-1) = work(1:n,1:n)
321 : #endif
322 : ! skip Y4'*Y4 part
 : ! offset = size of the last (possibly partial) block; the merge kernel
 : ! starts past those columns of the reduced Gram matrix
323 0 : offset = mod(n,blocksize)
324 0 : if (offset .eq. 0) offset=blocksize
325 : call qr_tmerge_set_kernel_&
326 : &PRECISION &
327 0 : (n,blocksize,t,ldt,work(1,n+1+offset),n)
328 :
329 : end subroutine
330 :
 : ! Tree-structured variant of the T-matrix merge: same global V'*V
 : ! computation as qr_pdlarft_set_merge_1dcomm, but the per-block T
 : ! matrices are merged pairwise with tree order treeorder via
 : ! qr_tmerge_tree_kernel. No-op when n <= blocksize (only one block).
 : ! work(n,*) needs 2*n*n elements; lwork == -1 is a workspace query.
331 : subroutine qr_pdlarft_tree_merge_1dcomm_&
332 0 : &PRECISION &
333 0 : (m,mb,n,blocksize,treeorder,v,ldv,t,ldt,baseidx,rev,mpicomm,work,lwork)
334 : use precision
335 : use qr_utils_mod
336 :
337 : implicit none
338 :
339 : ! input variables (local)
340 : integer(kind=ik) :: ldv,ldt,lwork
341 : real(kind=C_DATATYPE_KIND) :: v(ldv,*),t(ldt,*),work(n,*)
342 :
343 : ! input variables (global)
344 : integer(kind=ik) :: m,mb,n,blocksize,treeorder,baseidx,rev,mpicomm
345 :
346 : ! output variables (global)
347 :
348 : ! derived input variables from QR_PQRPARAM
349 :
350 : ! local scalars
351 : integer(kind=ik) :: localsize,offset,baseoffset
352 : integer(kind=ik) :: mpirank,mpiprocs,mpierr
353 :
 : ! workspace query only
354 0 : if (lwork .eq. -1) then
355 : #ifdef DOUBLE_PRECISION_REAL
356 0 : work(1,1) = real(2*n*n,kind=rk8)
357 : #else
358 0 : work(1,1) = real(2*n*n,kind=rk4)
359 : #endif
360 0 : return
361 : end if
362 :
363 0 : if (n .le. blocksize) return ! nothing to do
364 0 : call MPI_Comm_rank(mpicomm,mpirank,mpierr)
365 0 : call MPI_Comm_size(mpicomm,mpiprocs,mpierr)
366 : call local_size_offset_1d(m,mb,baseidx,baseidx,rev,mpirank,mpiprocs, &
367 0 : localsize,baseoffset,offset)
368 :
 : ! local Gram matrix: work(:,1:n) = V_local' * V_local (upper triangle);
 : ! ranks without local rows contribute zeros to the reduction
369 : #ifdef DOUBLE_PRECISION_REAL
370 0 : if (localsize .gt. 0) then
371 0 : call dsyrk("Upper","Trans",n,localsize,1.0_rk8,v(baseoffset,1),ldv,0.0_rk8,work(1,1),n)
372 : else
373 0 : work(1:n,1:n) = 0.0_rk8
374 : end if
375 : #else
376 0 : if (localsize .gt. 0) then
377 0 : call ssyrk("Upper","Trans",n,localsize,1.0_rk4,v(baseoffset,1),ldv,0.0_rk4,work(1,1),n)
378 : else
379 0 : work(1:n,1:n) = 0.0_rk4
380 : end if
381 : #endif
382 :
 : ! sum the local Gram matrices; the global result lands in columns n+1..2n
383 : #ifdef WITH_MPI
384 :
385 : #ifdef DOUBLE_PRECISION_REAL
386 0 : call mpi_allreduce(work(1,1),work(1,n+1),n*n,mpi_real8,mpi_sum,mpicomm,mpierr)
387 : #else
388 0 : call mpi_allreduce(work(1,1),work(1,n+1),n*n,mpi_real4,mpi_sum,mpicomm,mpierr)
389 : #endif
390 : #else
391 0 : work(1:n,n+1:n+1+n-1) = work(1:n,1:n)
392 : #endif
393 : ! skip Y4'*Y4 part
 : ! offset = size of the last (possibly partial) block; the merge kernel
 : ! starts past those columns of the reduced Gram matrix
394 0 : offset = mod(n,blocksize)
395 0 : if (offset .eq. 0) offset=blocksize
396 : call qr_tmerge_tree_kernel_&
397 : &PRECISION &
398 0 : (n,blocksize,treeorder,t,ldt,work(1,n+1+offset),n)
399 :
400 : end subroutine
401 :
402 : ! apply householder Vector to the left
403 : ! - assume unitary matrix
404 : ! - assume right positions for v
405 : subroutine qr_pdlarfl_1dcomm_&
406 0 : &PRECISION &
407 0 : (v,incv,baseidx,a,lda,tau,work,lwork,m,n,idx,mb,rev,mpicomm)
408 : use precision
409 : use elpa1_impl
410 : use qr_utils_mod
411 :
412 : implicit none
413 :
414 : ! input variables (local)
415 : integer(kind=ik) :: incv,lda,lwork,baseidx
416 : real(kind=C_DATATYPE_KIND) :: v(*),a(lda,*),work(*)
417 :
418 : ! input variables (global)
419 : integer(kind=ik) :: m,n,mb,rev,idx,mpicomm
420 : real(kind=C_DATATYPE_KIND) :: tau
421 :
422 : ! output variables (global)
423 :
424 : ! local scalars
425 : integer(kind=ik) :: mpierr,mpirank,mpiprocs
426 : integer(kind=ik) :: sendsize,recvsize,icol
427 : integer(kind=ik) :: local_size,local_offset
428 : integer(kind=ik) :: v_local_offset
429 :
430 : ! external functions
431 : real(kind=C_DATATYPE_KIND), external :: ddot
432 0 : call MPI_Comm_rank(mpicomm, mpirank, mpierr)
433 0 : call MPI_Comm_size(mpicomm, mpiprocs, mpierr)
434 0 : sendsize = n
435 0 : recvsize = sendsize
436 :
437 0 : if (lwork .eq. -1) then
438 : #ifdef DOUBLE_PRECISION_REAL
439 0 : work(1) = real(sendsize + recvsize,kind=rk8)
440 : #else
441 0 : work(1) = real(sendsize + recvsize,kind=rk4)
442 : #endif
443 0 : return
444 : end if
445 :
446 0 : if (n .le. 0) return
447 :
448 0 : if (idx .le. 1) return
449 :
450 : call local_size_offset_1d(m,mb,baseidx,idx,rev,mpirank,mpiprocs, &
451 0 : local_size,v_local_offset,local_offset)
452 :
453 : !print *,'hl ref',local_size,n
454 :
455 0 : v_local_offset = v_local_offset * incv
456 :
457 0 : if (local_size > 0) then
458 :
459 0 : do icol=1,n
460 0 : work(icol) = dot_product(v(v_local_offset:v_local_offset+local_size-1),a(local_offset:local_offset+local_size-1,icol))
461 :
462 : end do
463 : else
464 : #ifdef DOUBLE_PRECISION_REAL
465 0 : work(1:n) = 0.0_rk8
466 : #else
467 0 : work(1:n) = 0.0_rk4
468 : #endif
469 : end if
470 : #ifdef WITH_MPI
471 :
472 : #ifdef DOUBLE_PRECISION_REAL
473 0 : call mpi_allreduce(work, work(sendsize+1), sendsize, mpi_real8, mpi_sum, mpicomm, mpierr)
474 : #else
475 0 : call mpi_allreduce(work, work(sendsize+1), sendsize, mpi_real4, mpi_sum, mpicomm, mpierr)
476 : #endif
477 : #else
478 0 : work(sendsize+1:sendsize+1+sendsize+1+sendsize-1) = work(1:sendsize)
479 : #endif
480 0 : if (local_size > 0) then
481 :
482 0 : do icol=1,n
483 : a(local_offset:local_offset+local_size-1,icol) = a(local_offset:local_offset+local_size-1,icol) &
484 : - tau*work(sendsize+icol)*v(v_local_offset:v_local_offset+ &
485 0 : local_size-1)
486 : enddo
487 : end if
488 :
489 : end subroutine
490 :
 : ! Apply two Householder reflectors (columns v2col=1 and v1col=2 of v)
 : ! from the left in one pass, using the 2x2 T matrix t. The local
 : ! products v1'*A and v2'*A are reduced with a single allreduce, then
 : ! the rank-2 update is applied; the implicit "1" top entries of the
 : ! reflectors are handled separately by the ranks owning those rows.
 : ! work(*) needs sendsize + recvsize = 4*n elements (lwork == -1 query).
491 : subroutine qr_pdlarfl2_tmatrix_1dcomm_&
492 0 : &PRECISION &
493 0 : (v,ldv,baseidx,a,lda,t,ldt,work,lwork,m,n,idx,mb,rev,mpicomm)
494 : use precision
495 : use elpa1_impl
496 : use qr_utils_mod
497 :
498 : implicit none
499 :
500 : ! input variables (local)
501 : integer(kind=ik) :: ldv,lda,lwork,baseidx,ldt
502 : real(kind=C_DATATYPE_KIND) :: v(ldv,*),a(lda,*),work(*),t(ldt,*)
503 :
504 : ! input variables (global)
505 : integer(kind=ik) :: m,n,mb,rev,idx,mpicomm
506 :
507 : ! output variables (global)
508 :
509 : ! local scalars
510 : integer(kind=ik) :: mpierr,mpirank,mpiprocs,mpirank_top1,mpirank_top2
511 : integer(kind=ik) :: dgemv1_offset,dgemv2_offset
512 : integer(kind=ik) :: sendsize, recvsize
513 : integer(kind=ik) :: local_size1,local_offset1
514 : integer(kind=ik) :: local_size2,local_offset2
515 : integer(kind=ik) :: local_size_dger,local_offset_dger
516 : integer(kind=ik) :: v1_local_offset,v2_local_offset
517 : integer(kind=ik) :: v_local_offset_dger
518 : real(kind=C_DATATYPE_KIND) :: hvdot
519 : integer(kind=ik) :: irow,icol,v1col,v2col
520 :
521 : ! external functions
 : ! NOTE(review): ddot is declared but never referenced in this routine
522 : real(kind=C_DATATYPE_KIND), external :: ddot
523 0 : call MPI_Comm_rank(mpicomm, mpirank, mpierr)
524 0 : call MPI_Comm_size(mpicomm, mpiprocs, mpierr)
525 0 : sendsize = 2*n
526 0 : recvsize = sendsize
527 :
 : ! workspace query only
528 0 : if (lwork .eq. -1) then
529 0 : work(1) = sendsize + recvsize
530 0 : return
531 : end if
532 :
 : ! work layout: two n-vectors (one per reflector) in the send region
533 0 : dgemv1_offset = 1
534 0 : dgemv2_offset = dgemv1_offset + n
535 :
536 : ! in 2x2 matrix case only one householder Vector was generated
537 0 : if (idx .le. 2) then
538 : call qr_pdlarfl_1dcomm_&
539 : &PRECISION &
540 : (v(1,2),1,baseidx,a,lda,t(2,2), &
541 0 : work,lwork,m,n,idx,mb,rev,mpicomm)
542 0 : return
543 : end if
544 :
 : ! local row ranges below the top element of each reflector
545 : call local_size_offset_1d(m,mb,baseidx,idx,rev,mpirank,mpiprocs, &
546 0 : local_size1,v1_local_offset,local_offset1)
547 : call local_size_offset_1d(m,mb,baseidx,idx-1,rev,mpirank,mpiprocs, &
548 0 : local_size2,v2_local_offset,local_offset2)
549 :
550 0 : v1_local_offset = v1_local_offset * 1
551 0 : v2_local_offset = v2_local_offset * 1
552 :
 : ! column indices of the two reflectors within v
553 0 : v1col = 2
554 0 : v2col = 1
555 :
556 : ! keep buffers clean in case that local_size1/local_size2 are zero
557 : #ifdef DOUBLE_PRECISION_REAL
558 0 : work(1:sendsize) = 0.0_rk8
559 :
560 0 : call dgemv("Trans",local_size1,n,1.0_rk8,a(local_offset1,1),lda,v(v1_local_offset,v1col),1,0.0_rk8,work(dgemv1_offset),1)
561 : call dgemv("Trans",local_size2,n,t(v2col,v2col),a(local_offset2,1),lda,v(v2_local_offset,v2col),1,0.0_rk8, &
562 0 : work(dgemv2_offset),1)
563 : #else
564 0 : work(1:sendsize) = 0.0_rk4
565 :
566 0 : call sgemv("Trans",local_size1,n,1.0_rk4,a(local_offset1,1),lda,v(v1_local_offset,v1col),1,0.0_rk4,work(dgemv1_offset),1)
567 : call sgemv("Trans",local_size2,n,t(v2col,v2col),a(local_offset2,1),lda,v(v2_local_offset,v2col),1,0.0_rk4, &
568 0 : work(dgemv2_offset),1)
569 : #endif
570 :
 : ! reduce both local products; results land in work(sendsize+1:2*sendsize)
571 : #ifdef WITH_MPI
572 :
573 : #ifdef DOUBLE_PRECISION_REAL
574 0 : call mpi_allreduce(work, work(sendsize+1), sendsize, mpi_real8, mpi_sum, mpicomm, mpierr)
575 : #else
576 0 : call mpi_allreduce(work, work(sendsize+1), sendsize, mpi_real4, mpi_sum, mpicomm, mpierr)
577 : #endif
578 : #else
579 0 : work(sendsize+1:sendsize+1+sendsize-1) = work(1:sendsize)
580 : #endif
581 : ! update second Vector
582 : #ifdef DOUBLE_PRECISION_REAL
583 0 : call daxpy(n,t(1,2),work(sendsize+dgemv1_offset),1,work(sendsize+dgemv2_offset),1)
584 : #else
585 0 : call saxpy(n,t(1,2),work(sendsize+dgemv1_offset),1,work(sendsize+dgemv2_offset),1)
586 : #endif
587 :
 : ! rows strictly below both top elements take the full rank-2 update
588 : call local_size_offset_1d(m,mb,baseidx,idx-2,rev,mpirank,mpiprocs, &
589 0 : local_size_dger,v_local_offset_dger,local_offset_dger)
590 :
591 : ! get ranks of processes with topelements
592 0 : mpirank_top1 = MOD((idx-1)/mb,mpiprocs)
593 0 : mpirank_top2 = MOD((idx-2)/mb,mpiprocs)
594 :
 : ! owners of the top rows point their offsets at the last local row
595 0 : if (mpirank_top1 .eq. mpirank) local_offset1 = local_size1
596 0 : if (mpirank_top2 .eq. mpirank) then
597 0 : local_offset2 = local_size2
598 0 : v2_local_offset = local_size2
599 : end if
600 :
601 : ! use hvdot as temporary variable
602 0 : hvdot = t(v1col,v1col)
603 0 : do icol=1,n
604 : ! make use of "1" entries in householder vectors
605 0 : if (mpirank_top1 .eq. mpirank) then
606 : a(local_offset1,icol) = a(local_offset1,icol) &
607 0 : - work(sendsize+dgemv1_offset+icol-1)*hvdot
608 : end if
609 :
610 0 : if (mpirank_top2 .eq. mpirank) then
611 : a(local_offset2,icol) = a(local_offset2,icol) &
612 : - v(v2_local_offset,v1col)*work(sendsize+dgemv1_offset+icol-1)*hvdot &
613 0 : - work(sendsize+dgemv2_offset+icol-1)
614 : end if
615 :
 : ! generic rank-2 update for the remaining local rows
616 0 : do irow=1,local_size_dger
617 : a(local_offset_dger+irow-1,icol) = a(local_offset_dger+irow-1,icol) &
618 : - work(sendsize+dgemv1_offset+icol-1)*v(v_local_offset_dger+irow-1,v1col)*hvdot &
619 0 : - work(sendsize+dgemv2_offset+icol-1)*v(v_local_offset_dger+irow-1,v2col)
620 : end do
621 : end do
622 :
623 : end subroutine
624 :
625 : ! generalized pdlarfl2 version
626 : ! TODO: include T merge here (separate by "old" and "new" index)
 : ! Merge an old T factor (oldk reflectors) with a new one (k reflectors)
 : ! and apply the blocked update to A, sharing a single allreduce.
 : ! updatemode == ichar('I'): update with the complete merged (k+oldk) T;
 : ! otherwise update only with the new k x k T.
 : ! work(*) is a 1-D buffer holding, in order: the update part
 : ! (updatelda x n), the merge part (k x oldk, only when oldk > 0), and a
 : ! receive region of the same total size. lwork == -1 is a size query.
627 : subroutine qr_tmerge_pdlarfb_1dcomm_&
628 0 : &PRECISION &
629 0 : (m,mb,n,oldk,k,v,ldv,t,ldt,a,lda,baseidx,rev,updatemode,mpicomm,work,lwork)
630 : use precision
631 : use qr_utils_mod
632 :
633 : implicit none
634 :
635 : ! input variables (local)
636 : integer(kind=ik) :: ldv,ldt,lda,lwork
637 : real(kind=C_DATATYPE_KIND) :: v(ldv,*),t(ldt,*),work(*),a(lda,*)
638 :
639 : ! input variables (global)
640 : integer(kind=ik) :: m,mb,n,k,oldk,baseidx,rev,updatemode,mpicomm
641 :
642 : ! output variables (global)
643 :
644 : ! derived input variables from QR_PQRPARAM
645 :
646 : ! local scalars
647 : integer(kind=ik) :: localsize,offset,baseoffset
648 : integer(kind=ik) :: mpirank,mpiprocs,mpierr
649 :
650 : integer(kind=ik) :: sendoffset,recvoffset,sendsize
651 : integer(kind=ik) :: updateoffset,updatelda,updatesize
652 : integer(kind=ik) :: mergeoffset,mergelda,mergesize
653 : integer(kind=ik) :: tgenoffset,tgenlda,tgensize
654 :
655 : ! quickfix
656 0 : mergeoffset = 0
657 :
 : ! leading dimension of the update part: all k+oldk reflectors for a
 : ! full ('I') update, only the new k otherwise
658 0 : if (updatemode .eq. ichar('I')) then
659 0 : updatelda = oldk+k
660 : else
661 0 : updatelda = k
662 : end if
663 :
664 0 : updatesize = updatelda*n
665 :
666 0 : mergelda = k
667 0 : mergesize = mergelda*oldk
668 :
 : ! T generation part is currently unused (size 0)
669 0 : tgenlda = 0
670 0 : tgensize = 0
671 :
672 0 : sendsize = updatesize + mergesize + tgensize
673 :
 : ! workspace query: send region plus equally sized receive region
674 0 : if (lwork .eq. -1) then
675 : #ifdef DOUBLE_PRECISION_REAL
676 0 : work(1) = real(2*sendsize,kind=rk8)
677 : #else
678 0 : work(1) = real(2*sendsize,kind=rk4)
679 : #endif
680 0 : return
681 : end if
682 0 : call MPI_Comm_rank(mpicomm,mpirank,mpierr)
683 0 : call MPI_Comm_size(mpicomm,mpiprocs,mpierr)
684 : ! use baseidx as idx here, otherwise the upper triangle part will be lost
685 : ! during the calculation, especially in the reversed case
686 : call local_size_offset_1d(m,mb,baseidx,baseidx,rev,mpirank,mpiprocs, &
687 0 : localsize,baseoffset,offset)
688 :
689 0 : sendoffset = 1
690 :
691 0 : if (oldk .gt. 0) then
 : ! layout with a merge part appended after the update part
692 0 : updateoffset = 0
693 0 : mergeoffset = updateoffset + updatesize
694 0 : tgenoffset = mergeoffset + mergesize
695 :
696 0 : sendsize = updatesize + mergesize + tgensize
697 :
698 : !print *,'sendsize',sendsize,updatesize,mergesize,tgensize
699 : !print *,'merging nr of rotations', oldk+k
700 : #ifdef DOUBLE_PRECISION_REAL
701 0 : if (localsize .gt. 0) then
702 : ! calculate matrix matrix product of householder vectors and target matrix
703 0 : if (updatemode .eq. ichar('I')) then
704 : ! Z' = (Y1,Y2)' * A
705 : call dgemm("Trans","Notrans",k+oldk,n,localsize,1.0_rk8,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk8, &
706 0 : work(sendoffset+updateoffset),updatelda)
707 : else
708 : ! Z' = Y1' * A
709 : call dgemm("Trans","Notrans",k,n,localsize,1.0_rk8,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk8, &
710 0 : work(sendoffset+updateoffset),updatelda)
711 : end if
712 :
713 : ! calculate parts needed for T merge
714 : call dgemm("Trans","Notrans",k,oldk,localsize,1.0_rk8,v(baseoffset,1),ldv,v(baseoffset,k+1),ldv,0.0_rk8, &
715 0 : work(sendoffset+mergeoffset),mergelda)
716 :
717 : else
718 : ! cleanup buffer
719 0 : work(sendoffset:sendoffset+sendsize-1) = 0.0_rk8
720 : end if
721 : #else /* DOUBLE_PRECISION_REAL */
722 0 : if (localsize .gt. 0) then
723 : ! calculate matrix matrix product of householder vectors and target matrix
724 0 : if (updatemode .eq. ichar('I')) then
725 : ! Z' = (Y1,Y2)' * A
726 : call sgemm("Trans","Notrans",k+oldk,n,localsize,1.0_rk4,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk4, &
727 0 : work(sendoffset+updateoffset),updatelda)
728 : else
729 : ! Z' = Y1' * A
730 : call sgemm("Trans","Notrans",k,n,localsize,1.0_rk4,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk4, &
731 0 : work(sendoffset+updateoffset),updatelda)
732 : end if
733 :
734 : ! calculate parts needed for T merge
735 : call sgemm("Trans","Notrans",k,oldk,localsize,1.0_rk4,v(baseoffset,1),ldv,v(baseoffset,k+1),ldv,0.0_rk4, &
736 0 : work(sendoffset+mergeoffset),mergelda)
737 :
738 : else
739 : ! cleanup buffer
740 0 : work(sendoffset:sendoffset+sendsize-1) = 0.0_rk4
741 : end if
742 : #endif /* DOUBLE_PRECISION_REAL */
743 :
744 : else
745 : ! do not calculate parts for T merge as there is nothing to merge
746 :
747 0 : mergeoffset = 0
748 0 : updateoffset = 0
749 :
750 0 : tgenoffset = updateoffset + updatesize
751 :
752 0 : sendsize = updatesize + tgensize
753 : #ifdef DOUBLE_PRECISION_REAL
754 0 : if (localsize .gt. 0) then
755 : ! calculate matrix matrix product of householder vectors and target matrix
756 : ! Z' = (Y1)' * A
757 : call dgemm("Trans","Notrans",k,n,localsize,1.0_rk8,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk8, &
758 0 : work(sendoffset+updateoffset),updatelda)
759 :
760 : else
761 : ! cleanup buffer
762 0 : work(sendoffset:sendoffset+sendsize-1) = 0.0_rk8
763 : end if
764 : #else
765 0 : if (localsize .gt. 0) then
766 : ! calculate matrix matrix product of householder vectors and target matrix
767 : ! Z' = (Y1)' * A
768 : call sgemm("Trans","Notrans",k,n,localsize,1.0_rk4,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk4, &
769 0 : work(sendoffset+updateoffset),updatelda)
770 :
771 : else
772 : ! cleanup buffer
773 0 : work(sendoffset:sendoffset+sendsize-1) = 0.0_rk4
774 : end if
775 : #endif
776 : end if
777 :
778 0 : recvoffset = sendoffset + sendsize
779 :
780 0 : if (sendsize .le. 0) return ! nothing to do
781 :
782 : ! exchange data
783 : #ifdef WITH_MPI
784 : #ifdef DOUBLE_PRECISION_REAL
785 0 : call mpi_allreduce(work(sendoffset),work(recvoffset),sendsize,mpi_real8,mpi_sum,mpicomm,mpierr)
786 : #else
787 0 : call mpi_allreduce(work(sendoffset),work(recvoffset),sendsize,mpi_real4,mpi_sum,mpicomm,mpierr)
788 : #endif
789 :
790 : #else
791 0 : work(recvoffset:recvoffset+sendsize-1) = work(sendoffset:sendoffset+sendsize-1)
792 : #endif
 : ! rebase the part offsets into the receive region
793 0 : updateoffset = recvoffset+updateoffset
794 0 : mergeoffset = recvoffset+mergeoffset
795 0 : tgenoffset = recvoffset+tgenoffset
796 :
797 0 : if (oldk .gt. 0) then
 : ! fold the reduced Y1'*Y2 block into the combined T matrix
798 : call qr_pdlarft_merge_kernel_local_&
799 : &PRECISION &
800 0 : (oldk,k,t,ldt,work(mergeoffset),mergelda)
801 :
802 0 : if (localsize .gt. 0) then
803 0 : if (updatemode .eq. ichar('I')) then
804 :
805 : ! update matrix (pdlarfb) with complete T
806 : call qr_pdlarfb_kernel_local_&
807 : &PRECISION &
808 : (localsize,n,k+oldk,a(offset,1),lda,v(baseoffset,1),ldv,t(1,1),ldt, &
809 0 : work(updateoffset),updatelda)
810 : else
811 : ! update matrix (pdlarfb) with small T (same as update with no old T TODO)
812 : call qr_pdlarfb_kernel_local_&
813 : &PRECISION &
814 : (localsize,n,k,a(offset,1),lda,v(baseoffset,1),ldv,t(1,1),ldt, &
815 0 : work(updateoffset),updatelda)
816 : end if
817 : end if
818 : else
819 0 : if (localsize .gt. 0) then
820 : ! update matrix (pdlarfb) with small T
821 : call qr_pdlarfb_kernel_local_&
822 : &PRECISION &
823 : (localsize,n,k,a(offset,1),lda,v(baseoffset,1),ldv,t(1,1),ldt, &
824 0 : work(updateoffset),updatelda)
825 : end if
826 : end if
827 :
828 : end subroutine
|