Line data Source code
1 : !
2 : ! The ELPA library was originally created by the ELPA consortium,
3 : ! consisting of the following organizations:
4 : !
5 : ! - Max Planck Computing and Data Facility (MPCDF), formerly known as
6 : ! Rechenzentrum Garching der Max-Planck-Gesellschaft (RZG),
7 : ! - Bergische Universität Wuppertal, Lehrstuhl für angewandte
8 : ! Informatik,
9 : ! - Technische Universität München, Lehrstuhl für Informatik mit
10 : ! Schwerpunkt Wissenschaftliches Rechnen,
11 : ! - Fritz-Haber-Institut, Berlin, Abt. Theorie,
12 : ! - Max-Planck-Institut für Mathematik in den Naturwissenschaften,
13 : ! Leipzig, Abt. Komplexe Strukturen in Biologie und Kognition,
14 : ! and
15 : ! - IBM Deutschland GmbH
16 : !
17 : ! This particular source code file contains additions, changes and
18 : ! enhancements authored by Intel Corporation which is not part of
19 : ! the ELPA consortium.
20 : !
21 : ! More information can be found here:
22 : ! http://elpa.mpcdf.mpg.de/
23 : !
24 : ! ELPA is free software: you can redistribute it and/or modify
25 : ! it under the terms of the version 3 of the license of the
26 : ! GNU Lesser General Public License as published by the Free
27 : ! Software Foundation.
28 : !
29 : ! ELPA is distributed in the hope that it will be useful,
30 : ! but WITHOUT ANY WARRANTY; without even the implied warranty of
31 : ! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
32 : ! GNU Lesser General Public License for more details.
33 : !
34 : ! You should have received a copy of the GNU Lesser General Public License
35 : ! along with ELPA. If not, see <http://www.gnu.org/licenses/>
36 : !
37 : ! ELPA reflects a substantial effort on the part of the original
38 : ! ELPA consortium, and we ask you to respect the spirit of the
39 : ! license that we chose: i.e., please contribute any changes you
40 : ! may have back to the original ELPA library distribution, and keep
41 : ! any derivatives of ELPA under the same license that we chose for
42 : ! the original distribution, the GNU Lesser General Public License.
43 : !
44 : !
45 : ! ELPA1 -- Faster replacements for ScaLAPACK symmetric eigenvalue routines
46 : !
47 : ! Copyright of the original code rests with the authors inside the ELPA
48 : ! consortium. The copyright of any additional modifications shall rest
49 : ! with their original authors, but shall adhere to the licensing terms
50 : ! distributed along with the original code in the file "COPYING".
51 : !
52 : ! Author: A. Marek, MPCDF
53 :
54 :
55 : #include "../general/sanity.F90"
56 :
57 : use elpa1_compute
58 : use elpa_mpi
59 : use precision
60 : use elpa_abstract_impl
61 : implicit none
62 :
63 : #include "../../src/general/precision_kinds.F90"
64 : class(elpa_abstract_impl_t), intent(inout) :: obj
65 :
66 : character*1 :: uplo_a, uplo_c
67 :
68 : integer(kind=ik), intent(in) :: ldb, ldbCols, ldc, ldcCols
69 : integer(kind=ik) :: na, ncb
70 : #ifdef USE_ASSUMED_SIZE
71 : MATH_DATATYPE(kind=rck) :: a(obj%local_nrows,*), b(ldb,*), c(ldc,*)
72 : #else
73 : MATH_DATATYPE(kind=rck) :: a(obj%local_nrows,obj%local_ncols), b(ldb,ldbCols), c(ldc,ldcCols)
74 : #endif
75 : integer(kind=ik) :: my_prow, my_pcol, np_rows, np_cols, mpierr
76 : integer(kind=ik) :: l_cols, l_rows, l_rows_np
77 : integer(kind=ik) :: np, n, nb, nblk_mult, lrs, lre, lcs, lce
78 : integer(kind=ik) :: gcol_min, gcol, goff
79 : integer(kind=ik) :: nstor, nr_done, noff, np_bc, n_aux_bc, nvals
80 2016 : integer(kind=ik), allocatable :: lrs_save(:), lre_save(:)
81 :
82 : logical :: a_lower, a_upper, c_lower, c_upper
83 4032 : MATH_DATATYPE(kind=rck), allocatable :: aux_mat(:,:), aux_bc(:), tmp1(:,:), tmp2(:,:)
84 : integer(kind=ik) :: istat
85 : character(200) :: errorMessage
86 : logical :: success
87 : integer(kind=ik) :: nblk, mpi_comm_rows, mpi_comm_cols, lda, ldaCols, error
88 :
89 : call obj%timer%start("elpa_mult_at_b_&
90 : &MATH_DATATYPE&
91 : &_&
92 : &PRECISION&
93 2016 : &")
94 :
95 2016 : na = obj%na
96 2016 : nblk = obj%nblk
97 2016 : lda = obj%local_nrows
98 2016 : ldaCols = obj%local_ncols
99 :
100 :
101 2016 : call obj%get("mpi_comm_rows",mpi_comm_rows,error)
102 2016 : if (error .ne. ELPA_OK) then
103 0 : print *,"Problem getting option. Aborting..."
104 0 : stop
105 : endif
106 2016 : call obj%get("mpi_comm_cols",mpi_comm_cols,error)
107 2016 : if (error .ne. ELPA_OK) then
108 0 : print *,"Problem getting option. Aborting..."
109 0 : stop
110 : endif
111 :
112 :
113 2016 : success = .true.
114 :
115 2016 : call obj%timer%start("mpi_communication")
116 2016 : call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
117 2016 : call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
118 2016 : call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
119 2016 : call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)
120 2016 : call obj%timer%stop("mpi_communication")
121 2016 : l_rows = local_index(na, my_prow, np_rows, nblk, -1) ! Local rows of a and b
122 2016 : l_cols = local_index(ncb, my_pcol, np_cols, nblk, -1) ! Local cols of b
123 :
124 : ! Block factor for matrix multiplications, must be a multiple of nblk
125 :
126 2016 : if (na/np_rows<=256) then
127 2016 : nblk_mult = (31/nblk+1)*nblk
128 : else
129 0 : nblk_mult = (63/nblk+1)*nblk
130 : endif
131 :
132 2016 : allocate(aux_mat(l_rows,nblk_mult), stat=istat, errmsg=errorMessage)
133 2016 : if (istat .ne. 0) then
134 : print *,"elpa_mult_at_b_&
135 : &MATH_DATATYPE&
136 0 : &: error when allocating aux_mat "//errorMessage
137 0 : stop 1
138 : endif
139 :
140 2016 : allocate(aux_bc(l_rows*nblk), stat=istat, errmsg=errorMessage)
141 2016 : if (istat .ne. 0) then
142 : print *,"elpa_mult_at_b_&
143 : &MATH_DATATYPE&
144 0 : &: error when allocating aux_bc "//errorMessage
145 0 : stop 1
146 : endif
147 :
148 2016 : allocate(lrs_save(nblk), stat=istat, errmsg=errorMessage)
149 2016 : if (istat .ne. 0) then
150 : print *,"elpa_mult_at_b_&
151 : &MATH_DATATYPE&
152 0 : &: error when allocating lrs_save "//errorMessage
153 0 : stop 1
154 : endif
155 :
156 2016 : allocate(lre_save(nblk), stat=istat, errmsg=errorMessage)
157 2016 : if (istat .ne. 0) then
158 : print *,"elpa_mult_at_b_&
159 : &MATH_DATATYPE&
160 0 : &: error when allocating lre_save "//errorMessage
161 0 : stop 1
162 : endif
163 :
164 2016 : a_lower = .false.
165 2016 : a_upper = .false.
166 2016 : c_lower = .false.
167 2016 : c_upper = .false.
168 :
169 2016 : if (uplo_a=='u' .or. uplo_a=='U') a_upper = .true.
170 2016 : if (uplo_a=='l' .or. uplo_a=='L') a_lower = .true.
171 2016 : if (uplo_c=='u' .or. uplo_c=='U') c_upper = .true.
172 2016 : if (uplo_c=='l' .or. uplo_c=='L') c_lower = .true.
173 :
174 : ! Build up the result matrix by processor rows
175 :
176 5376 : do np = 0, np_rows-1
177 :
178 : ! In this turn, procs of row np assemble the result
179 :
180 3360 : l_rows_np = local_index(na, np, np_rows, nblk, -1) ! local rows on receiving processors
181 :
182 3360 : nr_done = 0 ! Number of rows done
183 3360 : aux_mat = 0
184 3360 : nstor = 0 ! Number of columns stored in aux_mat
185 :
186 : ! Loop over the blocks on row np
187 :
188 23520 : do nb=0,(l_rows_np-1)/nblk
189 :
190 20160 : goff = nb*np_rows + np ! Global offset in blocks corresponding to nb
191 :
192 : ! Get the processor column which owns this block (A is transposed, so we need the column)
193 : ! and the offset in blocks within this column.
194 : ! The corresponding block column in A is then broadcast to all for multiplication with B
195 :
196 20160 : np_bc = MOD(goff,np_cols)
197 20160 : noff = goff/np_cols
198 20160 : n_aux_bc = 0
199 :
200 : ! Gather up the complete block column of A on the owner
201 :
202 322560 : do n = 1, min(l_rows_np-nb*nblk,nblk) ! Loop over columns to be broadcast
203 :
204 302400 : gcol = goff*nblk + n ! global column corresponding to n
205 302400 : if (nstor==0 .and. n==1) gcol_min = gcol
206 :
207 302400 : lrs = 1 ! 1st local row number for broadcast
208 302400 : lre = l_rows ! last local row number for broadcast
209 302400 : if (a_lower) lrs = local_index(gcol, my_prow, np_rows, nblk, +1)
210 302400 : if (a_upper) lre = local_index(gcol, my_prow, np_rows, nblk, -1)
211 :
212 302400 : if (lrs<=lre) then
213 302400 : nvals = lre-lrs+1
214 302400 : if (my_pcol == np_bc) aux_bc(n_aux_bc+1:n_aux_bc+nvals) = a(lrs:lre,noff*nblk+n)
215 302400 : n_aux_bc = n_aux_bc + nvals
216 : endif
217 :
218 302400 : lrs_save(n) = lrs
219 302400 : lre_save(n) = lre
220 :
221 : enddo
222 :
223 : ! Broadcast block column
224 : #ifdef WITH_MPI
225 13440 : call obj%timer%start("mpi_communication")
226 : #if REALCASE == 1
227 : call MPI_Bcast(aux_bc, n_aux_bc, &
228 : MPI_REAL_PRECISION, &
229 7680 : np_bc, mpi_comm_cols, mpierr)
230 : #endif
231 : #if COMPLEXCASE == 1
232 : call MPI_Bcast(aux_bc, n_aux_bc, &
233 : MPI_COMPLEX_PRECISION, &
234 5760 : np_bc, mpi_comm_cols, mpierr)
235 : #endif
236 13440 : call obj%timer%stop("mpi_communication")
237 : #endif /* WITH_MPI */
238 : ! Insert what we got in aux_mat
239 :
240 20160 : n_aux_bc = 0
241 322560 : do n = 1, min(l_rows_np-nb*nblk,nblk)
242 302400 : nstor = nstor+1
243 302400 : lrs = lrs_save(n)
244 302400 : lre = lre_save(n)
245 302400 : if (lrs<=lre) then
246 302400 : nvals = lre-lrs+1
247 302400 : aux_mat(lrs:lre,nstor) = aux_bc(n_aux_bc+1:n_aux_bc+nvals)
248 302400 : n_aux_bc = n_aux_bc + nvals
249 : endif
250 : enddo
251 :
252 : ! If we got nblk_mult columns in aux_mat or this is the last block
253 : ! do the matrix multiplication
254 :
255 20160 : if (nstor==nblk_mult .or. nb*nblk+nblk >= l_rows_np) then
256 :
257 11424 : lrs = 1 ! 1st local row number for multiply
258 11424 : lre = l_rows ! last local row number for multiply
259 11424 : if (a_lower) lrs = local_index(gcol_min, my_prow, np_rows, nblk, +1)
260 11424 : if (a_upper) lre = local_index(gcol, my_prow, np_rows, nblk, -1)
261 :
262 11424 : lcs = 1 ! 1st local col number for multiply
263 11424 : lce = l_cols ! last local col number for multiply
264 11424 : if (c_upper) lcs = local_index(gcol_min, my_pcol, np_cols, nblk, +1)
265 11424 : if (c_lower) lce = MIN(local_index(gcol, my_pcol, np_cols, nblk, -1),l_cols)
266 :
267 11424 : if (lcs<=lce) then
268 11424 : allocate(tmp1(nstor,lcs:lce),tmp2(nstor,lcs:lce), stat=istat, errmsg=errorMessage)
269 11424 : if (istat .ne. 0) then
270 : print *,"elpa_mult_at_b_&
271 : &MATH_DATATYPE&
272 0 : &: error when allocating tmp1 "//errorMessage
273 0 : stop 1
274 : endif
275 :
276 11424 : if (lrs<=lre) then
277 11424 : call obj%timer%start("blas")
278 : call PRECISION_GEMM(BLAS_TRANS_OR_CONJ, 'N', nstor, lce-lcs+1, lre-lrs+1, ONE, &
279 11424 : aux_mat(lrs,1), ubound(aux_mat,dim=1), b(lrs,lcs), ldb,ZERO, tmp1, nstor)
280 11424 : call obj%timer%stop("blas")
281 : else
282 0 : tmp1 = 0
283 : endif
284 :
285 : ! Sum up the results and send to processor row np
286 : #ifdef WITH_MPI
287 8064 : call obj%timer%start("mpi_communication")
288 : call mpi_reduce(tmp1, tmp2, nstor*(lce-lcs+1), MPI_MATH_DATATYPE_PRECISION, &
289 8064 : MPI_SUM, np, mpi_comm_rows, mpierr)
290 8064 : call obj%timer%stop("mpi_communication")
291 : ! Put the result into C
292 8064 : if (my_prow==np) c(nr_done+1:nr_done+nstor,lcs:lce) = tmp2(1:nstor,lcs:lce)
293 :
294 : #else /* WITH_MPI */
295 : ! tmp2 = tmp1
296 : ! Put the result into C
297 3360 : if (my_prow==np) c(nr_done+1:nr_done+nstor,lcs:lce) = tmp1(1:nstor,lcs:lce)
298 :
299 : #endif /* WITH_MPI */
300 :
301 11424 : deallocate(tmp1,tmp2, stat=istat, errmsg=errorMessage)
302 11424 : if (istat .ne. 0) then
303 : print *,"elpa_mult_at_b_&
304 : &MATH_DATATYPE&
305 0 : &: error when deallocating tmp1 "//errorMessage
306 0 : stop 1
307 : endif
308 :
309 : endif
310 :
311 11424 : nr_done = nr_done+nstor
312 11424 : nstor=0
313 11424 : aux_mat(:,:)=0
314 : endif
315 : enddo
316 : enddo
317 :
318 2016 : deallocate(aux_mat, aux_bc, lrs_save, lre_save, stat=istat, errmsg=errorMessage)
319 2016 : if (istat .ne. 0) then
320 : print *,"elpa_mult_at_b_&
321 : &MATH_DATATYPE&
322 0 : &: error when deallocating aux_mat "//errorMessage
323 0 : stop 1
324 : endif
325 :
326 : call obj%timer%stop("elpa_mult_at_b_&
327 : &MATH_DATATYPE&
328 : &_&
329 : &PRECISION&
330 2016 : &")
331 :
|