forked from jasonwei20/eda_nlp
Commit 1c5cda4 (1 parent: 5c5c36a)
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 35 changed files with 767 additions and 0 deletions.
code/augment.py
@@ -0,0 +1,48 @@
# Easy data augmentation techniques for text classification
# Jason Wei and Kai Zou

from eda import *

# arguments to be parsed from command line
import argparse
ap = argparse.ArgumentParser()
ap.add_argument("--input", required=True, type=str, help="input file of unaugmented data")
ap.add_argument("--output", required=False, type=str, help="output file of augmented data")
ap.add_argument("--num_aug", required=False, type=int, help="number of augmented sentences per original sentence")
args = ap.parse_args()

# the output file; default to 'eda_' + input filename in the same directory
output = None
if args.output:
    output = args.output
else:
    from os.path import dirname, basename, join
    output = join(dirname(args.input), 'eda_' + basename(args.input))

# number of augmented sentences to generate per original sentence
num_aug = 9  # default
if args.num_aug:
    num_aug = args.num_aug

# generate more data with standard augmentation
def gen_eda(train_orig, output_file, num_aug=9):

    writer = open(output_file, 'w')
    lines = open(train_orig, 'r').readlines()

    # each input line is expected to be "label<TAB>sentence"
    for i, line in enumerate(lines):
        parts = line[:-1].split('\t')
        label = parts[0]
        sentence = parts[1]
        aug_sentences = eda(sentence, num_aug=num_aug)
        for aug_sentence in aug_sentences:
            writer.write(label + "\t" + aug_sentence + '\n')

    writer.close()
    print("generated augmented sentences with eda for " + train_orig + " to " + output_file + " with num_aug=" + str(num_aug))

# main function
if __name__ == "__main__":

    # generate augmented sentences and output into a new file
    gen_eda(args.input, output, num_aug=num_aug)
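
# typical invocation of this script (assumption: the paths below are
# placeholders, not from this commit; each input line must be a
# tab-separated "label<TAB>sentence" pair):
#   python code/augment.py --input=data/train.txt --num_aug=16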
code/eda.py
@@ -0,0 +1,219 @@
# Easy data augmentation techniques for text classification
# Jason Wei and Kai Zou

import random
from random import shuffle
random.seed(1)

# stop words list
stop_words = ['i', 'me', 'my', 'myself', 'we', 'our',
              'ours', 'ourselves', 'you', 'your', 'yours',
              'yourself', 'yourselves', 'he', 'him', 'his',
              'himself', 'she', 'her', 'hers', 'herself',
              'it', 'its', 'itself', 'they', 'them', 'their',
              'theirs', 'themselves', 'what', 'which', 'who',
              'whom', 'this', 'that', 'these', 'those', 'am',
              'is', 'are', 'was', 'were', 'be', 'been', 'being',
              'have', 'has', 'had', 'having', 'do', 'does', 'did',
              'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or',
              'because', 'as', 'until', 'while', 'of', 'at',
              'by', 'for', 'with', 'about', 'against', 'between',
              'into', 'through', 'during', 'before', 'after',
              'above', 'below', 'to', 'from', 'up', 'down', 'in',
              'out', 'on', 'off', 'over', 'under', 'again',
              'further', 'then', 'once', 'here', 'there', 'when',
              'where', 'why', 'how', 'all', 'any', 'both', 'each',
              'few', 'more', 'most', 'other', 'some', 'such', 'no',
              'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too',
              'very', 's', 't', 'can', 'will', 'just', 'don',
              'should', 'now', '']

# cleaning up text: keep only lowercase ascii letters and single spaces
import re
def get_only_chars(line):

    clean_line = ""

    line = line.replace("’", "")
    line = line.replace("'", "")
    line = line.replace("-", " ")  # replace hyphens with spaces
    line = line.replace("\t", " ")
    line = line.replace("\n", " ")
    line = line.lower()

    for char in line:
        if char in 'qwertyuiopasdfghjklzxcvbnm ':
            clean_line += char
        else:
            clean_line += ' '

    clean_line = re.sub(' +', ' ', clean_line)  # delete extra spaces
    if clean_line and clean_line[0] == ' ':  # guard against empty strings
        clean_line = clean_line[1:]
    return clean_line
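
# a quick sanity check of the cleaner (the sample string below is
# hypothetical, not from the original file):
#   get_only_chars("Don't re-try")  ->  'dont re try'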

########################################################################
# Synonym replacement
# Replace n words in the sentence with synonyms from wordnet
########################################################################

# for the first time you use wordnet:
# import nltk
# nltk.download('wordnet')
from nltk.corpus import wordnet

def synonym_replacement(words, n):
    new_words = words.copy()
    random_word_list = list(set([word for word in words if word not in stop_words]))
    random.shuffle(random_word_list)
    num_replaced = 0
    for random_word in random_word_list:
        synonyms = get_synonyms(random_word)
        if len(synonyms) >= 1:
            synonym = random.choice(list(synonyms))
            new_words = [synonym if word == random_word else word for word in new_words]
            #print("replaced", random_word, "with", synonym)
            num_replaced += 1
        if num_replaced >= n:  # only replace up to n words
            break

    # join and re-split so that multi-word synonyms become individual tokens
    sentence = ' '.join(new_words)
    new_words = sentence.split(' ')

    return new_words

def get_synonyms(word):
    synonyms = set()
    for syn in wordnet.synsets(word):
        for l in syn.lemmas():
            synonym = l.name().replace("_", " ").replace("-", " ").lower()
            synonym = "".join([char for char in synonym if char in ' qwertyuiopasdfghjklzxcvbnm'])
            synonyms.add(synonym)
    if word in synonyms:
        synonyms.remove(word)
    return list(synonyms)
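
# for reference, a hedged example using the standard NLTK WordNet corpus
# (exact lemmas vary by WordNet version; the input word itself is removed
# and the order is arbitrary because the result comes from a set):
#   get_synonyms('happy')  ->  e.g. ['felicitous', 'glad', 'well chosen']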

########################################################################
# Random deletion
# Randomly delete words from the sentence with probability p
########################################################################

def random_deletion(words, p):

    # obviously, if there's only one word, don't delete it
    if len(words) == 1:
        return words

    # randomly delete words with probability p
    new_words = []
    for word in words:
        r = random.uniform(0, 1)
        if r > p:
            new_words.append(word)

    # if you end up deleting all words, just return a random word
    if len(new_words) == 0:
        rand_int = random.randint(0, len(words)-1)
        return [words[rand_int]]

    return new_words

########################################################################
# Random swap
# Randomly swap two words in the sentence n times
########################################################################

def random_swap(words, n):
    new_words = words.copy()
    for _ in range(n):
        new_words = swap_word(new_words)
    return new_words

def swap_word(new_words):
    random_idx_1 = random.randint(0, len(new_words)-1)
    random_idx_2 = random_idx_1
    counter = 0
    # pick a second, distinct index; give up after 3 tries (e.g. one-word sentences)
    while random_idx_2 == random_idx_1:
        random_idx_2 = random.randint(0, len(new_words)-1)
        counter += 1
        if counter > 3:
            return new_words
    new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1]
    return new_words

########################################################################
# Random insertion
# Randomly insert n words into the sentence
########################################################################

def random_insertion(words, n):
    new_words = words.copy()
    for _ in range(n):
        add_word(new_words)
    return new_words

def add_word(new_words):
    # find a synonym of a random word, then insert it at a random position
    synonyms = []
    counter = 0
    while len(synonyms) < 1:
        random_word = new_words[random.randint(0, len(new_words)-1)]
        synonyms = get_synonyms(random_word)
        counter += 1
        if counter >= 10:
            return
    random_synonym = synonyms[0]
    random_idx = random.randint(0, len(new_words)-1)
    new_words.insert(random_idx, random_synonym)

########################################################################
# main data augmentation function
########################################################################

def eda(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9):

    sentence = get_only_chars(sentence)
    words = sentence.split(' ')
    words = [word for word in words if word != '']  # drop empty tokens
    num_words = len(words)

    augmented_sentences = []
    num_new_per_technique = int(num_aug/4) + 1
    n_sr = max(1, int(alpha_sr*num_words))
    n_ri = max(1, int(alpha_ri*num_words))
    n_rs = max(1, int(alpha_rs*num_words))

    # sr
    for _ in range(num_new_per_technique):
        a_words = synonym_replacement(words, n_sr)
        augmented_sentences.append(' '.join(a_words))

    # ri
    for _ in range(num_new_per_technique):
        a_words = random_insertion(words, n_ri)
        augmented_sentences.append(' '.join(a_words))

    # rs
    for _ in range(num_new_per_technique):
        a_words = random_swap(words, n_rs)
        augmented_sentences.append(' '.join(a_words))

    # rd
    for _ in range(num_new_per_technique):
        a_words = random_deletion(words, p_rd)
        augmented_sentences.append(' '.join(a_words))

    augmented_sentences = [get_only_chars(sentence) for sentence in augmented_sentences]
    shuffle(augmented_sentences)

    # trim so that we have the desired number of augmented sentences;
    # a fractional num_aug < 1 keeps each candidate with probability
    # num_aug / len(candidates), so num_aug sentences survive on average
    if num_aug >= 1:
        augmented_sentences = augmented_sentences[:num_aug]
    else:
        keep_prob = num_aug / len(augmented_sentences)
        augmented_sentences = [s for s in augmented_sentences if random.uniform(0, 1) < keep_prob]

    # append the original sentence
    augmented_sentences.append(sentence)

    return augmented_sentences
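
# a minimal usage sketch, not part of the original commit: the sample
# sentence is hypothetical, and outputs depend on the seed and the
# installed WordNet data
if __name__ == "__main__":
    for s in eda("the quick brown fox jumps over the lazy dog", num_aug=4):
        print(s)  # prints 4 augmented variants, then the original cleaned sentence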