Skip to content

Commit

Permalink
Merge pull request #3 from AllanChain/update-with-force
Browse files Browse the repository at this point in the history
Add tri feature which is needed for force calculation
  • Loading branch information
GiantElephant123 authored Jun 7, 2024
2 parents 9767071 + f38db65 commit b1abaf5
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 73 deletions.
1 change: 1 addition & 0 deletions DeepSolid/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def default() -> ml_collections.ConfigDict:
'hidden_dims': ((256, 32), (256, 32), (256, 32)),
'determinants': 8,
'after_determinants': 1,
'distance_type': 'nu',
},
'twist': (0.0, 0.0, 0.0), # Difine the twist of wavefunction,
# twists are given in terms of fractions of supercell reciprocal vectors
Expand Down
140 changes: 68 additions & 72 deletions DeepSolid/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,19 @@ def enforce_pbc(latvec, epos):


def init_solid_fermi_net_params(
key: jnp.ndarray,
data,
atoms: jnp.ndarray,
spins: Tuple[int, int],
envelope_type: str = 'full',
bias_orbitals: bool = False,
use_last_layer: bool = False,
eps: float = 0.01,
full_det: bool = True,
hidden_dims: FermiLayers = ((256, 32), (256, 32), (256, 32)),
determinants: int = 16,
after_determinants: Union[int, Tuple[int, ...]] = 1,
key: jnp.ndarray,
data,
atoms: jnp.ndarray,
spins: Tuple[int, int],
envelope_type: str = 'full',
bias_orbitals: bool = False,
use_last_layer: bool = False,
eps: float = 0.01,
full_det: bool = True,
hidden_dims: FermiLayers = ((256, 32), (256, 32), (256, 32)),
determinants: int = 16,
after_determinants: Union[int, Tuple[int, ...]] = 1,
distance_type='nu',
):
"""Initializes parameters for the Fermionic Neural Network.
Expand Down Expand Up @@ -107,7 +108,13 @@ def init_solid_fermi_net_params(
del data

natom = atoms.shape[0]
in_dims = (natom * 4, 4)
if distance_type == 'nu':
in_dims = (natom * 4, 4)
elif distance_type == 'tri':
in_dims = (natom * 7, 7)
else:
raise ValueError('Unrecognized distance function.')

active_spin_channels = [spin for spin in spins if spin > 0]
nchannels = len(active_spin_channels)
# The input to layer L of the one-electron stream is from
Expand Down Expand Up @@ -179,38 +186,6 @@ def init_solid_fermi_net_params(
return params


def construct_input_features(
x: jnp.ndarray,
atoms: jnp.ndarray,
ndim: int = 3) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Constructs inputs to Fermi Net from raw electron and atomic positions.
Args:
x: electron positions. Shape (nelectrons*ndim,).
atoms: atom positions. Shape (natoms, ndim).
ndim: dimension of system. Change only with caution.
Returns:
ae, ee, r_ae, r_ee tuple, where:
ae: atom-electron vector. Shape (nelectron, natom, 3).
ee: atom-electron vector. Shape (nelectron, nelectron, 3).
r_ae: atom-electron distance. Shape (nelectron, natom, 1).
r_ee: electron-electron distance. Shape (nelectron, nelectron, 1).
The diagonal terms in r_ee are masked out such that the gradients of these
terms are also zero.
"""

assert atoms.shape[1] == ndim
ae = jnp.reshape(x, [-1, 1, ndim]) - atoms[None, ...]
ee = jnp.reshape(x, [1, -1, ndim]) - jnp.reshape(x, [-1, 1, ndim])

r_ae = jnp.linalg.norm(ae, axis=2, keepdims=True)
# Avoid computing the norm of zero, as is has undefined grad
n = ee.shape[0]
r_ee = (
jnp.linalg.norm(ee + jnp.eye(n)[..., None], axis=-1) * (1.0 - jnp.eye(n)))

return ae, ee, r_ae, r_ee[..., None]


def scaled_f(w):
"""
see Phys. Rev. B 94, 035157
Expand All @@ -229,23 +204,6 @@ def scaled_g(w):
return w * (1 - 3. / 2. * jnp.abs(w / jnp.pi) + 1. / 2. * jnp.abs(w / jnp.pi) ** 2)


def sin_relative_distance(xea, a, b):
'''
:param xea:
:param a:
:param b:
:return: periodic relative distance [ne, na, 6]
'''
w = jnp.einsum('...ijk,lk->...ijl', xea, b)
mod = (w + jnp.pi) // (2 * jnp.pi)
w = (w - mod * 2 * jnp.pi)
# w = jnp.mod(w + jnp.pi, 2 * jnp.pi) - jnp.pi
# r1 = jnp.einsum('...i,ij->...j', scaled_f(w), a)
r2 = jnp.einsum('...i,ij->...j', scaled_g(w), a)
return r2


def nu_distance(xea, a, b):
"""
see Phys. Rev. B 94, 035157
Expand All @@ -265,11 +223,36 @@ def nu_distance(xea, a, b):
sd = result ** 0.5
return sd, rel


def tri_distance(xea, a, b):
"""
see Phys. Rev. Lett. 130, 036401 (2023).
:param xea: relative distance between electrons and atoms
:param a: lattice vectors of primitive cell divided by 2\pi.
:param b: reciprocal vectors of primitive cell.
:return: periodic generalized relative and absolute distance of xea.
"""
w = jnp.einsum('...ijk,lk->...ijl', xea, b)
sg = jnp.sin(w)
cg = jnp.cos(w)
rel_sin = jnp.einsum('...i,ij->...j', sg, a)
rel_cos = jnp.einsum('...i,ij->...j', cg, a)
rel = jnp.concatenate([rel_sin, rel_cos], axis=-1)
metric = jnp.einsum('ij,kj->ik', a, a)
vector_sin = sg[..., :, None] * sg[..., None, :]
vector_cos = (1-cg[..., :, None]) * (1-cg[..., None, :])
vector = vector_cos + vector_sin
sd = jnp.einsum('...ij,ij->...', vector, metric) ** 0.5
return sd, rel


def construct_periodic_input_features(
x: jnp.ndarray,
atoms: jnp.ndarray,
simulation_cell=None,
ndim: int = 3) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
x: jnp.ndarray,
atoms: jnp.ndarray,
simulation_cell=None,
ndim: int = 3,
distance_type: str = 'nu',
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Constructs a periodic generalized inputs to Fermi Net from raw electron and atomic positions.
see Phys. Rev. B 94, 035157
Args:
Expand All @@ -285,14 +268,21 @@ def construct_periodic_input_features(
The diagonal terms in r_ee are masked out such that the gradients of these
terms are also zero.
"""
if distance_type == 'nu':
distance_func = nu_distance
elif distance_type == 'tri':
distance_func = tri_distance
else:
raise ValueError('Unrecognized distance function.')

primitive_cell = simulation_cell.original_cell
x = x.reshape(-1, ndim)
n = x.shape[0]
prim_x, _ = enforce_pbc(primitive_cell.a, x)

# prim_xea = minimal_imag.dist_i(atoms.ravel(), prim_x.ravel())
prim_xea = prim_x[..., None, :] - atoms
prim_periodic_sea, prim_periodic_xea = nu_distance(prim_xea,
prim_periodic_sea, prim_periodic_xea = distance_func(prim_xea,
primitive_cell.AV,
primitive_cell.BV)
prim_periodic_sea = prim_periodic_sea[..., None]
Expand All @@ -301,7 +291,7 @@ def construct_periodic_input_features(
# sim_xee = sim_minimal_imag.dist_matrix(sim_x.ravel())
sim_xee = sim_x[:, None, :] - sim_x[None, :, :]

sim_periodic_see, sim_periodic_xee = nu_distance(sim_xee + jnp.eye(n)[..., None],
sim_periodic_see, sim_periodic_xee = distance_func(sim_xee + jnp.eye(n)[..., None],
simulation_cell.AV,
simulation_cell.BV)
sim_periodic_see = sim_periodic_see * (1.0 - jnp.eye(n))
Expand Down Expand Up @@ -474,7 +464,8 @@ def solid_fermi_net_orbitals(params, x,
atoms=None,
spins=(None, None),
envelope_type=None,
full_det=False):
full_det=False,
distance_type='nu'):
"""Forward evaluation of the Solid Neural Network up to the orbitals.
Args:
params: A dictionary of parameters, containing fields:
Expand Down Expand Up @@ -506,9 +497,9 @@ def solid_fermi_net_orbitals(params, x,
envelope, depending on the envelope type.
"""

ae_, ee_, r_ae, r_ee = construct_periodic_input_features(x, atoms,
simulation_cell=simulation_cell,
)
ae_, ee_, r_ae, r_ee = construct_periodic_input_features(
x, atoms, simulation_cell=simulation_cell, distance_type=distance_type
)
ae = jnp.concatenate((r_ae, ae_), axis=2)
ae = jnp.reshape(ae, [jnp.shape(ae)[0], -1])
ee = jnp.concatenate((r_ee, ee_), axis=2)
Expand Down Expand Up @@ -576,6 +567,7 @@ def eval_func(params, x,
spins=(None, None),
envelope_type='full',
full_det=False,
distance_type='nu',
method_name='eval_slogdet'):
'''
generates the wavefunction of simulation cell.
Expand All @@ -597,6 +589,7 @@ def eval_func(params, x,
atoms=atoms,
spins=spins,
envelope_type=envelope_type,
distance_type=distance_type,
full_det=full_det)
if method_name == 'eval_slogdet':
_, result = logdet_matmul(orbitals)
Expand All @@ -623,6 +616,7 @@ def make_solid_fermi_net(
hidden_dims: FermiLayers = ((256, 32), (256, 32), (256, 32)),
determinants: int = 16,
after_determinants: Union[int, Tuple[int, ...]] = 1,
distance_type='nu',
method_name='eval_logdet',
):
'''
Expand Down Expand Up @@ -655,6 +649,7 @@ def make_solid_fermi_net(
hidden_dims=hidden_dims,
determinants=determinants,
after_determinants=after_determinants,
distance_type=distance_type,
)
network = functools.partial(
eval_func,
Expand All @@ -664,6 +659,7 @@ def make_solid_fermi_net(
spins=simulation_cell.nelec,
envelope_type=envelope_type,
full_det=full_det,
distance_type=distance_type,
method_name=method_name,
)
method.init = init
Expand Down
1 change: 0 additions & 1 deletion bin/deepsolid
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
# modified from FermiNet:https://github.com/deepmind/ferminet

import sys
import os

from absl import app
from absl import flags
Expand Down

0 comments on commit b1abaf5

Please sign in to comment.