Made the total charge gradients optional.
jintuzhang committed Feb 7, 2024
1 parent 61c1787 commit 62718c4
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions mace/modules/models.py
@@ -1164,6 +1164,7 @@ def forward(
         compute_stress: bool = False,
         compute_displacement: bool = False,
         charge_cv_expr: Optional[Callable] = None,
+        compute_total_charge_gradients: Optional[bool] = False
     ) -> Dict[str, Optional[torch.Tensor]]:
         assert compute_force is False
         assert compute_virials is False
@@ -1218,6 +1219,12 @@ def forward(
         atomic_charges = torch.sum(contributions_charges, dim=-1)  # [n_nodes,1]
         contributions_total_charge = torch.stack(total_charge_list, dim=-1)
         total_charge = torch.sum(contributions_total_charge, dim=-1)  # [n_graphs,]
+        if (not training and compute_total_charge_gradients):
+            total_charge_gradients = compute_charge_cv_gradients(
+                charge_cvs=total_charge, positions=data["positions"]
+            )
+        else:
+            total_charge_gradients = None
         if (not training and (charge_cv_expr is not None)):
             for i in range(0, len(data["ptr"]) - 1):
                 charge_cv = charge_cv_expr(atomic_charges[data["ptr"][i]:data["ptr"][i+1]])
@@ -1227,9 +1234,6 @@
             charge_cv_gradients = compute_charge_cv_gradients(
                 charge_cvs=charge_cvs, positions=data["positions"]
             )
-            total_charge_gradients = compute_charge_cv_gradients(
-                charge_cvs=total_charge, positions=data["positions"]
-            )
         else:
             charge_cv_gradients = None
             charge_cvs = None
@@ -1376,6 +1380,7 @@ def forward(
         compute_stress: bool = False,
         compute_displacement: bool = False,
         charge_cv_expr: Optional[Callable] = None,
+        compute_total_charge_gradients: Optional[bool] = False
     ) -> Dict[str, Optional[torch.Tensor]]:
         # Setup
         data["node_attrs"].requires_grad_(True)
@@ -1463,6 +1468,12 @@
         atomic_charges = torch.sum(contributions_charges, dim=-1)  # [n_nodes,1]
         contributions_total_charge = torch.stack(total_charge_list, dim=-1)
         total_charge = torch.sum(contributions_total_charge, dim=-1)  # [n_graphs,]
+        if (not training and compute_total_charge_gradients):
+            total_charge_gradients = compute_charge_cv_gradients(
+                charge_cvs=total_charge, positions=data["positions"]
+            )
+        else:
+            total_charge_gradients = None
         if (not training and (charge_cv_expr is not None)):
             for i in range(0, len(data["ptr"]) - 1):
                 charge_cv = charge_cv_expr(atomic_charges[data["ptr"][i]:data["ptr"][i+1]])
@@ -1472,9 +1483,6 @@
             charge_cv_gradients = compute_charge_cv_gradients(
                 charge_cvs=charge_cvs, positions=data["positions"]
             )
-            total_charge_gradients = compute_charge_cv_gradients(
-                charge_cvs=total_charge, positions=data["positions"]
-            )
         else:
             charge_cv_gradients = None
             charge_cvs = None
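The change is symmetric across the two forward methods shown: the total-charge gradient computation moves out of the charge_cv_expr branch and behind its own compute_total_charge_gradients flag, so callers that do not opt in skip the extra backward pass and receive total_charge_gradients = None. A minimal usage sketch, assuming model is an instance of one of the modified classes, batch is the usual MACE data dict, and the returned dictionary exposes a "total_charge_gradients" key (the return statement is not shown in this diff):

    # Hypothetical usage; the output key name is an assumption, not shown in the diff.
    out = model(
        batch,
        training=False,                       # the new branch only runs outside training
        compute_total_charge_gradients=True,  # opt in to d(total_charge)/d(positions)
    )
    dq_dr = out["total_charge_gradients"]     # expected shape [n_atoms, 3]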
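For context, compute_charge_cv_gradients itself is not part of this diff. A plausible sketch of such a helper, under the assumption that it simply backpropagates the per-graph scalars to the atomic positions with torch.autograd:

    import torch

    def compute_charge_cv_gradients(
        charge_cvs: torch.Tensor,  # per-graph scalars, shape [n_graphs]
        positions: torch.Tensor,   # positions with requires_grad set, [n_atoms, 3]
    ) -> torch.Tensor:
        # Summing over graphs before differentiating still gives each atom the
        # gradient of its own graph's CV, since each atom belongs to one graph.
        return torch.autograd.grad(
            outputs=[charge_cvs.sum()],
            inputs=[positions],
            retain_graph=True,  # the autograd graph may be reused for other CVs
        )[0]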
