Skip to content

Commit

Permalink
Initial commit.
Browse files Browse the repository at this point in the history
  • Loading branch information
GiantElephant123 committed May 17, 2022
0 parents commit 47673fd
Show file tree
Hide file tree
Showing 47 changed files with 9,079 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
*.pyc
*.egg-info
*.csv
*.npz
.idea
5 changes: 5 additions & 0 deletions DeepSolid/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
161 changes: 161 additions & 0 deletions DeepSolid/base_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.

import ml_collections
from ml_collections import config_dict


def default() -> ml_collections.ConfigDict:
"""Create set of default parameters for running qmc.py.
Note: placeholders (cfg.system.molecule and cfg.system.electrons) must be
replaced with appropriate values.
Returns:
ml_collections.ConfigDict containing default settings.
"""
# wavefunction output.
cfg = ml_collections.ConfigDict({
'batch_size': 100, # batch size
# Config module used. Should be set in get_config function as either the
# absolute module or relative to the configs subdirectory. Relative
# imports must start with a '.' (e.g. .atom). Do *not* override on
# command-line. Do *not* set using __name__ from inside a get_config
# function, as config_flags overrides this when importing the module using
# importlib.import_module.
'config_module': __name__,
'use_x64': True, # use float64 or 32
'optim': {
'iterations': 1000000, # number of iterations
'optimizer': 'kfac',
'local_energy_outlier_width': 5.0,
'lr': {
'rate': 5.e-2, # learning rate, different from the reported lr in FermiNet
# since DeepSolid energy gradient is not batch-size dependent
'decay': 1.0, # exponent of learning rate decay
'delay': 10000.0, # term that sets the scale of the rate decay
},
'clip_el': 5.0, # If not none, scale at which to clip local energy
'clip_type': 'real', # Clip real and imag part of gradient.
'gradient_clip': 5.0,
# ADAM hyperparameters. See optax documentation for details.
'adam': {
'b1': 0.9,
'b2': 0.999,
'eps': 1.e-8,
'eps_root': 0.0,
},
'kfac': {
'invert_every': 1,
'cov_update_every': 1,
'damping': 0.001,
'cov_ema_decay': 0.95,
'momentum': 0.0,
'momentum_type': 'regular',
# Warning: adaptive damping is not currently available.
'min_damping': 1.e-4,
'norm_constraint': 0.001,
'mean_center': True,
'l2_reg': 0.0,
'register_only_generic': False,
},
'ministeps': 1,
'laplacian_mode': 'for', # specify the laplacian evaluation mode, mode is one of 'for', 'partition' or 'hessian'
# 'for' mode calculates the laplacian of each electron one by one, which is slow but save GPU memory
# 'hessian' mode calculates the laplacian in a highly parallized mode, which is fast but require GPU memory
# 'partition' mode calculate the laplacian in a moderate way.
'partition_number': 3,
# Only used for 'partition' mode.
# partition_number must be divisivle by (dim * number of electrons). The smaller the faster, but requires more memory.
},
'log': {
'stats_frequency': 1, # iterations between logging of stats
'save_frequency': 10.0, # minutes between saving network params
'save_frequency_in_step': -1,
'save_path': '',
# specify the local save path
'restore_path': '',
# specify the restore path which contained saved Model parameters.
'local_energies': False,
'complex_polarization': False, # log polarization order parameter which is useful for hydrogen chain.
'structure_factor': False,
# return the strture factor S(k) at reciprocal lattices of supercell
# log S(k) requires a lot of storage space, be careful.
'stats_file_name': 'train_stats'
},
'system': {
'pyscf_cell': None, # simulation cell obj
'ndim': 3, #dimension of the system
'internal_cell': None,
},
'mcmc': {
# Note: HMC options are not currently used.
# Number of burn in steps after pretraining. If zero do not burn in
# or reinitialize walkers.
'burn_in': 100,
'steps': 20, # Number of MCMC steps to make between network updates.
# Width of (atom-centred) Gaussian used to generate initial electron
# configurations.
'init_width': 0.8,
# Width of Gaussian used for random moves for RMW or step size for
# HMC.
'move_width': 0.02,
# Number of steps after which to update the adaptive MCMC step size
'adapt_frequency': 100,
'init_means': (), # Not implemented in JAX.
# If true, scale the proposal width for each electron by the harmonic
# mean of the distance to the nuclei.
'importance_sampling': False,
# whether to use importance sampling in MCMC step, untested yet
# Metropolis sampling will be used if false
'one_electron': False
# If true, use one-electron moves, untested yet
},
'network': {
'detnet': {
'envelope_type': 'isotropic',
# only isotropic mode has been tested
'bias_orbitals': False,
'use_last_layer': False,
'full_det': False,
'hidden_dims': ((256, 32), (256, 32), (256, 32)),
'determinants': 8,
'after_determinants': 1,
},
'twist': (0.0, 0.0, 0.0), # Difine the twist of wavefunction,
# twists are given in terms of fractions of supercell reciprocal vectors
},
'debug': {
# Check optimizer state, parameters and loss and raise an exception if
# NaN is found.
'check_nan': False, # check whether the gradient contain nans before optimize, if True, retry.
'deterministic': False, # Use a deterministic seed.
},
'pretrain': {
'method': 'net', # Method is one of 'hf', 'net'.
'iterations': 1000,
'lr': 3e-4,
'steps': 1, #mcmc steps between each pretrain iterations
},
})

return cfg


def resolve(cfg):
cfg = cfg.copy_and_resolve_references()
return cfg
165 changes: 165 additions & 0 deletions DeepSolid/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.

import datetime
import os
from typing import Optional
import zipfile

from absl import logging
import jax
import numpy as np


def get_restore_path(restore_path: Optional[str] = None) -> Optional[str]:
"""Gets the path containing checkpoints from a previous calculation.
Args:
restore_path: path to checkpoints.
Returns:
The path or None if restore_path is falsy.
"""
if restore_path:
ckpt_restore_path = restore_path
else:
ckpt_restore_path = None
return ckpt_restore_path


def find_last_checkpoint(ckpt_path: Optional[str] = None) -> Optional[str]:
"""Finds most recent valid checkpoint in a directory.
Args:
ckpt_path: Directory containing checkpoints.
Returns:
Last QMC checkpoint (ordered by sorting all checkpoints by name in reverse)
or None if no valid checkpoint is found or ckpt_path is not given or doesn't
exist. A checkpoint is regarded as not valid if it cannot be read
successfully using np.load.
"""
if ckpt_path and os.path.exists(ckpt_path):
files = [f for f in os.listdir(ckpt_path) if 'qmcjax_ckpt_' in f]
# Handle case where last checkpoint is corrupt/empty.
for file in sorted(files, reverse=True):
fname = os.path.join(ckpt_path, file)
with open(fname, 'rb') as f:
try:
np.load(f, allow_pickle=True)
return fname
except (OSError, EOFError, zipfile.BadZipFile):
logging.info('Error loading checkpoint %s. Trying next checkpoint...',
fname)
return None


def create_save_path(save_path: Optional[str],) -> str:
"""Creates the directory for saving checkpoints, if it doesn't exist.
Args:
save_path: directory to use. If false, create a directory in the working
directory based upon the current time.
Returns:
Path to save checkpoints to.
"""
timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
default_save_path = os.path.join(os.getcwd(), f'DeepSolid_{timestamp}')
ckpt_save_path = save_path or default_save_path


if ckpt_save_path and not os.path.isdir(ckpt_save_path):
os.makedirs(ckpt_save_path)

return ckpt_save_path


def save(save_path: str, t: int, data, params, opt_state, mcmc_width,
remote_save_path: Optional[int] = None) -> str:
"""Saves checkpoint information to a npz file.
Args:
save_path: path to directory to save checkpoint to. The checkpoint file is
save_path/qmcjax_ckpt_$t.npz, where $t is the number of completed
iterations.
t: number of completed iterations.
data: MCMC walker configurations.
params: pytree of network parameters.
opt_state: optimization state.
mcmc_width: width to use in the MCMC proposal distribution.
Returns:
path to checkpoint file.
"""
ckpt_filename = os.path.join(save_path, f'qmcjax_ckpt_{t:06d}.npz')
logging.info('Saving checkpoint %s', ckpt_filename)
with open(ckpt_filename, 'wb') as f:
np.savez(
f,
t=t,
data=data,
params=params,
opt_state=opt_state,
mcmc_width=mcmc_width)

return ckpt_filename


def restore(restore_filename: str, batch_size: Optional[int] = None, shape_check=True):
"""Restores data saved in a checkpoint.
Args:
restore_filename: filename containing checkpoint.
batch_size: total batch size to be used. If present, check the data saved in
the checkpoint is consistent with the batch size requested for the
calculation.
Returns:
(t, data, params, opt_state, mcmc_width) tuple, where
t: number of completed iterations.
data: MCMC walker configurations.
params: pytree of network parameters.
opt_state: optimization state.
mcmc_width: width to use in the MCMC proposal distribution.
Raises:
ValueError: if the leading dimension of data does not match the number of
devices (i.e. the number of devices being parallelised over has changed) or
if the total batch size is not equal to the number of MCMC configurations in
data.
"""
logging.info('Loading checkpoint %s', restore_filename)
with open(restore_filename, 'rb') as f:
ckpt_data = np.load(f, allow_pickle=True)
# Retrieve data from npz file. Non-array variables need to be converted back
# to natives types using .tolist().
t = ckpt_data['t'].tolist() + 1 # Return the iterations completed.
data = ckpt_data['data']
params = ckpt_data['params'].tolist()
opt_state = ckpt_data['opt_state'].tolist()
mcmc_width = ckpt_data['mcmc_width'].tolist()
if shape_check:
if data.shape[0] != jax.local_device_count():
raise ValueError(
'Incorrect number of devices found. Expected {}, found {}.'.format(
data.shape[0], jax.local_device_count()))
if batch_size and data.shape[0] * data.shape[1] != batch_size:
raise ValueError(
'Wrong batch size in loaded data. Expected {}, found {}.'.format(
batch_size, data.shape[0] * data.shape[1]))
return t, data, params, opt_state, mcmc_width
36 changes: 36 additions & 0 deletions DeepSolid/config/diamond.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from pyscf.pbc import gto

from DeepSolid import base_config
from DeepSolid import supercell
from DeepSolid.utils import units


def get_config(input_str):
X, Y, L_Ang, S, basis= input_str.split(',')
S = np.eye(3) * int(S)
cfg = base_config.default()
L_Ang = float(L_Ang)
L_Bohr = units.angstrom2bohr(L_Ang)

# Set up cell
cell = gto.Cell()
cell.atom = [[X, [0.0, 0.0, 0.0]],
[Y, [0.25 * L_Bohr, 0.25 * L_Bohr, 0.25 * L_Bohr]]]

cell.basis = basis
cell.a = (np.ones((3, 3)) - np.eye(3)) * L_Bohr / 2
cell.unit = "B"
cell.verbose = 5
cell.exp_to_discard = 0.1
cell.build()
simulation_cell = supercell.get_supercell(cell, S)
cfg.system.pyscf_cell = simulation_cell

return cfg
Loading

0 comments on commit 47673fd

Please sign in to comment.