Commit

laohur committed Sep 28, 2019
1 parent 1765fb0 commit 333cc92
Showing 5 changed files with 10 additions and 10 deletions.
12 changes: 6 additions & 6 deletions StructuredSelfAttention.py
@@ -18,10 +18,10 @@ def __init__(self, config):
         use_pretrained_embeddings = config["use_pretrained_embeddings"]
         super(StructuredSelfAttention, self).__init__()
         self.embeddings = self._load_embeddings(config,use_pretrained_embeddings, vocab_size, 300)
-        # self.embeddings.requires_grad = False
+        self.embeddings.requires_grad = False
         self.lstm = torch.nn.LSTM(config["emb_dim"], config["lstm_hid_dim"], batch_first=True, bidirectional=True)
         self.linear_first = torch.nn.Linear(config["lstm_hid_dim"]* 2, config["d_a"])
-        self.r=config["r"]
+        self.r=config["r"]  # =1: keep only the sentence vector
         self.linear_second = torch.nn.Linear(config["d_a"], self.r)
         self.dropout = torch.nn.Dropout(0.1)

@@ -58,17 +58,17 @@ def softmax(self, input, axis=1):
     def forward(self, x): # batch_size*max_len
         x = x.to(device)
         embeddings = self.embeddings(x) # batch_size*max_len*emb_dim
-        return embeddings.sum(1) #batch*emb_dim
+        # return embeddings.sum(1) #batch*emb_dim
         outputs, _ = self.lstm(embeddings) # batch_size*max_len*emb_dim #10*256
         # outputs batch_size*max_len*lstm_hid_dim
         # x = F.tanh(self.linear_first(self.dropout(outputs))) # batch_size*max_len*d_a
         x = F.tanh(self.linear_first(outputs)) # batch_size*max_len*d_a
         x = self.linear_second(x) # batch_size*max_len*r
-        x = self.softmax(x, 1) #batch*seq*64
-
+        # x = self.softmax(x, 1) #batch*seq*64
+        x = F.softmax(x, dim=1)
         attention = x.transpose(1, 2) # batch_size*r*max_len
         sentence_embeddings = attention @ outputs # batch_size*r*lstm_hid_dim
-        avg_sentence_embeddings = torch.sum(sentence_embeddings, 1) / self.r # batch_size*lstm_hid_dim
+        avg_sentence_embeddings = torch.sum(sentence_embeddings, 1) / self.r # batch_size*lstm_hid_dim  # better to just set r=1
         # return F.log_softmax(avg_sentence_embeddings)
         return avg_sentence_embeddings #batch*128

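After this change the encoder pools the bidirectional LSTM outputs with a single attention hop and returns one 128-dimensional sentence vector (2*lstm_hid_dim). A minimal shape-check sketch of that forward path, using random tensors and the config values set in fewshot_main.py below (emb_dim=300, lstm_hid_dim=64, d_a=64, r=1); it illustrates the tensor shapes only, not the repository's trained model:

# Standalone shape check of the attention pooling; all tensors are random stand-ins.
import torch
import torch.nn.functional as F

batch, max_len, emb_dim, lstm_hid_dim, d_a, r = 4, 10, 300, 64, 64, 1
lstm = torch.nn.LSTM(emb_dim, lstm_hid_dim, batch_first=True, bidirectional=True)
linear_first = torch.nn.Linear(lstm_hid_dim * 2, d_a)
linear_second = torch.nn.Linear(d_a, r)

embeddings = torch.randn(batch, max_len, emb_dim)   # batch*max_len*emb_dim
outputs, _ = lstm(embeddings)                       # batch*max_len*(2*lstm_hid_dim)
x = torch.tanh(linear_first(outputs))               # batch*max_len*d_a
x = F.softmax(linear_second(x), dim=1)              # batch*max_len*r, weights over max_len
attention = x.transpose(1, 2)                       # batch*r*max_len
sentence_embeddings = attention @ outputs           # batch*r*(2*lstm_hid_dim)
avg = sentence_embeddings.sum(1) / r                # batch*128; with r=1 the average is a no-op
print(avg.shape)                                    # torch.Size([4, 128])
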
Binary file modified __pycache__/StructuredSelfAttention.cpython-36.pyc
Binary file modified __pycache__/models.cpython-36.pyc
6 changes: 3 additions & 3 deletions fewshot_main.py
@@ -18,13 +18,13 @@ def main():
         "EPISODE":1000000, #1000000
         "TEST_EPISODE":1000, #1000
         "LEARNING_RATE":0.0001, #0.01
-        "FEATURE_DIM":300,
+        "FEATURE_DIM":128,
         "RELATION_DIM":8,
         "max_len":12,
         "emb_dim": 300,
         "lstm_hid_dim": 64,
         "d_a": 64,
-        "r": 16,
+        "r": 1,
         "max_len": 10,
         "n_classes": 5,
         "num_layers": 1,
@@ -35,7 +35,7 @@
         "vocab_size": len(word2index)
     }
     feature_encoder = StructuredSelfAttention(config).to(device)
-    relation_network = RelationNetwork(config["FEATURE_DIM"], config["RELATION_DIM"]).to(device)
+    relation_network = RelationNetwork(2*config["FEATURE_DIM"], config["RELATION_DIM"]).to(device)

     feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(), lr=config["LEARNING_RATE"])
     feature_encoder_scheduler = StepLR(feature_encoder_optim, step_size=100000, gamma=0.5)
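
The config edits keep the pipeline consistent with the new encoder: FEATURE_DIM drops from 300 to 128 because the encoder now returns the attention-pooled bi-LSTM vector (2*lstm_hid_dim) instead of the 300-dimensional embedding sum, r drops to 1 (a single attention hop), and the relation network is sized for a concatenated support+query pair. A quick consistency check under those assumptions (the 2*FEATURE_DIM pairing is inferred from the RelationNetwork call above, not shown in this hunk):

# Consistency check for the new config values; the pairing width is an assumption.
config = {"FEATURE_DIM": 128, "lstm_hid_dim": 64, "r": 1, "RELATION_DIM": 8}
assert config["FEATURE_DIM"] == 2 * config["lstm_hid_dim"]  # bidirectional LSTM output width
relation_input_width = 2 * config["FEATURE_DIM"]            # support and query features concatenated
print(relation_input_width)                                  # 256
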
2 changes: 1 addition & 1 deletion models.py
@@ -49,7 +49,7 @@ class RelationNetwork(nn.Module):

     def __init__(self, input_size, hidden_size):
         super(RelationNetwork, self).__init__()
-        self.fc1 = nn.Linear(input_size * 2, hidden_size)
+        self.fc1 = nn.Linear(input_size, hidden_size)
         self.fc2 = nn.Linear(hidden_size, 1)

     def forward(self, a, b): # 475*100 #valid 25*100
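
With the doubling moved to the caller, fc1 now takes input_size exactly as passed in. A sketch of how the module lines up with the call RelationNetwork(2*config["FEATURE_DIM"], config["RELATION_DIM"]); the forward body is not part of this diff, so the concatenation and activations below are assumptions based on the (a, b) signature:

# Sketch only: fc1 width equals the caller's 2*FEATURE_DIM; the concatenation in
# forward() is assumed, since this diff shows only the constructor change.
import torch
import torch.nn as nn
import torch.nn.functional as F

class RelationNetworkSketch(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(RelationNetworkSketch, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)   # caller passes the doubled width
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, a, b):
        pair = torch.cat([a, b], dim=1)                 # e.g. 475 x (2*128)
        return torch.sigmoid(self.fc2(F.relu(self.fc1(pair))))

net = RelationNetworkSketch(2 * 128, 8)                 # mirrors RelationNetwork(2*config["FEATURE_DIM"], 8)
scores = net(torch.randn(475, 128), torch.randn(475, 128))
print(scores.shape)                                     # torch.Size([475, 1])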
