Skip to content

Commit

Permalink
wip: remove all random imports
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Jul 16, 2024
1 parent 0f75a2c commit 42206d4
Show file tree
Hide file tree
Showing 14 changed files with 195 additions and 169 deletions.
2 changes: 1 addition & 1 deletion flask/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os


class Config(object):
class Config:
SECRET_KEY = os.environ.get("SECRET_KEY")

MAIL_SERVER = os.environ.get("MAIL_SERVER")
Expand Down
90 changes: 40 additions & 50 deletions pyxtal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import itertools
import json
from copy import deepcopy
from random import choice, sample

import numpy as np
from ase import Atoms
Expand Down Expand Up @@ -171,7 +170,7 @@ class pyxtal:
"""

def __init__(self, molecular=False):
def __init__(self, molecular=False, random_state=None):
self.valid = False
self.molecular = molecular
self.standard_setting = True
Expand All @@ -185,6 +184,8 @@ def __init__(self, molecular=False):
self.dim = 3
self.factor = 1.0
self.PBC = [1, 1, 1]
self.random_state = np.random.default_rng(random_state)

if molecular:
self.molecules = []
self.mol_sites = []
Expand Down Expand Up @@ -272,6 +273,7 @@ def from_random(
block=None,
num_block=None,
seed=None,
random_state=None,
tm=None,
use_hall=False,
):
Expand Down Expand Up @@ -325,22 +327,24 @@ def from_random(
conventional=conventional,
tm=tm,
seed=seed,
random_state=random_state,
use_hall=use_hall,
)
else:
struc = random_crystal(
dim,
group,
species,
numIons,
factor,
thickness,
area,
lattice,
sites,
conventional,
tm,
dim=dim,
group=group,
species=species,
numIons=numIons,
factor=factor,
thickness=thickness,
area=area,
lattice=lattice,
sites=sites,
conventional=conventional,
tm=tm,
use_hall=use_hall,
random_state=random_state,
)
if force_pass or struc.valid:
quit = True
Expand Down Expand Up @@ -474,9 +478,7 @@ def _from_pymatgen(self, struc, tol=1e-3, a_tol=5.0, style="pyxtal", hn=None):
if self.valid:
d = sym_struc.composition.as_dict()
species = list(d.keys())
numIons = []
for ele in species:
numIons.append(int(d[ele]))
numIons = [int(d[ele]) for ele in species]
self.numIons = numIons
self.species = species
if hn is None:
Expand Down Expand Up @@ -569,9 +571,9 @@ def check_short_distances(self, r=0.7, exclude_H=True):
if exclude_H:
pmg_struc.remove_species("H")
res = pmg_struc.get_all_neighbors(r)
for i, neighs in enumerate(res):
for n in neighs:
pairs.append([pmg_struc.sites[i].specie, n.specie, n.nn_distance])
pairs = [
[pmg_struc.sites[i].specie, n.specie, n.nn_distance] for i, neighs in enumerate(res) for n in neighs
]
else:
raise NotImplementedError("Does not support cluster for now")
return pairs
Expand Down Expand Up @@ -716,7 +718,7 @@ def subgroup(
idx, sites, t_types, k_types = self._get_subgroup_ids(H, group_type, idx, max_cell, min_cell)
# randomly choose a subgroup from the available list
if N_groups is not None and len(idx) >= N_groups:
idx = sample(idx, N_groups)
idx = self.random_state.choice(idx, N_groups)
# print('max_sub_group', len(idx), max_subgroups)

valid_splitters = []
Expand Down Expand Up @@ -838,7 +840,7 @@ def subgroup_once(
# Try 100 times to see if a valid split can be found
count = 0
while count < 100:
id = choice(idx)
id = self.random_state.choice(idx)
gtype = (t_types + k_types)[id]
if gtype == "k":
id -= len(t_types)
Expand Down Expand Up @@ -905,8 +907,8 @@ def _apply_substitution(self, splitter, perms):
for site_id, site in enumerate(new_struc.atom_sites):
if site.specie in perms:
site_ids.append(site_id)
N = choice(range(1, len(site_ids))) if len(site_ids) > 1 else 1
sub_ids = sample(site_ids, N)
N = self.random_state.choice(range(1, len(site_ids))) if len(site_ids) > 1 else 1
sub_ids = self.random_state.choice(site_ids, N)
for sub_id in sub_ids:
key = new_struc.atom_sites[sub_id].specie
new_struc.atom_sites[sub_id].specie = perms[key]
Expand Down Expand Up @@ -1165,9 +1167,8 @@ def _get_formula(self):
self.numIons = numIons
numspecies = self.numIons
for i, s in zip(numspecies, species):
if type(s) == int:
s = Element(s).short_name
formula += f"{s:s}{int(i):d}"
specie = Element(s).short_name if isinstance(s, str) else s
formula += f"{specie:s}{int(i):d}"
self.formula = formula

def get_zprime(self, integer=False):
Expand Down Expand Up @@ -1466,13 +1467,10 @@ def save_dict(self):
"""
Save the model as a dictionary
"""
sites = []
if self.molecular:
for site in self.mol_sites:
sites.append(site.save_dict())
sites = [site.save_dict() for site in self.mol_sites]
else:
for site in self.atom_sites:
sites.append(site.save_dict())
sites = [site.save_dict() for site in self.atom_sites]

return {
"lattice": self.lattice.matrix,
Expand Down Expand Up @@ -1504,18 +1502,17 @@ def load_dict(self, dict0):
self.numMols = dict0["numMols"]
self.valid = dict0["valid"]
self.formula = dict0["formula"]
sites = []

if dict0["molecular"]:
self.molecules = [None] * len(self.numMols)
for site in dict0["sites"]:
msite = mol_site.load_dict(site)
sites.append(msite)
sites = [mol_site.load_dict(site) for site in dict0["sites"]]
# TODO: this for loop makes repeated calls for duplicated molecules
for msite in sites:
if self.molecules[msite.type] is None:
self.molecules[msite.type] = msite.molecule
self.mol_sites = sites
else:
for site in dict0["sites"]:
sites.append(atom_site.load_dict(site))
sites = [atom_site.load_dict(site) for site in dict0["sites"]]
self.atom_sites = sites

def build(self, group, species, numIons, lattice, sites, tol=1e-2, dim=3, use_hall=False):
Expand Down Expand Up @@ -1578,15 +1575,12 @@ def build(self, group, species, numIons, lattice, sites, tol=1e-2, dim=3, use_ha
wp0 = self.group[0]
for sp, wps in zip(species, sites):
for wp in wps:
if type(wp) is dict: # dict
if isinstance(wp, dict): # dict
for pair in wp.items():
(key, pos) = pair
_wp = choose_wyckoff(self.group, site=key)
if _wp is not False:
if _wp.get_dof() == 0: # fixed pos
pt = [0.0, 0.0, 0.0]
else:
pt = _wp.get_all_positions(pos)[0]
pt = [0.0, 0.0, 0.0] if _wp.get_dof() == 0 else _wp.get_all_positions(pos)[0]
_sites.append(atom_site(_wp, pt, sp))
else:
raise RuntimeError("Cannot interpret site", key)
Expand Down Expand Up @@ -2237,12 +2231,7 @@ def sort_sites_by_numIons(self, seq=None):
if seq is None:
seq = np.argsort(self.numIons)

sites = []
for i in seq:
for site in self.atom_sites:
if self.species[i] == site.specie:
sites.append(site)
self.atom_sites = sites
self.atom_sites = [site for i in seq for site in self.atom_sites if self.species[i] == site.specie]

def get_transition(self, ref_struc, d_tol=1.0, d_tol2=0.3, N_images=2, max_path=30, both=False):
"""
Expand Down Expand Up @@ -2840,7 +2829,7 @@ def substitute(self, dicts):
pmg = self.to_pymatgen()
pmg.replace_species(dicts)
if self.molecular:
for _e1e in dicts:
for ele in dicts:
smi = [m.smile.replace(ele, dicts[ele]) + ".smi" for m in self.molecules]
self.from_seed(pmg, smi)
else:
Expand Down Expand Up @@ -3500,6 +3489,7 @@ def check_validity(self, criteria, verbose=False):

if "Dimension" in criteria:
try:
# TODO: unclear what is the criteria_cutoff
dim1 = self.get_dimensionality(criteria_cutoff)
except:
dim1 = 3
Expand Down Expand Up @@ -3564,7 +3554,7 @@ def get_tabular_representations(self, N_max=30, N_wp=8, normalize=False, perturb
sites_mul = [range(min([min_wp, site.wp.multiplicity])) for site in self.atom_sites]
ids = list(itertools.product(*sites_mul))
if len(ids) > N_max:
ids = sample(ids, N_max)
ids = self.random_state.choice(ids, N_max)

print(f"Number of augments {len(ids):4d} from ", self.get_xtal_string())
for sites_id in ids:
Expand Down
3 changes: 1 addition & 2 deletions pyxtal/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pymatgen.core.bonds import CovalentBond
from pymatgen.core.structure import Molecule, Structure

from pyxtal import pyxtal
from pyxtal.constants import logo
from pyxtal.lattice import Lattice
from pyxtal.molecule import Orientation, compare_mol_connectivity, pyxtal_molecule
Expand All @@ -35,7 +34,7 @@ def in_merged_coords(wp, pt, pts, cell):
return False


def get_cif_str_for_pyxtal(struc: pyxtal, header: str = "", sym_num=None, style: str = "mp"):
def get_cif_str_for_pyxtal(struc, header: str = "", sym_num=None, style: str = "mp"):
"""Get the cif string for a given structure. The default setting for
_atom_site follows the materials project cif
Expand Down
18 changes: 11 additions & 7 deletions pyxtal/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
# Set required parameters
if PBC is None:
PBC = [1, 1, 1]

if ltype in ltype_keywords:
self.ltype = ltype.lower()
elif ltype is None:
Expand All @@ -68,7 +69,12 @@ def __init__(
self.dim = sum(PBC)
self.kwargs = {}
self.random = True
self.random_state = np.random.default_rng()

if isinstance(random_state, Generator):
self.random_state = random_state.spawn(1)[0]
else:
self.random_state = np.random.default_rng(random_state)

# Set optional values
self.allow_volume_reset = True
for key, value in kwargs.items():
Expand Down Expand Up @@ -602,6 +608,7 @@ def set_matrix(self, matrix=None):
raise ValueError(msg)
else:
self.reset_matrix()

para = matrix2para(self.matrix)
self.a, self.b, self.c, self.alpha, self.beta, self.gamma = para
self.volume = np.linalg.det(self.matrix)
Expand Down Expand Up @@ -677,9 +684,7 @@ def swap_axis(self, random=False, ids=None):
allowed_ids = [[0, 1, 2]]

if random:
from random import choice

ids = choice(allowed_ids)
ids = self.random_state.choice(allowed_ids)
else:
if ids not in allowed_ids:
print(ids)
Expand Down Expand Up @@ -716,9 +721,7 @@ def swap_angle(self, random=True, ids=None):
allowed_ids = ["No"]

if random:
from random import choice

ids = choice(allowed_ids)
ids = self.random_state.choice(allowed_ids)
else:
if ids not in allowed_ids:
print(ids)
Expand Down Expand Up @@ -1834,6 +1837,7 @@ def para2matrix(cell_para, radians=True, format="upper"):
sin_gamma = np.sin(gamma)
sin_alpha = np.sin(alpha)
matrix = np.zeros([3, 3])

if format == "lower":
# Generate a lower-diagonal matrix
c1 = c * cos_beta
Expand Down
10 changes: 4 additions & 6 deletions pyxtal/miscellaneous/Random_vasp_ase.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
import time
import warnings
from random import randint

import numpy as np
from ase import Atoms
from ase.calculators.vasp import Vasp
from ase.io import read
Expand All @@ -25,7 +25,7 @@
minvec = 2.5
maxangle = 150
minangle = 30

random_state = np.random.default_rng()
setup = None


Expand Down Expand Up @@ -209,14 +209,12 @@ def optimize(struc, dir1):

for i in range(1000):
os.chdir(dir0)
numIons[0] = randint(1, 5)
numIons[0] = random_state.integers(1, 5)
numIons[1] = 6 - numIons[0]
numIons[2] = numIons[0] + 2 * numIons[1]
run = True
while run:
# numIons[0] = randint(8,16)
sg = randint(3, 230)
# print(sg, species, numIons, factor)
sg = random_state.integers(3, 230)
rand_crystal = random_crystal(sg, species, numIons, factor)
if rand_crystal.valid:
run = False
Expand Down
5 changes: 3 additions & 2 deletions pyxtal/miscellaneous/test_3D_molecule.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
from random import randint
from time import time

import numpy as np
from pymatgen.io.cif import CifWriter

from pyxtal.molecular_crystal import molecular_crystal

rng = np.random.default_rng(0)
mols = ["CH4", "H2O", "NH3", "urea", "benzene", "roy", "aspirin", "pentacene", "C60"]
filename = "out.cif"
if os.path.isfile(filename):
Expand All @@ -15,7 +16,7 @@
for _i in range(10):
run = True
while run:
sg = randint(4, 191)
sg = rng.integers(4, 191)
start = time()
rand_crystal = molecular_crystal(sg, [mol], [4], 1.0)
if rand_crystal.valid:
Expand Down
Loading

0 comments on commit 42206d4

Please sign in to comment.