Commit
Showing 19 changed files with 1,635 additions and 1 deletion.
.gitignore
@@ -0,0 +1,10 @@
__pycache__/
*.py[cod]
*$py.class

.DS_Store
*.env
*.tmp
*.conf
README.md
@@ -0,0 +1,78 @@
SATRE with Data-efficiency and Computational Efficiency
==========

This code introduces Self-attention over Tree for Relation Extraction (SATRE) for the large-scale sentence-level relation extraction task (TACRED).

See below for an overview of the model architecture:

![SATRE](fig/satre.png "SATRE")

## Requirements

Our model was trained on two Nvidia GTX 1080Ti graphics cards.

- Python 3 (tested on 3.7.6)
- PyTorch (tested on 1.2.0)
- CUDA (tested on 10.2.89)
- tqdm
- unzip, wget (for downloading only)
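
To quickly confirm that your environment matches the versions above, you can run a minimal check (standard PyTorch calls only):

```
# environment sanity check
import sys
import torch

print("Python:", sys.version.split()[0])
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("GPUs visible:", torch.cuda.device_count())
```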

## Preparation

The code requires access to the TACRED dataset (an LDC license is required). Once you have the TACRED data, put the JSON files under the directory `dataset/tacred`.
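
As a quick sanity check that the splits are in place (`train.json` is an assumed name based on the standard TACRED splits; `dev.json`/`test.json` are what `eval.py` loads):

```
# check that the expected TACRED JSON splits exist (train.json is assumed)
import os

for split in ("train", "dev", "test"):
    path = os.path.join("dataset", "tacred", split + ".json")
    print(path, "found" if os.path.exists(path) else "MISSING")
```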

First, download and unzip GloVe vectors:

```
chmod +x download.sh; ./download.sh
```

Then prepare the vocabulary and initial word vectors with:

```
python3 prepare_vocab.py dataset/tacred dataset/vocab --glove_dir dataset/glove
```

This will write the vocabulary and the word vectors (as a numpy matrix) into the directory `dataset/vocab`.
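
To sanity-check the output, you can load the prepared files. The file names below (`vocab.pkl`, `embedding.npy`) are assumptions based on how `eval.py` later reads a `vocab.pkl`; adjust them if `prepare_vocab.py` writes different names:

```
# inspect the prepared vocabulary and word-vector matrix (file names assumed)
import pickle
import numpy as np

with open("dataset/vocab/vocab.pkl", "rb") as f:
    words = pickle.load(f)
emb = np.load("dataset/vocab/embedding.npy")
print("vocab size:", len(words))
print("embedding matrix:", emb.shape)  # rows should match the vocab size
```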

## Training

To train the SATRE model, run:

```
bash train.sh 0 1
```

Model checkpoints and logs will be saved to `./saved_models/1`.

For details on the use of other parameters, please refer to `train.py`.
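
Judging from `eval.sh` (included in this commit), the two arguments select the GPU and the model id, which is why checkpoints end up under `saved_models/1` above. If you want to peek inside a saved checkpoint, here is a minimal sketch, assuming the checkpoint is a plain pickled dict (the `load_config` call in `eval.py` suggests this, but it is not guaranteed):

```
# inspect a saved checkpoint; assumes it is a pickled dict (not guaranteed)
import torch

ckpt = torch.load("saved_models/1/best_model.pt", map_location="cpu")
print(list(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt))
```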

## Evaluation

Our trained model is saved under the directory `saved_models/1`. To run evaluation on the test set, run:

```
bash eval.sh 0 1 test
```

Here `0` is the GPU id, `1` is the model id under `saved_models/`, and `test` is the data split, as wired up in `eval.sh` (shown later in this commit).

## Related Repo

The code is adapted from the repo of the EMNLP 2018 paper [Graph Convolution over Pruned Dependency Trees Improves Relation Extraction](https://nlp.stanford.edu/pubs/zhang2018graph.pdf) and the repo of the ACL 2019 paper [Attention Guided Graph Convolutional Networks for Relation Extraction](https://aclanthology.org/P19-1024/).
data/loader.py
@@ -0,0 +1,149 @@
"""
Data loader for TACRED json files.
"""

import json
import random
import torch
import numpy as np

from utils import constant


class DataLoader(object):
    """
    Load data from json files, preprocess and prepare batches.
    """
    def __init__(self, filename, batch_size, opt, vocab, evaluation=False):
        self.batch_size = batch_size
        self.opt = opt
        self.vocab = vocab
        self.eval = evaluation
        self.label2id = constant.LABEL_TO_ID

        with open(filename) as infile:
            data = json.load(infile)
        self.raw_data = data
        data = self.preprocess(data, vocab, opt)

        # shuffle for training
        if not evaluation:
            indices = list(range(len(data)))
            random.shuffle(indices)
            data = [data[i] for i in indices]
        self.id2label = dict([(v, k) for k, v in self.label2id.items()])
        self.labels = [self.id2label[d[-1]] for d in data]
        self.num_examples = len(data)

        # chunk into batches
        data = [data[i:i+batch_size] for i in range(0, len(data), batch_size)]
        self.data = data
        print("{} batches created for {}".format(len(data), filename))

    def preprocess(self, data, vocab, opt):
        """ Preprocess the data and convert to ids. """
        processed = []
        for d in data:
            tokens = list(d['token'])
            if opt['lower']:
                tokens = [t.lower() for t in tokens]
            # anonymize tokens
            ss, se = d['subj_start'], d['subj_end']
            os, oe = d['obj_start'], d['obj_end']
            tokens[ss:se+1] = ['SUBJ-'+d['subj_type']] * (se-ss+1)
            tokens[os:oe+1] = ['OBJ-'+d['obj_type']] * (oe-os+1)
            tokens = map_to_ids(tokens, vocab.word2id)
            pos = map_to_ids(d['stanford_pos'], constant.POS_TO_ID)
            ner = map_to_ids(d['stanford_ner'], constant.NER_TO_ID)
            deprel = map_to_ids(d['stanford_deprel'], constant.DEPREL_TO_ID)
            head = [int(x) for x in d['stanford_head']]
            assert any([x == 0 for x in head])
            l = len(tokens)
            subj_positions = get_positions(d['subj_start'], d['subj_end'], l)
            obj_positions = get_positions(d['obj_start'], d['obj_end'], l)
            subj_type = [constant.SUBJ_NER_TO_ID[d['subj_type']]]
            obj_type = [constant.OBJ_NER_TO_ID[d['obj_type']]]
            relation = self.label2id[d['relation']]
            processed += [(tokens, pos, ner, deprel, head, subj_positions, obj_positions, subj_type, obj_type, relation)]
        return processed

    def gold(self):
        """ Return gold labels as a list. """
        return self.labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, key):
        """ Get a batch with index. """
        if not isinstance(key, int):
            raise TypeError
        if key < 0 or key >= len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        batch = list(zip(*batch))
        assert len(batch) == 10

        # sort all fields by lens for easy RNN operations
        lens = [len(x) for x in batch[0]]
        batch, orig_idx = sort_all(batch, lens)

        # word dropout
        if not self.eval:
            words = [word_dropout(sent, self.opt['word_dropout']) for sent in batch[0]]
        else:
            words = batch[0]

        # convert to tensors
        words = get_long_tensor(words, batch_size)
        masks = torch.eq(words, 0)
        pos = get_long_tensor(batch[1], batch_size)
        ner = get_long_tensor(batch[2], batch_size)
        deprel = get_long_tensor(batch[3], batch_size)
        head = get_long_tensor(batch[4], batch_size)
        subj_positions = get_long_tensor(batch[5], batch_size)
        obj_positions = get_long_tensor(batch[6], batch_size)
        subj_type = get_long_tensor(batch[7], batch_size)
        obj_type = get_long_tensor(batch[8], batch_size)

        rels = torch.LongTensor(batch[9])

        return (words, masks, pos, ner, deprel, head, subj_positions, obj_positions, subj_type, obj_type, rels, orig_idx)

    def __iter__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)


def map_to_ids(tokens, vocab):
    ids = [vocab[t] if t in vocab else constant.UNK_ID for t in tokens]
    return ids


def get_positions(start_idx, end_idx, length):
    """ Get subj/obj position sequence, e.g. [-3, -2, -1, 0, 0, 0, 1, 2, 3, 4]. """
    return list(range(-start_idx, 0)) + [0]*(end_idx - start_idx + 1) + \
            list(range(1, length-end_idx))


def get_long_tensor(tokens_list, batch_size):
    """ Convert list of list of tokens to a padded LongTensor. """
    token_len = max(len(x) for x in tokens_list)
    tokens = torch.LongTensor(batch_size, token_len).fill_(constant.PAD_ID)
    for i, s in enumerate(tokens_list):
        tokens[i, :len(s)] = torch.LongTensor(s)
    return tokens


def sort_all(batch, lens):
    """ Sort all fields by descending order of lens, and return the original indices. """
    unsorted_all = [lens] + [range(len(lens))] + list(batch)
    sorted_all = [list(t) for t in zip(*sorted(zip(*unsorted_all), reverse=True))]
    return sorted_all[2:], sorted_all[1]


def word_dropout(tokens, dropout):
    """ Randomly dropout tokens (IDs) and replace them with <UNK> tokens. """
    return [constant.UNK_ID if x != constant.UNK_ID and np.random.random() < dropout
            else x for x in tokens]
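
A minimal usage sketch for this loader, mirroring how `eval.py` drives it; the `opt` values and file paths are assumptions for illustration:

```
# usage sketch for data/loader.py (opt values and paths are assumed)
from utils.vocab import Vocab
from data.loader import DataLoader

vocab = Vocab("dataset/vocab/vocab.pkl", load=True)
opt = {"lower": True, "word_dropout": 0.04, "batch_size": 50}
loader = DataLoader("dataset/tacred/dev.json", opt["batch_size"], opt, vocab,
                    evaluation=True)
print(len(loader), "batches,", loader.num_examples, "examples")
words, masks, *rest = loader[0]  # fields in the order returned by __getitem__
print(words.shape, masks.shape)
```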
download.sh
@@ -0,0 +1,14 @@
#!/bin/bash

cd dataset; mkdir glove
cd glove

echo "==> Downloading glove vectors..."
wget http://nlp.stanford.edu/data/glove.840B.300d.zip

echo "==> Unzipping glove vectors..."
unzip glove.840B.300d.zip
rm glove.840B.300d.zip

echo "==> Done."
eval.py
@@ -0,0 +1,67 @@
"""
Run evaluation with saved models.
"""
import random
import argparse
from tqdm import tqdm
import torch

from data.loader import DataLoader
from model.trainer import GCNTrainer
from utils import torch_utils, scorer, constant, helper
from utils.vocab import Vocab


parser = argparse.ArgumentParser()
parser.add_argument('model_dir', type=str, help='Directory of the model.')
parser.add_argument('--model', type=str, default='best_model.pt', help='Name of the model file.')
parser.add_argument('--data_dir', type=str, default='dataset/tacred')
parser.add_argument('--dataset', type=str, default='test', help="Evaluate on dev or test.")

parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available())
parser.add_argument('--cpu', action='store_true')
args = parser.parse_args()

torch.manual_seed(args.seed)
random.seed(1234)
if args.cpu:
    args.cuda = False
elif args.cuda:
    torch.cuda.manual_seed(args.seed)

# load opt
model_file = args.model_dir + '/' + args.model
print("Loading model from {}".format(model_file))
opt = torch_utils.load_config(model_file)
trainer = GCNTrainer(opt)
trainer.load(model_file)

# load vocab
vocab_file = args.model_dir + '/vocab.pkl'
vocab = Vocab(vocab_file, load=True)
assert opt['vocab_size'] == vocab.size, "Vocab size must match that in the saved model."

# load data
data_file = opt['data_dir'] + '/{}.json'.format(args.dataset)
print("Loading data from {} with batch size {}...".format(data_file, opt['batch_size']))
batch = DataLoader(data_file, opt['batch_size'], opt, vocab, evaluation=True)

# helper.print_config(opt)
label2id = constant.LABEL_TO_ID
id2label = dict([(v, k) for k, v in label2id.items()])

predictions = []
all_probs = []
batch_iter = tqdm(batch, mininterval=2)
for i, b in enumerate(batch_iter):
    preds, probs, _ = trainer.predict(b)
    predictions += preds
    all_probs += probs

predictions = [id2label[p] for p in predictions]
p, r, f1 = scorer.score(batch.gold(), predictions, verbose=False)
print("{} set evaluate result: {:.2f}\t{:.2f}\t{:.2f}".format(args.dataset, p, r, f1))

print("Evaluation ended.")
eval.sh
@@ -0,0 +1,2 @@
#!/bin/bash
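# $1: GPU id for CUDA_VISIBLE_DEVICES, $2: model id under saved_models/, $3: dataset split (dev or test)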
CUDA_VISIBLE_DEVICES=$1 python eval.py saved_models/$2/ --dataset $3