Skip to content

Commit

Permalink
Fix: checkpoint-related naming and numpy compatibility of inhomogeneous arrays (alibaba#260)
Browse files Browse the repository at this point in the history
  • Loading branch information
LiSu authored Apr 26, 2023
1 parent a5ee83c commit 6677f70
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 19 deletions.
30 changes: 14 additions & 16 deletions graphlearn/examples/tf/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,20 @@ class TFTrainer(object):
Args:
ckpt_dir: checkpoint dir.
ckpt_freq: checkpoint frequency.
ckpt_steps: checkpoint steps.
save_checkpoint_secs: checkpoint frequency.
save_checkpoint_steps: checkpoint steps.
profiling: whether write timeline for profiling, default is False.
progress_steps: print a progress logs for given steps.
"""
def __init__(self,
ckpt_dir=None,
save_checkpoint_secs=600,
save_checkpoint_steps=None,
ckpt_steps=None,
profiling=False,
progress_steps=10):
self.ckpt_dir = ckpt_dir
self.save_checkpoint_secs = save_checkpoint_secs
self.save_checkpoint_steps = save_checkpoint_steps
self.ckpt_steps = ckpt_steps
self.profiling = profiling
self.progress_steps = progress_steps

Expand All @@ -85,9 +83,9 @@ def init_session(self, hooks=None, **kwargs):
checkpoint_args = dict()
if self.ckpt_dir is not None:
checkpoint_args['checkpoint_dir'] = self.ckpt_dir
if self.ckpt_freq is not None:
if self.save_checkpoint_secs is not None:
checkpoint_args['save_checkpoint_secs'] = self.save_checkpoint_secs
if self.ckpt_steps is not None:
if self.save_checkpoint_steps is not None:
checkpoint_args['save_checkpoint_steps'] = self.save_checkpoint_steps

self.sess = tf.train.MonitoredTrainingSession(
Expand Down Expand Up @@ -238,18 +236,18 @@ class LocalTrainer(TFTrainer):
Args:
ckpt_dir: checkpoint dir.
ckpt_freq: checkpoint frequency.
ckpt_steps: checkpoint steps.
save_checkpoint_freq: checkpoint frequency.
save_checkpoint_steps: checkpoint steps.
profiling: whether write timeline for profiling, default is False.
progress_steps: print a progress logs for given steps.
"""
def __init__(self,
ckpt_dir=None,
ckpt_freq=None,
ckpt_steps=None,
save_checkpoint_secs=None,
save_checkpoint_steps=None,
profiling=False,
progress_steps=10):
super().__init__(ckpt_dir, ckpt_freq, ckpt_steps, profiling, progress_steps)
super().__init__(ckpt_dir, save_checkpoint_secs, save_checkpoint_steps, profiling, progress_steps)
self.is_local = True

if hasattr(contextlib, 'nullcontext'):
Expand Down Expand Up @@ -287,8 +285,8 @@ class DistTrainer(TFTrainer):
task_index: index of this worker.
worker_count: The number of TensorFlow worker.
ckpt_dir: checkpoint dir.
ckpt_freq: checkpoint frequency.
ckpt_steps: checkpoint steps.
save_checkpoint_freq: checkpoint frequency.
save_checkpoint_steps: checkpoint steps.
profiling: whether write timeline for profiling, default is False.
progress_steps: print a progress logs for given steps.
"""
Expand All @@ -298,11 +296,11 @@ def __init__(self,
task_index,
worker_count,
ckpt_dir=None,
ckpt_freq=None,
ckpt_steps=None,
save_checkpoint_secs=None,
save_checkpoint_steps=None,
profiling=False,
progress_steps=10):
super().__init__(ckpt_dir, ckpt_freq, ckpt_steps, profiling, progress_steps)
super().__init__(ckpt_dir, save_checkpoint_secs, save_checkpoint_steps, profiling, progress_steps)
self.is_local = False
self.cluster_spec = cluster_spec
self.job_name = job_name
Expand Down
2 changes: 1 addition & 1 deletion graphlearn/python/nn/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def parse_value(value, masks):
values.extend([value.dense_shape])
else:
values.extend([None, None, None])
return list(np.array(values)[feat_masks + id_masks + sparse_masks])
return list(np.array(values, dtype=object)[feat_masks + id_masks + sparse_masks])

try:
values = self._ds.next()
Expand Down
4 changes: 2 additions & 2 deletions graphlearn/python/nn/tf/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,15 +319,15 @@ def _data_types_and_shapes(self, node_decoder, is_edge=False, is_sparse=False):
tf.TensorShape([None, node_decoder.string_attr_num]),
tf.TensorShape([None]), # labels
tf.TensorShape([None]), # weights
tf.TensorShape([None])])[feat_masks] # timestamps
tf.TensorShape([None])], dtype=object)[feat_masks] # timestamps

id_types = np.array([tf.int64, tf.int64])[id_masks] # ids, dst_ids
id_shapes = np.array([tf.TensorShape([None]), tf.TensorShape([None])])[id_masks]
# offsets, indices and dense_shape for sparse Data.
sparse_types = np.array([tf.int64, tf.int64, tf.int64])[sparse_masks]
sparse_shapes = np.array([tf.TensorShape([None]),
tf.TensorShape([None, 2]),
tf.TensorShape([None])])[sparse_masks]
tf.TensorShape([None])], dtype=object)[sparse_masks]
return list(feat_types) + list(id_types) + list(sparse_types), \
list(feat_shapes) + list(id_shapes) + list(sparse_shapes)

Expand Down

0 comments on commit 6677f70

Please sign in to comment.