forked from eliorc/node2vec
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit f44c7e8
Showing
13 changed files
with
779 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Compiled python modules. | ||
*.pyc | ||
|
||
# Setuptools distribution folder. | ||
/dist/ | ||
|
||
# Python egg metadata, regenerated from source files by setuptools. | ||
/*.egg-info |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Node2Vec | ||
|
||
Python3 implementation of the node2vec algorithm Aditya Grover, Jure Leskovec and Vid Kocijan. | ||
[node2vec: Scalable Feature Learning for Networks. A. Grover, J. Leskovec. ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), 2016.](https://snap.stanford.edu/node2vec/) | ||
|
||
## Installation | ||
|
||
`pip install node2vec` | ||
|
||
## Usage | ||
```python | ||
import networkx as nx | ||
from node2vec import Node2Vec | ||
|
||
# Create a graph | ||
graph = nx.fast_gnp_random_graph(n=100, p=0.5) | ||
|
||
# Precompute probabilities and generate walks | ||
node2vec = Node2Vec(graph, dimensions=64, walk_length=30, num_walks=200) | ||
|
||
# Embed | ||
model = node2vec.fit(window=10, min_count=1, batch_words=4) # Any keywords acceptable by gensim.Word2Vec can be passed | ||
|
||
# Look for most similar nodes | ||
model.most_similar('2') # Output node names are always strings | ||
|
||
``` | ||
|
||
### Parameters | ||
- `Node2Vec` constructor: | ||
1. `graph`: The first positional argument has to be a networkx graph. Node names must be all integers or all strings. On the output model they will always be strings. | ||
2. `dimensions`: Embedding dimensions (default: 128) | ||
3. `walk_length`: Number of nodes in each walk (default: 80) | ||
4. `num_walks`: Number of walks per node (default: 10) | ||
5. `p`: Return hyper parameter (default: 1) | ||
6. `q`: Inout parameter (default: 1) | ||
7. `weight_key`: On weighted graphs, this is the key for the weight attribute (default: 'weight') | ||
8. `sampling_strategy`: Node specific sampling strategies, supports setting node specific 'q', 'p', 'num_walks' and 'walk_length'. | ||
Use these keys exactly. If not set, will use the global ones which were passed on the object initialization` | ||
- `Node2Vec.fit` method: | ||
Accepts any key word argument acceptable by gensim.Word2Vec | ||
|
||
## Caveats | ||
- Node names in the input graph must be all strings, or all ints | ||
|
||
## TODO | ||
- Parallel implementation for probability precomputation and walk generation | ||
|
||
## Contributing | ||
I will probably not be maintaining this package actively, if someone wants to contribute and maintain, please contact me. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import networkx as nx | ||
from node2vec import Node2Vec | ||
|
||
# Create a graph | ||
graph = nx.fast_gnp_random_graph(n=100, p=0.5) | ||
|
||
# Precompute probabilities and generate walks | ||
node2vec = Node2Vec(graph, dimensions=64, walk_length=30, num_walks=200) | ||
|
||
# Embed | ||
model = node2vec.fit(window=10, min_count=1, batch_words=4) # Any keywords acceptable by gensim.Word2Vec can be passed | ||
|
||
# Look for most similar nodes | ||
model.most_similar('2') # Output node names are always strings |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .node2vec import Node2Vec |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
import random | ||
from ast import literal_eval | ||
|
||
import numpy as np | ||
import gensim | ||
from tqdm import tqdm | ||
|
||
|
||
class Node2Vec: | ||
|
||
PROBABILITIES_KEY = 'probabilities' | ||
NUM_WALKS_KEY = 'num_walks' | ||
WALK_LENGTH_KEY = 'walk_length' | ||
P_KEY = 'p' | ||
Q_KEY = 'q' | ||
|
||
def __init__(self, graph, dimensions=128, walk_length=80, num_walks=10, p=1, q=1, weight_key='weight', | ||
sampling_strategy={}): | ||
""" | ||
Initiates the Node2Vec object, precomputes walking probabilities and generates the walks. | ||
:param graph: Input graph | ||
:type graph: Networkx Graph | ||
:param dimensions: Embedding dimensions (default: 128) | ||
:type dimensions: int | ||
:param walk_length: Number of nodes in each walk (default: 80) | ||
:type walk_length: int | ||
:param num_walks: Number of walks per node (default: 10) | ||
:type num_walks: int | ||
:param p: Return hyper parameter (default: 1) | ||
:type p: float | ||
:param q: Inout parameter (default: 1) | ||
:type q: float | ||
:param weight_key: On weighted graphs, this is the key for the weight attribute (default: 'weight') | ||
:type weight_key: str | ||
:param sampling_strategy: Node specific sampling strategies, supports setting node specific 'q', 'p', 'num_walks' and 'walk_length'. | ||
Use these keys exactly. If not set, will use the global ones which were passed on the object initialization | ||
""" | ||
self.graph = graph | ||
self.dimensions = dimensions | ||
self.walk_length = walk_length | ||
self.num_walks = num_walks | ||
self.p = p | ||
self.q = q | ||
self.weight_key = weight_key | ||
self.sampling_strategy = sampling_strategy | ||
|
||
self._precompute_probabilities() | ||
self.walks = self._generate_walks() | ||
|
||
def _precompute_probabilities(self): | ||
""" | ||
Precomputes transition probabilities for each node. | ||
""" | ||
|
||
for source in self.graph.nodes(): | ||
|
||
for current_node in self.graph.neighbors(source): | ||
|
||
# Init probabilities dict | ||
if self.PROBABILITIES_KEY not in self.graph.node[current_node]: | ||
self.graph.node[current_node][self.PROBABILITIES_KEY] = dict() | ||
|
||
if current_node == 3 and source == 1: | ||
asd = 123 | ||
|
||
unnormalized_weights = list() | ||
|
||
# Calculate unnormalized weights | ||
for destination in self.graph.neighbors(current_node): | ||
|
||
# Retrieve p and q | ||
p = self.sampling_strategy[current_node].get(self.P_KEY, | ||
self.p) if current_node in self.sampling_strategy else self.p | ||
q = self.sampling_strategy[current_node].get(self.Q_KEY, | ||
self.q) if current_node in self.sampling_strategy else self.q | ||
|
||
if destination == source: # Backwards probability | ||
ss_weight = self.graph[current_node][destination].get(self.weight_key, 1) * 1/p | ||
elif destination in self.graph[source]: # If the neighbor is connected to the source | ||
ss_weight = self.graph[current_node][destination].get(self.weight_key, 1) | ||
else: | ||
ss_weight = self.graph[current_node][destination].get(self.weight_key, 1) * 1/q | ||
|
||
# Assign the unnormalized sampling strategy weight, normalize during random walk | ||
unnormalized_weights.append(ss_weight) | ||
|
||
# Normalize | ||
unnormalized_weights = np.array(unnormalized_weights) | ||
self.graph.node[current_node][self.PROBABILITIES_KEY][source] = unnormalized_weights / unnormalized_weights.sum() | ||
|
||
def _generate_walks(self): | ||
""" | ||
Generates the random walks which will be used as the skip-gram input. | ||
:return: List of walks. Each walk is a list of nodes. | ||
""" | ||
|
||
walks = list() | ||
|
||
with tqdm(total=self.num_walks) as pbar: | ||
pbar.set_description('Generating walks') | ||
|
||
for n_walk in range(self.num_walks): | ||
|
||
# Update progress bar | ||
pbar.update(1) | ||
|
||
# Shuffle the nodes | ||
shuffled_nodes = list(self.graph.nodes()) | ||
random.shuffle(shuffled_nodes) | ||
|
||
# Start a random walk from every node | ||
for source in shuffled_nodes: | ||
|
||
# Skip nodes with specific num_walks | ||
if source in self.sampling_strategy and \ | ||
self.NUM_WALKS_KEY in self.sampling_strategy[source] and \ | ||
self.sampling_strategy[source][self.NUM_WALKS_KEY] <= n_walk: | ||
continue | ||
|
||
# Start walk | ||
walk = [source] | ||
|
||
# Calculate walk length | ||
if source in self.sampling_strategy: | ||
walk_length = self.sampling_strategy[source].get(self.WALK_LENGTH_KEY, self.walk_length) | ||
else: | ||
walk_length = self.walk_length | ||
|
||
while len(walk) < walk_length: | ||
walk_options = list(self.graph.neighbors(walk[-1])) | ||
|
||
if len(walk) == 1: # For the first step | ||
walk_to = np.random.choice(walk_options, size=1)[0] | ||
else: | ||
probabilities = self.graph.node[walk[-1]][self.PROBABILITIES_KEY][walk[-2]] | ||
walk_to = np.random.choice(walk_options, size=1, p=probabilities)[0] | ||
|
||
walk.append(walk_to) | ||
|
||
walk = list(map(str, walk)) | ||
|
||
walks.append(walk) | ||
|
||
return walks | ||
|
||
def fit(self, **skip_gram_params): | ||
""" | ||
Creates the embeddings using gensim's Word2Vec. | ||
:param skip_gram_params: Parameteres for gensim.models.Word2Vec - do not supply 'size' it is taken from the Node2Vec 'dimensions' parameter | ||
:type skip_gram_params: dict | ||
:return: A gensim word2vec model | ||
""" | ||
return gensim.models.Word2Vec(self.walks, size=self.dimensions, **skip_gram_params) | ||
|
||
|
||
# Create a graph | ||
edgelist = [ | ||
('a', 'b'), | ||
('a', 'c'), | ||
('a', 'd'), | ||
('b', 'c'), # Clique of letters | ||
('b', 'd'), | ||
('c', 'd') | ||
] | ||
''' | ||
edgelist = [ | ||
(1, 2), | ||
(1, 3), | ||
(1, 4), | ||
(2, 3), # Clique of numbers | ||
(2, 4), | ||
(3, 4), | ||
(4, 5), | ||
(5, 6), | ||
(5, 7), | ||
(5, 8), | ||
(6, 7), # Clique of letters | ||
(6, 8), | ||
(7, 8) | ||
] | ||
''' | ||
import networkx as nx | ||
graph = nx.Graph() | ||
graph.add_edges_from(edgelist) | ||
x = Node2Vec(graph, walk_length=30, num_walks=800, sampling_strategy={1:{'walk_length': 2}, 3: {'p': 30}}) | ||
|
||
''' | ||
graph = nx.Graph() | ||
graph.add_edges_from([(1, 3), (2,3), (3,5), (3,4), (4, 5)]) | ||
x = Node2Vec(graph, walk_length=30, num_walks=800, sampling_strategy={1:{'walk_length': 2}, 3: {'p': 30}}) | ||
model = x.fit(window=5) | ||
''' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[metadata] | ||
description-file = README.md |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from distutils.core import setup | ||
|
||
setup( | ||
name='node2vec', | ||
packages=['node2vec'], | ||
version='0.1.0', | ||
description='Implementation of the node2vec algorithm.', | ||
author='Elior Cohen', | ||
author_email='', | ||
license='MIT', | ||
url='github package source url', | ||
keywords=['machine learning', 'embeddings'], | ||
) |