forked from qolina/NNED
Commit
Showing 5 changed files with 597 additions and 0 deletions.
@@ -0,0 +1 @@
python trigger_lstm.py -train ../ni_data/pre_processed_feng/tmp.train -test ../ni_data/pre_processed_feng/tmp.test -tag ../ni_data/pre_processed_feng/labellist -embed ../ni_data/pre_processed_feng/wordvector -vocab ../ni_data/pre_processed_feng/wordlist
@@ -0,0 +1,198 @@
### This file is the copy of the get_oracle.py from RNNG https://github.com/clab/rnng

import sys
import constituent_dict

# tokens is a list of tokens, so no need to split it again
def unkify(tokens, words_dict):
    final = []
    for token in tokens:
        # only process the train singletons and unknown words
        if len(token.rstrip()) == 0:
            final.append('UNK')
        elif not(token.rstrip() in words_dict):
            numCaps = 0
            hasDigit = False
            hasDash = False
            hasLower = False
            for char in token.rstrip():
                if char.isdigit():
                    hasDigit = True
                elif char == '-':
                    hasDash = True
                elif char.isalpha():
                    if char.islower():
                        hasLower = True
                    elif char.isupper():
                        numCaps += 1
            result = 'UNK'
            lower = token.rstrip().lower()
            ch0 = token.rstrip()[0]
            if ch0.isupper():
                if numCaps == 1:
                    result = result + '-INITC'
                    if lower in words_dict:
                        result = result + '-KNOWNLC'
                else:
                    result = result + '-CAPS'
            elif not(ch0.isalpha()) and numCaps > 0:
                result = result + '-CAPS'
            elif hasLower:
                result = result + '-LC'
            if hasDigit:
                result = result + '-NUM'
            if hasDash:
                result = result + '-DASH'
            if lower[-1] == 's' and len(lower) >= 3:
                ch2 = lower[-2]
                if not(ch2 == 's') and not(ch2 == 'i') and not(ch2 == 'u'):
                    result = result + '-s'
            elif len(lower) >= 5 and not(hasDash) and not(hasDigit and numCaps > 0):
                if lower[-2:] == 'ed':
                    result = result + '-ed'
                elif lower[-3:] == 'ing':
                    result = result + '-ing'
                elif lower[-3:] == 'ion':
                    result = result + '-ion'
                elif lower[-2:] == 'er':
                    result = result + '-er'
                elif lower[-3:] == 'est':
                    result = result + '-est'
                elif lower[-2:] == 'ly':
                    result = result + '-ly'
                elif lower[-3:] == 'ity':
                    result = result + '-ity'
                elif lower[-1] == 'y':
                    result = result + '-y'
                elif lower[-2:] == 'al':
                    result = result + '-al'
            final.append(result)
        else:
            final.append(token.rstrip())
    return final

def is_next_open_bracket(line, start_idx):
    for char in line[(start_idx + 1):]:
        if char == '(':
            return True
        elif char == ')':
            return False
    raise IndexError('Bracket possibly not balanced, open bracket not followed by closed bracket')

def get_between_brackets(line, start_idx):
    output = []
    for char in line[(start_idx + 1):]:
        if char == ')':
            break
        assert not(char == '(')
        output.append(char)
    return ''.join(output)

# start_idx = open bracket
#def skip_terminals(line, start_idx):
#    line_end_idx = len(line) - 1
#    for i in range(start_idx + 1, line_end_idx):
#        if line[i] == ')':
#            assert line[i + 1] == ' '
#            return (i + 2)
#    raise IndexError('No close bracket found in a terminal')

def get_tags_tokens_lowercase(line):
    output = []
    #print 'curr line', line_strip
    line_strip = line.rstrip()
    #print 'length of the sentence', len(line_strip)
    for i in range(len(line_strip)):
        if i == 0:
            assert line_strip[i] == '('
        if line_strip[i] == '(' and not(is_next_open_bracket(line_strip, i)): # fulfilling this condition means this is a terminal symbol
            output.append(get_between_brackets(line_strip, i))
    #print 'output:',output
    output_tags = []
    output_tokens = []
    output_lowercase = []
    for terminal in output:
        terminal_split = terminal.split()
        assert len(terminal_split) == 2 # each terminal contains a POS tag and word
        output_tags.append(terminal_split[0])
        output_tokens.append(terminal_split[1])
        output_lowercase.append(terminal_split[1].lower())
    return [output_tags, output_tokens, output_lowercase]

def get_nonterminal(line, start_idx):
    assert line[start_idx] == '(' # make sure it's an open bracket
    output = []
    for char in line[(start_idx + 1):]:
        if char == ' ':
            break
        assert not(char == '(') and not(char == ')')
        output.append(char)
    return ''.join(output)

def get_actions(line):
    output_actions = []
    line_strip = line.rstrip()
    i = 0
    max_idx = (len(line_strip) - 1)
    while i <= max_idx:
        assert line_strip[i] == '(' or line_strip[i] == ')'
        if line_strip[i] == '(':
            if is_next_open_bracket(line_strip, i): # open non-terminal
                curr_NT = get_nonterminal(line_strip, i)
                output_actions.append('NT(' + curr_NT + ')')
                i += 1
                while line_strip[i] != '(': # get the next open bracket, which may be a terminal or another non-terminal
                    i += 1
            else: # it's a terminal symbol
                output_actions.append('SHIFT')
                while line_strip[i] != ')':
                    i += 1
                i += 1
                while line_strip[i] != ')' and line_strip[i] != '(':
                    i += 1
        else:
            output_actions.append('REDUCE')
            if i == max_idx:
                break
            i += 1
            while line_strip[i] != ')' and line_strip[i] != '(':
                i += 1
    assert i == max_idx
    return output_actions

def main():
    if len(sys.argv) != 3:
        raise NotImplementedError('Program only takes two arguments: train file and dev file (for vocabulary mapping purposes)')
    train_file = open(sys.argv[1], 'r')
    lines = train_file.readlines()
    train_file.close()
    dev_file = open(sys.argv[2], 'r')
    dev_lines = dev_file.readlines()
    dev_file.close()
    words_list = constituent_dict.get_dict(lines)
    line_ctr = 0
    # get the oracle for the dev file (the vocabulary comes from the train file)
    for line in dev_lines:
        line_ctr += 1
        # assert that the parentheses are balanced
        if line.count('(') != line.count(')'):
            raise NotImplementedError('Unbalanced number of parenthesis in line ' + str(line_ctr))
        # first line: the bracketed tree itself
        print '# ' + line.rstrip()
        tags, tokens, lowercase = get_tags_tokens_lowercase(line)
        assert len(tags) == len(tokens)
        assert len(tokens) == len(lowercase)
        print ' '.join(tags)
        print ' '.join(tokens)
        print ' '.join(lowercase)
        unkified = unkify(tokens, words_list)
        print ' '.join(unkified)
        output_actions = get_actions(line)
        for action in output_actions:
            print action
        print ''

if __name__ == "__main__":
    main()
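To make the oracle extraction concrete, here is a small illustrative snippet (not part of the committed file) that runs the helpers above on a toy bracketed parse and a toy vocabulary; the tree, the word dictionary, and the expected outputs shown in comments are assumptions for illustration only.

    # Illustration only -- not part of the committed file.
    tree = '(S (NP (DT the) (NN dog)) (VP (VBZ barks)))'

    tags, tokens, lowercase = get_tags_tokens_lowercase(tree)
    # tags   -> ['DT', 'NN', 'VBZ']
    # tokens -> ['the', 'dog', 'barks']

    actions = get_actions(tree)
    # actions -> ['NT(S)', 'NT(NP)', 'SHIFT', 'SHIFT', 'REDUCE',
    #             'NT(VP)', 'SHIFT', 'REDUCE', 'REDUCE']

    # toy stand-in for the vocabulary built by constituent_dict.get_dict
    words = {'the': 1, 'dog': 1, 'barks': 1}
    unkify(['dog', 'Barks-2017'], words)
    # -> ['dog', 'UNK-INITC-NUM-DASH']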
@@ -0,0 +1,81 @@
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

class LSTMTrigger(nn.Module):
    def __init__(self, pretrain_embedding, pretrain_embed_dim, hidden_dim, vocab_size, tagset_size, dropout, bilstm, num_layers, random_dim, gpu):
        super(LSTMTrigger, self).__init__()

        embedding_dim = pretrain_embed_dim
        self.hidden_dim = hidden_dim
        self.random_embed = False
        if random_dim >= 50:
            self.word_embeddings = nn.Embedding(vocab_size, random_dim)
            self.pretrain_word_embeddings = torch.from_numpy(pretrain_embedding)
            self.random_embed = True
            embedding_dim += random_dim
        else:
            self.word_embeddings = nn.Embedding(vocab_size, pretrain_embed_dim)
            if pretrain_embedding is not None:
                self.word_embeddings.weight.data.copy_(torch.from_numpy(pretrain_embedding))

        self.drop = nn.Dropout(dropout)
        self.bilstm_flag = bilstm
        self.lstm_layer = num_layers

        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=self.lstm_layer, bidirectional=self.bilstm_flag)
        if self.bilstm_flag:
            self.hidden2tag = nn.Linear(hidden_dim*2, tagset_size)
        else:
            self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
        if gpu:
            self.drop = self.drop.cuda()
            self.word_embeddings = self.word_embeddings.cuda()
            self.lstm = self.lstm.cuda()
            self.hidden2tag = self.hidden2tag.cuda()

        self.hidden = self.init_hidden(gpu)

    def init_hidden(self, gpu):
        if self.bilstm_flag:
            h0 = autograd.Variable(torch.zeros(2*self.lstm_layer, 1, self.hidden_dim))
            c0 = autograd.Variable(torch.zeros(2*self.lstm_layer, 1, self.hidden_dim))
        else:
            h0 = autograd.Variable(torch.zeros(self.lstm_layer, 1, self.hidden_dim))
            c0 = autograd.Variable(torch.zeros(self.lstm_layer, 1, self.hidden_dim))

        if gpu:
            h0 = h0.cuda()
            c0 = c0.cuda()
        return (h0,c0)

    def forward(self, sentence, gpu):
        self.hidden = self.init_hidden(gpu)

        embeds = self.word_embeddings(sentence)
        #print embeds

        if self.random_embed:
            sent_tensor = sentence.data
            embeds = embeds.data
            if gpu: sent_tensor = sent_tensor.cpu()
            if gpu: embeds = embeds.cpu()
            pretrain_embeds = torch.index_select(self.pretrain_word_embeddings, 0, sent_tensor)
            embeds = torch.cat((pretrain_embeds, embeds.double()), 1)
            embeds = Variable(embeds.float())
            if gpu: embeds = embeds.cuda()
        #print embeds

        embeds = self.drop(embeds)
        lstm_out, self.hidden = self.lstm(
            embeds.view(len(sentence), 1, -1), self.hidden)
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_scores = F.log_softmax(tag_space)
        return tag_scores