Update error message format and provide expected value for keras/callbacks.py

PiperOrigin-RevId: 388983687
qlzh727 authored and tensorflower-gardener committed Aug 5, 2021
1 parent 9eeef2e commit 45e4287
Showing 1 changed file with 40 additions and 29 deletions.
--- a/keras/callbacks.py
+++ b/keras/callbacks.py
@@ -294,7 +294,8 @@ def _call_batch_hook(self, mode, hook, batch, logs=None):
     elif hook == 'end':
       self._call_batch_end_hook(mode, batch, logs)
     else:
-      raise ValueError('Unrecognized hook: {}'.format(hook))
+      raise ValueError(
+          f'Unrecognized hook: {hook}. Expected values are ["begin", "end"]')
 
   def _call_batch_begin_hook(self, mode, batch, logs):
     """Helper function for `on_*_batch_begin` methods."""
@@ -562,9 +563,10 @@ def _disallow_batch_hooks_in_ps_strategy(self):
           cb._implements_predict_batch_hooks()):
         unsupported_callbacks.append(cb)
     if unsupported_callbacks:
-      raise ValueError('Batch-level `Callback`s are not supported with '
-                       '`ParameterServerStrategy`. Found unsupported '
-                       'callbacks: {}'.format(unsupported_callbacks))
+      raise ValueError(
+          'Batch-level `Callback`s are not supported with '
+          '`ParameterServerStrategy`. Found unsupported '
+          f'callbacks: {unsupported_callbacks}')
     # pylint: enable=protected-access


@@ -971,7 +973,9 @@ def __init__(self, count_mode='samples', stateful_metrics=None):
     elif count_mode == 'steps':
       self.use_steps = True
     else:
-      raise ValueError('Unknown `count_mode`: ' + str(count_mode))
+      raise ValueError(
+          f'Unknown `count_mode`: {count_mode}. '
+          'Expected values are ["samples", "steps"]')
     # Defaults to all Model's metrics except for loss.
     self.stateful_metrics = set(stateful_metrics) if stateful_metrics else set()

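Note: the new `count_mode` message is easy to reproduce from the public API. A minimal sketch (the printed text follows the `+` lines above):

    import tensorflow as tf

    try:
      tf.keras.callbacks.ProgbarLogger(count_mode='batches')
    except ValueError as e:
      # Unknown `count_mode`: batches. Expected values are ["samples", "steps"]
      print(e)
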
@@ -1280,14 +1284,16 @@ def __init__(self,
           options, tf.train.CheckpointOptions):
         self._options = options or tf.train.CheckpointOptions()
       else:
-        raise TypeError('If save_weights_only is True, then `options` must be '
-                        'either None or a tf.train.CheckpointOptions')
+        raise TypeError(
+            'If save_weights_only is True, then `options` must be '
+            f'either None or a tf.train.CheckpointOptions. Got {options}.')
     else:
       if options is None or isinstance(options, tf.saved_model.SaveOptions):
         self._options = options or tf.saved_model.SaveOptions()
       else:
-        raise TypeError('If save_weights_only is False, then `options` must be'
-                        'either None or a tf.saved_model.SaveOptions')
+        raise TypeError(
+            'If save_weights_only is False, then `options` must be '
+            f'either None or a tf.saved_model.SaveOptions. Got {options}.')
 
     # Deprecated field `load_weights_on_restart` is for loading the checkpoint
     # file from `filepath` at the start of `model.fit()`
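Note: the two branches pair `save_weights_only` with different option types. A sketch of the valid pairings (the paths are placeholders, not from the commit):

    import tensorflow as tf

    # save_weights_only=True expects tf.train.CheckpointOptions (or None).
    weights_cb = tf.keras.callbacks.ModelCheckpoint(
        filepath='/tmp/ckpt/weights',
        save_weights_only=True,
        options=tf.train.CheckpointOptions())

    # save_weights_only=False expects tf.saved_model.SaveOptions (or None).
    model_cb = tf.keras.callbacks.ModelCheckpoint(
        filepath='/tmp/ckpt/model',
        save_weights_only=False,
        options=tf.saved_model.SaveOptions())
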
@@ -1329,7 +1335,9 @@ def __init__(self,
         self.best = np.Inf
 
     if self.save_freq != 'epoch' and not isinstance(self.save_freq, int):
-      raise ValueError('Unrecognized save_freq: {}'.format(self.save_freq))
+      raise ValueError(
+          f'Unrecognized save_freq: {self.save_freq}. '
+          'Expected save_freq is "epoch" or an integer')
 
     # Only the chief worker writes model checkpoints, but all workers
     # restore checkpoint at on_train_begin().
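Note: `save_freq` accepts only the string 'epoch' or an integer number of batches; anything else now raises the clearer ValueError above. A sketch (placeholder paths):

    import tensorflow as tf

    cb_epoch = tf.keras.callbacks.ModelCheckpoint('/tmp/ckpt/w', save_freq='epoch')
    cb_batch = tf.keras.callbacks.ModelCheckpoint('/tmp/ckpt/w', save_freq=500)
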
@@ -1347,8 +1355,8 @@ def on_train_begin(self, logs=None):
           # name matching the pattern.
           self.model.load_weights(filepath_to_load)
         except (IOError, ValueError) as e:
-          raise ValueError('Error loading file from {}. Reason: {}'.format(
-              filepath_to_load, e))
+          raise ValueError(
+              f'Error loading file from {filepath_to_load}. Reason: {e}')
 
   def _implements_train_batch_hooks(self):
     # Only call batch hooks when saving on batch
@@ -1437,13 +1445,13 @@ def _save_model(self, epoch, batch, logs):
       except IsADirectoryError as e:  # h5py 3.x
         raise IOError('Please specify a non-directory filepath for '
                       'ModelCheckpoint. Filepath used is an existing '
-                      'directory: {}'.format(filepath))
+                      f'directory: {filepath}')
       except IOError as e:  # h5py 2.x
         # `e.errno` appears to be `None` so checking the content of `e.args[0]`.
         if 'is a directory' in str(e.args[0]).lower():
           raise IOError('Please specify a non-directory filepath for '
                         'ModelCheckpoint. Filepath used is an existing '
-                        'directory: {}'.format(filepath))
+                        f'directory: {filepath}')
         # Re-throw the error for any other causes.
         raise e

@@ -1460,8 +1468,9 @@ def _get_file_path(self, epoch, batch, logs):
       file_path = self.filepath.format(
           epoch=epoch + 1, batch=batch + 1, **logs)
     except KeyError as e:
-      raise KeyError('Failed to format this callback filepath: "{}". '
-                     'Reason: {}'.format(self.filepath, e))
+      raise KeyError(
+          f'Failed to format this callback filepath: "{self.filepath}". '
+          f'Reason: {e}')
     self._write_filepath = distributed_file_utils.write_filepath(
         file_path, self.model.distribute_strategy)
     return self._write_filepath
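Note: the `KeyError` fires when `filepath` names a field that is not supplied, i.e. neither `epoch`, `batch`, nor a key in the logs dict. A sketch of a well-formed template (placeholder path):

    import tensorflow as tf

    # `val_loss` must actually be logged (validation must run during fit);
    # referencing a key that never appears triggers the KeyError above.
    cb = tf.keras.callbacks.ModelCheckpoint(
        filepath='/tmp/ckpt/weights.{epoch:02d}-{val_loss:.2f}.h5')
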
@@ -1665,10 +1674,9 @@ def on_train_begin(self, logs=None):
     if self.model._distribution_strategy and not isinstance(
         self.model.distribute_strategy, self._supported_strategies):
       raise NotImplementedError(
-          '%s is not supported yet. '
+          f'{type(self.model.distribute_strategy)} is not supported yet. '
           'Currently BackupAndRestore callback only supports empty strategy, '
-          'MirroredStrategy, MultiWorkerMirroredStrategy and TPUStrategy.' %
-          type(self.model.distribute_strategy).__name__)
+          'MirroredStrategy, MultiWorkerMirroredStrategy and TPUStrategy.')
     self.model._training_state = (
         worker_training_state.WorkerTrainingState(self.model, self.backup_dir))
     self._training_state = self.model._training_state
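Note: at the time of this commit, BackupAndRestore is exposed under the experimental namespace. A minimal sketch of a supported setup (assuming TF 2.6-era APIs; the backup directory is a placeholder):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
      model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
      model.compile(optimizer='sgd', loss='mse')

    backup_cb = tf.keras.callbacks.experimental.BackupAndRestore(
        backup_dir='/tmp/backup')
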
@@ -1950,9 +1958,10 @@ def on_epoch_begin(self, epoch, logs=None):
       lr = self.schedule(epoch)
     if not isinstance(lr, (tf.Tensor, float, np.float32, np.float64)):
       raise ValueError('The output of the "schedule" function '
-                       'should be float.')
+                       f'should be float. Got: {lr}')
     if isinstance(lr, tf.Tensor) and not lr.dtype.is_floating:
-      raise ValueError('The dtype of Tensor should be float')
+      raise ValueError(
+          f'The dtype of `lr` Tensor should be float. Got: {lr.dtype}')
     backend.set_value(self.model.optimizer.lr, backend.get_value(lr))
     if self.verbose > 0:
       print('\nEpoch %05d: LearningRateScheduler setting learning '
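Note: both checks pass when the schedule returns a Python float (or a floating-point tensor). A sketch:

    import tensorflow as tf

    def schedule(epoch):
      # Returns a float; an int, or an int-dtype tensor, would trip the
      # errors shown in this hunk.
      return 1e-3 * (0.9 ** epoch)

    lr_cb = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=1)
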
@@ -2193,14 +2202,15 @@ def _validate_kwargs(self, kwargs):
                       '2.0. Instead, all `Embedding` variables will be '
                       'visualized.')
 
-    unrecognized_kwargs = set(kwargs.keys()) - {
-        'write_grads', 'embeddings_layer_names', 'embeddings_data', 'batch_size'
-    }
+    supported_kwargs = {'write_grads', 'embeddings_layer_names',
+                        'embeddings_data', 'batch_size'}
+    unrecognized_kwargs = set(kwargs.keys()) - supported_kwargs
 
     # Only allow kwargs that were supported in V1.
     if unrecognized_kwargs:
-      raise ValueError('Unrecognized arguments in `TensorBoard` '
-                       'Callback: ' + str(unrecognized_kwargs))
+      raise ValueError(
+          'Unrecognized arguments in `TensorBoard` Callback: '
+          f'{unrecognized_kwargs}. Supported kwargs are: {supported_kwargs}')
 
   def set_model(self, model):
     """Sets Keras model and writes graph if specified."""
@@ -2291,7 +2301,7 @@ def _configure_embeddings(self):
                                                    str):
       raise ValueError('Unrecognized `Embedding` layer names passed to '
                        '`keras.callbacks.TensorBoard` `embeddings_metadata` '
-                       'argument: ' + str(self.embeddings_metadata.keys()))
+                       f'argument: {self.embeddings_metadata.keys()}')
 
     config_pbtxt = text_format.MessageToString(config)
     path = os.path.join(self._log_write_dir, 'projector_config.pbtxt')
@@ -2345,7 +2355,7 @@ def _init_profile_batch(self, profile_batch):
     profile_batch_error_message = (
         'profile_batch must be a non-negative integer or 2-tuple of positive '
         'integers. A pair of positive integers signifies a range of batches '
-        'to profile. Found: {}'.format(profile_batch))
+        f'to profile. Found: {profile_batch}')
 
     # Support legacy way of specifying "start,stop" or "start" as str.
     if isinstance(profile_batch, str):
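Note: per the message, `profile_batch` takes a non-negative integer or a 2-tuple of positive integers. A sketch (placeholder log dir):

    import tensorflow as tf

    tb_single = tf.keras.callbacks.TensorBoard(log_dir='/tmp/tb', profile_batch=5)
    tb_range = tf.keras.callbacks.TensorBoard(log_dir='/tmp/tb',
                                              profile_batch=(10, 20))
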
@@ -2629,7 +2639,8 @@ def __init__(self,
 
     self.monitor = monitor
     if factor >= 1.0:
-      raise ValueError('ReduceLROnPlateau ' 'does not support a factor >= 1.0.')
+      raise ValueError(
+          f'ReduceLROnPlateau does not support a factor >= 1.0. Got {factor}')
     if 'epsilon' in kwargs:
       min_delta = kwargs.pop('epsilon')
       logging.warning('`epsilon` argument is deprecated and '
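Note: `factor` is the multiplier applied to the learning rate when the monitored metric plateaus, so it must be strictly less than 1.0. A sketch:

    import tensorflow as tf

    # Halve the LR after 3 stagnant epochs; factor >= 1.0 raises the
    # ValueError above.
    rlr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                               factor=0.5, patience=3)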
