Skip to content

Commit

Permalink
Add necessary comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
GiantElephant123 committed Oct 8, 2022
1 parent 47673fd commit d8a25f6
Show file tree
Hide file tree
Showing 8 changed files with 323 additions and 36 deletions.
17 changes: 12 additions & 5 deletions DeepSolid/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ def make_complex_polarization(simulation_cell: pyscf.pbc.gto.Cell,
direction: int = 0,
ndim=3):
'''
the order parameter which is used to specify the hydrogen chain
:param simulation_cell:
:param direction:
generates the order parameter function of hydrogen chain.
:param simulation_cell: pyscf object of simulation cell.
:param direction: projection direction of electrons
:param ndim:
:return:
:return:the order parameter
'''

rec_vec = simulation_cell.reciprocal_vectors()[direction]
Expand All @@ -45,6 +45,13 @@ def complex_polarization(data):
def make_structure_factor(simulation_cell: pyscf.pbc.gto.Cell,
nq=4,
ndim=3):
'''
generates the structure factor function which is used for finite size error reduction.
see PHYSICAL REVIEW B 94, 035126 (2016) for details.
:param simulation_cell: pyscf object of simulation cell.
:param nq: number of sampled crystal momentum in each direction.
:return:the structure factor.
'''
mesh_grid = jnp.meshgrid(*[jnp.array(range(0, nq)) for _ in range(3)])
point_list = jnp.stack([m.ravel() for m in mesh_grid], axis=0).T
rec_vec = simulation_cell.reciprocal_vectors()
Expand All @@ -57,7 +64,7 @@ def structure_factor(data):
"""
:param data: electron walkers with shape [batch, ne * ndim]
:return: complex polarization with shape []
:return: structure factor with shape []
"""
leading_shape = list(data.shape[:-1])
data = data.reshape(leading_shape + [-1, ndim])
Expand Down
44 changes: 34 additions & 10 deletions DeepSolid/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
def local_kinetic_energy(f):
'''
holomorphic mode, which seems dangerous since many op don't support complex number now.
:param f:
:return:
:param f: function return the logdet of wavefunction
:return: local kinetic energy
'''
def _lapl_over_f(params, x):
ne = x.shape[-1]
Expand All @@ -44,9 +44,9 @@ def _body_fun(i, val):

def local_kinetic_energy_real_imag(f):
'''
evaluate real and imaginary part of laplacian, which is slower than holomorphic mode but is much safer.
:param f:
:return:
evaluate real and imaginary part of laplacian.
:param f: function return the logdet of wavefunction
:return: local kinetic energy
'''
def _lapl_over_f(params, x):
ne = x.shape[-1]
Expand All @@ -71,6 +71,11 @@ def _body_fun(i, val):


def local_kinetic_energy_real_imag_dim_batch(f):
'''
evaluate real and imaginary part of laplacian, in which vamp is used to accelerate.
:param f: function return the logdet of wavefunction
:return: local kinetic energy
'''

def _lapl_over_f(params, x):
ne = x.shape[-1]
Expand Down Expand Up @@ -99,8 +104,8 @@ def _body_fun(dummy_eye):
def local_kinetic_energy_real_imag_hessian(f):
'''
Use jax.hessian to evaluate laplacian, which requires huge amount of memory.
:param f:
:return:
:param f: function return the logdet of wavefunction
:return: local kinetic energy
'''
def _lapl_over_f(params, x):
ne = x.shape[-1]
Expand All @@ -122,9 +127,10 @@ def _lapl_over_f(params, x):
def local_kinetic_energy_partition(f, partition_number=3):
'''
Try to parallelize the evaluation of laplacian
:param f:
:param partition_number:
:return:
:param f: bfunction return the logdet of wavefunction
:param partition_number: partition_number must be divisivle by (dim * number of electrons).
The smaller the faster, but requires more memory.
:return: local kinetic energy
'''
vjvp = jax.vmap(jax.jvp, in_axes=(None, None, 0))

Expand Down Expand Up @@ -155,6 +161,11 @@ def _body_fun(val, e):


def local_ewald_energy(simulation_cell):
"""
generate local energy of ewald part.
:param simulation_cell:
:return:
"""
ewald = ewaldsum.EwaldSum(simulation_cell)
assert jnp.allclose(simulation_cell.energy_nuc(),
(ewald.ion_ion + ewald.ii_const),
Expand All @@ -181,6 +192,19 @@ def _local_energy(params, x):


def local_energy_seperate(f, simulation_cell, mode='for', partition_number=3):
"""
genetate the local energy function.
:param f: function return the logdet of wavefunction.
:param simulation_cell: pyscf object of simulation cell.
:param mode: specify the evaluation style of local energy.
'for' mode calculates the laplacian of each electron one by one, which is slow but save GPU memory
'hessian' mode calculates the laplacian in a highly parallized mode, which is fast but require GPU memory
'partition' mode calculate the laplacian in a moderate way.
:param partition_number: Only used if 'partition' mode is employed.
partition_number must be divisivle by (dim * number of electrons).
The smaller the faster, but requires more memory.
:return: the local energy function.
"""

if mode == 'for':
ke_ri = local_kinetic_energy_real_imag(f)
Expand Down
25 changes: 25 additions & 0 deletions DeepSolid/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def __init__(self, cell, twist=np.ones(3)*0.5):
# self.init_scf()

def init_scf(self):
"""
initialization function to set up HF ansatz.
"""
self.klist = []
for s, key in enumerate(self.coeff_key):
mclist = []
Expand All @@ -101,6 +104,12 @@ def init_scf(self):
zip(self.kmf.kpts, self.k_split[self.coeff_key[s]])]))

def eval_orbitals_pbc(self, coord, eval_str="GTOval_sph"):
"""
eval the atomic orbital valus of HF.
:param coord: electron walkers with shape [batch, ne * ndim].
:param eval_str:
:return: atomic orbital valus of HF.
"""
prim_coord, wrap = distance.np_enforce_pbc(self.primitive_cell.a, coord.reshape([coord.shape[0], -1]))
prim_coord = prim_coord.reshape([-1, 3])
wrap = wrap.reshape([-1, 3])
Expand All @@ -113,12 +122,23 @@ def eval_orbitals_pbc(self, coord, eval_str="GTOval_sph"):
return ao

def eval_mos_pbc(self, aos, s):
"""
eval the molecular orbital values.
:param aos: atomic orbital values.
:param s: spin index.
:return: molecular orbital values.
"""
c = self.coeff_key[s]
p = np.split(self.parameters[c], self.param_split[c], axis=-1)
mo = [ao.dot(p[k]) for k, ao in enumerate(aos)]
return np.concatenate(mo, axis=-1)

def eval_orb_mat(self, coord):
"""
eval the orbital matrix of HF.
:param coord: electron walkers with shape [batch, ne * ndim].
:return: orbital matrix of HF.
"""
batch, nelec, ndim = coord.shape
aos = self.eval_orbitals_pbc(coord)
aos_shape = (self.ns_tol, batch, nelec, -1)
Expand All @@ -133,6 +153,11 @@ def eval_orb_mat(self, coord):
return mos

def eval_slogdet(self, coord):
"""
eval the slogdet of HF
:param coord: electron walkers with shape [batch, ne * ndim].
:return: slogdet of HF.
"""
mos = self.eval_orb_mat(coord)
slogdets = [np.linalg.slogdet(mo) for mo in mos]
phase, slogdet = list(map(lambda x, y: [x[0] * y[0], x[1] + y[1]], *zip(slogdets)))[0]
Expand Down
106 changes: 105 additions & 1 deletion DeepSolid/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,21 @@ def construct_input_features(
x: jnp.ndarray,
atoms: jnp.ndarray,
ndim: int = 3) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Constructs inputs to Fermi Net from raw electron and atomic positions.
Args:
x: electron positions. Shape (nelectrons*ndim,).
atoms: atom positions. Shape (natoms, ndim).
ndim: dimension of system. Change only with caution.
Returns:
ae, ee, r_ae, r_ee tuple, where:
ae: atom-electron vector. Shape (nelectron, natom, 3).
ee: atom-electron vector. Shape (nelectron, nelectron, 3).
r_ae: atom-electron distance. Shape (nelectron, natom, 1).
r_ee: electron-electron distance. Shape (nelectron, nelectron, 1).
The diagonal terms in r_ee are masked out such that the gradients of these
terms are also zero.
"""

assert atoms.shape[1] == ndim
ae = jnp.reshape(x, [-1, 1, ndim]) - atoms[None, ...]
ee = jnp.reshape(x, [1, -1, ndim]) - jnp.reshape(x, [-1, 1, ndim])
Expand All @@ -197,10 +212,20 @@ def construct_input_features(


def scaled_f(w):
"""
see Phys. Rev. B 94, 035157
:param w: projection of position vectors on reciprocal vectors.
:return: function f in the ref.
"""
return jnp.abs(w) * (1 - jnp.abs(w / jnp.pi) ** 3 / 4.)


def scaled_g(w):
"""
see Phys. Rev. B 94, 035157
:param w: projection of position vectors on reciprocal vectors.
:return: function g in the ref.
"""
return w * (1 - 3. / 2. * jnp.abs(w / jnp.pi) + 1. / 2. * jnp.abs(w / jnp.pi) ** 2)


Expand All @@ -222,6 +247,13 @@ def sin_relative_distance(xea, a, b):


def nu_distance(xea, a, b):
"""
see Phys. Rev. B 94, 035157
:param xea: relative distance between electrons and atoms
:param a: lattice vectors of primitive cell divided by 2\pi.
:param b: reciprocal vectors of primitive cell.
:return: periodic generalized relative and absolute distance of xea.
"""
w = jnp.einsum('...ijk,lk->...ijl', xea, b)
mod = (w + jnp.pi) // (2 * jnp.pi)
w = (w - mod * 2 * jnp.pi)
Expand All @@ -238,6 +270,21 @@ def construct_periodic_input_features(
atoms: jnp.ndarray,
simulation_cell=None,
ndim: int = 3) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Constructs a periodic generalized inputs to Fermi Net from raw electron and atomic positions.
see Phys. Rev. B 94, 035157
Args:
x: electron positions. Shape (nelectrons*ndim,).
atoms: atom positions. Shape (natoms, ndim).
ndim: dimension of system. Change only with caution.
Returns:
ae, ee, r_ae, r_ee tuple, where:
ae: atom-electron vector. Shape (nelectron, natom, 3).
ee: atom-electron vector. Shape (nelectron, nelectron, 3).
r_ae: atom-electron distance. Shape (nelectron, natom, 1).
r_ee: electron-electron distance. Shape (nelectron, nelectron, 1).
The diagonal terms in r_ee are masked out such that the gradients of these
terms are also zero.
"""
primitive_cell = simulation_cell.original_cell
x = x.reshape(-1, ndim)
n = x.shape[0]
Expand Down Expand Up @@ -267,6 +314,20 @@ def construct_periodic_input_features(

def construct_symmetric_features(h_one: jnp.ndarray, h_two: jnp.ndarray,
spins: Tuple[int, int]) -> jnp.ndarray:
"""Combines intermediate features from rank-one and -two streams.
Args:
h_one: set of one-electron features. Shape: (nelectrons, n1), where n1 is
the output size of the previous layer.
h_two: set of two-electron features. Shape: (nelectrons, nelectrons, n2),
where n2 is the output size of the previous layer.
spins: number of spin-up and spin-down electrons.
Returns:
array containing the permutation-equivariant features: the input set of
one-electron features, the mean of the one-electron features over each
(occupied) spin channel, and the mean of the two-electron features over each
(occupied) spin channel. Output shape (nelectrons, 3*n1 + 2*n2) if there are
both spin-up and spin-down electrons and (nelectrons, 2*n1, n2) otherwise.
"""
# Split features into spin up and spin down electrons
h_ones = jnp.split(h_one, spins[0:1], axis=0)
h_twos = jnp.split(h_two, spins[0:1], axis=0)
Expand Down Expand Up @@ -414,6 +475,36 @@ def solid_fermi_net_orbitals(params, x,
spins=(None, None),
envelope_type=None,
full_det=False):
"""Forward evaluation of the Solid Neural Network up to the orbitals.
Args:
params: A dictionary of parameters, containing fields:
`single`: a list of dictionaries with params 'w' and 'b', weights for the
one-electron stream of the network.
`double`: a list of dictionaries with params 'w' and 'b', weights for the
two-electron stream of the network.
`orbital`: a list of two weight matrices, for spin up and spin down (no
bias is necessary as it only adds a constant to each row, which does
not change the determinant).
`dets`: weight on the linear combination of determinants
`envelope`: a dictionary with fields `sigma` and `pi`, weights for the
multiplicative envelope.
x: The input data, a 3N dimensional vector.
simulation_cell: pyscf object of simulation cell.
klist: Tuple with occupied k points of the spin up and spin down electrons
in simulation cell.
spins: Tuple with number of spin up and spin down electrons.
envelope_type: a string that specifies kind of envelope. One of:
`isotropic`: envelope is the same in every direction
full_det: If true, the determinants are dense, rather than block-sparse.
True by default, false is still available for backward compatibility.
Thus, the output shape of the orbitals will be (ndet, nalpha+nbeta,
nalpha+nbeta) if True, and (ndet, nalpha, nalpha) and (ndet, nbeta, nbeta)
if False.
Returns:
One (two matrices if full_det is False) that exchange columns under the
exchange of inputs, and additional variables that may be needed by the
envelope, depending on the envelope type.
"""

ae_, ee_, r_ae, r_ee = construct_periodic_input_features(x, atoms,
simulation_cell=simulation_cell,
Expand Down Expand Up @@ -486,6 +577,19 @@ def eval_func(params, x,
envelope_type='full',
full_det=False,
method_name='eval_slogdet'):
'''
generates the wavefunction of simulation cell.
:param params: parameter dict
:param x: The input data, a 3N dimensional vector.
:param simulation_cell: pyscf object of simulation cell.
:param klist: Tuple with occupied k points of the spin up and spin down electrons
in simulation cell.
:param atoms: array of atom positions in the primitive cell.
:param spins: Tuple with number of spin up and spin down electrons.
:param full_det: specify the mode of wavefunction, spin diagonalized or not.
:param method_name: specify the returned function of wavefunction
:return: required wavefunction
'''

orbitals, to_env = solid_fermi_net_orbitals(params, x,
klist=klist,
Expand Down Expand Up @@ -522,7 +626,7 @@ def make_solid_fermi_net(
method_name='eval_logdet',
):
'''
generates the wavefunction of simulation cell.
:param envelope_type: specify envelope
:param bias_orbitals: whether to contain bias in the last layer of orbitals
:param use_last_layer: wheter to use two-electron feature in the last layer
Expand Down
Loading

0 comments on commit d8a25f6

Please sign in to comment.