lstm_trigger
qolina committed Jul 31, 2017
1 parent d830bb8 commit f9d4cc8
Showing 5 changed files with 597 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/cmd.trigger
@@ -0,0 +1 @@
python trigger_lstm.py -train ../ni_data/pre_processed_feng/tmp.train -test ../ni_data/pre_processed_feng/tmp.test -tag ../ni_data/pre_processed_feng/labellist -embed ../ni_data/pre_processed_feng/wordvector -vocab ../ni_data/pre_processed_feng/wordlist
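For reference, below is a minimal sketch of how trigger_lstm.py (not included in this diff) might parse these flags with argparse. Only the flag names are taken from the command above; everything else is an assumption, not the repository's actual code.

# Hypothetical sketch of the flag parsing in trigger_lstm.py; only the flag
# names come from cmd.trigger, the rest is assumed for illustration.
import argparse

parser = argparse.ArgumentParser(description='LSTM trigger detection')
parser.add_argument('-train', help='pre-processed training file')
parser.add_argument('-test', help='pre-processed test file')
parser.add_argument('-tag', help='label list file')
parser.add_argument('-embed', help='pretrained word vector file')
parser.add_argument('-vocab', help='word list (vocabulary) file')
args = parser.parse_args()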
198 changes: 198 additions & 0 deletions src/get_constituent_topdown_oracle.py
@@ -0,0 +1,198 @@
### This file is a copy of get_oracle.py from RNNG: https://github.com/clab/rnng

import sys
import constituent_dict

# tokens is a list of tokens, so no need to split it again
def unkify(tokens, words_dict):
    final = []
    for token in tokens:
        # only process the train singletons and unknown words
        if len(token.rstrip()) == 0:
            final.append('UNK')
        elif not(token.rstrip() in words_dict):
            numCaps = 0
            hasDigit = False
            hasDash = False
            hasLower = False
            for char in token.rstrip():
                if char.isdigit():
                    hasDigit = True
                elif char == '-':
                    hasDash = True
                elif char.isalpha():
                    if char.islower():
                        hasLower = True
                    elif char.isupper():
                        numCaps += 1
            result = 'UNK'
            lower = token.rstrip().lower()
            ch0 = token.rstrip()[0]
            if ch0.isupper():
                if numCaps == 1:
                    result = result + '-INITC'
                    if lower in words_dict:
                        result = result + '-KNOWNLC'
                else:
                    result = result + '-CAPS'
            elif not(ch0.isalpha()) and numCaps > 0:
                result = result + '-CAPS'
            elif hasLower:
                result = result + '-LC'
            if hasDigit:
                result = result + '-NUM'
            if hasDash:
                result = result + '-DASH'
            if lower[-1] == 's' and len(lower) >= 3:
                ch2 = lower[-2]
                if not(ch2 == 's') and not(ch2 == 'i') and not(ch2 == 'u'):
                    result = result + '-s'
            elif len(lower) >= 5 and not(hasDash) and not(hasDigit and numCaps > 0):
                if lower[-2:] == 'ed':
                    result = result + '-ed'
                elif lower[-3:] == 'ing':
                    result = result + '-ing'
                elif lower[-3:] == 'ion':
                    result = result + '-ion'
                elif lower[-2:] == 'er':
                    result = result + '-er'
                elif lower[-3:] == 'est':
                    result = result + '-est'
                elif lower[-2:] == 'ly':
                    result = result + '-ly'
                elif lower[-3:] == 'ity':
                    result = result + '-ity'
                elif lower[-1] == 'y':
                    result = result + '-y'
                elif lower[-2:] == 'al':
                    result = result + '-al'
            final.append(result)
        else:
            final.append(token.rstrip())
    return final

def is_next_open_bracket(line, start_idx):
    for char in line[(start_idx + 1):]:
        if char == '(':
            return True
        elif char == ')':
            return False
    raise IndexError('Bracket possibly not balanced, open bracket not followed by closed bracket')

def get_between_brackets(line, start_idx):
    output = []
    for char in line[(start_idx + 1):]:
        if char == ')':
            break
        assert not(char == '(')
        output.append(char)
    return ''.join(output)

# start_idx = open bracket
#def skip_terminals(line, start_idx):
#    line_end_idx = len(line) - 1
#    for i in range(start_idx + 1, line_end_idx):
#        if line[i] == ')':
#            assert line[i + 1] == ' '
#            return (i + 2)
#    raise IndexError('No close bracket found in a terminal')

def get_tags_tokens_lowercase(line):
    output = []
    #print 'curr line', line_strip
    line_strip = line.rstrip()
    #print 'length of the sentence', len(line_strip)
    for i in range(len(line_strip)):
        if i == 0:
            assert line_strip[i] == '('
        if line_strip[i] == '(' and not(is_next_open_bracket(line_strip, i)): # fulfilling this condition means this is a terminal symbol
            output.append(get_between_brackets(line_strip, i))
    #print 'output:',output
    output_tags = []
    output_tokens = []
    output_lowercase = []
    for terminal in output:
        terminal_split = terminal.split()
        assert len(terminal_split) == 2 # each terminal contains a POS tag and word
        output_tags.append(terminal_split[0])
        output_tokens.append(terminal_split[1])
        output_lowercase.append(terminal_split[1].lower())
    return [output_tags, output_tokens, output_lowercase]

def get_nonterminal(line, start_idx):
    assert line[start_idx] == '(' # make sure it's an open bracket
    output = []
    for char in line[(start_idx + 1):]:
        if char == ' ':
            break
        assert not(char == '(') and not(char == ')')
        output.append(char)
    return ''.join(output)


def get_actions(line):
    output_actions = []
    line_strip = line.rstrip()
    i = 0
    max_idx = (len(line_strip) - 1)
    while i <= max_idx:
        assert line_strip[i] == '(' or line_strip[i] == ')'
        if line_strip[i] == '(':
            if is_next_open_bracket(line_strip, i): # open non-terminal
                curr_NT = get_nonterminal(line_strip, i)
                output_actions.append('NT(' + curr_NT + ')')
                i += 1
                while line_strip[i] != '(': # get the next open bracket, which may be a terminal or another non-terminal
                    i += 1
            else: # it's a terminal symbol
                output_actions.append('SHIFT')
                while line_strip[i] != ')':
                    i += 1
                i += 1
                while line_strip[i] != ')' and line_strip[i] != '(':
                    i += 1
        else:
            output_actions.append('REDUCE')
            if i == max_idx:
                break
            i += 1
            while line_strip[i] != ')' and line_strip[i] != '(':
                i += 1
    assert i == max_idx
    return output_actions

def main():
    if len(sys.argv) != 3:
        raise NotImplementedError('Program only takes two arguments: train file and dev file (for vocabulary mapping purposes)')
    train_file = open(sys.argv[1], 'r')
    lines = train_file.readlines()
    train_file.close()
    dev_file = open(sys.argv[2], 'r')
    dev_lines = dev_file.readlines()
    dev_file.close()
    words_list = constituent_dict.get_dict(lines)
    line_ctr = 0
    # get the oracle for the dev file
    for line in dev_lines:
        line_ctr += 1
        # assert that the parentheses are balanced
        if line.count('(') != line.count(')'):
            raise NotImplementedError('Unbalanced number of parentheses in line ' + str(line_ctr))
        # first line: the bracketed tree itself
        print '# ' + line.rstrip()
        tags, tokens, lowercase = get_tags_tokens_lowercase(line)
        assert len(tags) == len(tokens)
        assert len(tokens) == len(lowercase)
        print ' '.join(tags)
        print ' '.join(tokens)
        print ' '.join(lowercase)
        unkified = unkify(tokens, words_list)
        print ' '.join(unkified)
        output_actions = get_actions(line)
        for action in output_actions:
            print action
        print ''


if __name__ == "__main__":
    main()
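As a sanity check, this is the oracle the script prints for a toy bracketed tree, assuming (for illustration only) that every word of the tree is in the training vocabulary, so unkify leaves the tokens unchanged:

# (S (NP (DT The) (NN dog)) (VP (VBZ barks)))
DT NN VBZ
The dog barks
the dog barks
The dog barks
NT(S)
NT(NP)
SHIFT
SHIFT
REDUCE
NT(VP)
SHIFT
REDUCE
REDUCE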
81 changes: 81 additions & 0 deletions src/lstm_trigger.py
@@ -0,0 +1,81 @@
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

class LSTMTrigger(nn.Module):
    def __init__(self, pretrain_embedding, pretrain_embed_dim, hidden_dim, vocab_size, tagset_size, dropout, bilstm, num_layers, random_dim, gpu):
        super(LSTMTrigger, self).__init__()

        embedding_dim = pretrain_embed_dim
        self.hidden_dim = hidden_dim
        self.random_embed = False
        if random_dim >= 50:
            self.word_embeddings = nn.Embedding(vocab_size, random_dim)
            self.pretrain_word_embeddings = torch.from_numpy(pretrain_embedding)
            self.random_embed = True
            embedding_dim += random_dim
        else:
            self.word_embeddings = nn.Embedding(vocab_size, pretrain_embed_dim)
            if pretrain_embedding is not None:
                self.word_embeddings.weight.data.copy_(torch.from_numpy(pretrain_embedding))

        self.drop = nn.Dropout(dropout)
        self.bilstm_flag = bilstm
        self.lstm_layer = num_layers

        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=self.lstm_layer, bidirectional=self.bilstm_flag)
        if self.bilstm_flag:
            self.hidden2tag = nn.Linear(hidden_dim*2, tagset_size)
        else:
            self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
        if gpu:
            self.drop = self.drop.cuda()
            self.word_embeddings = self.word_embeddings.cuda()
            self.lstm = self.lstm.cuda()
            self.hidden2tag = self.hidden2tag.cuda()

        self.hidden = self.init_hidden(gpu)

    def init_hidden(self, gpu):
        if self.bilstm_flag:
            h0 = autograd.Variable(torch.zeros(2*self.lstm_layer, 1, self.hidden_dim))
            c0 = autograd.Variable(torch.zeros(2*self.lstm_layer, 1, self.hidden_dim))
        else:
            h0 = autograd.Variable(torch.zeros(self.lstm_layer, 1, self.hidden_dim))
            c0 = autograd.Variable(torch.zeros(self.lstm_layer, 1, self.hidden_dim))

        if gpu:
            h0 = h0.cuda()
            c0 = c0.cuda()
        return (h0, c0)

    def forward(self, sentence, gpu):
        self.hidden = self.init_hidden(gpu)

        embeds = self.word_embeddings(sentence)
        #print embeds

        if self.random_embed:
            sent_tensor = sentence.data
            embeds = embeds.data
            if gpu: sent_tensor = sent_tensor.cpu()
            if gpu: embeds = embeds.cpu()
            pretrain_embeds = torch.index_select(self.pretrain_word_embeddings, 0, sent_tensor)
            embeds = torch.cat((pretrain_embeds, embeds.double()), 1)
            embeds = Variable(embeds.float())
            if gpu: embeds = embeds.cuda()
        #print embeds

        embeds = self.drop(embeds)
        lstm_out, self.hidden = self.lstm(
            embeds.view(len(sentence), 1, -1), self.hidden)
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_scores = F.log_softmax(tag_space)
        return tag_scores
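The training loop lives in trigger_lstm.py, which is not among the files shown here. Purely as an illustration of how the class above could be driven, here is a rough sketch; the sizes, word indices, labels, loss, and optimizer settings are all assumptions, not the commit's actual training code.

# Illustrative driver only; not part of this commit. All hyperparameters and
# data below are made up, and the real training loop is in trigger_lstm.py.
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

vocab_size, embed_dim, tagset_size = 100, 50, 5
pretrain = np.random.rand(vocab_size, embed_dim).astype(np.float32)  # stand-in for the wordvector file

model = LSTMTrigger(pretrain, embed_dim, hidden_dim=64, vocab_size=vocab_size,
                    tagset_size=tagset_size, dropout=0.5, bilstm=True,
                    num_layers=1, random_dim=0, gpu=False)
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

sentence = Variable(torch.LongTensor([3, 17, 42, 8]))   # word indices for one sentence
targets = Variable(torch.LongTensor([0, 0, 2, 0]))      # one trigger label per token

tag_scores = model(sentence, gpu=False)                 # (seq_len, tagset_size) log-probabilities
loss = loss_function(tag_scores, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()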

