LCOV - code coverage report
Current view: top level - src/elpa1 - elpa_multiply_a_b.F90 (source / functions) Hit Total Coverage
Test: coverage_50ab7a7628bba174fc62cee3ab72b26e81f87fe5.info Lines: 103 123 83.7 %
Date: 2018-01-10 09:29:53 Functions: 0 0 -

          Line data    Source code
       1             : !
       2             : !    The ELPA library was originally created by the ELPA consortium,
       3             : !    consisting of the following organizations:
       4             : !
       5             : !    - Max Planck Computing and Data Facility (MPCDF), formerly known as
       6             : !      Rechenzentrum Garching der Max-Planck-Gesellschaft (RZG),
       7             : !    - Bergische Universität Wuppertal, Lehrstuhl für angewandte
       8             : !      Informatik,
       9             : !    - Technische Universität München, Lehrstuhl für Informatik mit
      10             : !      Schwerpunkt Wissenschaftliches Rechnen,
      11             : !    - Fritz-Haber-Institut, Berlin, Abt. Theorie,
      12             : !    - Max-Planck-Institut für Mathematik in den Naturwissenschaften,
      13             : !      Leipzig, Abt. Komplexe Strukturen in Biologie und Kognition,
      14             : !      and
      15             : !    - IBM Deutschland GmbH
      16             : !
      17             : !    This particular source code file contains additions, changes and
      18             : !    enhancements authored by Intel Corporation which is not part of
      19             : !    the ELPA consortium.
      20             : !
      21             : !    More information can be found here:
      22             : !    http://elpa.mpcdf.mpg.de/
      23             : !
      24             : !    ELPA is free software: you can redistribute it and/or modify
      25             : !    it under the terms of the version 3 of the license of the
      26             : !    GNU Lesser General Public License as published by the Free
      27             : !    Software Foundation.
      28             : !
      29             : !    ELPA is distributed in the hope that it will be useful,
      30             : !    but WITHOUT ANY WARRANTY; without even the implied warranty of
      31             : !    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      32             : !    GNU Lesser General Public License for more details.
      33             : !
      34             : !    You should have received a copy of the GNU Lesser General Public License
      35             : !    along with ELPA.  If not, see <http://www.gnu.org/licenses/>
      36             : !
      37             : !    ELPA reflects a substantial effort on the part of the original
      38             : !    ELPA consortium, and we ask you to respect the spirit of the
      39             : !    license that we chose: i.e., please contribute any changes you
      40             : !    may have back to the original ELPA library distribution, and keep
      41             : !    any derivatives of ELPA under the same license that we chose for
      42             : !    the original distribution, the GNU Lesser General Public License.
      43             : !
      44             : !
      45             : ! ELPA1 -- Faster replacements for ScaLAPACK symmetric eigenvalue routines
      46             : !
      47             : ! Copyright of the original code rests with the authors inside the ELPA
      48             : ! consortium. The copyright of any additional modifications shall rest
      49             : ! with their original authors, but shall adhere to the licensing terms
      50             : ! distributed along with the original code in the file "COPYING".
      51             : !
      52             : ! Author: A. Marek, MPCDF
      53             : 
      54             : 
      55             : #include "../general/sanity.F90"
      56             : 
      57             :       use elpa1_compute
      58             :       use elpa_mpi
      59             :       use precision
      60             :       use elpa_abstract_impl
      61             :       implicit none
      62             : 
      63             : #include "../../src/general/precision_kinds.F90"
      64             :       class(elpa_abstract_impl_t), intent(inout) :: obj
      65             : 
      66             :       character*1                   :: uplo_a, uplo_c
      67             : 
      68             :       integer(kind=ik), intent(in)  :: ldb, ldbCols, ldc, ldcCols
      69             :       integer(kind=ik)              :: na, ncb
      70             : #ifdef USE_ASSUMED_SIZE
      71             :       MATH_DATATYPE(kind=rck)                 :: a(obj%local_nrows,*), b(ldb,*), c(ldc,*)
      72             : #else
      73             :       MATH_DATATYPE(kind=rck)                 :: a(obj%local_nrows,obj%local_ncols), b(ldb,ldbCols), c(ldc,ldcCols)
      74             : #endif
      75             :       integer(kind=ik)              :: my_prow, my_pcol, np_rows, np_cols, mpierr
      76             :       integer(kind=ik)              :: l_cols, l_rows, l_rows_np
      77             :       integer(kind=ik)              :: np, n, nb, nblk_mult, lrs, lre, lcs, lce
      78             :       integer(kind=ik)              :: gcol_min, gcol, goff
      79             :       integer(kind=ik)              :: nstor, nr_done, noff, np_bc, n_aux_bc, nvals
      80        2016 :       integer(kind=ik), allocatable :: lrs_save(:), lre_save(:)
      81             : 
      82             :       logical                       :: a_lower, a_upper, c_lower, c_upper
      83        4032 :       MATH_DATATYPE(kind=rck), allocatable    :: aux_mat(:,:), aux_bc(:), tmp1(:,:), tmp2(:,:)
      84             :       integer(kind=ik)              :: istat
      85             :       character(200)                :: errorMessage
      86             :       logical                       :: success
      87             :       integer(kind=ik)              :: nblk, mpi_comm_rows, mpi_comm_cols, lda, ldaCols, error
      88             : 
      89             :       call obj%timer%start("elpa_mult_at_b_&
      90             :       &MATH_DATATYPE&
      91             :       &_&
      92             :       &PRECISION&
      93        2016 :       &")
      94             : 
      95        2016 :       na   = obj%na
      96        2016 :       nblk = obj%nblk
      97        2016 :       lda  = obj%local_nrows
      98        2016 :       ldaCols  = obj%local_ncols
      99             : 
     100             : 
     101        2016 :       call obj%get("mpi_comm_rows",mpi_comm_rows,error)
     102        2016 :       if (error .ne. ELPA_OK) then
     103           0 :         print *,"Problem getting option. Aborting..."
     104           0 :         stop
     105             :       endif
     106        2016 :       call obj%get("mpi_comm_cols",mpi_comm_cols,error)
     107        2016 :       if (error .ne. ELPA_OK) then
     108           0 :         print *,"Problem getting option. Aborting..."
     109           0 :         stop
     110             :       endif
     111             : 
     112             : 
     113        2016 :       success = .true.
     114             : 
     115        2016 :       call obj%timer%start("mpi_communication")
     116        2016 :       call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
     117        2016 :       call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
     118        2016 :       call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
     119        2016 :       call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)
     120        2016 :       call obj%timer%stop("mpi_communication")
     121        2016 :       l_rows = local_index(na,  my_prow, np_rows, nblk, -1) ! Local rows of a and b
     122        2016 :       l_cols = local_index(ncb, my_pcol, np_cols, nblk, -1) ! Local cols of b
     123             : 
     124             :       ! Block factor for matrix multiplications, must be a multiple of nblk
     125             : 
     126        2016 :       if (na/np_rows<=256) then
     127        2016 :          nblk_mult = (31/nblk+1)*nblk
     128             :       else
     129           0 :          nblk_mult = (63/nblk+1)*nblk
     130             :       endif
     131             : 
     132        2016 :       allocate(aux_mat(l_rows,nblk_mult), stat=istat, errmsg=errorMessage)
     133        2016 :       if (istat .ne. 0) then
     134             :         print *,"elpa_mult_at_b_&
     135             :   &MATH_DATATYPE&
     136           0 :   &: error when allocating aux_mat "//errorMessage
     137           0 :         stop 1
     138             :       endif
     139             : 
     140        2016 :       allocate(aux_bc(l_rows*nblk), stat=istat, errmsg=errorMessage)
     141        2016 :       if (istat .ne. 0) then
     142             :         print *,"elpa_mult_at_b_&
     143             :   &MATH_DATATYPE&
     144           0 :   &: error when allocating aux_bc "//errorMessage
     145           0 :         stop 1
     146             :       endif
     147             : 
     148        2016 :       allocate(lrs_save(nblk), stat=istat, errmsg=errorMessage)
     149        2016 :       if (istat .ne. 0) then
     150             :         print *,"elpa_mult_at_b_&
     151             :         &MATH_DATATYPE&
     152           0 :         &: error when allocating lrs_save "//errorMessage
     153           0 :         stop 1
     154             :       endif
     155             : 
     156        2016 :       allocate(lre_save(nblk), stat=istat, errmsg=errorMessage)
     157        2016 :       if (istat .ne. 0) then
     158             :         print *,"elpa_mult_at_b_&
     159             :         &MATH_DATATYPE&
     160           0 :         &: error when allocating lre_save "//errorMessage
     161           0 :         stop 1
     162             :       endif
     163             : 
     164        2016 :       a_lower = .false.
     165        2016 :       a_upper = .false.
     166        2016 :       c_lower = .false.
     167        2016 :       c_upper = .false.
     168             : 
     169        2016 :       if (uplo_a=='u' .or. uplo_a=='U') a_upper = .true.
     170        2016 :       if (uplo_a=='l' .or. uplo_a=='L') a_lower = .true.
     171        2016 :       if (uplo_c=='u' .or. uplo_c=='U') c_upper = .true.
     172        2016 :       if (uplo_c=='l' .or. uplo_c=='L') c_lower = .true.
     173             : 
     174             :       ! Build up the result matrix by processor rows
     175             : 
     176        5376 :       do np = 0, np_rows-1
     177             : 
     178             :         ! In this turn, procs of row np assemble the result
     179             : 
     180        3360 :         l_rows_np = local_index(na, np, np_rows, nblk, -1) ! local rows on receiving processors
     181             : 
     182        3360 :         nr_done = 0 ! Number of rows done
     183        3360 :         aux_mat = 0
     184        3360 :         nstor = 0   ! Number of columns stored in aux_mat
     185             : 
     186             :         ! Loop over the blocks on row np
     187             : 
     188       23520 :         do nb=0,(l_rows_np-1)/nblk
     189             : 
     190       20160 :           goff  = nb*np_rows + np ! Global offset in blocks corresponding to nb
     191             : 
     192             :           ! Get the processor column which owns this block (A is transposed, so we need the column)
     193             :           ! and the offset in blocks within this column.
     194             :           ! The corresponding block column in A is then broadcast to all for multiplication with B
     195             : 
     196       20160 :           np_bc = MOD(goff,np_cols)
     197       20160 :           noff = goff/np_cols
     198       20160 :           n_aux_bc = 0
     199             : 
     200             :           ! Gather up the complete block column of A on the owner
     201             : 
     202      322560 :           do n = 1, min(l_rows_np-nb*nblk,nblk) ! Loop over columns to be broadcast
     203             : 
     204      302400 :             gcol = goff*nblk + n ! global column corresponding to n
     205      302400 :             if (nstor==0 .and. n==1) gcol_min = gcol
     206             : 
     207      302400 :             lrs = 1       ! 1st local row number for broadcast
     208      302400 :             lre = l_rows  ! last local row number for broadcast
     209      302400 :             if (a_lower) lrs = local_index(gcol, my_prow, np_rows, nblk, +1)
     210      302400 :             if (a_upper) lre = local_index(gcol, my_prow, np_rows, nblk, -1)
     211             : 
     212      302400 :             if (lrs<=lre) then
     213      302400 :               nvals = lre-lrs+1
     214      302400 :               if (my_pcol == np_bc) aux_bc(n_aux_bc+1:n_aux_bc+nvals) = a(lrs:lre,noff*nblk+n)
     215      302400 :               n_aux_bc = n_aux_bc + nvals
     216             :             endif
     217             : 
     218      302400 :             lrs_save(n) = lrs
     219      302400 :             lre_save(n) = lre
     220             : 
     221             :           enddo
     222             : 
     223             :           ! Broadcast block column
     224             : #ifdef WITH_MPI
     225       13440 :           call obj%timer%start("mpi_communication")
     226             : #if REALCASE == 1
     227             :           call MPI_Bcast(aux_bc, n_aux_bc,    &
     228             :                          MPI_REAL_PRECISION,  &
     229        7680 :                          np_bc, mpi_comm_cols, mpierr)
     230             : #endif
     231             : #if COMPLEXCASE == 1
     232             :           call MPI_Bcast(aux_bc, n_aux_bc,    &
     233             :                          MPI_COMPLEX_PRECISION,  &
     234        5760 :                          np_bc, mpi_comm_cols, mpierr)
     235             : #endif
     236       13440 :           call obj%timer%stop("mpi_communication")
     237             : #endif /* WITH_MPI */
     238             :           ! Insert what we got in aux_mat
     239             : 
     240       20160 :           n_aux_bc = 0
     241      322560 :           do n = 1, min(l_rows_np-nb*nblk,nblk)
     242      302400 :             nstor = nstor+1
     243      302400 :             lrs = lrs_save(n)
     244      302400 :             lre = lre_save(n)
     245      302400 :             if (lrs<=lre) then
     246      302400 :               nvals = lre-lrs+1
     247      302400 :               aux_mat(lrs:lre,nstor) = aux_bc(n_aux_bc+1:n_aux_bc+nvals)
     248      302400 :               n_aux_bc = n_aux_bc + nvals
     249             :             endif
     250             :           enddo
     251             : 
     252             :           ! If we got nblk_mult columns in aux_mat or this is the last block
     253             :           ! do the matrix multiplication
     254             : 
     255       20160 :           if (nstor==nblk_mult .or. nb*nblk+nblk >= l_rows_np) then
     256             : 
     257       11424 :             lrs = 1       ! 1st local row number for multiply
     258       11424 :             lre = l_rows  ! last local row number for multiply
     259       11424 :             if (a_lower) lrs = local_index(gcol_min, my_prow, np_rows, nblk, +1)
     260       11424 :             if (a_upper) lre = local_index(gcol, my_prow, np_rows, nblk, -1)
     261             : 
     262       11424 :             lcs = 1       ! 1st local col number for multiply
     263       11424 :             lce = l_cols  ! last local col number for multiply
     264       11424 :             if (c_upper) lcs = local_index(gcol_min, my_pcol, np_cols, nblk, +1)
     265       11424 :             if (c_lower) lce = MIN(local_index(gcol, my_pcol, np_cols, nblk, -1),l_cols)
     266             : 
     267       11424 :             if (lcs<=lce) then
     268       11424 :               allocate(tmp1(nstor,lcs:lce),tmp2(nstor,lcs:lce), stat=istat, errmsg=errorMessage)
     269       11424 :               if (istat .ne. 0) then
     270             :                print *,"elpa_mult_at_b_&
     271             :                &MATH_DATATYPE&
     272           0 :                &: error when allocating tmp1 "//errorMessage
     273           0 :                stop 1
     274             :               endif
     275             : 
     276       11424 :               if (lrs<=lre) then
     277       11424 :                 call obj%timer%start("blas")
     278             :                 call PRECISION_GEMM(BLAS_TRANS_OR_CONJ, 'N', nstor, lce-lcs+1, lre-lrs+1, ONE, &
     279       11424 :                         aux_mat(lrs,1), ubound(aux_mat,dim=1), b(lrs,lcs), ldb,ZERO, tmp1, nstor)
     280       11424 :                 call obj%timer%stop("blas")
     281             :               else
     282           0 :                 tmp1 = 0
     283             :               endif
     284             : 
     285             :               ! Sum up the results and send to processor row np
     286             : #ifdef WITH_MPI
     287        8064 :               call obj%timer%start("mpi_communication")
     288             :               call mpi_reduce(tmp1, tmp2, nstor*(lce-lcs+1),  MPI_MATH_DATATYPE_PRECISION, &
     289        8064 :                               MPI_SUM, np, mpi_comm_rows, mpierr)
     290        8064 :               call obj%timer%stop("mpi_communication")
     291             :               ! Put the result into C
     292        8064 :               if (my_prow==np) c(nr_done+1:nr_done+nstor,lcs:lce) = tmp2(1:nstor,lcs:lce)
     293             : 
     294             : #else /* WITH_MPI */
     295             : !              tmp2 = tmp1
     296             :               ! Put the result into C
     297        3360 :               if (my_prow==np) c(nr_done+1:nr_done+nstor,lcs:lce) = tmp1(1:nstor,lcs:lce)
     298             : 
     299             : #endif /* WITH_MPI */
     300             : 
     301       11424 :               deallocate(tmp1,tmp2, stat=istat, errmsg=errorMessage)
     302       11424 :               if (istat .ne. 0) then
     303             :                print *,"elpa_mult_at_b_&
     304             :                &MATH_DATATYPE&
     305           0 :                &: error when deallocating tmp1 "//errorMessage
     306           0 :                stop 1
     307             :               endif
     308             : 
     309             :             endif
     310             : 
     311       11424 :             nr_done = nr_done+nstor
     312       11424 :             nstor=0
     313       11424 :             aux_mat(:,:)=0
     314             :           endif
     315             :         enddo
     316             :       enddo
     317             : 
     318        2016 :       deallocate(aux_mat, aux_bc, lrs_save, lre_save, stat=istat, errmsg=errorMessage)
     319        2016 :       if (istat .ne. 0) then
     320             :        print *,"elpa_mult_at_b_&
     321             :        &MATH_DATATYPE&
     322           0 :        &: error when deallocating aux_mat "//errorMessage
     323           0 :        stop 1
     324             :       endif
     325             : 
     326             :       call obj%timer%stop("elpa_mult_at_b_&
     327             :       &MATH_DATATYPE&
     328             :       &_&
     329             :       &PRECISION&
     330        2016 :       &")
     331             : 

Generated by: LCOV version 1.12