CrystalFormer is a transformer-based autoregressive model specifically designed for space-group-controlled generation of crystalline materials. Space group symmetry significantly simplifies the crystal space, which is crucial for data- and compute-efficient generative modeling of crystalline materials (see the paper cited below).
The model is an autoregressive transformer for the space-group-conditioned crystal probability distribution

P(C|g) = P(W_1|...) P(A_1|...) P(X_1|...) P(W_2|...) ... P(L|...),

where

g: space group number 1-230
W: Wyckoff letter ('a', 'b', ..., 'A')
A: atom type ('H', 'He', ..., 'Og')
X: fractional coordinates
L: lattice vector [a, b, c, alpha, beta, gamma]

P(W_i|...) and P(A_i|...) are categorical distributions, P(X_i|...) is a mixture of von Mises distributions, and P(L|...) is a mixture of Gaussian distributions.
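Since fractional coordinates live on a periodic unit interval, a mixture of von Mises distributions gives a density satisfying p(x) = p(x + 1). A minimal NumPy sketch with made-up parameters (in the model these come from the transformer heads):

```python
import numpy as np

# Sketch of a mixture-of-von-Mises density over a fractional coordinate x in [0, 1).
# The parameters below are illustrative, not the model's actual outputs.
def von_mises_mixture_pdf(x, weights, locs, kappas):
    theta = 2 * np.pi * x
    comps = np.exp(kappas * np.cos(theta - 2 * np.pi * locs)) / (2 * np.pi * np.i0(kappas))
    # factor 2*pi is the Jacobian from the angle theta back to the fractional coordinate x
    return 2 * np.pi * np.sum(weights * comps)

w = np.array([0.7, 0.3])      # mixture weights (sum to 1)
mu = np.array([0.2, 0.8])     # component locations in fractional units
kappa = np.array([5.0, 2.0])  # concentrations (larger = sharper peak)

# periodicity: the density matches at x = 0 and x = 1
print(np.isclose(von_mises_mixture_pdf(0.0, w, mu, kappa),
                 von_mises_mixture_pdf(1.0, w, mu, kappa)))  # True
```

The wrap-around property is exactly what a Gaussian on [0, 1) would lack at the boundary.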
We only consider symmetry-inequivalent atoms; the remaining atoms are restored from the space group and Wyckoff letter information. Note that there is a natural alphabetical ordering for the Wyckoff letters, starting with 'a' for a position whose site-symmetry group has maximal order and ending with the highest letter for the general position. The sampling procedure starts from higher-symmetry sites (with smaller multiplicities) and then proceeds to lower-symmetry ones (with larger multiplicities). Only in cases where the discrete Wyckoff letters cannot fully determine the structure does one need to further consider fractional coordinates in the loss or sampling.
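The high-to-low-symmetry ordering can be illustrated with a toy sampler (the letter set and stop probability are made up; this is not the repo's model):

```python
import random

# Toy illustration of the sampling order: Wyckoff letters are drawn so that the
# sequence never moves back toward higher symmetry ('a' = maximal site symmetry).
WYCKOFF_ORDER = "abcdefg"  # alphabetical; multiplicities grow toward the end

def sample_wyckoff_sequence(rng, n_max=5, p_stop=0.3):
    seq, start = [], 0
    for _ in range(n_max):
        if rng.random() < p_stop:  # made-up stop probability
            break
        idx = rng.randrange(start, len(WYCKOFF_ORDER))
        seq.append(WYCKOFF_ORDER[idx])
        start = idx  # later sites may repeat this letter or move to lower symmetry
    return seq

rng = random.Random(0)
seq = sample_wyckoff_sequence(rng)
print(seq == sorted(seq))  # True: letters appear in alphabetical order
```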
Machine: autodl-L20, Miniconda / conda3 / Python 3.10 / Ubuntu 22.04 / CUDA 11.8
Fork the repo so that you can change it as you want.
If you want to use my modifications, just clone this repo.
Also clone the repo openlam, which is modified from here.
Run the following commands to set up the environment:
conda init
source /etc/network_turbo #optional, speed-up command for autodl machines
conda create -y -n jax -c "nvidia/label/cuda-12.6.0" cuda python=3.10 virtualenv pip
conda activate jax
ssh-keygen
cat ~/.ssh/id_rsa.pub #copy the public key to the ssh key setting in the github setting page
git clone [email protected]:your_name/CrystalFormer.git #clone the forked repo through ssh url, so that you can modify the code as you want
cd CrystalFormer
python -m pip install --upgrade pip
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple #change to a faster source of pip according to your location
pip install --upgrade "jax[cuda12]"
cd ../
git clone [email protected]:your_name/openlam.git
cd openlam
pip install .
pip install ".[dp]"
pip install ".[mace]"
cd ../CrystalFormer
pip install -r requirements.txt
The original repo releases the weights of the model trained on the MP-20 dataset; more details can be found on their page.
I also trained a model on an A40 machine using the same MP-20 dataset and default settings. The model and training log can be found here.
training setting:
Adam optimizer
bs: 100
lr: 0.0001
decay: 0
clip: 1
A: 119
W: 28
N: 21
lambda_a: 1
lambda_w: 1
lambda_l: 1
Nf: 5
Kx: 16
Kl: 4
h0: 256
layer: 16
H: 16
k: 64
m: 64
e: 32
drop: 0.5
[optional] The training step will do this automatically; it is listed here only to show my modification.
In the original repo, the input data is saved as CSV files. The training script reads the CSV files and then converts the CIF strings to the standard input format (G, L, XYZ, A, W).
Change the code in 'crystalformer/src/utils.py' from
def process_one(cif, atom_types, wyck_types, n_max, tol=0.01)
to
def process_one(cif, atom_types, wyck_types, n_max, tol=0.001)
To speed up training when one wants to retrain the model with different parameters on the same data, one can use the preprocess script to save the standard input format to an npz file.
To preprocess the data, run the following command:
python crystalformer/data/preprocess.py input_file_name max_atom_in_cell
input_file_name: the name of the input file
max_atom_in_cell: the maximum number of atoms in a cell
Note:
The input (argv1) can be a 'tar.gz', 'zip', or 'csv' file, or a path containing tar.gz files.
The 'tar.gz' and 'zip' archives should be composed of CIF files; the 'csv' file should contain a column named 'cif'.
Output:
The command produces 2 output files in the input path:
jsonl: contains the raw CIF text and the crystal info (G, L, XYZ, A, W) extracted with pymatgen
npz: the array format (G, L, XYZ, A, W) used to train the model
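Assuming the npz stores one array per field of the (G, L, XYZ, A, W) tuple (the key names and shapes below are my guess; check the actual preprocess script), the layout can be sketched and read back like this:

```python
import numpy as np

# Hypothetical sketch of the npz layout; key names follow the (G, L, XYZ, A, W)
# tuple described above and may not match the actual preprocess script.
n, n_max = 2, 21
demo = dict(
    G=np.array([225, 1]),                      # space group numbers
    L=np.zeros((n, 6)),                        # [a, b, c, alpha, beta, gamma]
    XYZ=np.zeros((n, n_max, 3)),               # fractional coordinates per site
    A=np.zeros((n, n_max), dtype=np.int32),    # atom types (0 = padding)
    W=np.zeros((n, n_max), dtype=np.int32),    # Wyckoff letter indices
)
np.savez("demo_dataset.npz", **demo)

data = np.load("demo_dataset.npz")
print(sorted(data.files))  # ['A', 'G', 'L', 'W', 'XYZ']
```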
python ./main.py --train_path data/mp_20/train.csv --valid_path data/mp_20/val.csv
train_path: the path to the training dataset; it can be a 'csv', 'tar.gz' or 'zip' file
valid_path: the path to the validation dataset; it can be a 'csv', 'tar.gz' or 'zip' file
In JAX, parallel running on a node with multiple GPUs can be achieved with the function 'pmap'. We add a new file train_parallel.py in crystalformer/src/ to implement the parallel training logic. The parallel version cannot use bool-typed values in the model, so attention.py and transformer.py are also changed accordingly.
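The pmap-based data parallelism can be sketched as follows (a minimal toy with a stand-in quadratic loss, not the code in train_parallel.py): the batch gets a leading device axis, each device computes its own gradient, and jax.lax.pmean averages the gradients across devices so every replica applies the same update.

```python
from functools import partial

import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # toy quadratic loss; stands in for the model's negative log-likelihood
    return jnp.mean((batch - params) ** 2)

@partial(jax.pmap, axis_name="devices")
def parallel_grad(params, batch):
    g = jax.grad(loss_fn)(params, batch)
    # average gradients across devices so every replica sees the same update
    return jax.lax.pmean(g, axis_name="devices")

n_dev = jax.local_device_count()
params = jnp.zeros(())  # scalar "parameters" for the sketch
rep_params = jax.device_put_replicated(params, jax.local_devices())
batch = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)  # leading axis = devices
grads = parallel_grad(rep_params, batch)
print(grads.shape)  # one (identical) averaged gradient per device
```

With a single device pmap degenerates to a batch of one shard, so the same script runs on CPU for debugging.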
To run the parallel training, just add the '--parallel 1' option:
python main.py --parallel 1 --train_path data/mp_20/train.npz --valid_path data/mp_20/val.npz --test_path data/mp_20/test.npz
python ./main.py --optimizer none --restore_path model/epoch_005200.pkl --spacegroup 1 --num_samples 1000 --batchsize 1000 --temperature 1.0
optimizer: the optimizer to use; 'none' means no training, only sampling
restore_path: the path to the model weights
spacegroup: the space group number to sample, chosen from 0-230; 0 means sample all labels with num_samples each, while 1-230 samples a specific space group
num_samples: the number of samples to generate
batchsize: the batch size for sampling
temperature: the temperature for sampling
You can also use the elements option to sample specific elements. For example, --elements La Ni O will sample structures with La, Ni, and O atoms. The sampling results will be saved in the output_LABEL.csv file, where LABEL is the space group number g specified by --spacegroup.
The input for elements can also be a json file which specifies the atom mask for each Wyckoff site and the constraints. An example atoms.json file can be seen in the data folder. There are two keys in the atoms.json file:
atom_mask: sets the atom list for each Wyckoff position; an element can only be selected from the list for the corresponding Wyckoff position
constraints: sets constraints on the Wyckoff sites during sampling; you can specify pairs of Wyckoff sites that should have the same elements
Note1
If parallel training was used, sampling also needs the '--parallel 1' option.
Note2
The sampling code will also do the evaluation. It first converts (G, A, X, L, W) to CIF strings and then checks their structural (atoms are not too close) and compositional (charge balance) validity. The CIFs will be stored at model/cifs.
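The structural part of the check ("atoms are not too close") can be sketched with a simple pairwise distance test; the 0.5 Å cutoff and the use of Cartesian coordinates here are my assumptions, not the repo's exact criterion (which should also account for periodic images).

```python
import itertools

import numpy as np

# Minimal sketch of a structural validity check: reject structures where any
# two atoms are closer than an assumed cutoff (0.5 angstrom here).
def structure_valid(cart_coords, cutoff=0.5):
    for p, q in itertools.combinations(np.asarray(cart_coords, dtype=float), 2):
        if np.linalg.norm(p - q) < cutoff:
            return False
    return True

print(structure_valid([[0, 0, 0], [2.0, 0, 0]]))  # True
print(structure_valid([[0, 0, 0], [0.1, 0, 0]]))  # False
```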
Note3
To evaluate the formation energy for a folder of CIFs, just run:
python scripts/form_energy_eval.py
If the following warning shows up, find the solution here.
The NVIDIA driver's CUDA version is 11.7 which is older than the ptxas CUDA
version (11.8.89). Because the driver is older than the ptxas version, XLA is
disabling parallel compilation, which may slow down compilation. You should
update your NVIDIA driver or use the NVIDIA-provided CUDA forward
compatibility packages.
If you find this repo useful for your study, please cite the original paper
@misc{cao2024space,
title={Space Group Informed Transformer for Crystalline Materials Generation},
author={Zhendong Cao and Xiaoshan Luo and Jian Lv and Lei Wang},
year={2024},
eprint={2403.15734},
archivePrefix={arXiv},
primaryClass={cond-mat.mtrl-sci}
}
and our study:
@misc{crystalformer_exploring,
author = {Li, Bingzhi},
title = {CrystalFormer Exploring},
year = {2024},
publisher = {GitHub},
journal = {GitHub Repository},
howpublished = {Accessed: \url{https://github.com/Graph4HEP/CrystalFormer}},
}
Note: This project is unrelated to https://github.com/omron-sinicx/crystalformer, which has the same name.