Skip to content

Commit

Permalink
Copybara import of the pull request google/flax#2497
Browse files Browse the repository at this point in the history
--
e0b43740bbe0a6adf128f9d3b8fcabc03fb1d2c1 by ivyzheng <[email protected]>:

Explicitly raise error when Flax restore_checkpoint path given is nonexistent. Also fixing all third_party usages.

PiperOrigin-RevId: 488009224
  • Loading branch information
IvyZX authored and The jax3d Authors committed Nov 12, 2022
1 parent 36998f3 commit 838fbce
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
5 changes: 4 additions & 1 deletion jax3d/projects/generative/nerf/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,10 @@ def load_checkpoint(self, init_state: TrainState) -> TrainState:
logging.info("Using pre-trained checkpoint %s",
self.pre_trained_checkpoint)
checkpoint_dir = self.pre_trained_checkpoint
new_state = checkpoints.restore_checkpoint(checkpoint_dir, init_state)
try:
new_state = checkpoints.restore_checkpoint(checkpoint_dir, init_state)
except ValueError:
new_state = init_state

if new_state.step != init_state.step:
logging.info("Restored from checkpoint at step %i", new_state.step)
Expand Down
20 changes: 13 additions & 7 deletions jax3d/projects/nesf/nerfstatic/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,14 @@ def restore_ds_checkpoint_for_process(
Dataset iterator state.
"""
ds_state_ckpt = checkpoints.restore_checkpoint(
os.fspath(save_dir),
target=None,
prefix=_CKPT_PREFIX_DS.format(process_id=jax.process_index()),
)
try:
ds_state_ckpt = checkpoints.restore_checkpoint(
os.fspath(save_dir),
target=None,
prefix=_CKPT_PREFIX_DS.format(process_id=jax.process_index()),
)
except ValueError:
ds_state_ckpt = None

# Handle case where a ckpt was found.
if ds_state_ckpt:
Expand Down Expand Up @@ -150,8 +153,11 @@ def restore_opt_checkpoint(*,
Restored model state.
"""
state = checkpoints.restore_checkpoint(
os.fspath(save_dir), state, prefix=_CKPT_PREFIX_OPT)
try:
state = checkpoints.restore_checkpoint(
os.fspath(save_dir), state, prefix=_CKPT_PREFIX_OPT)
except ValueError:
pass
return state


Expand Down

0 comments on commit 838fbce

Please sign in to comment.