Skip to content

Commit

Permalink
Don't report origin_msg if any execption is raised in self._origin_msg
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 618237231
  • Loading branch information
yashk2810 authored and jax authors committed Mar 22, 2024
1 parent d7e5dde commit 0b46341
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,9 +713,13 @@ def sharding(self):
# This attribute is part of the jax.Array API, but only defined on concrete arrays.
# Raising a ConcretizationTypeError would make sense, but for backward compatibility
# we raise an AttributeError so that hasattr() and getattr() work as expected.
try:
orig_msg = self._origin_msg()
except:
orig_msg = ''
raise AttributeError(self,
f"The 'sharding' attribute is not available on {self._error_repr()}."
f"{self._origin_msg()}")
f"{orig_msg}")

@property
def addressable_shards(self):
Expand Down

0 comments on commit 0b46341

Please sign in to comment.