Skip to content

Commit

Permalink
refactor: updating conditions for __eq__ with new parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
laserkelvin committed Dec 16, 2024
1 parent b232aa0 commit ebdecba
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions matsciml/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,11 +666,15 @@ def _compare_mirror_images(image_a: np.ndarray, image_b: np.ndarray) -> bool:
return any([np.all(image_a == image_b), np.all((-1 * image_a) == image_b)])

def __eq__(self, other: Edge) -> bool:
index_eq = self.sorted_index == other.sorted_index
# for interactions with the same atom between neighboring images
if self.src != self.dst:
image_eq = True # shortcut since this is more likely to occur
if self.is_directed:
# if we care about directionality, equality is based on exact match
index_eq = (self.src, self.dst) == (other.src, other.dst)
else:
# otherwise, equality is based only on pair of nodes
index_eq = self.sorted_index == other.sorted_index
# for interactions with the same atom between neighboring images
image_eq = True
if self.src == self.dst and self.exclude_mirror:
image_eq = self._compare_mirror_images(self.image, other.image)
return all([index_eq, image_eq])

Expand Down

0 comments on commit ebdecba

Please sign in to comment.