Skip to content

Commit

Permalink
curator ver1.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Yangxinsix committed Apr 26, 2024
1 parent 8225729 commit ddce17b
Show file tree
Hide file tree
Showing 44 changed files with 946 additions and 56,355 deletions.
22 changes: 13 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,35 +1,39 @@
# <div align="center">CURATOR: Building Robust Machine Learning Potentials for Atomistic Simulations</div>
This package implements an autonomous active learning workflow for construction of equivarient Machine-learned interatomic potentials (MLIPs). In this workflow you can choose between three architechtures of message passing neural networks (MPNN): [PAINN](https://arxiv.org/abs/2102.03150), [NequiP](https://arxiv.org/abs/2101.03164) or [MACE](https://arxiv.org/abs/2206.07697).
This package implements an autonomous active learning workflow for construction of equivarient Machine-learned interatomic potentials (MLIPs). In this workflow you can choose between three architechtures of message passing neural networks (MPNN): [PAINN](https://arxiv.org/abs/2102.03150), [NequiP](https://arxiv.org/abs/2101.03164) or [MACE](https://arxiv.org/abs/2206.07697).

To acquire a more accurate MLIP batch active learning is used. By first simulating a particular structure, batch active learning picks out the most diverse and uncertain structures to be labelled. The learned features or gradients in the model are used for active learning. Several selection methods are implemented.
All the active learning codes are to be tested.
The labelled structures are added to the dataset to train a more accurate MLIP.

Before training your MLIP you need to acquire an initial dataset consisting of atomic structures. That can be a collection of molecular dynamic (MD) simulation trajectories, ionic steps from an atomistic optimization, nudged elastic band (NEB) and much more. The important thing is to keep the level of theory consistent for all data meaning you need to use the same density funtional theory (DFT) calcululator or at least one with similar level of theory. When you have acquired your initial data set, you need to combine all your data to one big trajectory file saved using [ASE](https://iopscience.iop.org/article/10.1088/1361-648X/aa680e) trajectory format (example: "database.traj").

## <div align="center">Documentation</div>
## <div align="left">Documentation</div>
The code itself is well documented and a working example is presented. A more indepth documentation has yet to be made. You can find all hyperparameters for the workflow and how they work in our default configuration folder.

## <div align="center">Quick Start</div>
## <div align="left">Quick Start</div>


## <div align="center">How to install</div>
## <div align="left">How to install</div>

This code is only tested on [**Python>=3.8.0**](https://www.python.org/) and [**PyTorch>=2.0**](https://pytorch.org/get-started/locally/).
Requirements: [PyTorch Lightning](https://lightning.ai/), [ASE](https://wiki.fysik.dtu.dk/ase/index.html),
[hydra](https://hydra.cc/), [myqueue](https://myqueue.readthedocs.io/en/latest/installation.html)(if you want to submit jobs automatically).

```
$ conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
$ python3 -m pip install myqueue==22.7.1
$ conda install pytorch pytorch-cuda=11.8 -c pytorch -c nvidia
$ git clone https://gitlab.gbar.dtu.dk/swano/Curator.git
$ cd Curator
$ python pip install .
$ pip install .
```

## <div align="center">How to use</div>
## <div align="left">How to use</div>
A working example is presented in `/example` where you will model the diffusivity of LiFePO4 using both MD simulation and NEB.
First you download the curator package as described above. Then you create a directory somewhere. You then need to copy the user configuration script `user_cfg.yaml` , the inital dataset `init_dataset.traj`, the MD simulation trajectories `LiFePO4_MD_0.traj`,`LiFePO4_MD_1.traj`,`LiFePO4_MD_4.traj`, and the initial and final images for the NEB `NEB_init_pristine.traj` and `NEB_final_pristine.traj`(You can also optimize these NEB structures yourself if you want you). You need to change the datapaths in the user configuration file such that it matches your current directory. To run the workflow you need to have a myqueue configuration folder and file `/.myqueue/config.py`. It can also be downloaded from the example case, but it should be customized to your HPC or local computer. To run the workflow on your HPC please change the nodename and cores in `user_cfg.yaml` for each task. To run the workflow you either need to copy the workflow script `curator-workflow` from the exmaple folder into the same diretcory as `user_cfg.yaml` or locate the path to the script in `Curator/scripts`. You then write `mq workflow curator-workflow` in the terminal and the workflow will starts. A more illustrative example and video tutorial will be published soon.

There are a couple of thing to note. First if you want to run [VASP](https://www.vasp.at/) in the labeling script you need to load a license version or else we recommend you to use [GPAW](https://wiki.fysik.dtu.dk/gpaw/). Secondly, in the end of each iteration you need to add the data to the initial dataset your self. Thirdly, if you do not want to train your model from scratch in the next iteration you should use the load_model paramater in `user_cfg.yaml` to load the previous iteration's model

If you want to dig into the code you can find all the workfing functions in `Curator/curator/cli.py` and to understand how the data was generated for the example case you can go to `Curator/example/Datageneration`
If you want to dig into the code you can find all the working functions in `Curator/curator/cli.py` and to understand how the data was generated for the example case you can go to `Curator/example/Datageneration`

## <div align="left">Reference</div>
If you are using the code for building MLIPs, please cite:
https://chemrxiv.org/engage/chemrxiv/article-details/65cd6a5366c1381729ab0854
160 changes: 87 additions & 73 deletions curator/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from omegaconf import DictConfig, OmegaConf, open_dict
import sys, os, json
from pathlib import Path
from .utils import read_user_config
from .utils import read_user_config, CustomFormatter
import logging
import socket
import contextlib

# very ugly solution for solving pytorch lighting and myqueue conflictions
if "SLURM_NTASKS" in os.environ:
Expand All @@ -16,23 +17,8 @@
del os.environ["SLURM_JOB_NAME"]

# Set up logger for the different tasks
log = logging.getLogger(__name__)

# Set up Early stopping for pytorch training
class EarlyStopping():
def __init__(self, patience=5, min_delta=0):

self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.early_stop = False

def __call__(self, val_loss, best_loss):
if val_loss - best_loss > self.min_delta:
self.counter +=1
if self.counter >= self.patience:
self.early_stop = True
return self.early_stop
log = logging.getLogger('curator')
log.setLevel(logging.DEBUG)

# Trainining with Pytorch Lightning (only with weights and biasses)
@hydra.main(config_path="configs", config_name="train", version_base=None)
Expand All @@ -45,54 +31,58 @@ def train(config: DictConfig) -> None:
None
"""
import torch
from pytorch_lightning import (
LightningDataModule,
Trainer,
)
from curator.model import LitNNP
from e3nn.util.jit import script

# set up logger
fh = logging.FileHandler(os.path.join(config.run_path, "training.log"), mode="w")
fh.setFormatter(CustomFormatter())
log.addHandler(fh)

# Load the arguments
if config.cfg is not None:
config = read_user_config(config.cfg, config_path="configs", config_name="train")

# Save yaml file in run_path
OmegaConf.resolve(config)
OmegaConf.save(config, f"{config.run_path}/config.yaml", resolved=True)
log.info("Running on host: " + str(socket.gethostname()))
OmegaConf.save(config, f"{config.run_path}/config.yaml", resolve=True)
log.debug("Running on host: " + str(socket.gethostname()))

# Set up seed
if "seed" in config:
log.info(f"Seed with <{config.seed}>")
log.debug(f"Seed with <{config.seed}>")
else:
log.info("Seed randomly...")
log.debug("Seed randomly...")

# Initiate the datamodule
log.info(f"Instantiating datamodule <{config.data._target_}>")
log.debug(f"Instantiating datamodule <{config.data._target_}> from dataset {config.data.datapath}")
if not os.path.isfile(config.data.datapath):
raise RuntimeError("Please provide valid data path!")
datamodule: LightningDataModule = hydra.utils.instantiate(config.data)

# Initiate the model
log.info(f"Instantiating model <{config.model._target_}>")
log.debug(f"Instantiating model <{config.model._target_}>")
model = hydra.utils.instantiate(config.model)
log.debug(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,d}")

# Initiate the task and load old model, if any
log.info(f"Instantiating task <{config.task._target_}>")
log.debug(f"Instantiating task <{config.task._target_}>")
task: LitNNP = hydra.utils.instantiate(config.task, model=model)
if config.model_path is not None:
log.info(f"Loading trained model from {config.model_path}")
log.debug(f"Loading trained model from {config.model_path}")
task = LitNNP.load_from_checkpoint(checkpoint_path=config.model_path, model=model)
# Save extra arguments in checkpoint
task.save_configuration(config)

# Initiate the training
log.info(f"Instantiating trainer <{config.trainer._target_}>")
log.debug(f"Instantiating trainer <{config.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(config.trainer)

# Train the model
log.info("Starting training.")
trainer.fit(model=task, datamodule=datamodule)

# Deploy model to a compiled model
Expand All @@ -109,12 +99,13 @@ def train(config: DictConfig) -> None:
model_path = model_path[index]

# Compile the model
log.info(f"Deploy trained model from {model_path} with validation loss of {val_loss:.3f}")
task = LitNNP.load_from_checkpoint(checkpoint_path=f"{model_path}", model=model)
outputs = torch.load(model_path)['outputs']
log.debug(f"Deploy trained model from {model_path} with validation loss of {val_loss:.3f}")
task = LitNNP.load_from_checkpoint(checkpoint_path=f"{model_path}", model=model, outputs=outputs)
model_compiled = script(task.model)
metadata = {"cutoff": str(model_compiled.representation.cutoff).encode("ascii")}
model_compiled.save(f"{config.run_path}/compiled_model.pt", _extra_files=metadata)
log.info(f"Deploying compiled model at <{config.run_path}/compiled_model.pt>")
log.debug(f"Deploying compiled model at <{config.run_path}/compiled_model.pt>")

# Training without Pytorch Lightning
@hydra.main(config_path="configs", config_name="train", version_base=None)
Expand All @@ -131,15 +122,15 @@ def tmp_train(config: DictConfig):
import torch
from e3nn.util.jit import script
from torch_ema import ExponentialMovingAverage
from .utils import EarlyStopping


# Load the arguments
if config.cfg is not None:
config = read_user_config(config.cfg, config_path="configs", config_name="train")

# Save yaml file in run_path
resolved = OmegaConf.to_container(config, resolve=True)
OmegaConf.save(resolved, f"{config.run_path}/config.yaml")
OmegaConf.save(config, f"{config.run_path}/config.yaml", resolve=True)
log.info("Running on host: " + str(socket.gethostname()))

# Set up the seed
Expand All @@ -161,8 +152,8 @@ def tmp_train(config: DictConfig):
raise RuntimeError("Please provide valid data path!")
datamodule = instantiate(config.data)
datamodule.setup()
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
train_loader = datamodule.train_dataloader
val_loader = datamodule.val_dataloader

# Set up the model, the optimizer and the scheduler
model = instantiate(config.model)
Expand All @@ -183,26 +174,34 @@ def tmp_train(config: DictConfig):
epoch = config.trainer.max_epochs
best_val_loss = torch.inf
prev_loss = None

rescale_layers = []
for layer in model.output_modules:
if hasattr(layer, "unscale"):
rescale_layers.append(layer)
# Start training
for e in range(epoch):
# train
model.train()
for i, batch in enumerate(train_loader):
# Initialize the batch, targets and loss
batch = {k: v.to(config.device) for k, v in batch.items()}
targets = {
output.target_property: batch[output.target_property]
for output in outputs
}
atoms_dict = {k: v for k, v in batch.items() if k not in targets}
optimizer.zero_grad()

unscaled_targets = targets.copy()
unscaled_targets.update(atoms_dict)
for layer in rescale_layers:
unscaled_targets = layer.unscale(unscaled_targets, force_process=True)
pred = model(batch)
loss = 0.0

# Calculate the loss and metrics
metrics = {}
for output in outputs:
tmp_loss = output.calculate_loss(targets, pred)
tmp_loss, _ = output.calculate_loss(unscaled_targets, pred)
metrics[f"{output.target_property}_loss"] = tmp_loss.detach()
loss += tmp_loss

Expand All @@ -213,47 +212,65 @@ def tmp_train(config: DictConfig):
ema.update()

# Log the training metrics
scaled_pred = pred.copy()
scaled_pred.update(atoms_dict)
for layer in rescale_layers:
scaled_pred = layer.scale(scaled_pred, force_process=True)
if i % config.trainer.log_every_n_steps == 0:
for output in outputs:
for k, v in output.metrics['train'].items():
metrics[f"{output.name}_{k}"] = v(pred[output.name], targets[output.name]).detach()
metrics[f"{output.name}_{k}"] = v(scaled_pred[output.name], targets[output.name]).detach()
log_outputs = ",".join([f"{k}: {v:8.3f} " for k, v in metrics.items()])
log.info(f"Training step: {i} {log_outputs}")

# validation for each epoch
model.eval()
metrics = {}
a_counts = 0
s_counts = 0
for i, batch in enumerate(val_loader):
# Initialize the batch, targets and loss
batch = {k: v.to(config.device) for k, v in batch.items()}
targets = {
output.target_property: batch[output.target_property]
for output in outputs
}

a = batch["forces"].shape[0]
s = batch["energy"].shape[0]
a_counts += a
s_counts += s
pred = model(batch)
# calculate loss
for output in outputs:
tmp_loss = output.calculate_loss(targets, pred).detach()
# metrics
if i == 0:
metrics[f"{output.target_property}_loss"] = tmp_loss
else:
metrics[f"{output.target_property}_loss"] += tmp_loss
if config.task.use_ema:
cm = ema.average_parameters()
else:
cm = contextlib.nullcontext()
with cm:
for i, batch in enumerate(val_loader):
# Initialize the batch, targets and loss
batch = {k: v.to(config.device) for k, v in batch.items()}
targets = {
output.target_property: batch[output.target_property]
for output in outputs
}
atoms_dict = {k: v for k, v in batch.items() if k not in targets}

a = batch["forces"].shape[0]
s = batch["energy"].shape[0]
a_counts += a
s_counts += s
pred = model(batch)
unscaled_targets, unscaled_pred = targets.copy(), pred.copy()
unscaled_pred.update(atoms_dict)
unscaled_targets.update(atoms_dict)
for layer in rescale_layers:
unscaled_targets = layer.unscale(unscaled_targets, force_process=True)
unscaled_pred = layer.unscale(unscaled_pred, force_process=True)

for k, v in output.metrics['train'].items():
m = v(pred[output.name], targets[output.name]).detach()
if "rmse" in k:
m = m ** 2
# calculate loss
for output in outputs:
tmp_loss = output.calculate_loss(unscaled_targets, unscaled_pred, return_num_obs=False).detach()
# metrics
if i == 0:
metrics[f"{output.name}_{k}"] = m
metrics[f"{output.target_property}_loss"] = tmp_loss
else:
metrics[f"{output.name}_{k}"] += m
metrics[f"{output.target_property}_loss"] += tmp_loss

for k, v in output.metrics['train'].items():
m = v(pred[output.name], targets[output.name]).detach()
if "rmse" in k:
m = m ** 2
if i == 0:
metrics[f"{output.name}_{k}"] = m
else:
metrics[f"{output.name}_{k}"] += m

# postprocess validation metrics
for k in metrics:
Expand Down Expand Up @@ -384,8 +401,7 @@ def simulate(config: DictConfig):
config = read_user_config(config.cfg, config_path="configs", config_name="simulate")

# Save yaml file in run_path
resolved = OmegaConf.to_container(config, resolve=True)
OmegaConf.save(resolved, f"{config.run_path}/config.yaml")
OmegaConf.save(config, f"{config.run_path}/config.yaml", resolve=True)

# set logger
log.setLevel(logging.DEBUG)
Expand Down Expand Up @@ -473,8 +489,7 @@ def select(config: DictConfig):
config = read_user_config(config.cfg, config_path="configs", config_name="select")

# Save yaml file in run_path
resolved = OmegaConf.to_container(config, resolve=True)
OmegaConf.save(resolved, f"{config.run_path}/config.yaml")
OmegaConf.save(config, f"{config.run_path}/config.yaml", resolve=True)
log.info("Running on host: " + str(socket.gethostname()))

# Set up the seed
Expand Down Expand Up @@ -565,8 +580,7 @@ def label(config: DictConfig):
config = read_user_config(config.cfg, config_path="configs", config_name="label")

# Save yaml file in run_path
resolved = OmegaConf.to_container(config, resolve=True)
OmegaConf.save(resolved, f"{config.run_path}/config.yaml")
OmegaConf.save(config, f"{config.run_path}/config.yaml", resolve=True)
log.info("Running on host: " + str(socket.gethostname()))

# Set up dataframe and load possible converged data id's
Expand Down
2 changes: 2 additions & 0 deletions curator/configs/data/custom.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ test_batch_size: 8
num_train: 0.9
num_val: 0.1
num_test: null
train_val_split: random
shuffle: true
num_workers: 1
pin_memory: True
species: auto
Expand Down
6 changes: 6 additions & 0 deletions curator/configs/model/repr_params/nequip_params.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# @package _global_
task:
optimizer:
amsgrad: True
lr: 0.005
weight_decay: 0.0
4 changes: 4 additions & 0 deletions curator/configs/model/repr_params/painn_params.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# @package _global_
task:
optimizer:
lr: 0.0005
Loading

0 comments on commit ddce17b

Please sign in to comment.