Skip to content

Commit

Permalink
tests: tests now all pass but more confirmation needed before merging
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Jul 16, 2024
1 parent 8d8b0db commit ba292c6
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 109 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ lint.ignore = [
"E722", # do not use bare 'except' -- should be fixed but too much tech debt to fix now
"RET505", # Unnecessary else after return
"SIM300", # Yoda conditions are not allowed
"PLW2901", # loop variable is overwritten
]
lint.pydocstyle.convention = "google"
lint.isort.known-third-party = ["wandb"]
135 changes: 59 additions & 76 deletions pyxtal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,19 +715,19 @@ def subgroup(
a list of pyxtal structures with lower symmetries
"""

idx, sites, t_types, k_types = self._get_subgroup_ids(H, group_type, idx, max_cell, min_cell)
ids, sites, t_types, k_types = self._get_subgroup_ids(H, group_type, idx, max_cell, min_cell)
# randomly choose a subgroup from the available list
if N_groups is not None and len(idx) >= N_groups:
idx = self.random_state.choice(idx, N_groups)
if N_groups is not None and len(ids) >= N_groups:
ids = self.random_state.choice(ids, N_groups)
# print('max_sub_group', len(idx), max_subgroups)

valid_splitters = []
bad_splitters = []
for id in idx:
gtype = (t_types + k_types)[id]
for idx in ids:
gtype = (t_types + k_types)[idx]
if gtype == "k":
id -= len(t_types)
splitter = wyckoff_split(G=self.group, wp1=sites, idx=id, group_type=gtype)
idx -= len(t_types)
splitter = wyckoff_split(G=self.group, wp1=sites, idx=idx, group_type=gtype)

if not splitter.error:
if perms is None:
Expand Down Expand Up @@ -903,15 +903,19 @@ def _apply_substitution(self, splitter, perms):
print(len(splitter.H_orbits), len(splitter.G2_orbits), len(self.atom_sites))
self._subgroup_by_splitter(splitter)

site_ids = []
for site_id, site in enumerate(new_struc.atom_sites):
if site.specie in perms:
site_ids.append(site_id)
N = self.random_state.choice(range(1, len(site_ids))) if len(site_ids) > 1 else 1
sub_ids = self.random_state.choice(site_ids, N)
for sub_id in sub_ids:
key = new_struc.atom_sites[sub_id].specie
new_struc.atom_sites[sub_id].specie = perms[key]
# Create a list of tuples (site_id, original_specie) for sites that can be substituted
site_info = [
(site_id, site.specie) for site_id, site in enumerate(new_struc.atom_sites) if site.specie in perms
]

# TODO range isn't inclusive of end number is this intended?
N = self.random_state.choice(range(1, len(site_info))) if len(site_info) > 1 else 1
sub_indices = self.random_state.choice(len(site_info), N, replace=False)

for index in sub_indices:
site_id, original_specie = site_info[index]
new_struc.atom_sites[site_id].specie = perms[original_specie]

new_struc._get_formula()
return new_struc

Expand Down Expand Up @@ -1148,6 +1152,7 @@ def _get_formula(self):
from pyxtal.database.element import Element

formula = ""

if self.molecular:
numspecies = self.numMols
species = [str(mol) for mol in self.molecules]
Expand All @@ -1166,9 +1171,11 @@ def _get_formula(self):
numIons[i] = specie_list.count(sp)
self.numIons = numIons
numspecies = self.numIons

for i, s in zip(numspecies, species):
specie = Element(s).short_name if isinstance(s, str) else s
specie = Element(s).short_name if isinstance(s, int) else s
formula += f"{specie:s}{int(i):d}"

self.formula = formula

def get_zprime(self, integer=False):
Expand Down Expand Up @@ -2004,7 +2011,8 @@ def get_disps_single(self, ref_struc, trans, d_tol=1.2):
if dist < 0.3:
match = True
break
elif dist < 1.2 * d_tol:

if dist < 1.2 * d_tol:
ds.append(dist)
ids.append(i)
_disps.append(disp)
Expand Down Expand Up @@ -2347,9 +2355,7 @@ def get_transition_by_path(self, ref_struc, path, d_tol, d_tol2=0.5, N_images=2,
# resort sites_H based on elements0
seq = [elements1.index(x) for x in elements0]
sites_H = [sites_H[i] for i in seq]
numIons_H = []
for site in sites_H:
numIons_H.append(sum([int(l[:-1]) for l in site]))
numIons_H = [sum(int(l[:-1]) for l in site) for site in sites_H]

# enumerate all possible solutions
ids = []
Expand Down Expand Up @@ -2409,12 +2415,11 @@ def get_transition_by_path(self, ref_struc, path, d_tol, d_tol2=0.5, N_images=2,
match = False
break
# composition
else:
number = sum([int(l[:-1]) for l in site])
if number != numIons_H[i]:
# print("bad number", site, number, numIons_H[i])
match = False
break
number = sum(int(l[:-1]) for l in site)
if number != numIons_H[i]:
# print("bad number", site, number, numIons_H[i])
match = False
break
# if int(mult) == 2: print(path, _sites0, match)
# make subgroup
if match:
Expand Down Expand Up @@ -2550,10 +2555,7 @@ def get_neighboring_molecules(self, site_id=0, factor=1.5, max_d=5.0, ignore_E=T
comps.extend(comp)
engs.extend(eng)

if engs[0] is None: # sort by distance
ids = np.argsort(min_ds)
else: # sort by energy
ids = np.argsort(engs) # min_ds)
ids = np.argsort(min_ds) if engs[0] is None else np.argsort(engs)

neighs = [neighs[i] for i in ids]
comps = [comps[i] for i in ids]
Expand Down Expand Up @@ -2718,18 +2720,17 @@ def substitute_1_2(
for _xtal in _xtals:
if max_wp is not None and len(_xtal.atom_sites) > max_wp:
continue
else:
if new_struc_wo_energy(_xtal, xtals):
add = True
if new_struc_wo_energy(_xtal, xtals):
add = True

if criteria is not None and not _xtal.check_validity(criteria):
add = False
if criteria is not None and not _xtal.check_validity(criteria):
add = False

if add:
xtals.append(_xtal)
print("Add substitution", _xtal.get_xtal_string())
if len(xtals) == N_max:
break
if add:
xtals.append(_xtal)
print("Add substitution", _xtal.get_xtal_string())
if len(xtals) == N_max:
break
# print('Add {:d} substitutions in subgroup {:d}'.format(len(_xtals), sub.group.number))
else:
print(f"Good representation ({len(xtals):d})", self.get_xtal_string())
Expand Down Expand Up @@ -2761,10 +2762,7 @@ def _substitute_1_2(self, dicts, ratio=None): # , group_type='t', max_cell=4, m

A, [B, C] = next(iter(dicts.items()))

numbers = []
for site in self.atom_sites:
if site.specie == A:
numbers.append(site.wp.multiplicity)
numbers = [site.wp.multiplicity for site in self.atom_sites if site.specie == A]
solutions = split_list_by_ratio(numbers, ratio)

# Output all possible substitutions
Expand Down Expand Up @@ -3057,11 +3055,11 @@ def from_CSD(self, csd_code):
smi = entry.molecule.smiles
if smi is None:
raise CSDError("No smile from CSD")
elif len(smi) > 350:

if len(smi) > 350:
raise CSDError(f"long smile {smi:s}")
else:
if Chem.MolFromSmiles(smi) is None:
raise CSDError(f"problematic smiles: {smi:s}")
if Chem.MolFromSmiles(smi) is None:
raise CSDError(f"problematic smiles: {smi:s}")

cif = entry.to_string(format="cif")
smiles = [s + ".smi" for s in smi.split(".")]
Expand Down Expand Up @@ -3096,41 +3094,26 @@ def from_CSD(self, csd_code):
for ele in pmg.composition.elements:
if ele.symbol == "D":
pmg.replace_species({ele: Element("H")})
elif ele.value not in [
"C",
"Si",
"H",
"O",
"N",
"S",
"F",
"Cl",
"Br",
"I",
"P",
]:
elif ele.value not in "C Si H O N S F Cl Br I P".split(" "):
organic = False
break

if not organic:
msg = "Cannot handle the organometallic entry from CSD: "
msg += entry.formula
raise CSDError(msg)
else:
# print(smiles); self.from_seed(pmg, smiles)
try:
self.from_seed(pmg, smiles)
except ReadSeedError:
try:
# print(smiles)#; import sys; sys.exit()
self.from_seed(pmg, smiles)
except ReadSeedError:
try:
# print("Add_H=============================================")
self.from_seed(pmg, smiles, add_H=True)
except:
msg = f"unknown problems in Reading CSD {csd_code:s} {smi:s}"
raise CSDError(msg)
# print("Add_H=============================================")
self.from_seed(pmg, smiles, add_H=True)
except:
msg = f"unknown problems in Reading CSD {csd_code:s} {smi:s}"
raise CSDError(msg)
except:
msg = f"unknown problems in Reading CSD {csd_code:s} {smi:s}"
raise CSDError(msg)
self.source = "CSD: " + csd_code
else:
msg = csd_code + " does not have 3D structure"
Expand Down Expand Up @@ -3522,11 +3505,11 @@ def get_xtal_string(self, dicts=None, header=None):
# Similarity, energy, status
for key in dicts:
value = dicts[key]
if type(value) in [float, np.float64]:
if isinstance(value, (float, np.float64)):
strs += f" {value:8.3f} "
elif type(value) == str:
elif isinstance(value, str):
strs += f" {value:24s} "
elif type(value) == bool:
elif isinstance(value, bool):
strs += f" {value!s:5s} "

for s in self.atom_sites:
Expand Down
1 change: 1 addition & 0 deletions pyxtal/molecular_crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ def _set_orientation(self, pyxtal_mol, pt, oris, wp):
"""
# Use a Wyckoff_site object for the current site
self.numattempts += 1
# NOTE removing this copy causes tests to fail -> state not managed well
ori = self.random_state.choice(oris).copy()
ori.change_orientation(flip=True)
ms0 = mol_site(pyxtal_mol, pt, ori, wp, self.lattice)
Expand Down
10 changes: 4 additions & 6 deletions pyxtal/optimize/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,10 @@ def summary(self, calc=None):
xtal = db.get_pyxtal(code)

c_info = row.data["charmm_info"]
prm = open(work_dir + "/pyxtal.prm", "w")
prm.write(c_info["prm"])
prm.close()
rtf = open(work_dir + "/pyxtal.rtf", "w")
rtf.write(c_info["rtf"])
rtf.close()
with open(work_dir + "/pyxtal.prm", "w") as prm:
prm.write(c_info["prm"])
with open(work_dir + "/pyxtal.rtf", "w") as rtf:
rtf.write(c_info["rtf"])
g_info = row.data["gulp_info"]

pmg = xtal.to_pymatgen()
Expand Down
19 changes: 9 additions & 10 deletions pyxtal/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def from_pyxtal(cls, struc, standard=True):
symmetry = [struc.atom_sites[0].wp.hall_number]
lat = struc.lattice.encode()
vector = [symmetry + lat]
for site in struc.atom_sites:
vector.append(site.encode())
vector.extend([site.encode() for site in struc.atom_sites])
x = vector
return cls(x)

Expand Down Expand Up @@ -504,8 +503,8 @@ def get_dist(self, rep):
print(rep3.to_string())

print("Test other cases")
string1 = "81 14.08 6.36 25.31 83.9 1 0 0.83 0.40 0.63 136.6 -21.6 -151.1 -101.1 -131.2 154.7 -176.4 -147.8 178.2 -179.1 -53.3 0"
string2 = "81 14.08 6.36 25.31 83.9 1 0 0.03 0.84 0.89 149.1 -8.0 -37.8 -39.9 -104.2 176.2 -179.6 137.8 -178.5 -173.3 -103.6 0"
string1 = "81 14.08 6.36 25.31 83.9 1 0 0.83 0.40 0.63 136.6 -21.6 -151.1 -101.1 -131.2 154.7 -176.4 -147.8 178.2 -179.1 -53.3 0" # noqa: E501
string2 = "81 14.08 6.36 25.31 83.9 1 0 0.03 0.84 0.89 149.1 -8.0 -37.8 -39.9 -104.2 176.2 -179.6 137.8 -178.5 -173.3 -103.6 0" # noqa: E501
smiles = ["CC1=CC=C(C=C1)S(=O)(=O)C2=C(N=C(S2)C3=CC=C(C=C3)NC(=O)OCC4=CC=CC=C4)C"]
rep4 = representation.from_string(string1, smiles)
rep5 = representation.from_string(string2, smiles)
Expand All @@ -523,12 +522,12 @@ def get_dist(self, rep):
print(xtal)
print(rep.to_pyxtal())
# strings = [
# "83 14.08 6.36 25.31 83.9 1 0.72 0.40 0.27 131.6 -17.0 -120.0 -83.8 -134.1 -174.5 -175.7 -168.8 173.9 178.0 -157.4 0",
# "81 14.08 6.36 25.31 83.9 1 0.59 0.81 0.39 -117.8 -50.1 -95.3 -25.8 -80.6 164.7 155.9 -124.9 -159.2 178.6 -154.7 0",
# "81 14.08 6.36 25.31 83.9 1 0.75 0.09 0.01 133.8 -19.5 -55.1 -86.7 -91.7 -175.0 -170.4 -176.8 173.3 -164.8 -58.4 0",
# "81 14.08 6.36 25.31 83.9 1 0.72 0.44 0.01 135.2 27.5 97.2 -101.1 -105.1 -29.7 -169.7 -50.1 172.2 -173.1 131.6 0",
# "82 14.00 6.34 25.26 83.6 1 0.21 0.08 0.54 146.0 -12.0 50.2 108.0 112.3 -166.3 -158.7 -35.5 172.3 -168.7 133.0 0",
# "81 14.08 6.36 25.31 83.9 1 0.05 0.30 0.89 -68.2 41.2 148.8 -66.9 -85.0 -167.4 172.3 -166.2 -178.3 166.4 -45.9 0",
# "83 14.08 6.36 25.31 83.9 1 0.72 0.40 0.27 131.6 -17.0 -120.0 -83.8 -134.1 -174.5 -175.7 -168.8 173.9 178.0 -157.4 0", # noqa: E501
# "81 14.08 6.36 25.31 83.9 1 0.59 0.81 0.39 -117.8 -50.1 -95.3 -25.8 -80.6 164.7 155.9 -124.9 -159.2 178.6 -154.7 0", # noqa: E501
# "81 14.08 6.36 25.31 83.9 1 0.75 0.09 0.01 133.8 -19.5 -55.1 -86.7 -91.7 -175.0 -170.4 -176.8 173.3 -164.8 -58.4 0", # noqa: E501
# "81 14.08 6.36 25.31 83.9 1 0.72 0.44 0.01 135.2 27.5 97.2 -101.1 -105.1 -29.7 -169.7 -50.1 172.2 -173.1 131.6 0", # noqa: E501
# "82 14.00 6.34 25.26 83.6 1 0.21 0.08 0.54 146.0 -12.0 50.2 108.0 112.3 -166.3 -158.7 -35.5 172.3 -168.7 133.0 0", # noqa: E501
# "81 14.08 6.36 25.31 83.9 1 0.05 0.30 0.89 -68.2 41.2 148.8 -66.9 -85.0 -167.4 172.3 -166.2 -178.3 166.4 -45.9 0", # noqa: E501
# ]

# import pymatgen.analysis.structure_matcher as sm
Expand Down
2 changes: 1 addition & 1 deletion pyxtal/supergroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def search_supergroup(self, d_tol=0.9, max_per_G=2500, max_solutions=None):
for idx, sols in self.solutions:
if len(sols) > max_per_G:
print("Warning: ignore some solutions: ", len(sols) - max_per_G)
sols = [sols[i] for i in self.random_state.choice(len(sols), max_per_G)] # noqa: PLW2901
sols = [sols[i] for i in self.random_state.choice(len(sols), max_per_G)]
# sols=[(['8c'], ['4a', '4b'], ['4b', '8c', '8c'])]

for _i, sol in enumerate(sols):
Expand Down
32 changes: 22 additions & 10 deletions pyxtal/symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2897,13 +2897,15 @@ def choose_wyckoff(G, number=None, site=None, dim=3, random_state: int | None |
if random_state.random() > 0.5:
for wyckoff in wyckoffs_organized:
if len(wyckoff[0]) <= number:
return random_state.choice(wyckoff)
# NOTE wyckoff is a ragged list of lists
return wyckoff[random_state.choice(len(wyckoff))]
return False
else:
good_wyckoff = [w for wyckoff in wyckoffs_organized if len(wyckoff[0]) <= number for w in wyckoff]

if len(good_wyckoff) > 0:
return random_state.choice(good_wyckoff)
# NOTE good_wyckoff is a ragged list of lists
return good_wyckoff[random_state.choice(len(good_wyckoff))]
else:
return False

Expand Down Expand Up @@ -2948,17 +2950,27 @@ def choose_wyckoff_mol(

wyckoffs = G.wyckoffs_organized

def filter_valid_wyckoffs(wyckoffs, orientations, number):
if gen_site or np.random.random() > 0.5: # choose from high to low
for j, wyckoff in enumerate(wyckoffs):
if len(wyckoff[0]) <= number:
yield from (w for k, w in enumerate(wyckoff) if orientations[j][k])

if gen_site or random_state.random() > 0.5:
good_wyckoffs = list(filter_valid_wyckoffs(wyckoffs, orientations, number))
return random_state.choice(good_wyckoffs) if good_wyckoffs else False
good_wyckoffs = []
for k, w in enumerate(wyckoff):
if orientations[j][k] != []:
good_wyckoffs.append(w)
if len(good_wyckoffs) > 0:
return good_wyckoffs[random_state.choice(len(good_wyckoffs))]
return False
else:
good_wyckoffs = list(filter_valid_wyckoffs(wyckoffs, orientations, number))
return random_state.choice(good_wyckoffs) if good_wyckoffs else False
good_wyckoffs = []
for j, wyckoff in enumerate(wyckoffs):
if len(wyckoff[0]) <= number:
for k, w in enumerate(wyckoff):
if orientations[j][k] != []:
good_wyckoffs.append(w)
if len(good_wyckoffs) > 0:
return good_wyckoffs[random_state.choice(len(good_wyckoffs))]
else:
return False


# -------------------- quick utilities for symmetry conversion ----------------
Expand Down
Loading

0 comments on commit ba292c6

Please sign in to comment.