Skip to content

Commit

Permalink
fixed small bugs adwords env
Browse files Browse the repository at this point in the history
  • Loading branch information
alomrani committed Oct 31, 2021
1 parent 5b89be1 commit 79d1a39
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion policy/ff_model_invariant.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(

self.embedding_dim = embedding_dim
self.decode_type = None
self.num_actions = 3 if opts.problem != "adwords" else 4
self.num_actions = 3 if opts.problem != "adwords" else 5
self.is_bipartite = problem.NAME == "bipartite"
self.problem = problem
self.shrink_size = None
Expand Down
4 changes: 2 additions & 2 deletions policy/gnn_hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _inner(self, input, opts):
s,
idx.repeat(1, state.u_size + 1, 1),
state.size.unsqueeze(2).repeat(1, state.u_size + 1, 1)
/ state.u_size,
/ state.orig_budget.sum(-1)[:, None, None],
fixed_node_identity,
mask.unsqueeze(2),
incoming_node_embeddings.repeat(1, state.u_size + 1, 1),
Expand All @@ -228,7 +228,7 @@ def _inner(self, input, opts):
u_embeddings.mean(1).unsqueeze(1).repeat(1, state.u_size + 1, 1),
),
dim=2,
)
).float()
pi = self.ff(s).reshape(state.batch_size, state.u_size + 1)
# Select the indices of the next nodes in the sequences, result (batch_size) long
selected, p = self._select_node(
Expand Down
2 changes: 1 addition & 1 deletion policy/inv_ff_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(

self.embedding_dim = embedding_dim
self.decode_type = None
self.num_actions = 16 if opts.problem != "adwords" else 18
self.num_actions = 16 if opts.problem != "adwords" else 19
self.problem = problem
self.model_name = "inv-ff-hist"
self.ff = nn.Sequential(
Expand Down
18 changes: 9 additions & 9 deletions problem_state/adwords_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,13 @@ def get_curr_state(self, model):
w = self.adj[:, 0, :].float().clone()
s = None
if model == "ff":
s = torch.cat((w, self.curr_budget, mask.float()), dim=1)
s = torch.cat((w, self.curr_budget, mask.float()), dim=1).float()
elif model == "inv-ff":
deg = (w != 0).float().sum(1)
deg[deg == 0.0] = 1.0
mean_w = w.sum(1) / deg
mean_budget = self.curr_budget.sum(2) / self.u_size
mean_budget = mean_budget[:, None, :].repeat(1, self.u_size + 1, 1)
mean_budget = self.curr_budget.sum(1) / self.u_size
mean_budget = mean_budget[:, None, None].repeat(1, self.u_size + 1, 1)
mean_w = mean_w[:, None, None].repeat(1, self.u_size + 1, 1)
fixed_node_identity = torch.zeros(
self.batch_size, self.u_size + 1, 1, device=opts.device
Expand All @@ -188,7 +188,7 @@ def get_curr_state(self, model):
mean_budget,
),
dim=2,
)
).float()

elif model == "ff-hist" or model == "ff-supervised":
(
Expand All @@ -210,7 +210,7 @@ def get_curr_state(self, model):
h_mean.squeeze(1),
h_var.squeeze(1),
h_mean_degree.squeeze(1),
self.size / self.orig_budget.sum(-1),
self.size / self.orig_budget.sum(-1).unsqueeze(1),
ind.float(),
mean_sol,
var_sol,
Expand All @@ -226,9 +226,9 @@ def get_curr_state(self, model):
deg = (w != 0).float().sum(1)
deg[deg == 0.0] = 1.0
mean_w = w.sum(1) / deg
mean_budget = self.curr_budget.sum(2) / self.u_size
mean_budget = self.curr_budget.sum(1) / self.u_size
mean_w = mean_w[:, None, None].repeat(1, self.u_size + 1, 1)
mean_budget = mean_budget[:, None, :].repeat(1, self.u_size + 1, 1)
mean_budget = mean_budget[:, None, None].repeat(1, self.u_size + 1, 1)
s = w.reshape(self.batch_size, self.u_size + 1, 1)
(
h_mean,
Expand Down Expand Up @@ -258,7 +258,7 @@ def get_curr_state(self, model):
h_mean_degree.transpose(1, 2),
ind.unsqueeze(2).repeat(1, self.u_size + 1, 1),
self.size.unsqueeze(2).repeat(1, self.u_size + 1, 1)
/ self.orig_budget.sum(-1),
/ self.orig_budget.sum(-1)[:, None, None],
mean_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1),
var_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1),
n_skip.unsqueeze(2).repeat(1, self.u_size + 1, 1),
Expand Down Expand Up @@ -293,7 +293,7 @@ def get_node_features(self):
(future_node_feature, fixed_node_feature, incoming_node_features), dim=1
).reshape(batch_size, step_size)

return node_features
return node_features.float()

def get_hist_features(self):
i = self.i - (self.u_size + 1)
Expand Down

0 comments on commit 79d1a39

Please sign in to comment.