Skip to content

Commit

Permalink
Refactor:Remove update_tau_pos in ucell (deepmodeling#5783)
Browse files Browse the repository at this point in the history
* modify periodic_boundary_adjustment

* modify update_pos_tau

* update compile

* delete ucell referenc in update_pos_tau

* add unittest for update_pos_tau

* move back test file

* use EXPECT_THAT instead of EXPECT_EQ in relax_old and use regex to remove the title

* remove the bug in the relax_old for it didn't run update_pos

* [pre-commit.ci lite] apply automatic fixes

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
A-006 and pre-commit-ci-lite[bot] authored Jan 2, 2025
1 parent 7ae18a5 commit c53f445
Show file tree
Hide file tree
Showing 24 changed files with 301 additions and 137 deletions.
1 change: 1 addition & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ OBJS_CELL=atom_pseudo.o\
cell_index.o\
check_atomic_stru.o\
update_cell.o\
bcast_cell.o\

OBJS_DEEPKS=LCAO_deepks.o\
deepks_force.o\
Expand Down
1 change: 1 addition & 0 deletions source/module_cell/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_library(
cell_index.cpp
check_atomic_stru.cpp
update_cell.cpp
bcast_cell.cpp
)

if(ENABLE_COVERAGE)
Expand Down
15 changes: 15 additions & 0 deletions source/module_cell/bcast_cell.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include "unitcell.h"

namespace unitcell
{
void bcast_atoms_tau(Atom* atoms,
const int ntype)
{
#ifdef __MPI
MPI_Barrier(MPI_COMM_WORLD);
for (int i = 0; i < ntype; i++) {
atoms[i].bcast_atom(); // bcast tau array
}
#endif
}
}
10 changes: 10 additions & 0 deletions source/module_cell/bcast_cell.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#ifndef BCAST_CELL_H
#define BCAST_CELL_H

namespace unitcell
{
void bcast_atoms_tau(Atom* atoms,
const int ntype);
}

#endif // BCAST_CELL_H
7 changes: 4 additions & 3 deletions source/module_cell/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ install(FILES unitcell_test_parallel.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR})

list(APPEND cell_simple_srcs
../unitcell.cpp
../update_cell.cpp
../bcast_cell.cpp
../read_atoms.cpp
../atom_spec.cpp
../atom_pseudo.cpp
Expand Down Expand Up @@ -103,14 +105,14 @@ add_test(NAME cell_parallel_kpoints_test
AddTest(
TARGET cell_unitcell_test
LIBS parameter ${math_libs} base device cell_info symmetry
SOURCES unitcell_test.cpp ../../module_io/output.cpp ../../module_elecstate/cal_ux.cpp ../update_cell.cpp
SOURCES unitcell_test.cpp ../../module_io/output.cpp ../../module_elecstate/cal_ux.cpp

)

AddTest(
TARGET cell_unitcell_test_readpp
LIBS parameter ${math_libs} base device cell_info
SOURCES unitcell_test_readpp.cpp ../../module_io/output.cpp
SOURCES unitcell_test_readpp.cpp ../../module_io/output.cpp
)

AddTest(
Expand All @@ -123,7 +125,6 @@ AddTest(
TARGET cell_unitcell_test_setupcell
LIBS parameter ${math_libs} base device cell_info
SOURCES unitcell_test_setupcell.cpp ../../module_io/output.cpp
../../module_cell/update_cell.cpp
)

add_test(NAME cell_unitcell_test_parallel
Expand Down
2 changes: 0 additions & 2 deletions source/module_cell/test/support/mock_unitcell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,9 @@ bool UnitCell::read_atom_positions(std::ifstream& ifpos,
std::ofstream& ofs_warning) {
return true;
}
void UnitCell::update_pos_tau(const double* pos) {}
void UnitCell::update_pos_taud(double* posd_in) {}
void UnitCell::update_pos_taud(const ModuleBase::Vector3<double>* posd_in) {}
void UnitCell::update_vel(const ModuleBase::Vector3<double>* vel_in) {}
void UnitCell::periodic_boundary_adjustment() {}
void UnitCell::bcast_atoms_tau() {}
bool UnitCell::judge_big_cell() const { return true; }
void UnitCell::update_stress(ModuleBase::matrix& scs) {}
Expand Down
7 changes: 5 additions & 2 deletions source/module_cell/test/unitcell_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,9 @@ TEST_F(UcellDeathTest, PeriodicBoundaryAdjustment1)
PARAM.input.relax_new = utp.relax_new;
ucell = utp.SetUcellInfo();
testing::internal::CaptureStdout();
EXPECT_EXIT(ucell->periodic_boundary_adjustment(), ::testing::ExitedWithCode(1), "");
EXPECT_EXIT(unitcell::periodic_boundary_adjustment(
ucell->atoms,ucell->latvec,ucell->ntype),
::testing::ExitedWithCode(1), "");
std::string output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("the movement of atom is larger than the length of cell"));
}
Expand All @@ -793,7 +795,8 @@ TEST_F(UcellTest, PeriodicBoundaryAdjustment2)
UcellTestPrepare utp = UcellTestLib["C1H2-Index"];
PARAM.input.relax_new = utp.relax_new;
ucell = utp.SetUcellInfo();
EXPECT_NO_THROW(ucell->periodic_boundary_adjustment());
EXPECT_NO_THROW(unitcell::periodic_boundary_adjustment(
ucell->atoms,ucell->latvec,ucell->ntype));
}

TEST_F(UcellTest, PrintCell)
Expand Down
34 changes: 32 additions & 2 deletions source/module_cell/test/unitcell_test_para.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
#include "mpi.h"
#endif
#include "prepare_unitcell.h"

#include "../update_cell.h"
#include "../bcast_cell.h"
#ifdef __LCAO
InfoNonlocal::InfoNonlocal()
{
Expand Down Expand Up @@ -44,6 +45,7 @@ Magnetism::~Magnetism()
/**
* - Tested Functions:
* - UpdatePosTaud
* - update_pos_tau(double* pos)
* - update_pos_taud(const double* pos)
* - bcast_atoms_tau() is also called in the above function, which calls Atom::bcast_atom with many
* atomic info in addition to tau
Expand Down Expand Up @@ -123,7 +125,34 @@ TEST_F(UcellTest, BcastUnitcell)
EXPECT_EQ(atom_labels[1], atom_type2_expected);
}
}

TEST_F(UcellTest, UpdatePosTau)
{
double* pos_in = new double[ucell->nat * 3];
ucell->set_iat2itia();
std::fill(pos_in, pos_in + ucell->nat * 3, 0);
for (int iat = 0; iat < ucell->nat; ++iat)
{
int it, ia;
ucell->iat2iait(iat, &ia, &it);
for (int ik = 0; ik < 3; ++ik)
{
ucell->atoms[it].mbl[ia][ik] = true;
pos_in[iat * 3 + ik] = (iat * 3 + ik) / (ucell->nat * 3.0) * (ucell->lat.lat0);
}
}
unitcell::update_pos_tau(ucell->lat,pos_in,ucell->ntype,ucell->nat,ucell->atoms);
for (int iat = 0; iat < ucell->nat; ++iat)
{
int it, ia;
ucell->iat2iait(iat, &ia, &it);
for (int ik = 0; ik < 3; ++ik)
{
EXPECT_DOUBLE_EQ(ucell->atoms[it].tau[ia][ik],
(iat*3+ik)/(ucell->nat*3.0));
}
}
delete[] pos_in;
}
TEST_F(UcellTest, UpdatePosTaud)
{
double* pos_in = new double[ucell->nat * 3];
Expand All @@ -147,6 +176,7 @@ TEST_F(UcellTest, UpdatePosTaud)
EXPECT_DOUBLE_EQ(ucell->atoms[it].taud[ia].y, tmp[iat].y + 0.01);
EXPECT_DOUBLE_EQ(ucell->atoms[it].taud[ia].z, tmp[iat].z + 0.01);
}
delete[] tmp;
delete[] pos_in;
}

Expand Down
2 changes: 1 addition & 1 deletion source/module_cell/test_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ install(FILES unitcell_test_pw_para.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
AddTest(
TARGET cell_unitcell_test_pw
LIBS parameter ${math_libs} base device
SOURCES unitcell_test_pw.cpp ../unitcell.cpp ../read_atoms.cpp ../atom_spec.cpp
SOURCES unitcell_test_pw.cpp ../unitcell.cpp ../read_atoms.cpp ../atom_spec.cpp ../update_cell.cpp ../bcast_cell.cpp
../atom_pseudo.cpp ../pseudo.cpp ../read_pp.cpp ../read_pp_complete.cpp ../read_pp_upf201.cpp ../read_pp_upf100.cpp
../read_pp_vwr.cpp ../read_pp_blps.cpp ../../module_io/output.cpp ../../module_elecstate/read_pseudo.cpp ../../module_elecstate/cal_nelec_nband.cpp
)
Expand Down
76 changes: 4 additions & 72 deletions source/module_cell/unitcell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include "module_ri/serialization_cereal.h"
#endif


#include "update_cell.h"
UnitCell::UnitCell() {
if (test_unitcell) {
ModuleBase::TITLE("unitcell", "Constructor");
Expand Down Expand Up @@ -312,29 +314,7 @@ std::vector<ModuleBase::Vector3<int>> UnitCell::get_constrain() const
return constrain;
}

void UnitCell::update_pos_tau(const double* pos) {
int iat = 0;
for (int it = 0; it < this->ntype; it++) {
Atom* atom = &this->atoms[it];
for (int ia = 0; ia < atom->na; ia++) {
for (int ik = 0; ik < 3; ++ik) {
if (atom->mbl[ia][ik]) {
atom->dis[ia][ik]
= pos[3 * iat + ik] / this->lat0 - atom->tau[ia][ik];
atom->tau[ia][ik] = pos[3 * iat + ik] / this->lat0;
}
}

// the direct coordinates also need to be updated.
atom->dis[ia] = atom->dis[ia] * this->GT;
atom->taud[ia] = atom->tau[ia] * this->GT;
iat++;
}
}
assert(iat == this->nat);
this->periodic_boundary_adjustment();
this->bcast_atoms_tau();
}

void UnitCell::update_pos_taud(double* posd_in) {
int iat = 0;
Expand All @@ -349,7 +329,7 @@ void UnitCell::update_pos_taud(double* posd_in) {
}
}
assert(iat == this->nat);
this->periodic_boundary_adjustment();
unitcell::periodic_boundary_adjustment(this->atoms,this->latvec, this->ntype);
this->bcast_atoms_tau();
}

Expand All @@ -367,7 +347,7 @@ void UnitCell::update_pos_taud(const ModuleBase::Vector3<double>* posd_in) {
}
}
assert(iat == this->nat);
this->periodic_boundary_adjustment();
unitcell::periodic_boundary_adjustment(this->atoms,this->latvec, this->ntype);
this->bcast_atoms_tau();
}

Expand All @@ -383,54 +363,6 @@ void UnitCell::update_vel(const ModuleBase::Vector3<double>* vel_in) {
assert(iat == this->nat);
}

void UnitCell::periodic_boundary_adjustment() {
//----------------------------------------------
// because of the periodic boundary condition
// we need to adjust the atom positions,
// first adjust direct coordinates,
// then update them into cartesian coordinates,
//----------------------------------------------
for (int it = 0; it < this->ntype; it++) {
Atom* atom = &this->atoms[it];
for (int ia = 0; ia < atom->na; ia++) {
// mohan update 2011-03-21
if (atom->taud[ia].x < 0) {
atom->taud[ia].x += 1.0;
}
if (atom->taud[ia].y < 0) {
atom->taud[ia].y += 1.0;
}
if (atom->taud[ia].z < 0) {
atom->taud[ia].z += 1.0;
}
if (atom->taud[ia].x >= 1.0) {
atom->taud[ia].x -= 1.0;
}
if (atom->taud[ia].y >= 1.0) {
atom->taud[ia].y -= 1.0;
}
if (atom->taud[ia].z >= 1.0) {
atom->taud[ia].z -= 1.0;
}

if (atom->taud[ia].x < 0 || atom->taud[ia].y < 0
|| atom->taud[ia].z < 0 || atom->taud[ia].x >= 1.0
|| atom->taud[ia].y >= 1.0 || atom->taud[ia].z >= 1.0) {
GlobalV::ofs_warning << " it=" << it + 1 << " ia=" << ia + 1
<< std::endl;
GlobalV::ofs_warning << "d=" << atom->taud[ia].x << " "
<< atom->taud[ia].y << " "
<< atom->taud[ia].z << std::endl;
ModuleBase::WARNING_QUIT(
"Ions_Move_Basic::move_ions",
"the movement of atom is larger than the length of cell.");
}

atom->tau[ia] = atom->taud[ia] * this->latvec;
}
}
return;
}

void UnitCell::bcast_atoms_tau() {
#ifdef __MPI
Expand Down
2 changes: 0 additions & 2 deletions source/module_cell/unitcell.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,9 @@ class UnitCell {
void print_cell(std::ofstream& ofs) const;
void print_cell_xyz(const std::string& fn) const;

void update_pos_tau(const double* pos);
void update_pos_taud(const ModuleBase::Vector3<double>* posd_in);
void update_pos_taud(double* posd_in);
void update_vel(const ModuleBase::Vector3<double>* vel_in);
void periodic_boundary_adjustment();
void bcast_atoms_tau();
bool judge_big_cell() const;

Expand Down
Loading

0 comments on commit c53f445

Please sign in to comment.