Skip to content

Commit

Permalink
arxiv version code
Browse files Browse the repository at this point in the history
  • Loading branch information
muhanzhang committed Apr 26, 2019
1 parent aa41865 commit 089bbf5
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 24 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ results/
software/
PyG_GNN/
data/
plot_figure.py
33 changes: 20 additions & 13 deletions Main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@



parser = argparse.ArgumentParser(description='Link Prediction with SEAL')
parser = argparse.ArgumentParser(description='Learning Inductive Graph Patterns for recommender systems')
# general settings
parser.add_argument('--testing', action='store_true', default=False,
help='turn on testing mode')
Expand Down Expand Up @@ -60,13 +60,16 @@
help='number of epochs to train')
parser.add_argument('--batch-size', type=int, default=50, metavar='N',
help='batch size during training')
# transfer learning settings
# transfer learning and visualization settings
parser.add_argument('--standard-rating', action='store_true', default=False,
help='if True, maps all ratings to standard 1, 2, 3.4, 5 before training')
parser.add_argument('--transfer', action='store_true', default=False,
help='if True, load a pretrained model instead of training')
parser.add_argument('--model-pos', default='',
help="where to load the transferred model's state")
help="where to load the transferred model's state, will use current \
res_dir's model if not specified")
parser.add_argument('--visualize', action='store_true', default=False,
help='if True, load a pretrained model and do visualization exps')
# sparsity experiment settings
parser.add_argument('--ratio', type=float, default=1.0,
help="For ml_100k, if ratio < 1, sort train data by timestamp and\
Expand Down Expand Up @@ -102,6 +105,8 @@
else:
val_test_appendix = 'valmode'
args.res_dir = os.path.join(args.file_dir, 'results/{}{}_{}'.format(args.data_name, args.save_appendix, val_test_appendix))
if args.model_pos == '':
args.model_pos = os.path.join(args.res_dir, 'model_checkpoint{}.pth'.format(args.epochs))
if args.transfer:
args.res_dir += '_transfer'
if not os.path.exists(args.res_dir):
Expand All @@ -119,14 +124,14 @@
# backup current main.py, model.py files
copy('Main.py', args.res_dir)
copy('util_functions.py', args.res_dir)
copy('PyG_GNN/models.py', args.res_dir)
copy('PyG_GNN/train_eval.py', args.res_dir)
copy('models.py', args.res_dir)
copy('train_eval.py', args.res_dir)
if args.transfer: copy(args.model_pos, args.res_dir)
# save command line input
cmd_input = 'python ' + ' '.join(sys.argv)
with open(os.path.join(args.res_dir, 'cmd_input.txt'), 'w') as f:
f.write(cmd_input)
print('Command line input: ' + cmd_input + ' is saved.')
# save command line input
cmd_input = 'python ' + ' '.join(sys.argv)
with open(os.path.join(args.res_dir, 'cmd_input.txt'), 'a') as f:
f.write(cmd_input)
print('Command line input: ' + cmd_input + ' is saved.')


if args.data_name == 'ml_1m' or args.data_name == 'ml_10m':
Expand Down Expand Up @@ -242,9 +247,10 @@
num_bases=4,
regression=True)

with open(os.path.join(args.res_dir, 'cmd_input.txt'), 'a') as f:
f.write(str(model.k))
print('k is saved.')
if not args.transfer:
with open(os.path.join(args.res_dir, 'cmd_input.txt'), 'a') as f:
f.write(' --k ' + str(model.k) + '\n')
print('k is saved.')

def logger(info, model, optimizer):
epoch, train_loss, test_rmse = info['epoch'], info['train_loss'], info['test_rmse']
Expand Down Expand Up @@ -278,6 +284,7 @@ def logger(info, model, optimizer):
rmse = test_once(test_graphs, model, args.batch_size, logger)
print('Transfer learning rmse is: {:.4f}'.format(rmse))
elif args.visualize:
visualize(model, test_graphs, args.res_dir, args.data_name, class_values)



Expand Down
4 changes: 2 additions & 2 deletions data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def download_dataset(dataset, files, data_dir):
print('Downloading %s dataset' % dataset)

if dataset in ['ml_100k', 'ml_1m']:
target_dir = 'data/' + dataset.replace('_', '-')
target_dir = 'raw_data/' + dataset.replace('_', '-')
elif dataset == 'ml_10m':
target_dir = 'data/' + 'ml-10M100K'
target_dir = 'raw_data/' + 'ml-10M100K'
else:
raise ValueError('Invalid dataset option %s' % dataset)

Expand Down
1 change: 1 addition & 0 deletions raw_data/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ml*
1 change: 0 additions & 1 deletion raw_data/ml_100k/README.md

This file was deleted.

1 change: 0 additions & 1 deletion raw_data/ml_10m/README.md

This file was deleted.

1 change: 0 additions & 1 deletion raw_data/ml_1m/README.md

This file was deleted.

76 changes: 70 additions & 6 deletions train_eval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import time
import os
import math
import numpy as np
import networkx as nx
import torch
import torch.nn.functional as F
from torch import tensor
Expand All @@ -8,6 +11,10 @@
from torch_geometric.data import DataLoader, DenseDataLoader as DenseLoader
from tqdm import tqdm
import pdb
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from util_functions import PyGGraph_to_nx

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Expand Down Expand Up @@ -149,14 +156,71 @@ def eval_rmse(model, loader, device, show_progress=False):
return rmse


def visualize(model, graphs, num=3):
def visualize(model, graphs, res_dir, data_name, class_values, num=5):
model.eval()
model.to(device)
preds = []
for data in graphs:
graph_loader = DataLoader(graphs, 50, shuffle=False)
for data in tqdm(graph_loader):
data = data.to(device)
pred = model(data)
preds.append(pred.item())
order = np.argsort(preds)
highest = [graphs[i] for i in order[-num:]]
lowest = [graphs[i] for i in order[:num]]
preds.extend(pred.view(-1).tolist())
order = np.argsort(preds).tolist()
highest = [PyGGraph_to_nx(graphs[i]) for i in order[-num:][::-1]]
lowest = [PyGGraph_to_nx(graphs[i]) for i in order[:num]]
highest_scores = [preds[i] for i in order[-num:][::-1]]
lowest_scores = [preds[i] for i in order[:num]]
scores = highest_scores + lowest_scores
type_to_label = {0: 'u0', 1: 'v0', 2: 'u1', 3: 'v1', 4: 'u2', 5: 'v2'}
#type_to_color = {0: 'r', 1: 'r', 2: 'k', 3: 'k', 4: 'k', 5: 'k'}
#type_to_color = {0: 'xkcd:orangered', 1: 'xkcd:azure', 2: 'xkcd:orange', 3: 'xkcd:lightblue', 4: 'y', 5: 'g'}
type_to_color = {0: 'xkcd:red', 1: 'xkcd:blue', 2: 'xkcd:orange', 3: 'xkcd:lightblue', 4: 'y', 5: 'g'}
plt.axis('off')
f = plt.figure(figsize=(20, 10))
axs = f.subplots(2, num)
#cmap = plt.cm.coolwarm
cmap = plt.cm.get_cmap('rainbow')
vmin, vmax = min(class_values), max(class_values)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
sm.set_array([])
for i, g in enumerate(highest + lowest):
u_nodes = [x for x, y in g.nodes(data=True) if y['type'] % 2 == 0]
u0, v0 = 0, len(u_nodes)
pos = nx.drawing.layout.bipartite_layout(g, u_nodes)
bottom_u_node = min(pos, key=lambda x: (pos[x][0], pos[x][1]))
bottom_v_node = min(pos, key=lambda x: (-pos[x][0], pos[x][1]))
# swap u0 and v0 with bottom nodes if they are not already
if u0 != bottom_u_node:
pos[u0], pos[bottom_u_node] = pos[bottom_u_node], pos[u0]
if v0 != bottom_v_node:
pos[v0], pos[bottom_v_node] = pos[bottom_v_node], pos[v0]
labels = {x: type_to_label[y] for x, y in nx.get_node_attributes(g, 'type').items()}
node_colors = [type_to_color[y] for x, y in nx.get_node_attributes(g, 'type').items()]
edge_types = nx.get_edge_attributes(g, 'type')
edge_types = [edge_types[x] for x in g.edges()]
#f.add_subplot(2, num, i+1)
axs[i//num, i%num].axis('off')
nx.draw_networkx(g, pos,
#labels=labels,
with_labels=False,
node_size=150,
node_color=node_colors, edge_color=edge_types,
ax=axs[i//num, i%num], edge_cmap=cmap, edge_vmin=vmin, edge_vmax=vmax,
)
# make u0 v0 on top of other nodes
nx.draw_networkx_nodes(g, {u0: pos[u0]}, nodelist=[u0], node_size=150,
node_color='xkcd:red', ax=axs[i//num, i%num])
nx.draw_networkx_nodes(g, {v0: pos[v0]}, nodelist=[v0], node_size=150,
node_color='xkcd:blue', ax=axs[i//num, i%num])
axs[i//num, i%num].set_title('{:.4f}'.format(scores[i]), x=0.5, y=-0.05, fontsize=20)
f.subplots_adjust(right=0.85)
cbar_ax = f.add_axes([0.88, 0.15, 0.02, 0.7])
if len(class_values) > 20:
class_values = np.linspace(min(class_values), max(class_values), 20, dtype=int).tolist()
cbar = plt.colorbar(sm, cax=cbar_ax, ticks=class_values)
cbar.ax.tick_params(labelsize=22)
f.savefig(os.path.join(res_dir, "visualization_{}.pdf".format(data_name)),
interpolation='nearest', bbox_inches='tight')



0 comments on commit 089bbf5

Please sign in to comment.