Skip to content

Commit

Permalink
optimize the memory in builder/optimize/symmetry
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Sep 8, 2024
1 parent dec54c7 commit b2acbc6
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 95 deletions.
59 changes: 31 additions & 28 deletions pyxtal/lego/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,12 @@ def minimize_from_x(x, dim, spg, wps, elements, calculator, ref_environments,
while True:
count += 1
try:
xtal.from_random(
dim, g, elements, numIons, sites=sites_wp, factor=1.0, random_state=random_state)
xtal.from_random(dim, g, elements, numIons,
sites=sites_wp, factor=1.0,
random_state=random_state)
except RuntimeError:
print(g.number, numIons, sites)
print("Trouble in generating random xtals from pyxtal, try again")
print("Trouble in generating random xtals from pyxtal")
if xtal.valid:
atoms = xtal.to_ase(resort=False, add_vaccum=False)
try:
Expand Down Expand Up @@ -793,11 +794,12 @@ def optimize_xtals_mproc(self, xtals, ncpu, args):
args: (opt_type, T, n_iter, early_quit, add_db, symmetrize, minimizers)
"""
from multiprocessing import Pool
pool = Pool(processes=ncpu)
from collections import deque
import gc

pool = Pool(processes=ncpu)
(opt_type, T, niter, early_quit, add_db, symmetrize, minimizers) = args

xtals_opt = []
xtals_opt = deque()

# Split the input structures to minibatches
N_rep = 4
Expand All @@ -806,38 +808,39 @@ def optimize_xtals_mproc(self, xtals, ncpu, args):
start, end = i, min([i+N_batches, len(xtals)])
ids = list(range(start, end))
print(f"Rank {self.rank} minibatch {start} {end}")
args_list = []
for j in range(ncpu):
_ids = ids[j::ncpu]
# print(f"test batch_{_i} cpu_{j}", _ids)
wp_libs = []
for id in _ids:
xtal = xtals[id]
x = xtal.get_1d_rep_x()
spg, wps, _ = self.get_input_from_ref_xtal(xtal)
wp_libs.append((x, xtal.group.number, wps))

args_list.append((self.dim,
wp_libs,
self.elements,
self.calculator,
self.ref_environments,
opt_type,
T,
niter,
early_quit,
minimizers))

def generate_args():
"""
A generator to yield argument lists for minimize_from_x_par.
"""
for j in range(ncpu):
_ids = ids[j::ncpu]
wp_libs = []
for id in _ids:
xtal = xtals[id]
x = xtal.get_1d_rep_x()
spg, wps, _ = self.get_input_from_ref_xtal(xtal)
wp_libs.append((x, xtal.group.number, wps))
yield (self.dim, wp_libs, self.elements, self.calculator,
self.ref_environments, opt_type, T, niter,
early_quit, minimizers)
# Use the generator to pass args to reduce memory usage
for result in pool.imap_unordered(minimize_from_x_par, args_list):
if result is not None:
(_xtals, _xs) = result
valid_xtals = self.process_xtals(
_xtals, _xs, add_db, symmetrize)
xtals_opt.extend(valid_xtals)
xtals_opt.extend(valid_xtals) # Use deque to reduce memory

# Remove the duplicate structures
self.db.update_row_topology(overwrite=False, prefix=self.prefix)
self.db.clean_structures_spg_topology(dim=self.dim)

# After each minibatch, delete the local variables and run garbage collection
del ids, wp_libs, _xtals, _xs
gc.collect() # Explicitly call garbage collector to free memory

xtals_opt = list(xtals_opt)
print(f"Rank {self.rank} finish optimize_xtals_mproc {len(xtals_opt)}")
return xtals_opt

Expand Down
38 changes: 22 additions & 16 deletions pyxtal/optimize/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from multiprocessing import Pool
from concurrent.futures import TimeoutError
import signal
import gc

import logging
import os
Expand Down Expand Up @@ -1046,37 +1047,42 @@ def local_optimization_mproc(self, xtals, ncpu, ids=None, qrs=False, pool=None):
ncpu (int): number of parallel python processes
ids (list):
qrs (bool): Force mutation or not (related to QRS)
pool : multiprocess pool
"""
gen = self.generation
t0 = time()
args = self._get_local_optimization_args()

if ids is None:
ids = range(len(xtals))

N_cycle = int(np.ceil(len(xtals) / ncpu))
args_lists = []

# Assign args
for i in range(ncpu):
id1 = i * N_cycle
id2 = min([id1 + N_cycle, len(xtals)])
# os.makedirs(folder, exist_ok=True)
_ids = ids[id1: id2]
job_tags = [self.tag + "-g" + str(gen)
+ "-p" + str(id) for id in _ids]
_xtals = [xtals[id][0] for id in range(id1, id2)]
mutates = [False if qrs else xtal is not None for xtal in _xtals]
my_args = [_xtals, _ids, mutates, job_tags, *args, self.timeout]#, self.logging]
args_lists.append(tuple(my_args))

# Generator to create arg_lists for multiprocessing tasks
def generate_args_lists():
for i in range(ncpu):
id1 = i * N_cycle
id2 = min([id1 + N_cycle, len(xtals)])
_ids = ids[id1: id2]
job_tags = [self.tag + "-g" + str(gen)
+ "-p" + str(id) for id in _ids]
_xtals = [xtals[id][0] for id in range(id1, id2)]
mutates = [False if qrs else xtal is not None for xtal in _xtals]
my_args = [_xtals, _ids, mutates, job_tags, *args, self.timeout]
yield tuple(my_args) # Yield args instead of appending to a list

self.logging.info(f"Rank {self.rank} assign args in local_opt_mproc")
gen_results = []

for result in pool.imap_unordered(process_task, args_lists):
gen_results = []
# Stream the results to avoid holding too much in memory at once
for result in pool.imap_unordered(process_task, generate_args_lists()):
if result is not None:
#self.logging.info(f"Rank {self.rank} grab {len(result)} strucs")
for _res in result:
gen_results.append(_res)
# Explicitly delete the result and call garbage collection
def result
gc.collect()

return gen_results

Expand Down
113 changes: 62 additions & 51 deletions pyxtal/symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import re
from copy import deepcopy
from ast import literal_eval

import numpy as np
from monty.serialization import loadfn
Expand Down Expand Up @@ -731,7 +732,6 @@ class Group:
[[71, 139], [129, 139], [137, 139]]
Args:
group: the group symbol or international number
dim (defult: 3): the periodic dimension of the group
Expand All @@ -745,6 +745,8 @@ def __init__(self, group, dim=3, use_hall=False, style="pyxtal", quick=False):
self.dim = dim
names = ["Point", "Rod", "Layer", "Space"]
self.header = "-- " + names[dim] + "group --"

# Retrieve symbol and number for the group (avoid redundancy)
if not use_hall:
self.symbol, self.number = get_symbol_and_number(group, dim)
else:
Expand All @@ -755,60 +757,68 @@ def __init__(self, group, dim=3, use_hall=False, style="pyxtal", quick=False):

if dim == 3:
results = get_point_group(self.number)
self.point_group = results[0]
self.pg_number = results[1]
self.polar = results[2]
self.inversion = results[3]
self.chiral = results[4]
self.point_group, self.pg_number, self.polar, self.inversion, self.chiral = results

# Lazy load Wyckoff positions and hall data unless quick=True
if not quick:
if dim == 3:
if not use_hall:
if style == "pyxtal":
self.hall_number = pyxtal_hall_numbers[self.number - 1]
else:
self.hall_number = spglib_hall_numbers[self.number - 1]
self._initialize_hall_data(group, use_hall, style, dim)
self._initialize_wyckoff_data(dim)

def _initialize_hall_data(self, group, use_hall, style, dim):
"""Initialize hall number and transformation matrices."""
if dim == 3:
if not use_hall:
if style == "pyxtal":
self.hall_number = pyxtal_hall_numbers[self.number - 1]
else:
self.hall_number = group
self.P = abc2matrix(HALL_TABLE["P"][self.hall_number - 1])
self.P1 = abc2matrix(HALL_TABLE["P^-1"][self.hall_number - 1])
self.hall_number = spglib_hall_numbers[self.number - 1]
else:
self.hall_number = None
self.P = None
self.P1 = None

# Wyckoff positions, site_symmetry, generator
self.wyckoffs = get_wyckoffs(self.number, dim=dim)
self.w_symm = get_wyckoff_symmetry(self.number, dim=dim)

wpdicts = [
{
"index": i,
"letter": letter_from_index(i, self.wyckoffs, dim=self.dim),
"ops": self.wyckoffs[i],
"multiplicity": len(self.wyckoffs[i]),
"symmetry": self.w_symm[i],
"PBC": self.PBC,
"dim": self.dim,
"number": self.number,
"symbol": self.symbol,
"P": self.P,
"P1": self.P1,
"hall_number": self.hall_number,
"directions": self.get_symmetry_directions(),
"lattice_type": self.lattice_type,
}
for i in range(len(self.wyckoffs))
]
self.hall_number = group
self.P = abc2matrix(HALL_TABLE["P"][self.hall_number - 1])
self.P1 = abc2matrix(HALL_TABLE["P^-1"][self.hall_number - 1])
else:
self.hall_number, self.P, self.P1 = None, None, None

def _initialize_wyckoff_data(self, dim):
"""Initialize Wyckoff positions and organize them."""
# Wyckoff positions, site_symmetry, generator
self.wyckoffs = get_wyckoffs(self.number, dim=dim)
self.w_symm = get_wyckoff_symmetry(self.number, dim=dim)

# Create dicts with relevant Wyckoff position data lazily
wpdicts_gen = [
{
"index": i,
"letter": letter_from_index(i, self.wyckoffs, dim=self.dim),
"ops": self.wyckoffs[i],
"multiplicity": len(self.wyckoffs[i]),
"symmetry": self.w_symm[i],
"PBC": self.PBC,
"dim": self.dim,
"number": self.number,
"symbol": self.symbol,
"P": self.P,
"P1": self.P1,
"hall_number": self.hall_number,
"directions": self.get_symmetry_directions(),
"lattice_type": self.lattice_type,
}
for i in range(len(self.wyckoffs))
]

# A list of Wyckoff_positions sorted by descending multiplicity
self.Wyckoff_positions = []
for wpdict in wpdicts:
wp = Wyckoff_position.from_dict(wpdict)
self.Wyckoff_positions.append(wp)
# A list of Wyckoff_positions sorted by descending multiplicity
#self.Wyckoff_positions = []
#for wpdict in wpdicts:
# wp = Wyckoff_position.from_dict(wpdict)
# self.Wyckoff_positions.append(wp)

# A 2D list of WP objects, grouped and sorted by multiplicity
self.wyckoffs_organized = organized_wyckoffs(self)
# Use a generator to avoid keeping the full list of dicts in memory
self.Wyckoff_positions = [
Wyckoff_position.from_dict(wpdict) for wpdict in wpdicts_gen
]

# Organize wyckoffs by multiplicity
self.wyckoffs_organized = organized_wyckoffs(self)

def __str__(self):
if self.string is not None:
Expand Down Expand Up @@ -3899,7 +3909,8 @@ def get_wyckoffs(num, organized=False, dim=3):
elif dim == 0:
df = SYMDATA.get_wyckoff_pg()

wyckoff_strings = eval(df["0"][num])
# Convert the string from df into a list of wyckoff strings
wyckoff_strings = literal_eval(df["0"][num]) # Use literal_eval instead of eval

wyckoffs = []
for x in wyckoff_strings:
Expand All @@ -3920,9 +3931,9 @@ def get_wyckoffs(num, organized=False, dim=3):
wyckoffs_organized[-1].append(wp)
return wyckoffs_organized
else:
# Return Wyckoff positions without organization
return wyckoffs


def get_wyckoff_symmetry(num, dim=3):
"""
Returns a list of site symmetry for a given group.
Expand Down
12 changes: 12 additions & 0 deletions tests/test_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,28 @@ def print_memory_usage():
if __name__ == "__main__":
print_memory_usage()
from pyxtal.symmetry import Group

print("\nImport group")
print_memory_usage()

print("\nCreate layer Group with quick")
g = Group(4, dim=2, quick=True)
print_memory_usage()

print("\nCreate Group with quick")
g = Group(4, quick=True)
print_memory_usage()

print("\nCreate Group")
g = Group(227)
print_memory_usage()

print("\nCall pyxtal")
from pyxtal import pyxtal
xtal = pyxtal()
xtal.from_spg_wps_rep(194, ['2c', '2b'], [2.46, 6.70])
print_memory_usage()

print("\nCall subgroup")
xtal.subgroup_once()
print_memory_usage()

0 comments on commit b2acbc6

Please sign in to comment.