Skip to content

Commit

Permalink
add working eda wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonwei20 committed Feb 3, 2019
1 parent 5c5c36a commit 1c5cda4
Show file tree
Hide file tree
Showing 35 changed files with 767 additions and 0 deletions.
Binary file added code/__pycache__/eda.cpython-36.pyc
Binary file not shown.
48 changes: 48 additions & 0 deletions code/augment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Easy data augmentation techniques for text classification
# Jason Wei and Kai Zou

from eda import *

#arguments to be parsed from command line
import argparse
ap = argparse.ArgumentParser()
ap.add_argument("--input", required=True, type=str, help="input file of unaugmented data")
ap.add_argument("--output", required=False, type=str, help="output file of unaugmented data")
ap.add_argument("--num_aug", required=False, type=int, help="number of augmented sentences per original sentence")
args = ap.parse_args()

#the output file
output = None
if args.output:
output = args.output
else:
from os.path import dirname, basename, join
output = join(dirname(args.input), 'eda_' + basename(args.input))

#number of augmented sentences to generate per original sentence
num_aug = 9 #default
if args.num_aug:
num_aug = args.num_aug

#generate more data with standard augmentation
def gen_eda(train_orig, output_file, num_aug=9):

writer = open(output_file, 'w')
lines = open(train_orig, 'r').readlines()

for i, line in enumerate(lines):
parts = line[:-1].split('\t')
label = parts[0]
sentence = parts[1]
aug_sentences = eda(sentence, num_aug=num_aug)
for aug_sentence in aug_sentences:
writer.write(label + "\t" + aug_sentence + '\n')

writer.close()
print("generated augmented sentences with eda for " + train_orig + " to " + output_file + " with num_aug=" + str(num_aug))

#main function
if __name__ == "__main__":

#generate augmented sentences and output into a new file
gen_eda(args.input, output, num_aug=num_aug)
Empty file added code/dataeda_sst2_train_500.txt
Empty file.
219 changes: 219 additions & 0 deletions code/eda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
# Easy data augmentation techniques for text classification
# Jason Wei and Kai Zou

import random
from random import shuffle
random.seed(1)

#stop words list
stop_words = ['i', 'me', 'my', 'myself', 'we', 'our',
'ours', 'ourselves', 'you', 'your', 'yours',
'yourself', 'yourselves', 'he', 'him', 'his',
'himself', 'she', 'her', 'hers', 'herself',
'it', 'its', 'itself', 'they', 'them', 'their',
'theirs', 'themselves', 'what', 'which', 'who',
'whom', 'this', 'that', 'these', 'those', 'am',
'is', 'are', 'was', 'were', 'be', 'been', 'being',
'have', 'has', 'had', 'having', 'do', 'does', 'did',
'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or',
'because', 'as', 'until', 'while', 'of', 'at',
'by', 'for', 'with', 'about', 'against', 'between',
'into', 'through', 'during', 'before', 'after',
'above', 'below', 'to', 'from', 'up', 'down', 'in',
'out', 'on', 'off', 'over', 'under', 'again',
'further', 'then', 'once', 'here', 'there', 'when',
'where', 'why', 'how', 'all', 'any', 'both', 'each',
'few', 'more', 'most', 'other', 'some', 'such', 'no',
'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too',
'very', 's', 't', 'can', 'will', 'just', 'don',
'should', 'now', '']

#cleaning up text
import re
def get_only_chars(line):

clean_line = ""

line = line.replace("’", "")
line = line.replace("'", "")
line = line.replace("-", " ") #replace hyphens with spaces
line = line.replace("\t", " ")
line = line.replace("\n", " ")
line = line.lower()

for char in line:
if char in 'qwertyuiopasdfghjklzxcvbnm ':
clean_line += char
else:
clean_line += ' '

clean_line = re.sub(' +',' ',clean_line) #delete extra spaces
if clean_line[0] == ' ':
clean_line = clean_line[1:]
return clean_line

########################################################################
# Synonym replacement
# Replace n words in the sentence with synonyms from wordnet
########################################################################

#for the first time you use wordnet
#import nltk
#nltk.download('wordnet')
from nltk.corpus import wordnet

def synonym_replacement(words, n):
new_words = words.copy()
random_word_list = list(set([word for word in words if word not in stop_words]))
random.shuffle(random_word_list)
num_replaced = 0
for random_word in random_word_list:
synonyms = get_synonyms(random_word)
if len(synonyms) >= 1:
synonym = random.choice(list(synonyms))
new_words = [synonym if word == random_word else word for word in new_words]
#print("replaced", random_word, "with", synonym)
num_replaced += 1
if num_replaced >= n: #only replace up to n words
break

#this is stupid but we need it, trust me
sentence = ' '.join(new_words)
new_words = sentence.split(' ')

return new_words

def get_synonyms(word):
synonyms = set()
for syn in wordnet.synsets(word):
for l in syn.lemmas():
synonym = l.name().replace("_", " ").replace("-", " ").lower()
synonym = "".join([char for char in synonym if char in ' qwertyuiopasdfghjklzxcvbnm'])
synonyms.add(synonym)
if word in synonyms:
synonyms.remove(word)
return list(synonyms)

########################################################################
# Random deletion
# Randomly delete words from the sentence with probability p
########################################################################

def random_deletion(words, p):

#obviously, if there's only one word, don't delete it
if len(words) == 1:
return words

#randomly delete words with probability p
new_words = []
for word in words:
r = random.uniform(0, 1)
if r > p:
new_words.append(word)

#if you end up deleting all words, just return a random word
if len(new_words) == 0:
rand_int = random.randint(0, len(words)-1)
return [words[rand_int]]

return new_words

########################################################################
# Random swap
# Randomly swap two words in the sentence n times
########################################################################

def random_swap(words, n):
new_words = words.copy()
for _ in range(n):
new_words = swap_word(new_words)
return new_words

def swap_word(new_words):
random_idx_1 = random.randint(0, len(new_words)-1)
random_idx_2 = random_idx_1
counter = 0
while random_idx_2 == random_idx_1:
random_idx_2 = random.randint(0, len(new_words)-1)
counter += 1
if counter > 3:
return new_words
new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1]
return new_words

########################################################################
# Random insertion
# Randomly insert n words into the sentence
########################################################################

def random_insertion(words, n):
new_words = words.copy()
for _ in range(n):
add_word(new_words)
return new_words

def add_word(new_words):
synonyms = []
counter = 0
while len(synonyms) < 1:
random_word = new_words[random.randint(0, len(new_words)-1)]
synonyms = get_synonyms(random_word)
counter += 1
if counter >= 10:
return
random_synonym = synonyms[0]
random_idx = random.randint(0, len(new_words)-1)
new_words.insert(random_idx, random_synonym)

########################################################################
# main data augmentation function
########################################################################

def eda(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9):

sentence = get_only_chars(sentence)
words = sentence.split(' ')
words = [word for word in words if word is not '']
num_words = len(words)

augmented_sentences = []
num_new_per_technique = int(num_aug/4)+1
n_sr = max(1, int(alpha_sr*num_words))
n_ri = max(1, int(alpha_ri*num_words))
n_rs = max(1, int(alpha_rs*num_words))

#sr
for _ in range(num_new_per_technique):
a_words = synonym_replacement(words, n_sr)
augmented_sentences.append(' '.join(a_words))

#ri
for _ in range(num_new_per_technique):
a_words = random_insertion(words, n_ri)
augmented_sentences.append(' '.join(a_words))

#rs
for _ in range(num_new_per_technique):
a_words = random_swap(words, n_rs)
augmented_sentences.append(' '.join(a_words))

#rd
for _ in range(num_new_per_technique):
a_words = random_deletion(words, p_rd)
augmented_sentences.append(' '.join(a_words))

augmented_sentences = [get_only_chars(sentence) for sentence in augmented_sentences]
shuffle(augmented_sentences)

#trim so that we have the desired number of augmented sentences
if num_aug >= 1:
augmented_sentences = augmented_sentences[:num_aug]
else:
keep_prob = num_aug / len(augmented_sentences)
augmented_sentences = [s for s in augmented_sentences if random.uniform(0, 1) < keep_prob]

#append the original sentence
augmented_sentences.append(sentence)

return augmented_sentences
Loading

0 comments on commit 1c5cda4

Please sign in to comment.