!-*- mode: F90 -*-!
! This file is distributed as part of the Wannier90 code and !
! under the terms of the GNU General Public License. See the !
! file `LICENSE' in the root directory of the Wannier90      !
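! distribution, or http://www.gnu.org/copyleft/gpl.txt       !
!                                                            !
! The webpage of the Wannier90 code is www.wannier.org       !
!                                                            !
! The Wannier90 code is hosted on GitHub:                    !
!                                                            !
! https://github.com/wannier-developers/wannier90            !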

module w90_dos
  !! Compute Density of States
  use w90_constants, only: dp

  implicit none

  ! Only the names in the public statement below are exported; the
  ! module variables are shared between the procedures of this module.
  private

  public :: dos_main, dos_get_levelspacing, dos_get_k

  integer       :: num_freq
  !! Number of sampling points of the energy grid (set in dos_main)
  real(kind=dp) :: d_omega
  !! Step between energies of the grid (set in dos_main)

contains

  !                   PUBLIC PROCEDURES                     !

  subroutine dos_main
    !! Computes the electronic density of states. Can
    !! resolve into up-spin and down-spin parts, project
    !! onto selected Wannier orbitals, and use adaptive
    !! broadening, as in PRB 75, 195121 (2007) [YWVS07].
    !!
    !! Steps:
    !!
    !!  1. Build an evenly spaced energy grid from dos_energy_min,
    !!     dos_energy_max and dos_energy_step (always at least two points).
    !!  2. Accumulate the weighted contribution of every k-point handled
    !!     by this node, either from the list int_kpts (when
    !!     wanint_kpoint_file is set) or from the uniform dos_kmesh grid.
    !!  3. Sum the partial results over all nodes with comms_reduce; the
    !!     root node writes them to '<seedname>-dos.dat'.
    !!
    !! Sets the module variables num_freq and d_omega.

    use w90_io, only: io_error, io_file_unit, io_date, io_stopwatch, &
      seedname, stdout
    use w90_comms, only: on_root, num_nodes, my_node_id, comms_reduce
    use w90_postw90_common, only: num_int_kpts_on_node, int_kpts, weight, &
      pw90common_fourier_R_to_k
    use w90_parameters, only: num_wann, dos_energy_min, dos_energy_max, &
      dos_energy_step, timing_level, &
      wanint_kpoint_file, dos_kmesh, &
      dos_smr_index, dos_adpt_smr, dos_adpt_smr_fac, &
      dos_adpt_smr_max, spin_decomp, &
      dos_smr_fixed_en_width, &
      dos_project, num_dos_project
    use w90_get_oper, only: get_HH_R, get_SS_R, HH_R
    use w90_wan_ham, only: wham_get_eig_deleig
    use w90_utility, only: utility_diagonalize

    ! 'dos_k' contains contrib. from one k-point,
    ! 'dos_all' from all nodes/k-points (first summed on one node and
    ! then reduced (i.e. summed) over all nodes)
    real(kind=dp), allocatable :: dos_k(:, :)
    real(kind=dp), allocatable :: dos_all(:, :)

    real(kind=dp)    :: kweight, kpt(3), omega
    integer          :: i, loop_x, loop_y, loop_z, loop_tot, ifreq
    integer          :: dos_unit, ndim, ierr, num_kpts_mesh
    real(kind=dp), dimension(:), allocatable :: dos_energyarray

    complex(kind=dp), allocatable :: HH(:, :)
    complex(kind=dp), allocatable :: delHH(:, :, :)
    complex(kind=dp), allocatable :: UU(:, :)
    real(kind=dp) :: del_eig(num_wann, 3)
    real(kind=dp) :: eig(num_wann), levelspacing_k(num_wann)

    ! Energy grid. dos_get_k needs at least two grid points, hence the
    ! bump from one to two points.
    num_freq = nint((dos_energy_max - dos_energy_min)/dos_energy_step) + 1
    if (num_freq == 1) num_freq = 2
    d_omega = (dos_energy_max - dos_energy_min)/real(num_freq - 1, dp)

    allocate (dos_energyarray(num_freq), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating dos_energyarray in ' &
                                 //'dos subroutine')
    do ifreq = 1, num_freq
      dos_energyarray(ifreq) = dos_energy_min + real(ifreq - 1, dp)*d_omega
    end do

    allocate (HH(num_wann, num_wann), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating HH in dos')
    allocate (delHH(num_wann, num_wann, 3), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating delHH in dos')
    allocate (UU(num_wann, num_wann), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating UU in dos')

    ! Second dimension of the DOS arrays: total only, or
    ! total + spin-up + spin-down when the spin decomposition is requested.
    call get_HH_R
    if (spin_decomp) then
      ndim = 3
      call get_SS_R
    else
      ndim = 1
    end if

    allocate (dos_k(num_freq, ndim), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating dos_k in dos')
    allocate (dos_all(num_freq, ndim), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating dos_all in dos')

    if (on_root) then

      if (timing_level > 1) call io_stopwatch('dos', 1)

!       write(stdout,'(/,1x,a)') '============'
!       write(stdout,'(1x,a)')   'Calculating:'
!       write(stdout,'(1x,a)')   '============'

      write (stdout, '(/,/,1x,a)') &
        'Properties calculated in module  d o s'
      write (stdout, '(1x,a)') &
        '--------------------------------------'

      if (num_dos_project == num_wann) then
        write (stdout, '(/,3x,a)') '* Total density of states (_dos)'
      else
        write (stdout, '(/,3x,a)') &
          '* Density of states projected onto selected WFs (_dos)'
        write (stdout, '(3x,a)') 'Selected WFs |Rn> are:'
        do i = 1, num_dos_project
          write (stdout, '(5x,a,2x,i3)') 'n =', dos_project(i)
        end do
      end if

      write (stdout, '(/,5x,a,f9.4,a,f9.4,a)') &
        'Energy range: [', dos_energy_min, ',', dos_energy_max, '] eV'

      write (stdout, '(/,5x,a,(f6.3,1x))') &
        'Adaptive smearing width prefactor: ', &
        dos_adpt_smr_fac

      write (stdout, '(/,/,1x,a20,3(i0,1x))') 'Interpolation grid: ', &
        dos_kmesh(1:3)

    end if

    dos_all = 0.0_dp

    if (wanint_kpoint_file) then
      ! Unlike for optical properties, this should always work for the DOS
      if (on_root) write (stdout, '(/,1x,a)') 'Sampling the irreducible BZ only'

      ! Loop over k-points on the irreducible wedge of the Brillouin zone,
      ! read from file 'kpoint.dat'
      do loop_tot = 1, num_int_kpts_on_node(my_node_id)
        kpt(:) = int_kpts(:, loop_tot)
        call accumulate_kpoint(weight(loop_tot))
      end do

    else

      if (on_root) write (stdout, '(/,1x,a)') 'Sampling the full BZ'

      ! Uniform mesh: this node handles the linear indices my_node_id,
      ! my_node_id + num_nodes, ...; each index is decoded into (x, y, z)
      ! with z running fastest.
      num_kpts_mesh = product(dos_kmesh)
      kweight = 1.0_dp/real(num_kpts_mesh, kind=dp)
      do loop_tot = my_node_id, num_kpts_mesh - 1, num_nodes
        loop_x = loop_tot/(dos_kmesh(2)*dos_kmesh(3))
        loop_y = (loop_tot - loop_x*(dos_kmesh(2) &
                                     *dos_kmesh(3)))/dos_kmesh(3)
        loop_z = loop_tot - loop_x*(dos_kmesh(2)*dos_kmesh(3)) &
                 - loop_y*dos_kmesh(3)
        kpt(1) = real(loop_x, dp)/real(dos_kmesh(1), dp)
        kpt(2) = real(loop_y, dp)/real(dos_kmesh(2), dp)
        kpt(3) = real(loop_z, dp)/real(dos_kmesh(3), dp)
        call accumulate_kpoint(kweight)
      end do

    end if

    ! Collect contributions from all nodes
    call comms_reduce(dos_all(1, 1), num_freq*ndim, 'SUM')

    if (on_root) then
      write (stdout, '(1x,a)') 'Output data files:'
      write (stdout, '(/,3x,a)') trim(seedname)//'-dos.dat'
      dos_unit = io_file_unit()
      open (dos_unit, FILE=trim(seedname)//'-dos.dat', STATUS='UNKNOWN', &
            FORM='FORMATTED', iostat=ierr)
      if (ierr /= 0) call io_error('Error in opening '//trim(seedname) &
                                   //'-dos.dat in dos_main')
      ! One line per energy: omega followed by the ndim DOS columns
      do ifreq = 1, num_freq
        omega = dos_energyarray(ifreq)
        write (dos_unit, '(4E16.8)') omega, dos_all(ifreq, :)
      end do
      close (dos_unit)
      if (timing_level > 1) call io_stopwatch('dos', 2)
    end if

    deallocate (HH, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating HH in dos_main')
    deallocate (delHH, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating delHH in dos_main')
    deallocate (UU, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating UU in dos_main')

  contains

    subroutine accumulate_kpoint(wt)
      !! Adds wt times the DOS contribution of the host variable kpt to
      !! dos_all. Works on the host's scratch arrays (eig, del_eig, HH,
      !! delHH, UU, levelspacing_k, dos_k).
      real(kind=dp), intent(in) :: wt
      !! Weight of the current k-point

      if (dos_adpt_smr) then
        ! Adaptive smearing: the band gradients give the level spacing
        call wham_get_eig_deleig(kpt, eig, del_eig, HH, delHH, UU)
        call dos_get_levelspacing(del_eig, dos_kmesh, levelspacing_k)
        call dos_get_k(kpt, dos_energyarray, eig, dos_k, &
                       smr_index=dos_smr_index, &
                       adpt_smr_fac=dos_adpt_smr_fac, &
                       adpt_smr_max=dos_adpt_smr_max, &
                       levelspacing_k=levelspacing_k, &
                       UU=UU)
      else
        ! Fixed smearing width: only the eigenvalues are needed
        call pw90common_fourier_R_to_k(kpt, HH_R, HH, 0)
        call utility_diagonalize(HH, num_wann, eig, UU)
        call dos_get_k(kpt, dos_energyarray, eig, dos_k, &
                       smr_index=dos_smr_index, &
                       smr_fixed_en_width=dos_smr_fixed_en_width, &
                       UU=UU)
      end if
      dos_all = dos_all + dos_k*wt

    end subroutine accumulate_kpoint

  end subroutine dos_main

  ! =========================================================================

  !> This subroutine calculates the contribution to the DOS of a single k point
  !> \todo still to do: adapt spin_get_nk to read in input the UU rotation matrix
  !> \note This routine simply provides the dos contribution of a given
  !>       point. This must be externally summed after proper weighting.
  !>       The weight factor (for a full BZ sampling with N^3 points) is 1/N^3 if we want
  !>       the final DOS to be normalized to the total number of electrons.
  !> \note The only factor that is included INSIDE this routine is the spin degeneracy
  !>       factor (=num_elec_per_state variable)
  !> \note The EnergyArray is assumed to be evenly spaced (and the energy spacing
  !>       is taken from EnergyArray(2)-EnergyArray(1))
  !> \note The routine is assuming that EnergyArray has at least two elements.
  !> \note The dos_k array must have dimensions size(EnergyArray) * ndim, where
  !>       ndim=1 if spin_decomp==false, or ndim=3 if spin_decomp==true. This is not checked.
  !> \note If smearing/binwidth < min_smearing_binwidth_ratio,
  !>       no smearing is applied (for that k point)
  !> \param kpt         the three coordinates of the k point vector whose DOS contribution we
  !>                    want to calculate (in relative coordinates)
  !> \param EnergyArray array with the energy grid on which to calculate the DOS (in eV)
  !>                    It must have at least two elements
  !> \param eig_k       array with the eigenvalues at the given k point (in eV)
  !> \param dos_k       array in which the contribution is stored. Two dimensions:
  !>                    dos_k(energyidx, spinidx), where:
  !>                    - energyidx is the index of the energies, corresponding to the one
  !>                      of the EnergyArray array;
  !>                    - spinidx=1 contains the total dos; if spin_decomp==.true., then
  !>                      spinidx=2 and spinidx=3 contain the spin-up and spin-down contributions to the DOS
  !> \param smr_index  index that tells the kind of smearing
  !> \param smr_fixed_en_width optional parameter with the fixed energy for smearing, in eV. Can be provided only if the
  !>                    levelspacing_k parameter is NOT given
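  !> \param adpt_smr_fac optional parameter with the factor for the adaptive smearing. Can be provided only if the
  !>                    levelspacing_k parameter IS given
  !> \param adpt_smr_max optional parameter with the upper bound of the adaptive smearing width, i.e. the width
  !>                    used is min(levelspacing_k*adpt_smr_fac, adpt_smr_max). Can be provided only if the
  !>                    levelspacing_k parameter IS given
  !> \param levelspacing_k optional array with the level spacings, i.e. how much each level changes
  !>                    near a given point of the interpolation mesh, as given by the
  !>                    dos_get_levelspacing() routine
  !>                    If present: adaptive smearing
  !>                    If not present: fixed-energy-width smearing
  !> \param UU          optional matrix; |UU(n,i)|^2 is used as the weight of WF n in band i. It is needed
  !>                    when the DOS is projected onto selected WFs (num_dos_project /= num_wann)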
  subroutine dos_get_k(kpt, EnergyArray, eig_k, dos_k, smr_index, &
                       smr_fixed_en_width, adpt_smr_fac, adpt_smr_max, levelspacing_k, UU)
    !! Contribution of a single k-point to the DOS, evaluated on the
    !! evenly spaced grid EnergyArray. The result is not weighted: the
    !! caller multiplies by the k-point weight and sums.
    !!
    !! Two mutually exclusive modes, selected by the optional arguments:
    !!
    !!  - adaptive smearing: levelspacing_k, adpt_smr_fac and adpt_smr_max
    !!    are all present and smr_fixed_en_width is absent;
    !!  - fixed smearing: only smr_fixed_en_width is present.
    !!
    !! Any other combination stops with io_error. For a band whose
    !! smearing width is below min_smearing_binwidth_ratio grid steps,
    !! no smearing is applied and the state is assigned to the nearest
    !! grid point with height 1/binwidth.
    use w90_io, only: io_error
    use w90_constants, only: dp, smearing_cutoff, min_smearing_binwidth_ratio
    use w90_utility, only: utility_w0gauss
    use w90_parameters, only: num_wann, spin_decomp, num_elec_per_state, &
      num_dos_project, dos_project
    use w90_spin, only: spin_get_nk

    ! Arguments
    real(kind=dp), dimension(3), intent(in)               :: kpt
    !! k-point (passed to spin_get_nk when spin_decomp is set)
    real(kind=dp), dimension(:), intent(in)               :: EnergyArray
    !! Evenly spaced energy grid, at least two elements
    real(kind=dp), dimension(:), intent(in)               :: eig_k
    !! Eigenvalues at this k-point; elements 1..num_wann are used
    real(kind=dp), dimension(:, :), intent(out)            :: dos_k
    !! dos_k(energy index, 1) total; (:, 2) and (:, 3) spin-up and
    !! spin-down, written only when spin_decomp is set
    integer, intent(in)                                   :: smr_index
    !! Smearing type, passed to utility_w0gauss
    real(kind=dp), intent(in), optional                    :: smr_fixed_en_width
    !! Fixed smearing width
    real(kind=dp), intent(in), optional                    :: adpt_smr_fac
    !! Prefactor multiplying the level spacing (adaptive smearing)
    real(kind=dp), intent(in), optional                    :: adpt_smr_max
    !! Upper bound of the adaptive smearing width
    real(kind=dp), dimension(:), intent(in), optional      :: levelspacing_k
    !! Level spacing of each band; its presence selects adaptive smearing
    complex(kind=dp), dimension(:, :), intent(in), optional :: UU
    !! |UU(n,i)|^2 weights band i on WF n; needed for the projected DOS

    ! Adaptive smearing
    real(kind=dp) :: eta_smr, arg

    ! Misc/Dummy
    integer          :: i, j, loop_f, min_f, max_f, num_e
    real(kind=dp)    :: rdum, spn_nk(num_wann), alpha_sq, beta_sq
    real(kind=dp)    :: binwidth, r_num_elec_per_state
    real(kind=dp)    :: e_first, e_span, half_width, proj_weight
    logical          :: DoSmearing

    ! --- Consistency of the optional arguments ---
    if (present(levelspacing_k)) then
      if (present(smr_fixed_en_width)) &
        call io_error('Cannot call doskpt with levelspacing_k and ' &
                      //'with smr_fixed_en_width parameters together')
      if (.not. (present(adpt_smr_fac))) &
        call io_error('Cannot call doskpt with levelspacing_k and ' &
                      //'without adpt_smr_fac parameter')
      if (.not. (present(adpt_smr_max))) &
        call io_error('Cannot call doskpt with levelspacing_k and ' &
                      //'without adpt_smr_max parameter')
    else
      if (present(adpt_smr_fac)) &
        call io_error('Cannot call doskpt without levelspacing_k and ' &
                      //'with adpt_smr_fac parameter')
      if (present(adpt_smr_max)) &
        call io_error('Cannot call doskpt without levelspacing_k and ' &
                      //'with adpt_smr_max parameter')
      if (.not. (present(smr_fixed_en_width))) &
        call io_error('Cannot call doskpt without levelspacing_k and ' &
                      //'without smr_fixed_en_width parameter')
    end if

    ! The projected DOS reads UU; an absent optional must not be referenced
    if (num_dos_project /= num_wann .and. .not. present(UU)) &
      call io_error('Cannot call doskpt with num_dos_project/=num_wann ' &
                    //'and without UU parameter')

    num_e = size(EnergyArray)
    if (num_e < 2) &
      call io_error('dos_get_k: EnergyArray must have at least two elements')

    r_num_elec_per_state = real(num_elec_per_state, kind=dp)

    ! Get spin projections for every band
    if (spin_decomp) call spin_get_nk(kpt, spn_nk)

    e_first = EnergyArray(1)
    e_span = EnergyArray(num_e) - e_first
    binwidth = EnergyArray(2) - EnergyArray(1)

    dos_k = 0.0_dp
    do i = 1, num_wann
      if (spin_decomp) then
        ! Contribution to spin-up DOS of Bloch spinor with component
        ! (alpha,beta) with respect to the chosen quantization axis
        alpha_sq = (1.0_dp + spn_nk(i))/2.0_dp ! |alpha|^2
        ! Contribution to spin-down DOS
        beta_sq = 1.0_dp - alpha_sq ! |beta|^2 = 1 - |alpha|^2
      end if

      if (.not. present(levelspacing_k)) then
        eta_smr = smr_fixed_en_width
      else
        ! Eq.(35) YWVS07
        eta_smr = min(levelspacing_k(i)*adpt_smr_fac, adpt_smr_max)
!          eta_smr=max(eta_smr,min_smearing_binwidth_ratio) !! No: it would render the next if always false
      end if

      ! Faster optimization: I precalculate the indices.
      ! With smearing the window is eig_k(i) +/- smearing_cutoff*eta_smr;
      ! without it the window collapses onto the nearest grid point.
      DoSmearing = .not. (eta_smr/binwidth < min_smearing_binwidth_ratio)
      if (DoSmearing) then
        half_width = smearing_cutoff*eta_smr
      else
        half_width = 0.0_dp
        rdum = 1._dp/binwidth
      end if
      min_f = max(nint((eig_k(i) - half_width - e_first)/e_span &
                       *real(num_e - 1, kind=dp)) + 1, 1)
      max_f = min(nint((eig_k(i) + half_width - e_first)/e_span &
                       *real(num_e - 1, kind=dp)) + 1, num_e)

      ! Empty when the window lies entirely outside the grid
      do loop_f = min_f, max_f
        ! kind of smearing read from input (internal smearing_index variable)
        if (DoSmearing) then
          arg = (EnergyArray(loop_f) - eig_k(i))/eta_smr
          rdum = utility_w0gauss(arg, smr_index)/eta_smr
        end if

        ! Contribution to total DOS
        if (num_dos_project == num_wann) then
          ! Total DOS (default): do not loop over j, to save time
          dos_k(loop_f, 1) = dos_k(loop_f, 1) + rdum*r_num_elec_per_state
          ! [GP] I don't put num_elec_per_state here below: if we are
          ! calculating the spin decomposition, we should be doing a
          ! calculation with spin-orbit, and thus num_elec_per_state=1!
          if (spin_decomp) then
            ! Spin-up contribution
            dos_k(loop_f, 2) = dos_k(loop_f, 2) + rdum*alpha_sq
            ! Spin-down contribution
            dos_k(loop_f, 3) = dos_k(loop_f, 3) + rdum*beta_sq
          end if
        else ! 0<num_dos_project<num_wann
          ! Partial DOS, projected onto the WFs with indices
          ! n=dos_project(1:num_dos_project)
          do j = 1, num_dos_project
            proj_weight = abs(UU(dos_project(j), i))**2
            dos_k(loop_f, 1) = dos_k(loop_f, 1) + rdum*r_num_elec_per_state &
                               *proj_weight
            if (spin_decomp) then
              ! Spin-up contribution
              dos_k(loop_f, 2) = dos_k(loop_f, 2) &
                                 + rdum*alpha_sq*proj_weight
              ! Spin-down contribution
              dos_k(loop_f, 3) = dos_k(loop_f, 3) &
                                 + rdum*beta_sq*proj_weight
            end if
          end do
        end if
      end do !loop_f
    end do !loop over bands

  end subroutine dos_get_k

  ! =========================================================================

  subroutine dos_get_levelspacing(del_eig, kmesh, levelspacing)
    !! Level spacing of every band at one point of the interpolation mesh:
    !! the norm of the band gradient times the mesh spacing returned by
    !! pw90common_kmesh_spacing(kmesh).
    use w90_parameters, only: num_wann
    use w90_postw90_common, only: pw90common_kmesh_spacing

    real(kind=dp), dimension(num_wann, 3), intent(in) :: del_eig
    !! Band velocities, already corrected when degeneracies occur
    integer, dimension(3), intent(in)                :: kmesh
    !! array of three integers, giving the number of k points along
    !! each of the three directions defined by the reciprocal lattice vectors
    real(kind=dp), dimension(num_wann), intent(out)  :: levelspacing
    !! On output, the spacing for each of the bands (in eV)

    real(kind=dp) :: mesh_step
    real(kind=dp) :: gradient(3)
    integer :: ib

    mesh_step = pw90common_kmesh_spacing(kmesh)

    do ib = 1, num_wann
      ! |grad E_ib| * Delta_k
      gradient = del_eig(ib, :)
      levelspacing(ib) = sqrt(dot_product(gradient, gradient))*mesh_step
    end do

  end subroutine dos_get_levelspacing

end module w90_dos