Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
benedekrozemberczki authored Feb 16, 2018
1 parent 024b18f commit 09245ba
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 539 deletions.
20 changes: 16 additions & 4 deletions src/calculation_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,26 @@ def min_norm(g, node_1, node_2):
min_norm = min(len(set(nx.neighbors(g, node_1))), len(set(nx.neighbors(g, node_2))))
return float(inter)/float(min_norm)

def overlap_generator(args, graph):
    """
    Generate a weight for every edge of the graph, in both directions.

    :param args: Namespace whose `overlap_weighting` attribute selects the
                 metric ("normalized_overlap", "overlap", "min_norm"; any
                 other value falls back to unit weights).
    :param graph: NetworkX graph whose edges are weighted.
    :return: Dict mapping both (u, v) and (v, u) to the edge's weight.
    """
    # Dispatch table instead of an if/elif chain; unknown settings fall back
    # to the unit weighter, matching the original else branch.
    weighters = {"normalized_overlap": normalized_overlap,
                 "overlap": overlap,
                 "min_norm": min_norm}
    overlap_weighter = weighters.get(args.overlap_weighting, unit)
    print(" ")
    print("Weight calculation started.")
    print(" ")
    edges = nx.edges(graph)
    weights = {edge: overlap_weighter(graph, edge[0], edge[1]) for edge in tqdm(edges)}
    # .items() (not the Python 2-only .iteritems()) keeps this working on
    # Python 3 as well; the weights are symmetric, so mirror every edge.
    weights_prime = {(edge[1], edge[0]): value for edge, value in weights.items()}
    weights.update(weights_prime)
    print(" ")
    return weights

def index_generation(weights, a_random_walk):
Expand Down
8 changes: 4 additions & 4 deletions src/embedding_clustering.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from parser import parameter_parser
from print_and_read import graph_reader
from model import GEMSECWithRegularization, GEMSEC, DWWithRegularization, DW
from model import GEMSECWithRegularization, GEMSEC, DeepWalkWithRegularization, DeepWalk

def create_and_run_model(args):
"""
Expand All @@ -11,10 +11,10 @@ def create_and_run_model(args):
model = GEMSECWithRegularization(args, graph)
elif args.model == "GEMSEC":
model = GEMSEC(args, graph)
elif args.model == "DWWithRegularization":
model = DWWithRegularization(args, graph)
elif args.model == "DeepWalkWithRegularization":
model = DeepWalkWithRegularization(args, graph)
else:
model = DW(args, graph)
model = DeepWalk(args, graph)
model.train()

if __name__ == "__main__":
Expand Down
104 changes: 104 additions & 0 deletions src/layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import tensorflow as tf
import math
import numpy as np

class DeepWalker:
    """
    DeepWalk embedding layer class.

    Builds the TF1 graph pieces for Skip-gram style node embedding training:
    input placeholders, the embedding matrix, and the sampled-softmax output
    weights/biases.
    """
    def __init__(self, args, vocab_size, degrees):
        """
        Initialization of the layer with proper matrices and biases.
        The input variables are also initialized here.

        :param args: Settings namespace (uses window_size, dimensions,
                     negative_sample_number, distortion).
        :param vocab_size: Number of nodes, i.e. rows of the embedding matrix.
        :param degrees: Per-node counts fed to the unigram sampler below —
                        presumably node degrees; confirm against the caller.
        """
        self.args = args
        self.vocab_size = vocab_size
        self.degrees = degrees
        # Context nodes: one row of window_size targets per source node.
        self.train_labels = tf.placeholder(tf.int64, shape=[None, self.args.window_size])


        # Source node indices, one per row of train_labels.
        self.train_inputs = tf.placeholder(tf.int64, shape=[None])

        # Embeddings initialized uniformly in +/- 0.1/dimensions.
        self.embedding_matrix = tf.Variable(tf.random_uniform([self.vocab_size, self.args.dimensions],
        -0.1/self.args.dimensions, 0.1/self.args.dimensions))


        # Output-side weights for the sampled softmax loss.
        self.nce_weights = tf.Variable(tf.truncated_normal([self.vocab_size, self.args.dimensions],
        stddev=1.0 / math.sqrt(self.args.dimensions)))

        self.nce_biases = tf.Variable(tf.random_uniform([self.vocab_size], -0.1/self.args.dimensions, 0.1/self.args.dimensions))


    def __call__(self):
        """
        Calculating the embedding cost with NCE and returning it.

        :return: Scalar tensor — mean sampled-softmax loss over all
                 (source, context) pairs.
        """
        # Flatten the window of labels into single (input, label) pairs.
        self.train_labels_flat = tf.reshape(self.train_labels, [-1, 1])
        self.input_ones = tf.ones_like(self.train_labels)
        # Repeat each source index window_size times so inputs align with
        # the flattened labels.
        self.train_inputs_flat = tf.reshape(tf.multiply(self.input_ones, tf.reshape(self.train_inputs,[-1,1])),[-1])
        # max_norm=1 clips each looked-up embedding row to unit L2 norm.
        self.embedding_partial = tf.nn.embedding_lookup(self.embedding_matrix, self.train_inputs_flat, max_norm = 1)

        # Negative samples drawn proportionally to degrees**distortion.
        self.sampler = tf.nn.fixed_unigram_candidate_sampler(true_classes = self.train_labels_flat,
                                                             num_true = 1,
                                                             num_sampled = self.args.negative_sample_number,
                                                             unique = True,
                                                             range_max = self.vocab_size,
                                                             distortion = self.args.distortion,
                                                             unigrams = self.degrees)

        self.embedding_losses = tf.nn.sampled_softmax_loss(weights = self.nce_weights,
                                                           biases = self.nce_biases,
                                                           labels = self.train_labels_flat,
                                                           inputs = self.embedding_partial,
                                                           num_true = 1,
                                                           num_sampled = self.args.negative_sample_number,
                                                           num_classes = self.vocab_size,
                                                           sampled_values = self.sampler)

        return tf.reduce_mean(self.embedding_losses)

class Clustering:
    """
    Latent space clustering layer: a trainable matrix of cluster centers
    and the cost that pulls embeddings toward their nearest center.
    """
    def __init__(self, args):
        """
        Set up the trainable cluster center matrix.

        :param args: Settings namespace (uses cluster_number, dimensions).
        """
        self.args = args
        bound = 0.1 / self.args.dimensions
        shape = [self.args.cluster_number, self.args.dimensions]
        self.cluster_means = tf.Variable(tf.random_uniform(shape, -bound, bound))

    def __call__(self, Walker):
        """
        Return the clustering cost for the current batch of embeddings.

        :param Walker: DeepWalker layer exposing `embedding_partial`.
        :return: Scalar tensor — mean distance to the closest center.
        """
        # Broadcast to all (embedding, center) difference vectors.
        self.clustering_differences = tf.expand_dims(Walker.embedding_partial, 1) - self.cluster_means
        # Euclidean distance from every embedding to every center.
        self.cluster_distances = tf.norm(self.clustering_differences, ord=2, axis=2)
        # Each embedding is only charged for its nearest center.
        self.to_be_averaged = tf.reduce_min(self.cluster_distances, axis=1)
        return tf.reduce_mean(self.to_be_averaged)

class Regularization:
    """
    Smoothness regularization class.

    Penalizes distance between the embeddings of adjacent nodes, weighted
    by the precomputed edge overlap scores.
    """
    def __init__(self, args):
        """
        Initializing the indexing variables and the weight vector.

        :param args: Settings namespace (uses regularization_noise,
                     random_walk_length, dimensions, lambd).
        """
        self.args = args
        # Positions (within the flattened walk batch) of each edge's endpoints.
        self.edge_indices_right = tf.placeholder(tf.int64, shape=[None])
        self.edge_indices_left = tf.placeholder(tf.int64, shape=[None])
        # Per-edge overlap weights, one column vector.
        self.overlap = tf.placeholder(tf.float32, shape=[None, 1])

    def __call__(self, Walker):
        """
        Calculating the regularization cost.

        :param Walker: DeepWalker layer exposing `embedding_partial`.
        :return: Scalar tensor — lambda-scaled, overlap-weighted mean
                 distance between adjacent embeddings.
        """
        self.left_features = tf.nn.embedding_lookup(Walker.embedding_partial, self.edge_indices_left, max_norm=1)
        self.right_features = tf.nn.embedding_lookup(Walker.embedding_partial, self.edge_indices_right, max_norm=1)
        # FIX: noise must be symmetric in [-noise, +noise]. The original
        # passed regularization_noise as BOTH low and high, and
        # np.random.uniform(a, a, size) returns the constant a — a fixed
        # positive offset on every coordinate instead of jitter.
        noise = np.random.uniform(-self.args.regularization_noise,
                                  self.args.regularization_noise,
                                  (self.args.random_walk_length - 1, self.args.dimensions))
        self.regularization_differences = self.left_features - self.right_features + noise
        self.regularization_distances = tf.norm(self.regularization_differences, ord=2, axis=1)
        self.regularization_distances = tf.reshape(self.regularization_distances, [-1, 1])
        # Overlap-weighted sum of distances, averaged into a scalar.
        self.regularization_loss = tf.reduce_mean(tf.matmul(tf.transpose(self.overlap), self.regularization_distances))
        return self.args.lambd * self.regularization_loss
Loading

0 comments on commit 09245ba

Please sign in to comment.