Skip to content

Commit

Permalink
ADMM| Fix bugs in ADMMP and ADMMS forces/virial
Browse files Browse the repository at this point in the history
  • Loading branch information
abussy committed Aug 6, 2023
1 parent f783b28 commit da72331
Show file tree
Hide file tree
Showing 29 changed files with 985 additions and 794 deletions.
216 changes: 98 additions & 118 deletions src/admm_methods.F

Large diffs are not rendered by default.

12 changes: 11 additions & 1 deletion src/admm_types.F
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ MODULE admm_types
do_admm_blocked_projection,&
do_admm_blocking_purify_full,&
do_admm_charge_constrained_projection,&
do_admm_exch_scaling_merlot,&
do_admm_exch_scaling_none,&
do_admm_purify_none
USE input_section_types, ONLY: section_vals_release,&
Expand Down Expand Up @@ -160,7 +161,8 @@ MODULE admm_types
n_large_basis(3) = 0.0_dp
INTEGER :: nao_orb = 0, nao_aux_fit = 0, nmo(2) = 0
INTEGER :: purification_method = do_admm_purify_none
LOGICAL :: charge_constrain = .FALSE.
LOGICAL :: charge_constrain = .FALSE., do_admmp = .FALSE., &
do_admmq = .FALSE., do_admms = .FALSE.
INTEGER :: scaling_model = do_admm_exch_scaling_none, &
aux_exch_func = do_admm_aux_exch_func_none
LOGICAL :: aux_exch_func_param = .FALSE.
Expand Down Expand Up @@ -368,6 +370,14 @@ SUBROUTINE admm_env_create(admm_env, admm_control, mos, para_env, natoms, nao_au
admm_env%aux_exch_func_param = admm_control%aux_exch_func_param
admm_env%aux_x_param(:) = admm_control%aux_x_param(:)
!ADMMP, ADMMQ, ADMMS
IF ((.NOT. admm_env%charge_constrain) .AND. (admm_env%scaling_model == do_admm_exch_scaling_merlot)) &
admm_env%do_admmp = .TRUE.
IF (admm_env%charge_constrain .AND. (admm_env%scaling_model == do_admm_exch_scaling_none)) &
admm_env%do_admmq = .TRUE.
IF (admm_env%charge_constrain .AND. (admm_env%scaling_model == do_admm_exch_scaling_merlot)) &
admm_env%do_admms = .TRUE.
IF ((admm_control%method == do_admm_blocking_purify_full) .OR. &
(admm_control%method == do_admm_blocked_projection)) THEN
! Create block map
Expand Down
18 changes: 15 additions & 3 deletions src/hfx_admm_utils.F
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ SUBROUTINE hfx_admm_init(qs_env)

INTEGER :: handle, ispin, n_rep_hf, nao_aux_fit, &
natoms, nelectron, nmo
LOGICAL :: s_mstruct_changed
LOGICAL :: s_mstruct_changed, use_virial
REAL(dp) :: maxocc
TYPE(admm_type), POINTER :: admm_env
TYPE(cp_blacs_env_type), POINTER :: blacs_env
Expand All @@ -166,9 +166,10 @@ SUBROUTINE hfx_admm_init(qs_env)
TYPE(qs_ks_env_type), POINTER :: ks_env
TYPE(qs_rho_type), POINTER :: rho
TYPE(section_vals_type), POINTER :: hfx_sections, input, xc_section
TYPE(virial_type), POINTER :: virial

CALL timeset(routineN, handle)
NULLIFY (admm_env, hfx_sections, mos, mos_aux_fit, para_env, &
NULLIFY (admm_env, hfx_sections, mos, mos_aux_fit, para_env, virial, &
mo_coeff_aux_fit, xc_section, rho, ks_env, dft_control, input, &
qs_kind_set, mo_coeff_b, aux_fit_fm_struct, blacs_env)

Expand All @@ -181,7 +182,8 @@ SUBROUTINE hfx_admm_init(qs_env)
rho=rho, &
ks_env=ks_env, &
dft_control=dft_control, &
input=input)
input=input, &
virial=virial)

hfx_sections => section_vals_get_subs_vals(input, "DFT%XC%HF")

Expand Down Expand Up @@ -270,6 +272,16 @@ SUBROUTINE hfx_admm_init(qs_env)
CPABORT("GAPW ADMM not implemented for MCWEENY or NONE_DM purification.")
END IF

!ADMMS and ADMMP stress tensors only available for close-shell systesms, because virial cannot
!be scaled by gsi spin component wise
use_virial = virial%pv_availability .AND. (.NOT. virial%pv_numer)
IF (use_virial .AND. admm_env%do_admms .AND. dft_control%nspins == 2) THEN
CPABORT("ADMMS stress tensor is only available for closed-shell systems")
END IF
IF (use_virial .AND. admm_env%do_admmp .AND. dft_control%nspins == 2) THEN
CPABORT("ADMMP stress tensor is only available for closed-shell systems")
END IF

IF (dft_control%do_admm_dm .AND. .NOT. ASSOCIATED(admm_env%admm_dm)) THEN
CALL admm_dm_create(admm_env%admm_dm, dft_control%admm_control, nspins=dft_control%nspins, natoms=natoms)
END IF
Expand Down
36 changes: 21 additions & 15 deletions src/qs_ks_atom.F
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ MODULE qs_ks_atom
!> \param oce_external ...
!> \param sab_external ...
!> \param kscale ...
!> \param fscale ...
!> \par History
!> created [MI]
!> the loop over the spins is done internally [03-05,MI]
Expand All @@ -102,7 +103,7 @@ MODULE qs_ks_atom
!> Allow for external kind_set, rho_atom_set, oce, sab 12.2019 (A. Bussy)
! **************************************************************************************************
SUBROUTINE update_ks_atom(qs_env, ksmat, pmat, forces, tddft, rho_atom_external, &
kind_set_external, oce_external, sab_external, kscale)
kind_set_external, oce_external, sab_external, kscale, fscale)

TYPE(qs_environment_type), POINTER :: qs_env
TYPE(dbcsr_p_type), DIMENSION(*), INTENT(INOUT) :: ksmat, pmat
Expand All @@ -115,7 +116,7 @@ SUBROUTINE update_ks_atom(qs_env, ksmat, pmat, forces, tddft, rho_atom_external,
TYPE(oce_matrix_type), OPTIONAL, POINTER :: oce_external
TYPE(neighbor_list_set_p_type), DIMENSION(:), &
OPTIONAL, POINTER :: sab_external
REAL(KIND=dp), INTENT(IN), OPTIONAL :: kscale
REAL(KIND=dp), INTENT(IN), OPTIONAL :: kscale, fscale(2)

CHARACTER(len=*), PARAMETER :: routineN = 'update_ks_atom'

Expand All @@ -133,7 +134,7 @@ SUBROUTINE update_ks_atom(qs_env, ksmat, pmat, forces, tddft, rho_atom_external,
REAL(dp), ALLOCATABLE, DIMENSION(:, :) :: a_matrix, p_matrix
REAL(dp), DIMENSION(3) :: rac, rbc
REAL(dp), DIMENSION(3, 3) :: force_tmp
REAL(kind=dp) :: eps_cpc, factor1, factor2
REAL(kind=dp) :: eps_cpc, factor1, factor2, force_fac(2)
REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: C_int_h, C_int_s, coc
REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: dCPC_h, dCPC_s
REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :) :: PC_h, PC_s
Expand Down Expand Up @@ -197,6 +198,9 @@ SUBROUTINE update_ks_atom(qs_env, ksmat, pmat, forces, tddft, rho_atom_external,
factor2 = factor2*kscale
END IF

force_fac = 1.0_dp
IF (PRESENT(fscale)) force_fac(:) = fscale(:)

IF (PRESENT(rho_atom_external)) my_rho_atom => rho_atom_external
IF (PRESENT(kind_set_external)) my_kind_set => kind_set_external
IF (PRESENT(oce_external)) my_oce => oce_external
Expand Down Expand Up @@ -296,7 +300,7 @@ SUBROUTINE update_ks_atom(qs_env, ksmat, pmat, forces, tddft, rho_atom_external,
!$OMP , basis_set_list, nimages, cell_to_index &
!$OMP , ksmat, pmat, natom, nkind, my_kind_set, my_oce &
!$OMP , my_rho_atom, factor1, factor2, use_virial &
!$OMP , atom_of_kind, ldCPC, force, locks &
!$OMP , atom_of_kind, ldCPC, force, locks, force_fac &
!$OMP ) &
!$OMP PRIVATE( slot_num, is_entry_null, TASK, is_task_valid &
!$OMP , C_int_h, C_int_s, coc, a_matrix, p_matrix &
Expand Down Expand Up @@ -464,7 +468,7 @@ SUBROUTINE update_ks_atom(qs_env, ksmat, pmat, forces, tddft, rho_atom_external,
CALL add_vhxca_forces(mat_p, C_coeff_hh_a, C_coeff_hh_b, C_coeff_ss_a, C_coeff_ss_b, &
rho_at, force_tmp, nspins, iatom, jatom, nsoctot, &
list_a, n_cont_a, list_b, n_cont_b, dCPC_h, dCPC_s, ldCPC, &
PC_h, PC_s, p_matrix)
PC_h, PC_s, p_matrix, force_fac)
force_tmp = factor2*force_tmp
!$ CALL omp_set_lock(locks((ka_kind - 1)*nkind + kkind))
force(kkind)%vhxc_atom(1:3, ka_kind) = force(kkind)%vhxc_atom(1:3, ka_kind) + force_tmp(1:3, 3)
Expand All @@ -483,7 +487,7 @@ SUBROUTINE update_ks_atom(qs_env, ksmat, pmat, forces, tddft, rho_atom_external,
CALL add_vhxca_forces(mat_p, C_coeff_hh_b, C_coeff_hh_a, C_coeff_ss_b, C_coeff_ss_a, &
rho_at, force_tmp, nspins, jatom, iatom, nsoctot, &
list_b, n_cont_b, list_a, n_cont_a, dCPC_h, dCPC_s, ldCPC, &
PC_h, PC_s, p_matrix)
PC_h, PC_s, p_matrix, force_fac)
force_tmp = factor2*force_tmp
!$ CALL omp_set_lock(locks((ka_kind - 1)*nkind + kkind))
force(kkind)%vhxc_atom(1:3, ka_kind) = force(kkind)%vhxc_atom(1:3, ka_kind) + force_tmp(1:3, 3)
Expand Down Expand Up @@ -726,10 +730,11 @@ END SUBROUTINE add_vhxca_to_ks
!> \param PC_h ...
!> \param PC_s ...
!> \param p_matrix ...
!> \param force_scaling ...
! **************************************************************************************************
SUBROUTINE add_vhxca_forces(mat_p, C_hh_a, C_hh_b, C_ss_a, C_ss_b, &
rho_atom, force, nspins, ia, ja, nsp, lista, nconta, listb, ncontb, &
dCPC_h, dCPC_s, ldCPC, PC_h, PC_s, p_matrix)
dCPC_h, dCPC_s, ldCPC, PC_h, PC_s, p_matrix, force_scaling)
TYPE(cp_2d_r_p_type), DIMENSION(:), INTENT(IN), &
POINTER :: mat_p
REAL(KIND=dp), DIMENSION(:, :, :), INTENT(IN) :: C_hh_a, C_hh_b, C_ss_a, C_ss_b
Expand All @@ -744,6 +749,7 @@ SUBROUTINE add_vhxca_forces(mat_p, C_hh_a, C_hh_b, C_ss_a, C_ss_b, &
INTEGER, INTENT(IN) :: ldCPC
REAL(KIND=dp), DIMENSION(:, :, :) :: PC_h, PC_s
REAL(KIND=dp), DIMENSION(:, :) :: p_matrix
REAL(KIND=dp), DIMENSION(2), INTENT(IN) :: force_scaling

INTEGER :: dir, ispin
REAL(dp), DIMENSION(:, :), POINTER :: int_hard, int_soft
Expand Down Expand Up @@ -781,15 +787,15 @@ SUBROUTINE add_vhxca_forces(mat_p, C_hh_a, C_hh_b, C_ss_a, C_ss_b, &
C_hh_a(:, :, dir), SIZE(C_hh_a, 1), &
0.0_dp, dCPC_h, SIZE(dCPC_h, 1))
trace = trace_r_AxB(dCPC_h, ldCPC, int_hard, nsp, nsp, nsp)
force(dir - 1, 3) = force(dir - 1, 3) + ieqj*trace
force(dir - 1, 1) = force(dir - 1, 1) - ieqj*trace
force(dir - 1, 3) = force(dir - 1, 3) + ieqj*trace*force_scaling(ispin)
force(dir - 1, 1) = force(dir - 1, 1) - ieqj*trace*force_scaling(ispin)

CALL DGEMM('T', 'N', nsp, nsp, nconta, 1.0_dp, PC_s(:, :, ispin), SIZE(PC_s, 1), &
C_ss_a(:, :, dir), SIZE(C_ss_a, 1), &
0.0_dp, dCPC_s, SIZE(dCPC_s, 1))
trace = trace_r_AxB(dCPC_s, ldCPC, int_soft, nsp, nsp, nsp)
force(dir - 1, 3) = force(dir - 1, 3) - ieqj*trace
force(dir - 1, 1) = force(dir - 1, 1) + ieqj*trace
force(dir - 1, 3) = force(dir - 1, 3) - ieqj*trace*force_scaling(ispin)
force(dir - 1, 1) = force(dir - 1, 1) + ieqj*trace*force_scaling(ispin)
END DO

! j-k contributions
Expand All @@ -805,15 +811,15 @@ SUBROUTINE add_vhxca_forces(mat_p, C_hh_a, C_hh_b, C_ss_a, C_ss_b, &
C_hh_b(:, :, dir), SIZE(C_hh_b, 1), &
0.0_dp, dCPC_h, SIZE(dCPC_h, 1))
trace = trace_r_AxB(dCPC_h, ldCPC, int_hard, nsp, nsp, nsp)
force(dir - 1, 3) = force(dir - 1, 3) + ieqj*trace
force(dir - 1, 2) = force(dir - 1, 2) - ieqj*trace
force(dir - 1, 3) = force(dir - 1, 3) + ieqj*trace*force_scaling(ispin)
force(dir - 1, 2) = force(dir - 1, 2) - ieqj*trace*force_scaling(ispin)

CALL DGEMM('T', 'N', nsp, nsp, ncontb, 1.0_dp, PC_s(:, :, ispin), SIZE(PC_s, 1), &
C_ss_b(:, :, dir), SIZE(C_ss_b, 1), &
0.0_dp, dCPC_s, SIZE(dCPC_s, 1))
trace = trace_r_AxB(dCPC_s, ldCPC, int_soft, nsp, nsp, nsp)
force(dir - 1, 3) = force(dir - 1, 3) - ieqj*trace
force(dir - 1, 2) = force(dir - 1, 2) + ieqj*trace
force(dir - 1, 3) = force(dir - 1, 3) - ieqj*trace*force_scaling(ispin)
force(dir - 1, 2) = force(dir - 1, 2) + ieqj*trace*force_scaling(ispin)
END DO

END DO !ispin
Expand Down
10 changes: 7 additions & 3 deletions src/qs_ks_methods.F
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ SUBROUTINE qs_ks_build_kohn_sham_matrix(qs_env, calculate_forces, just_energy, &
LOGICAL :: do_adiabatic_rescaling, do_ddapc, do_hfx, do_ppl, gapw, gapw_xc, &
hfx_treat_lsd_in_core, just_energy_xc, lrigpw, my_print, rigpw, use_virial
REAL(KIND=dp) :: ecore_ppl, edisp, ee_ener, ekin_mol, &
mulliken_order_p
mulliken_order_p, vscale
REAL(KIND=dp), DIMENSION(3, 3) :: h_stress, pv_loc
TYPE(admm_type), POINTER :: admm_env
TYPE(cdft_control_type), POINTER :: cdft_control
Expand Down Expand Up @@ -538,8 +538,12 @@ SUBROUTINE qs_ks_build_kohn_sham_matrix(qs_env, calculate_forces, just_energy, &
NULLIFY (rho_struct)

IF (use_virial .AND. calculate_forces) THEN
virial%pv_exc = virial%pv_exc - virial%pv_xc
virial%pv_virial = virial%pv_virial - virial%pv_xc
vscale = 1.0_dp
!Note: ADMMS and ADMMP stress tensor only for closed-shell calculations
IF (admm_env%do_admms) vscale = admm_env%gsi(1)**(2.0_dp/3.0_dp)
IF (admm_env%do_admmp) vscale = admm_env%gsi(1)**2
virial%pv_exc = virial%pv_exc - vscale*virial%pv_xc
virial%pv_virial = virial%pv_virial - vscale*virial%pv_xc
! virial%pv_xc will be zeroed in the xc routines
END IF
xc_section => admm_env%xc_section_primary
Expand Down
11 changes: 4 additions & 7 deletions src/qs_ks_utils.F
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@ MODULE qs_ks_utils
USE hfx_types, ONLY: hfx_type
USE input_constants, ONLY: &
cdft_alpha_constraint, cdft_beta_constraint, cdft_charge_constraint, &
cdft_magnetization_constraint, do_admm_aux_exch_func_none, do_admm_exch_scaling_merlot, &
do_ppl_grid, sic_ad, sic_eo, sic_list_all, sic_list_unpaired, sic_mauri_spz, sic_mauri_us, &
sic_none
cdft_magnetization_constraint, do_admm_aux_exch_func_none, do_ppl_grid, sic_ad, sic_eo, &
sic_list_all, sic_list_unpaired, sic_mauri_spz, sic_mauri_us, sic_none
USE input_section_types, ONLY: section_vals_get_subs_vals,&
section_vals_type,&
section_vals_val_get
Expand Down Expand Up @@ -1653,11 +1652,9 @@ SUBROUTINE sum_up_and_integrate(qs_env, ks_matrix, rho, my_rho, &
END IF
fadm = 1.0_dp
! Calculate bare scaling of force according to Merlot, 1. IF: ADMMP, 2. IF: ADMMS,
IF ((.NOT. admm_env%charge_constrain) .AND. &
(admm_env%scaling_model == do_admm_exch_scaling_merlot)) THEN
IF (admm_env%do_admmp) THEN
fadm = admm_env%gsi(ispin)**2
ELSE IF (admm_env%charge_constrain .AND. &
(admm_env%scaling_model == do_admm_exch_scaling_merlot)) THEN
ELSE IF (admm_env%do_admms) THEN
fadm = (admm_env%gsi(ispin))**(2.0_dp/3.0_dp)
END IF

Expand Down
66 changes: 66 additions & 0 deletions tests/QS/regtest-admm-qps-2/H2O-ADMMP-GAPW_force.inp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
&FORCE_EVAL
METHOD Quickstep
&DFT
BASIS_SET_FILE_NAME BASIS_ccGRB_UZH
BASIS_SET_FILE_NAME BASIS_ADMM_UZH
POTENTIAL_FILE_NAME POTENTIAL_UZH
&MGRID
CUTOFF 200
REL_CUTOFF 40
&END MGRID
&AUXILIARY_DENSITY_MATRIX_METHOD
ADMM_PURIFICATION_METHOD NONE
METHOD BASIS_PROJECTION
EXCH_SCALING_MODEL MERLOT
&END
&QS
EPS_DEFAULT 1.0E-12
METHOD GAPW
&END
&SCF
SCF_GUESS ATOMIC
EPS_SCF 1.0E-6
&END SCF
&XC
&XC_FUNCTIONAL NONE
&END XC_FUNCTIONAL
&HF
FRACTION 1.0
&INTERACTION_POTENTIAL
POTENTIAL_TYPE TRUNCATED
CUTOFF_RADIUS 2.0
&END
&END
&END XC
&END DFT
&SUBSYS
&CELL
ABC 4.5 4.5 4.5
&END CELL
&TOPOLOGY
&END
&COORD
O 0.000000 0.000000 -0.045587
H 0.000000 -0.757136 0.510545
H 0.000000 0.757136 0.510545
&END COORD
&KIND H
BASIS_SET ccGRB-D-q1
BASIS_SET AUX_FIT admm-dz-q1
POTENTIAL GTH-HYB-q1
&END KIND
&KIND O
BASIS_SET ccGRB-D-q6
BASIS_SET AUX_FIT admm-dz-q6
POTENTIAL GTH-HYB-q6
&END KIND
&END SUBSYS
&END FORCE_EVAL
&GLOBAL
PROJECT H2O-ADMMP-GAPW
PRINT_LEVEL MEDIUM
RUN_TYPE DEBUG
&END GLOBAL
&DEBUG
CHECK_ATOM_FORCE 1 Z
&END
Loading

0 comments on commit da72331

Please sign in to comment.