Initial commit 47673fd (0 parents): 47 changed files with 9,079 additions and 0 deletions.
@@ -0,0 +1,5 @@
*.pyc
*.egg-info
*.csv
*.npz
.idea
@@ -0,0 +1,5 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,161 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.

import ml_collections
from ml_collections import config_dict


def default() -> ml_collections.ConfigDict:
  """Create set of default parameters for running qmc.py.

  Note: placeholders (cfg.system.molecule and cfg.system.electrons) must be
  replaced with appropriate values.

  Returns:
    ml_collections.ConfigDict containing default settings.
  """
  # wavefunction output.
  cfg = ml_collections.ConfigDict({
      'batch_size': 100,  # batch size
      # Config module used. Should be set in get_config function as either the
      # absolute module or relative to the configs subdirectory. Relative
      # imports must start with a '.' (e.g. .atom). Do *not* override on
      # command-line. Do *not* set using __name__ from inside a get_config
      # function, as config_flags overrides this when importing the module
      # using importlib.import_module.
      'config_module': __name__,
      'use_x64': True,  # use float64 (True) or float32 (False)
      'optim': {
          'iterations': 1000000,  # number of iterations
          'optimizer': 'kfac',
          'local_energy_outlier_width': 5.0,
          'lr': {
              'rate': 5.e-2,  # learning rate; differs from the lr reported for FermiNet
              # since the DeepSolid energy gradient is not batch-size dependent
              'decay': 1.0,  # exponent of learning rate decay
              'delay': 10000.0,  # term that sets the scale of the rate decay
          },
          'clip_el': 5.0,  # If not None, scale at which to clip local energy
          'clip_type': 'real',  # Clip real and imaginary parts of the gradient.
          'gradient_clip': 5.0,
          # ADAM hyperparameters. See optax documentation for details.
          'adam': {
              'b1': 0.9,
              'b2': 0.999,
              'eps': 1.e-8,
              'eps_root': 0.0,
          },
          'kfac': {
              'invert_every': 1,
              'cov_update_every': 1,
              'damping': 0.001,
              'cov_ema_decay': 0.95,
              'momentum': 0.0,
              'momentum_type': 'regular',
              # Warning: adaptive damping is not currently available.
              'min_damping': 1.e-4,
              'norm_constraint': 0.001,
              'mean_center': True,
              'l2_reg': 0.0,
              'register_only_generic': False,
          },
          'ministeps': 1,
          'laplacian_mode': 'for',  # Laplacian evaluation mode: one of 'for', 'partition' or 'hessian'.
          # 'for' mode calculates the laplacian of each electron one by one, which is slow but saves GPU memory.
          # 'hessian' mode calculates the laplacian in a highly parallelized way, which is fast but requires more GPU memory.
          # 'partition' mode calculates the laplacian in an intermediate way.
          'partition_number': 3,
          # Only used for 'partition' mode.
          # partition_number must be divisible by (dim * number of electrons).
          # The smaller it is, the faster, but the more memory it requires.
      },
      'log': {
          'stats_frequency': 1,  # iterations between logging of stats
          'save_frequency': 10.0,  # minutes between saving network params
          'save_frequency_in_step': -1,
          'save_path': '',
          # specify the local save path
          'restore_path': '',
          # specify the restore path which contains saved model parameters.
          'local_energies': False,
          'complex_polarization': False,  # log the polarization order parameter, which is useful for hydrogen chains.
          'structure_factor': False,
          # return the structure factor S(k) at the reciprocal lattice vectors of the supercell.
          # Logging S(k) requires a lot of storage space; be careful.
          'stats_file_name': 'train_stats'
      },
      'system': {
          'pyscf_cell': None,  # simulation cell object
          'ndim': 3,  # dimension of the system
          'internal_cell': None,
      },
      'mcmc': {
          # Note: HMC options are not currently used.
          # Number of burn-in steps after pretraining. If zero, do not burn in
          # or reinitialize walkers.
          'burn_in': 100,
          'steps': 20,  # Number of MCMC steps to make between network updates.
          # Width of (atom-centred) Gaussian used to generate initial electron
          # configurations.
          'init_width': 0.8,
          # Width of Gaussian used for random moves for RMW, or step size for
          # HMC.
          'move_width': 0.02,
          # Number of steps after which to update the adaptive MCMC step size.
          'adapt_frequency': 100,
          'init_means': (),  # Not implemented in JAX.
          # If true, scale the proposal width for each electron by the harmonic
          # mean of the distance to the nuclei.
          'importance_sampling': False,
          # Whether to use importance sampling in the MCMC step; not tested yet.
          # Metropolis sampling will be used if false.
          'one_electron': False
          # If true, use one-electron moves; not tested yet.
      },
      'network': {
          'detnet': {
              'envelope_type': 'isotropic',
              # only the isotropic mode has been tested
              'bias_orbitals': False,
              'use_last_layer': False,
              'full_det': False,
              'hidden_dims': ((256, 32), (256, 32), (256, 32)),
              'determinants': 8,
              'after_determinants': 1,
          },
          'twist': (0.0, 0.0, 0.0),
          # Define the twist of the wavefunction.
          # Twists are given in terms of fractions of supercell reciprocal vectors.
      },
      'debug': {
          # Check optimizer state, parameters and loss and raise an exception if
          # NaN is found.
          'check_nan': False,  # check whether the gradient contains NaNs before the optimizer step; if True, retry.
          'deterministic': False,  # Use a deterministic seed.
      },
      'pretrain': {
          'method': 'net',  # Method is one of 'hf', 'net'.
          'iterations': 1000,
          'lr': 3e-4,
          'steps': 1,  # MCMC steps between each pretraining iteration
      },
  })

  return cfg


def resolve(cfg):
  cfg = cfg.copy_and_resolve_references()
  return cfg
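
For reference, a minimal usage sketch of building and customizing this config before training. The import path DeepSolid.base_config matches the config files later in this commit; the override values below are purely illustrative, not recommended settings.

from DeepSolid import base_config

cfg = base_config.default()
cfg.batch_size = 4096            # illustrative override, not a recommended value
cfg.optim.lr.rate = 1e-2         # any field can be overridden before training
cfg.mcmc.steps = 30
cfg = base_config.resolve(cfg)   # resolve any ConfigDict references
print(cfg.optim.optimizer)       # -> 'kfac'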
@@ -0,0 +1,165 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.

import datetime
import os
from typing import Optional
import zipfile

from absl import logging
import jax
import numpy as np


def get_restore_path(restore_path: Optional[str] = None) -> Optional[str]:
  """Gets the path containing checkpoints from a previous calculation.

  Args:
    restore_path: path to checkpoints.

  Returns:
    The path, or None if restore_path is falsy.
  """
  if restore_path:
    ckpt_restore_path = restore_path
  else:
    ckpt_restore_path = None
  return ckpt_restore_path


def find_last_checkpoint(ckpt_path: Optional[str] = None) -> Optional[str]:
  """Finds the most recent valid checkpoint in a directory.

  Args:
    ckpt_path: Directory containing checkpoints.

  Returns:
    Last QMC checkpoint (ordered by sorting all checkpoints by name in
    reverse), or None if no valid checkpoint is found, ckpt_path is not given,
    or the directory does not exist. A checkpoint is regarded as invalid if it
    cannot be read successfully using np.load.
  """
  if ckpt_path and os.path.exists(ckpt_path):
    files = [f for f in os.listdir(ckpt_path) if 'qmcjax_ckpt_' in f]
    # Handle case where last checkpoint is corrupt/empty.
    for file in sorted(files, reverse=True):
      fname = os.path.join(ckpt_path, file)
      with open(fname, 'rb') as f:
        try:
          np.load(f, allow_pickle=True)
          return fname
        except (OSError, EOFError, zipfile.BadZipFile):
          logging.info('Error loading checkpoint %s. Trying next checkpoint...',
                       fname)
  return None


def create_save_path(save_path: Optional[str]) -> str:
  """Creates the directory for saving checkpoints, if it doesn't exist.

  Args:
    save_path: directory to use. If false, create a directory in the working
      directory based upon the current time.

  Returns:
    Path to save checkpoints to.
  """
  timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
  default_save_path = os.path.join(os.getcwd(), f'DeepSolid_{timestamp}')
  ckpt_save_path = save_path or default_save_path

  if ckpt_save_path and not os.path.isdir(ckpt_save_path):
    os.makedirs(ckpt_save_path)

  return ckpt_save_path


def save(save_path: str, t: int, data, params, opt_state, mcmc_width,
         remote_save_path: Optional[str] = None) -> str:
  """Saves checkpoint information to a npz file.

  Args:
    save_path: path to directory to save checkpoint to. The checkpoint file is
      save_path/qmcjax_ckpt_$t.npz, where $t is the number of completed
      iterations.
    t: number of completed iterations.
    data: MCMC walker configurations.
    params: pytree of network parameters.
    opt_state: optimization state.
    mcmc_width: width to use in the MCMC proposal distribution.
    remote_save_path: not used in this function.

  Returns:
    path to checkpoint file.
  """
  ckpt_filename = os.path.join(save_path, f'qmcjax_ckpt_{t:06d}.npz')
  logging.info('Saving checkpoint %s', ckpt_filename)
  with open(ckpt_filename, 'wb') as f:
    np.savez(
        f,
        t=t,
        data=data,
        params=params,
        opt_state=opt_state,
        mcmc_width=mcmc_width)

  return ckpt_filename


def restore(restore_filename: str, batch_size: Optional[int] = None,
            shape_check: bool = True):
  """Restores data saved in a checkpoint.

  Args:
    restore_filename: filename containing checkpoint.
    batch_size: total batch size to be used. If present, check that the data
      saved in the checkpoint is consistent with the batch size requested for
      the calculation.
    shape_check: if True, verify that the restored walker data matches the
      local device count and the requested batch size.

  Returns:
    (t, data, params, opt_state, mcmc_width) tuple, where
    t: number of completed iterations.
    data: MCMC walker configurations.
    params: pytree of network parameters.
    opt_state: optimization state.
    mcmc_width: width to use in the MCMC proposal distribution.

  Raises:
    ValueError: if the leading dimension of data does not match the number of
      devices (i.e. the number of devices being parallelised over has changed),
      or if the total batch size is not equal to the number of MCMC
      configurations in data.
  """
  logging.info('Loading checkpoint %s', restore_filename)
  with open(restore_filename, 'rb') as f:
    ckpt_data = np.load(f, allow_pickle=True)
    # Retrieve data from npz file. Non-array variables need to be converted
    # back to native types using .tolist().
    t = ckpt_data['t'].tolist() + 1  # Resume from the step after the last completed iteration.
    data = ckpt_data['data']
    params = ckpt_data['params'].tolist()
    opt_state = ckpt_data['opt_state'].tolist()
    mcmc_width = ckpt_data['mcmc_width'].tolist()
    if shape_check:
      if data.shape[0] != jax.local_device_count():
        raise ValueError(
            'Incorrect number of devices found. Expected {}, found {}.'.format(
                jax.local_device_count(), data.shape[0]))
      if batch_size and data.shape[0] * data.shape[1] != batch_size:
        raise ValueError(
            'Wrong batch size in loaded data. Expected {}, found {}.'.format(
                batch_size, data.shape[0] * data.shape[1]))
  return t, data, params, opt_state, mcmc_width
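
A hedged sketch of how these helpers could be chained when resuming a run. The module path DeepSolid.checkpoint is assumed (file names are not shown in this diff), and the cfg fields come from the base config above; the resume flow itself is illustrative, not the training driver from this commit.

from DeepSolid import base_config
from DeepSolid import checkpoint  # assumed module path; not shown in this diff

cfg = base_config.default()
ckpt_save_path = checkpoint.create_save_path(cfg.log.save_path)
ckpt_file = checkpoint.find_last_checkpoint(
    checkpoint.get_restore_path(cfg.log.restore_path))
if ckpt_file is not None:
  t_init, data, params, opt_state, mcmc_width = checkpoint.restore(
      ckpt_file, batch_size=cfg.batch_size)
else:
  t_init = 0  # no checkpoint found: start from scratch
# ... run the training loop from t_init, then periodically:
# checkpoint.save(ckpt_save_path, t, data, params, opt_state, mcmc_width)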
@@ -0,0 +1,36 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from pyscf.pbc import gto

from DeepSolid import base_config
from DeepSolid import supercell
from DeepSolid.utils import units


def get_config(input_str):
  X, Y, L_Ang, S, basis = input_str.split(',')
  S = np.eye(3) * int(S)
  cfg = base_config.default()
  L_Ang = float(L_Ang)
  L_Bohr = units.angstrom2bohr(L_Ang)

  # Set up cell
  cell = gto.Cell()
  cell.atom = [[X, [0.0, 0.0, 0.0]],
               [Y, [0.25 * L_Bohr, 0.25 * L_Bohr, 0.25 * L_Bohr]]]

  cell.basis = basis
  cell.a = (np.ones((3, 3)) - np.eye(3)) * L_Bohr / 2
  cell.unit = "B"
  cell.verbose = 5
  cell.exp_to_discard = 0.1
  cell.build()
  simulation_cell = supercell.get_supercell(cell, S)
  cfg.system.pyscf_cell = simulation_cell

  return cfg
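
For illustration only, a hedged example of calling this get_config directly. The element symbols, lattice constant, supercell multiplier, and basis string below are placeholder values chosen for the example, not settings taken from this commit.

# input_str fields: element, element, lattice constant (Angstrom), supercell multiplier, basis.
cfg = get_config('C,C,3.567,2,ccpvdz')
# cfg.system.pyscf_cell now holds the 2x2x2 supercell built from the two-atom
# FCC primitive cell defined above (atoms at the origin and at L/4 along each
# axis); the remaining fields keep their defaults from base_config.default().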