Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
alomrani committed Jun 15, 2021
2 parents 2c8e829 + 0673d68 commit 6f013be
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 18 deletions.
24 changes: 12 additions & 12 deletions encoder/graph_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@ def __init__(
):
super(MPNN, self).__init__()
self.l1 = nn.Linear(1, embed_dim ** 2)
self.l2 = nn.Linear(1, embed_dim ** 2)
self.node_embed = nn.Linear(node_dim_u, embed_dim)
self.node_embed_u = nn.Linear(node_dim_u, embed_dim)
if node_dim_u != node_dim_v:
self.node_embed_v = nn.Linear(node_dim_v, embed_dim)

self.conv1 = NNConv(embed_dim, embed_dim, self.l1, aggr="mean")
self.conv2 = NNConv(embed_dim, embed_dim, self.l2, aggr="mean")
self.n_layers = n_layers
self.problem = opts.problem

Expand All @@ -45,15 +43,17 @@ def forward(self, x_u, x_v, edge_index, edge_attribute, i, dummy):
x_v = self.node_embed_v(x_v)
x = torch.cat((x_u, x_v), dim=0)
else:
x = self.node_embed(x_u)

x = F.relu(x)
x = self.conv1(x, edge_index, edge_attribute.float())
x = F.relu(x)
x = self.conv2(x, edge_index, edge_attribute.float())
#for j in range(n_encode_layers):
# x = F.relu(x)
# x = self.conv1(x, edge_index, edge_attribute.float())
x = self.node_embed_u(x_u)
x_v = self.node_embed_u(x_v)
x = torch.cat((x_u, x_v), dim=0)

# x = F.relu(x)
# x = self.conv1(x, edge_index, edge_attribute.float())
# x = F.relu(x)
# x = self.conv2(x, edge_index, edge_attribute.float())
for j in range(n_encode_layers):
x = F.relu(x)
x = self.conv1(x, edge_index, edge_attribute.float())

# x = self.norm(x.view(-1, x.size(-1))).view(*x.size())

Expand Down
2 changes: 1 addition & 1 deletion policy/ff_model_hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(

self.embedding_dim = embedding_dim
self.decode_type = None
self.num_actions = 5 * (opts.u_size + 1) + 7
self.num_actions = 5 * (opts.u_size + 1) + 8
self.is_bipartite = problem.NAME == "bipartite"
self.problem = problem
self.shrink_size = None
Expand Down
1 change: 0 additions & 1 deletion policy/gnn_hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ def _inner(self, input, opts):
),
dim=2,
)
# print(s)
pi = self.ff(s).reshape(state.batch_size, state.u_size + 1)
# Select the indices of the next nodes in the sequences, result (batch_size) long
selected, p = self._select_node(
Expand Down
10 changes: 7 additions & 3 deletions problem_state/edge_obm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def initialize(

# permute the nodes for data
idx = torch.arange(adj.shape[1], device=opts.device)
if "supervised" not in opts.model and not opts.eval_only:
idx = torch.randperm(adj.shape[1], device=opts.device)
adj = adj[:, idx, :].view(adj.size())
# if "supervised" not in opts.model and not opts.eval_only:
# idx = torch.randperm(adj.shape[1], device=opts.device)
# adj = adj[:, idx, :].view(adj.size())

return StateEdgeBipartite(
graphs=input,
Expand Down Expand Up @@ -164,6 +164,7 @@ def get_curr_state(self, model):
self.sum_sol_sq - ((self.size ** 2) / curr_sol_size)
) / curr_sol_size
mean_sol = self.size / curr_sol_size
matched_ratio = self.matched_nodes.sum(1) / self.u_size
s = torch.cat(
(
w,
Expand All @@ -178,6 +179,7 @@ def get_curr_state(self, model):
self.num_skip / i,
self.max_sol,
self.min_sol,
matched_ratio.unsqueeze(1),
),
dim=1,
).float()
Expand All @@ -197,6 +199,7 @@ def get_curr_state(self, model):
self.sum_sol_sq - ((self.size ** 2) / curr_sol_size)
) / curr_sol_size
mean_sol = self.size / curr_sol_size
matched_ratio = self.matched_nodes.sum(1) / self.u_size
s = torch.cat(
(
s,
Expand All @@ -212,6 +215,7 @@ def get_curr_state(self, model):
self.num_skip.unsqueeze(2).repeat(1, self.u_size + 1, 1) / i,
self.max_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1),
self.min_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1),
matched_ratio.unsqueeze(2).repeat(1, self.u_size + 1, 1),
),
dim=2,
).float()
Expand Down
6 changes: 5 additions & 1 deletion problem_state/osbm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def get_curr_state(self, model):
s[:, 0, :], mean_w[:, 0, :] = -1.0, -1.0

s = torch.cat((s, mean_w,), dim=2,)
elif model == "ff-hist":
elif model == "ff-hist" or model == "ff-supervised":
w = w.clone()
h_mean = self.hist_sum.squeeze(1) / i
h_var = ((self.hist_sum_sq - ((self.hist_sum ** 2) / i)) / i).squeeze(1)
Expand All @@ -263,6 +263,7 @@ def get_curr_state(self, model):
self.sum_sol_sq - ((self.size ** 2) / curr_sol_size)
) / curr_sol_size
mean_sol = self.size / curr_sol_size
matched_ratio = self.matched_nodes.sum(1) / self.u_size
s = torch.cat(
(
w,
Expand All @@ -277,6 +278,7 @@ def get_curr_state(self, model):
self.num_skip / i,
self.max_sol,
self.min_sol,
matched_ratio.unsqueeze(1),
),
dim=1,
).float()
Expand All @@ -295,6 +297,7 @@ def get_curr_state(self, model):
self.sum_sol_sq - ((self.size ** 2) / curr_sol_size)
) / curr_sol_size
mean_sol = self.size / curr_sol_size
matched_ratio = self.matched_nodes.sum(1) / self.u_size
s = torch.cat(
(
s,
Expand All @@ -310,6 +313,7 @@ def get_curr_state(self, model):
self.num_skip.unsqueeze(2).repeat(1, self.u_size + 1, 1) / i,
self.max_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1),
self.min_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1),
matched_ratio.unsqueeze(2).repeat(1, self.u_size + 1, 1),
),
dim=2,
).float()
Expand Down

0 comments on commit 6f013be

Please sign in to comment.