Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/qzhu2017/PyXtal
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Sep 17, 2024
2 parents ad0b306 + 5ffc6bf commit b4c33fd
Show file tree
Hide file tree
Showing 11 changed files with 344 additions and 171 deletions.
1 change: 1 addition & 0 deletions pyxtal/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1442,6 +1442,7 @@ def get_db_unique(self, db_name=None, prec=3):
kvp[key] = getattr(row, key)
db.write(row.toatoms(), key_value_pairs=kvp)
print(f"Created {db_name:s} with {db.count():d} strucs")
return db.count()

def check_overlap(self, reference_db, etol=2e-3, verbose=True):
"""
Expand Down
45 changes: 22 additions & 23 deletions pyxtal/interface/charmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,33 +93,37 @@ def __init__(
try:
self.structure.optimize_lattice()
except:
print("bug")
print("bug in Lattice")
print(self.structure)
print(self.structure.lattice)
print(self.structure.lattice.matrix)
raise ValueError("Problem in Lattice")
#raise ValueError("Problem in Lattice")
self.error = True
# print("\nbeginining lattice: ", struc.lattice)

self.lat_mut = lat_mut
self.rotate = rotate

def run(self, clean=True):
if not os.path.exists(self.folder):
os.makedirs(self.folder)
cwd = os.getcwd()
os.chdir(self.folder)

self.write() # ; print("write", time()-t0)
res = self.execute() # ; print("exe", time()-t0)
if res is not None:
self.read() # ; print("read", self.structure.energy)
else:
self.structure.energy = self.errorE
self.error = True
if clean:
self.clean()
"""
Only run calc if it makes sense
"""
if not self.error:
os.makedirs(self.folder, exist_ok=True)
cwd = os.getcwd()
os.chdir(self.folder)

self.write() # ; print("write", time()-t0)
res = self.execute() # ; print("exe", time()-t0)
if res is not None:
self.read() # ; print("read", self.structure.energy)
else:
self.structure.energy = self.errorE
self.error = True
if clean:
self.clean()

os.chdir(cwd)
os.chdir(cwd)

def execute(self):
cmd = self.exe + " < " + self.input + " > " + self.output
Expand Down Expand Up @@ -310,11 +314,6 @@ def read(self):
self.structure.energy *= Z

count = 0
# for i, site in enumerate(self.structure.mol_sites):
# coords = positions[count:count+len(site.mol)]
# site.update(coords, self.structure.lattice)
# count += len(site.mol)

# if True:
try:
for _i, site in enumerate(self.structure.mol_sites):
Expand All @@ -330,7 +329,7 @@ def read(self):
self.structure.energy = self.errorE
self.error = True
if self.debug:
print("Unable to retrieve Structure after optimization")
print("Cannot retrieve Structure after optimization")
print("lattice", self.structure.lattice)
self.structure.to_file("1.cif")
print("Check 1.cif in ", os.getcwd())
Expand Down
160 changes: 129 additions & 31 deletions pyxtal/lego/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,12 @@ def minimize_from_x(x, dim, spg, wps, elements, calculator, ref_environments,
while True:
count += 1
try:
xtal.from_random(
dim, g, elements, numIons, sites=sites_wp, factor=1.0, random_state=random_state)
xtal.from_random(dim, g, elements, numIons,
sites=sites_wp, factor=1.0,
random_state=random_state)
except RuntimeError:
print(g.number, numIons, sites)
print("Trouble in generating random xtals from pyxtal, try again")
print("Trouble in generating random xtals from pyxtal")
if xtal.valid:
atoms = xtal.to_ase(resort=False, add_vaccum=False)
try:
Expand Down Expand Up @@ -334,8 +335,11 @@ def print_fun(x, f, accepted):

# Extract the optimized xtal
xtal = pyxtal()
xtal.from_1d_rep(res.x, sites, dim=dim)
return xtal, (x0, res.x)
try:
xtal.from_1d_rep(res.x, sites, dim=dim)
return xtal, (x0, res.x)
except:
return None


def calculate_dSdx(x, xtal, des_ref, f, eps=1e-4, symmetry=True, verbose=False):
Expand Down Expand Up @@ -591,6 +595,13 @@ def __str__(self):
def __repr__(self):
    """Return the same text as ``str(self)`` so both representations agree."""
    return str(self)

def print_memory_usage(self):
    """
    Report the resident memory (RSS) of the current process.

    The same one-line report is sent to ``self.logging`` and printed to
    stdout, tagged with ``self.rank``.

    ``psutil`` is imported lazily because it is only needed for this
    diagnostic. If it is not installed, the report states that the
    memory figure is unavailable instead of raising ``ImportError``,
    so a missing optional package cannot abort a running optimization.
    """
    try:
        import psutil
    except ImportError:
        report = (f"Rank {self.rank} memory: "
                  "unavailable (psutil is not installed)")
    else:
        process = psutil.Process(os.getpid())
        # rss is divided by 1024 ** 2 to express it in MB
        mem = process.memory_info().rss / 1024 ** 2
        report = f"Rank {self.rank} memory: {mem:.1f} MB"

    self.logging.info(report)
    print(report)

def set_descriptor_calculator(self, dtype='SO3', mykwargs={}):
"""
Set up the calculator for descriptor computation.
Expand Down Expand Up @@ -783,6 +794,89 @@ def optimize_xtals_serial(self, xtals, args):
xtals_opt.append(xtal)
return xtals_opt

def optimize_reps(self, reps, ncpu=1, opt_type='local',
                  T=0.2, niter=20, early_quit=0.02,
                  add_db=True, symmetrize=False,
                  minimizers=None,
                  ):
    """
    Perform optimization for each structure given as a representation.

    The optimization settings are packed into a single tuple and handed
    to ``optimize_reps_mproc``; only the multiprocess path exists.

    Args:
        reps: list of reps
        ncpu (int): number of parallel processes, must be larger than 1
        opt_type, T, niter, early_quit: optimizer settings, forwarded
            unchanged to the worker function
        add_db, symmetrize: forwarded unchanged to the post-processing
            of the optimized structures
        minimizers: list of (method, steps) pairs; when None, defaults
            to [('Nelder-Mead', 100), ('L-BFGS-B', 100)]

    Returns:
        whatever ``optimize_reps_mproc`` returns (the valid structures)

    Raises:
        NotImplementedError: if ``ncpu`` is not larger than 1
    """
    # Build the default per call so that no list object is shared
    # between separate invocations (mutable default argument pitfall).
    if minimizers is None:
        minimizers = [('Nelder-Mead', 100), ('L-BFGS-B', 100)]

    args = (opt_type, T, niter, early_quit, add_db, symmetrize, minimizers)
    if ncpu > 1:
        return self.optimize_reps_mproc(reps, ncpu, args)
    raise NotImplementedError("optimize_reps works in parallel mode")

def optimize_reps_mproc(self, reps, ncpu, args):
    """
    Optimization in multiprocess mode.

    The reps are processed in minibatches of ``4 * ncpu`` entries. For
    each minibatch one argument tuple per worker is generated lazily,
    the workers run ``minimize_from_x_par``, and every non-None result
    is passed to ``self.process_xtals``. After each minibatch the
    database topology is updated and duplicates are cleaned.

    Args:
        reps: list of reps
        ncpu (int): number of parallel python processes
        args: (opt_type, T, n_iter, early_quit, add_db, symmetrize, minimizers)

    Returns:
        list of the structures returned by ``self.process_xtals``
    """
    from multiprocessing import Pool
    from collections import deque
    import gc

    (opt_type, T, niter, early_quit, add_db, symmetrize, minimizers) = args
    xtals_opt = deque()

    def generate_args(ids):
        """
        A generator to yield argument lists for minimize_from_x_par.
        The ids are passed in explicitly so the generator never depends
        on a loop variable of the enclosing scope.
        """
        for j in range(ncpu):
            wp_libs = []
            # Strided slice: worker j takes every ncpu-th id of the batch
            for rep_id in ids[j::ncpu]:
                xtal = pyxtal()
                xtal.from_tabular_representation(reps[rep_id], normalize=False)
                x = xtal.get_1d_rep_x()
                spg, wps, _ = self.get_input_from_ref_xtal(xtal)
                wp_libs.append((x, spg, wps))
            yield (self.dim, wp_libs, self.elements, self.calculator,
                   self.ref_environments, opt_type, T, niter,
                   early_quit, minimizers)

    # Nothing to do for empty input: skip spawning worker processes
    if len(reps) > 0:
        # Split the input structures to minibatches
        N_rep = 4
        N_batches = N_rep * ncpu

        # The context manager guarantees that the worker processes are
        # released when the loop finishes or an exception propagates.
        with Pool(processes=ncpu) as pool:
            for start in range(0, len(reps), N_batches):
                end = min(start + N_batches, len(reps))
                ids = list(range(start, end))
                print(f"Rank {self.rank} minibatch {start} {end}")
                self.print_memory_usage()

                # Use the generator to pass args to reduce memory usage
                result = _xtals = _xs = valid_xtals = None
                for result in pool.imap_unordered(minimize_from_x_par,
                                                  generate_args(ids)):
                    if result is not None:
                        (_xtals, _xs) = result
                        valid_xtals = self.process_xtals(
                            _xtals, _xs, add_db, symmetrize)
                        xtals_opt.extend(valid_xtals)

                # Remove the duplicate structures
                self.db.update_row_topology(overwrite=False, prefix=self.prefix)
                self.db.clean_structures_spg_topology(dim=self.dim)

                # Drop the per-batch references before collecting. Rebinding
                # to None (rather than `del`) is safe even when every worker
                # of the batch returned None.
                result = _xtals = _xs = valid_xtals = None
                gc.collect()

    xtals_opt = list(xtals_opt)
    print(f"Rank {self.rank} finish optimize_reps_mproc {len(xtals_opt)}")
    return xtals_opt

def optimize_xtals_mproc(self, xtals, ncpu, args):
"""
Optimization in multiprocess mode.
Expand All @@ -793,11 +887,12 @@ def optimize_xtals_mproc(self, xtals, ncpu, args):
args: (opt_type, T, n_iter, early_quit, add_db, symmetrize, minimizers)
"""
from multiprocessing import Pool
pool = Pool(processes=ncpu)
from collections import deque
import gc

pool = Pool(processes=ncpu)
(opt_type, T, niter, early_quit, add_db, symmetrize, minimizers) = args

xtals_opt = []
xtals_opt = deque()

# Split the input structures to minibatches
N_rep = 4
Expand All @@ -806,38 +901,41 @@ def optimize_xtals_mproc(self, xtals, ncpu, args):
start, end = i, min([i+N_batches, len(xtals)])
ids = list(range(start, end))
print(f"Rank {self.rank} minibatch {start} {end}")
args_list = []
for j in range(ncpu):
_ids = ids[j::ncpu]
# print(f"test batch_{_i} cpu_{j}", _ids)
wp_libs = []
for id in _ids:
xtal = xtals[id]
x = xtal.get_1d_rep_x()
spg, wps, _ = self.get_input_from_ref_xtal(xtal)
wp_libs.append((x, xtal.group.number, wps))

args_list.append((self.dim,
wp_libs,
self.elements,
self.calculator,
self.ref_environments,
opt_type,
T,
niter,
early_quit,
minimizers))
for result in pool.imap_unordered(minimize_from_x_par, args_list):
self.print_memory_usage()

def generate_args():
"""
A generator to yield argument lists for minimize_from_x_par.
"""
for j in range(ncpu):
_ids = ids[j::ncpu]
wp_libs = []
for id in _ids:
xtal = xtals[id]
x = xtal.get_1d_rep_x()
spg, wps, _ = self.get_input_from_ref_xtal(xtal)
wp_libs.append((x, xtal.group.number, wps))
yield (self.dim, wp_libs, self.elements, self.calculator,
self.ref_environments, opt_type, T, niter,
early_quit, minimizers)
# Use the generator to pass args to reduce memory usage
_xtal, _xs = None, None
for result in pool.imap_unordered(minimize_from_x_par, generate_args()):
if result is not None:
(_xtals, _xs) = result
valid_xtals = self.process_xtals(
_xtals, _xs, add_db, symmetrize)
xtals_opt.extend(valid_xtals)
xtals_opt.extend(valid_xtals) # Use deque to reduce memory

# Remove the duplicate structures
self.db.update_row_topology(overwrite=False, prefix=self.prefix)
self.db.clean_structures_spg_topology(dim=self.dim)

# After each minibatch, delete the local variables and run garbage collection
del ids, _xtals, _xs
gc.collect() # Explicitly call garbage collector to free memory

xtals_opt = list(xtals_opt)
print(f"Rank {self.rank} finish optimize_xtals_mproc {len(xtals_opt)}")
return xtals_opt

Expand Down
14 changes: 8 additions & 6 deletions pyxtal/optimize/DFS.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,22 +162,20 @@ def _run(self, pool=None):
# Related to the FF optimization
N_added = 0
success_rate = 0

# To save for comparison
cur_survivals = [0] * self.N_pop # track the survivals
hist_best_xtals = [None] * self.N_pop
hist_best_engs = [self.E_max] * self.N_pop
print(f"Rank {self.rank} starts DFS in {self.tag}")

for gen in range(self.N_gen):
self.generation = gen
cur_xtals = None
print(f"Rank {self.rank} entering generation {gen} in {self.tag}")
self.logging.info(f"Gen {gen} starts in Rank {self.rank}")

if self.rank == 0:
print(f"\nGeneration {gen:d} starts")
self.logging.info(f"Generation {gen:d} starts")
t0 = time()

# Initialize structure and tags
cur_xtals = [(None, "Random")] * self.N_pop

Expand All @@ -193,7 +191,7 @@ def _run(self, pool=None):
if min_E < mid_E and cur_survivals[id] < self.N_survival:

if self.random_state.random() < 0.7:
source = prev_xtals[id][0]
source = prev_xtals[id][0]
else:
source = hist_best_xtals[id]

Expand All @@ -216,11 +214,12 @@ def _run(self, pool=None):

# Local optimization
gen_results = self.local_optimization(cur_xtals, pool=pool)
self.logging.info(f"Rank {self.rank} finishes local_opt.")

prev_xtals = None
if self.rank == 0:
# pass results, summary_and_ranking
cur_xtals, matches, engs = self.gen_summary(t0,
cur_xtals, matches, engs = self.gen_summary(t0,
gen_results, cur_xtals)

# update hist_best
Expand All @@ -237,6 +236,7 @@ def _run(self, pool=None):
# broadcast
if self.use_mpi:
prev_xtals = self.comm.bcast(prev_xtals, root=0)
self.logging.info(f"Gen {gen} bcast in Rank {self.rank}")

# Update the FF parameters if necessary
if self.ff_opt:
Expand All @@ -256,9 +256,11 @@ def _run(self, pool=None):
if self.use_mpi:
quit = self.comm.bcast(quit, root=0)
self.comm.Barrier()
self.logging.info(f"Gen {gen} Finish in Rank {self.rank}")

# Ensure that all ranks exit
if quit:
self.logging.info(f"Early Termination in Rank {self.rank}")
return success_rate

return success_rate
Expand Down
Loading

0 comments on commit b4c33fd

Please sign in to comment.