improve speed of the data pipeline and the evaluation
a-arbabi committed Jan 29, 2020
1 parent 6477e7d commit b1c30dd
Showing 3 changed files with 16 additions and 11 deletions.
9 changes: 6 additions & 3 deletions data_utils.py
@@ -80,16 +80,19 @@ def _make_dense(x):
         'sub': x['sub'],
         'rel': x['rel'],
         'obj_list': tf.scatter_nd(
-            tf.expand_dims(x['obj_list'], 1), tf.ones_like(x['obj_list']), [n_entities]),
+            tf.expand_dims(x['obj_list'], 1),
+            tf.ones_like(x['obj_list'], dtype=tf.float32),
+            [n_entities]),
     }
     if ground_truth_triplets is not None:
       output['gt_obj_list'] = tf.scatter_nd(
           tf.expand_dims(x['gt_obj_list'], 1),
-          tf.ones_like(x['gt_obj_list']),
+          tf.ones_like(x['gt_obj_list'], dtype=tf.float32),
           [n_entities])
     return output
 
-  dataset = dataset.cache().map(_make_dense)
+  dataset = dataset.map(_make_dense)
+  #dataset = dataset.cache().map(_make_dense)
 
   return dataset

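A note on the data_utils.py change above: tf.scatter_nd turns the sparse list of object indices into a dense multi-hot vector, and building it as float32 up front is what lets the tf.cast calls in evaluation.py and experiment.py below be dropped. A minimal sketch of that densification, with a toy n_entities and index list made up for illustration:

import tensorflow as tf

n_entities = 6                    # toy vocabulary size, illustration only
obj_list = tf.constant([1, 4])    # sparse object indices for one (sub, rel) pair

# Dense multi-hot target built directly as float32, as in the new _make_dense.
dense = tf.scatter_nd(
    tf.expand_dims(obj_list, 1),               # indices, shape [2, 1]
    tf.ones_like(obj_list, dtype=tf.float32),  # float updates, shape [2]
    [n_entities])                              # output shape
print(dense.numpy())              # [0. 1. 0. 0. 1. 0.]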
12 changes: 7 additions & 5 deletions evaluation.py
@@ -6,13 +6,15 @@ def evaluate(model, config, dataset):
   for batch in dataset:
     logits = model(batch, training=False)
     probs = tf.nn.softmax(logits, axis=1)
-    negatives = 1.0 - tf.cast(batch['gt_obj_list'], tf.float32)
+    negatives = 1.0 - batch['gt_obj_list']
     negative_probs = probs * negatives
     indecies = tf.where(batch['obj_list'])
-    for row in indecies:
-      rank = tf.math.count_nonzero(probs[row[0], row[1]]<=negative_probs[row[0]])+1
-      ranks.append(int(rank))
-  ranks = np.array(ranks)
+    gathered_probs = tf.gather_nd(probs, indecies)
+    gathered_negs = tf.gather(negative_probs, indecies[:,0])
+    batch_ranks = tf.math.count_nonzero(tf.expand_dims(gathered_probs, -1) <= gathered_negs, axis=1)+1
+    ranks.append(batch_ranks)
+
+  ranks = tf.concat(ranks, 0).numpy()
   return {
       'R@1': np.mean(ranks==1),
       'R@3': np.mean(ranks<=3),
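The evaluation.py change replaces the per-example Python loop with one vectorized rank computation per batch: tf.gather_nd pulls the score of every true (row, object) pair, tf.gather pulls each pair's row of filtered-negative scores, and broadcasting counts how many negatives score at least as high. A small sketch with made-up scores (the 4-entity vocabulary and the values are assumptions, not from the repo):

import tensorflow as tf

# Toy scores for a batch of 2 queries over 4 candidate entities.
probs = tf.constant([[0.1, 0.5, 0.3, 0.1],
                     [0.4, 0.2, 0.2, 0.2]])
obj_list = tf.constant([[0., 1., 0., 0.],      # true objects to be ranked
                        [1., 0., 0., 0.]])
gt_obj_list = tf.constant([[0., 1., 1., 0.],   # all known objects, filtered out of the negatives
                           [1., 0., 0., 0.]])

negatives = 1.0 - gt_obj_list
negative_probs = probs * negatives                         # scores of the filtered negatives only
indecies = tf.where(obj_list)                              # (row, entity) pairs of the true objects
gathered_probs = tf.gather_nd(probs, indecies)             # score of each true object, shape [K]
gathered_negs = tf.gather(negative_probs, indecies[:, 0])  # matching rows of negative scores, [K, n_entities]
ranks = tf.math.count_nonzero(
    tf.expand_dims(gathered_probs, -1) <= gathered_negs, axis=1) + 1
print(ranks.numpy())  # [1 1]: each true object outscores every filtered negative in its row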
6 changes: 3 additions & 3 deletions experiment.py
@@ -53,7 +53,7 @@ def train_tucker(model, train_dataset, valid_dataset, config):
       labels[i, train_sub_rel_pair_to_objs[tuple(batch[i])]] = 1.0
     '''
 
-    labels = tf.cast(batch['obj_list'], dtype=tf.float32)
+    labels = batch['obj_list']
 
     with tf.GradientTape() as tape:
       logits = model(batch, training=True)
@@ -89,12 +89,12 @@ def filter_fn(x):
   train_dataset = data_utils.create_aggregated_dataset(
       triplets['train'], config.n_entities)
   train_dataset = (
-      train_dataset.filter(filter_fn).shuffle(500).batch(config.batch_size))
+      train_dataset.filter(filter_fn).shuffle(500).batch(config.batch_size)).cache()
 
   valid_dataset = data_utils.create_aggregated_dataset(
       triplets['valid'], config.n_entities, gt_triplets)
   valid_dataset = (
-      valid_dataset.filter(filter_fn).shuffle(500).batch(config.batch_size))
+      valid_dataset.filter(filter_fn).batch(config.batch_size))
 
   models = [tucker.TuckerModel(config) for i in range(config.n_models)]
   for model in models:
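The experiment.py change moves cache() to the end of the training pipeline, after filter, shuffle, and batch, so the densified, batched tensors are what stays in memory, and it drops the now-unneeded shuffle on the validation set. A rough sketch of the ordering on a toy dataset (the range dataset, lambda filter, and batch size are stand-ins, not the project's real inputs):

import tensorflow as tf

dataset = tf.data.Dataset.range(10)       # stand-in for the aggregated triplet dataset
batch_size = 4

train = (
    dataset.filter(lambda x: x % 2 == 0)  # stand-in for filter_fn
    .shuffle(500)
    .batch(batch_size)).cache()           # cache batched elements; later epochs skip the earlier stages

valid = (
    dataset.filter(lambda x: x % 2 == 0)
    .batch(batch_size))                   # no shuffle: evaluation does not depend on order

for epoch in range(2):
  for batch in train:
    pass                                  # training step would go here

One trade-off to be aware of: with cache() placed after shuffle(), the element order produced in the first epoch is what gets cached and replayed in later epochs.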
