Commit 70e9976

Merge remote-tracking branch 'origin/main'

Daniele Comi committed Mar 17, 2023
2 parents f82dfee + 4adbbdd commit 70e9976

Showing 8 changed files with 202 additions and 139 deletions.
32 changes: 22 additions & 10 deletions scripts/evaluation.py
@@ -6,23 +6,33 @@
 from os import listdir
 from os.path import isfile, join
 
-with open('BANKING77-OOS/id-oos/test/seq.in', 'r') as seqs, open('BANKING77-OOS/id-oos/test/label_original', 'r') as labels:
+with open("BANKING77-OOS/id-oos/test/seq.in", "r") as seqs, open(
+    "BANKING77-OOS/id-oos/test/label_original", "r"
+) as labels:
     seq = seqs.readlines()
     label = labels.readlines()
 
 with open("BANKING77-OOS/id-oos/test/test.csv", "w") as data:
-    data.write("text,category\n")
-    for s,l in zip(seq, label):
-        data.write(s.strip().replace(",", "").replace(".", "") + ',' + l.strip() + '\n')
+    data.write("text,category\n")
+    for s, l in zip(seq, label):
+        data.write(s.strip().replace(",", "").replace(".", "") + "," + l.strip() + "\n")
 
 df = pd.read_csv("BANKING77-OOS/id-oos/test/test.csv")
 df = df.dropna()
 df.to_csv("BANKING77-OOS/id-oos/test/test.csv", index=False)
 
-dataset = load_dataset('csv', data_files={'train': 'BANKING77-OOS/id-oos/train/train.csv',
-                                          'test': 'BANKING77-OOS/id-oos/test/test.csv'}, encoding="ISO-8859-1")
-
-model = SentenceTransformer('sentence-transformers/distiluse-base-multilingual-cased-v1')
+dataset = load_dataset(
+    "csv",
+    data_files={
+        "train": "BANKING77-OOS/id-oos/train/train.csv",
+        "test": "BANKING77-OOS/id-oos/test/test.csv",
+    },
+    encoding="ISO-8859-1",
+)
+
+model = SentenceTransformer(
+    "sentence-transformers/distiluse-base-multilingual-cased-v1"
+)
 sim_pred = []
 
 filenames = [f for f in listdir("../results/") if isfile(join("../results/", f))]
@@ -32,10 +42,12 @@
     df = df.dropna()
     df = df.reset_index(drop=True)
     for i in tqdm(range(0, len(df))):
-        label = dataset['test'][i]['category'].replace("_", " ")
+        label = dataset["test"][i]["category"].replace("_", " ")
 
         embedding_1 = model.encode(label, convert_to_tensor=True)
-        embedding_2 = model.encode(" ".join(df["prediction"][i].split()[:2]), convert_to_tensor=True)
+        embedding_2 = model.encode(
+            " ".join(df["prediction"][i].split()[:2]), convert_to_tensor=True
+        )
 
         sim_pred.append(util.pytorch_cos_sim(embedding_1, embedding_2).item())
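Note: the loop above scores each generation by cosine similarity between sentence embeddings of the gold category and the first two words of the predicted text. A minimal standalone sketch of that check, assuming sentence-transformers is installed (the label and prediction strings are made up):

from sentence_transformers import SentenceTransformer, util

# Same encoder the script loads above.
model = SentenceTransformer("sentence-transformers/distiluse-base-multilingual-cased-v1")

label = "card arrival"                         # hypothetical gold category, underscores stripped
prediction = "card arrival within five days"   # hypothetical generated text

# Compare the label against the first two words of the prediction.
emb_label = model.encode(label, convert_to_tensor=True)
emb_pred = model.encode(" ".join(prediction.split()[:2]), convert_to_tensor=True)

print(util.pytorch_cos_sim(emb_label, emb_pred).item())  # near 1.0 for a matching intent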
14 changes: 10 additions & 4 deletions scripts/generate_prompts.py
@@ -2,17 +2,23 @@
 from typing import List
 from urllib.request import urlopen
 
+
 def get_prompts(url: str) -> List[str]:
     cases = []
     with urlopen(url) as response:
         for line in response.readlines():
             line = line.strip().decode()
             if line:
-                cases.append(
-                    line.title()
-                )
+                cases.append(line.title())
     return cases
 
 
 if __name__ == "__main__":
-    print(os.linesep.join(get_prompts("https://raw.githubusercontent.com/jianguoz/Few-Shot-Intent-Detection/main/Datasets/BANKING77-OOS/id-oos/test/seq.in")) + os.linesep)
+    print(
+        os.linesep.join(
+            get_prompts(
+                "https://raw.githubusercontent.com/jianguoz/Few-Shot-Intent-Detection/main/Datasets/BANKING77-OOS/id-oos/test/seq.in"
+            )
+        )
+        + os.linesep
+    )
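get_prompts fetches one utterance per line from the given URL and title-cases it. A hedged usage sketch (the import assumes the call is made from inside scripts/):

from generate_prompts import get_prompts

url = (
    "https://raw.githubusercontent.com/jianguoz/Few-Shot-Intent-Detection/"
    "main/Datasets/BANKING77-OOS/id-oos/test/seq.in"
)
cases = get_prompts(url)
print(cases[0])  # each utterance comes back title-cased, e.g. "How Do I ..."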
29 changes: 22 additions & 7 deletions scripts/main.py
@@ -6,24 +6,39 @@
 from src.zberta.intent_discovery.unknown_intents import unknown_intents_set
 from src.zberta.intent_discovery.zberta import ZBERTA
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     model_name = "bert-base-uncased"
     training = False
     testing = False
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     if training:
         snli = DataSNLI(model_name, device)
         train_iterator, valid_iterator, test_iterator = snli.iterators()
-        berta = model.instantiate_model(snli.labels(), snli.output_dim(), device, model_name, snli.nli_labels())
+        berta = model.instantiate_model(
+            snli.labels(), snli.output_dim(), device, model_name, snli.nli_labels()
+        )
         trainer = Trainer(berta, train_iterator, valid_iterator, test_iterator, device)
         trainer.start_training()
         if testing:
             trainer.start_testing()
     z_banking = DataBanking(model_name, device)
     z_dataset = z_banking.z_iterator()
-    z_intents = unknown_intents_set("en_core_web_trf", z_dataset['test']['text'])
-    berta = model.instantiate_model(z_banking.labels(), z_banking.output_dim(), device, model_name,
-                                    z_banking.nli_labels(), path="model.pt", dict=True)
-    zberta = ZBERTA(berta, model_name, z_dataset['test']['text'], z_intents, z_dataset['test']['category'])
+    z_intents = unknown_intents_set("en_core_web_trf", z_dataset["test"]["text"])
+    berta = model.instantiate_model(
+        z_banking.labels(),
+        z_banking.output_dim(),
+        device,
+        model_name,
+        z_banking.nli_labels(),
+        path="model.pt",
+        dict=True,
+    )
+    zberta = ZBERTA(
+        berta,
+        model_name,
+        z_dataset["test"]["text"],
+        z_intents,
+        z_dataset["test"]["category"],
+    )
     z_acc = zberta.zero_shot_intents()
     print(z_acc)
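The unknown-intent step passes the spaCy pipeline name "en_core_web_trf" into unknown_intents_set, so that model must be installed before the script runs. A minimal sketch of the setup this presupposes (the repo's own unknown_intents_set internals are not shown here):

import spacy

# Requires a one-time download: python -m spacy download en_core_web_trf
nlp = spacy.load("en_core_web_trf")
doc = nlp("when will my new card get here?")
print([(t.text, t.pos_) for t in doc])  # transformer-backed token annotations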
52 changes: 37 additions & 15 deletions scripts/nli_augmentation.py
@@ -7,48 +7,70 @@


 def nli_augmentation(df_train):
-    nltk.download('wordnet')
-    nltk.download('omw')
+    nltk.download("wordnet")
+    nltk.download("omw")
     kw_model = KeyBERT()
-    data = kw_model.extract_keywords(df_train['text'][0], keyphrase_ngram_range=(1, 1), stop_words=None)
+    data = kw_model.extract_keywords(
+        df_train["text"][0], keyphrase_ngram_range=(1, 1), stop_words=None
+    )
     words = []
     word = ""
     for i in tqdm(range(len(df_train))):
         try:
-            word = kw_model.extract_keywords(df_train['text'][i], keyphrase_ngram_range=(1, 1), stop_words=None)[0][0]
+            word = kw_model.extract_keywords(
+                df_train["text"][i], keyphrase_ngram_range=(1, 1), stop_words=None
+            )[0][0]
             wn.synset(wn.synsets(word)[0].name()).definition()
         except:
             words.append(word)
 
     words_ = []
     word = ""
     out = []
-    intro = 'this text is about '
+    intro = "this text is about "
 
     for i in tqdm(range(len(df_train))):
         try:
-            word = kw_model.extract_keywords(df_train['text'][i], keyphrase_ngram_range=(1, 1), stop_words=None)
+            word = kw_model.extract_keywords(
+                df_train["text"][i], keyphrase_ngram_range=(1, 1), stop_words=None
+            )
             if word[0][0] not in words:
-                out.append(intro + wn.synset(wn.synsets(word[0][0])[0].name()).definition())
+                out.append(
+                    intro + wn.synset(wn.synsets(word[0][0])[0].name()).definition()
+                )
             elif len(word[1][0]) > 3:
-                out.append(intro + wn.synset(wn.synsets(word[1][0])[0].name()).definition())
+                out.append(
+                    intro + wn.synset(wn.synsets(word[1][0])[0].name()).definition()
+                )
             else:
                 out.append(intro + word[0][0])
         except:
             words_.append(word[1][0])
             out.append(intro + word[0][0])
 
-    out = [x.replace('(', '').replace(')', '') for x in out]
+    out = [x.replace("(", "").replace(")", "") for x in out]
     out = [x.lower() for x in out]
-    df_train['hypothesis'] = out
-    df_train['gold_label'] = 'entailment'
+    df_train["hypothesis"] = out
+    df_train["gold_label"] = "entailment"
 
     df_temp = df_train
-    categories = list(df_temp.drop_duplicates("category")['category'])
-    dictionary = [ n for n in wn.all_lemma_names() if len(n) > 6]
+    categories = list(df_temp.drop_duplicates("category")["category"])
+    dictionary = [n for n in wn.all_lemma_names() if len(n) > 6]
     for category in tqdm(categories):
-        df_category = df_temp[df_train['category'].str.match(category)]
+        df_category = df_temp[df_train["category"].str.match(category)]
         for i, data in tqdm(df_category.iterrows(), total=df_category.shape[0]):
             rand_word = random.choice(dictionary)
-            df_temp = df_temp.append(pd.DataFrame({"text":[data['text']], "category":[category], "hypothesis":[intro + wn.synset(wn.synsets(rand_word)[0].name()).definition()], "gold_label":["contradiction"]}))
+            df_temp = df_temp.append(
+                pd.DataFrame(
+                    {
+                        "text": [data["text"]],
+                        "category": [category],
+                        "hypothesis": [
+                            intro
+                            + wn.synset(wn.synsets(rand_word)[0].name()).definition()
+                        ],
+                        "gold_label": ["contradiction"],
+                    }
+                )
+            )
     return df_temp
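One caveat: DataFrame.append, used in the contradiction-sampling loop above, is deprecated since pandas 1.4 and removed in 2.0. On newer pandas the same row addition can be written with pd.concat; a sketch under the same column layout (the row values are made up):

import pandas as pd

df_temp = pd.DataFrame(columns=["text", "category", "hypothesis", "gold_label"])

row = pd.DataFrame(
    {
        "text": ["how do i top up my card"],  # hypothetical utterance
        "category": ["top_up"],               # hypothetical category
        "hypothesis": ["this text is about a randomly drawn definition"],
        "gold_label": ["contradiction"],
    }
)
# Equivalent of df_temp = df_temp.append(row) on pandas >= 2.0.
df_temp = pd.concat([df_temp, row], ignore_index=True)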
59 changes: 27 additions & 32 deletions scripts/zs_generator.py
@@ -1,70 +1,65 @@
 import pandas as pd
 import click
 from tqdm import tqdm
-from zs_model_generators import t0_generator, gpt_generator, bloom_generator, keybert_generator
+from zs_model_generators import t0_generator, gpt_generator, keybert_generator
 
 from generate_prompts import get_prompts
 
 def find_model_class(model_name_or_path):
 
     if "t0" in model_name_or_path.lower():
         return t0_generator
     elif "gpt" in model_name_or_path.lower():
         return gpt_generator
-    elif "bloom" in model_name_or_path.lower():
-        return bloom_generator
     elif "keybert" in model_name_or_path.lower():
         return keybert_generator
     else:
         raise ValueError(f"{model_name_or_path} is not supported")
 
 
 @click.command()
-@click.option('--url',
-              type=str,
-              required=True,
-              help='Url to input examples.'
-)
-@click.option('--prompt',
-              type=str,
-              required=True,
-              help='Prompt to be used.'
-)
-@click.option('--model_name_or_path',
-              type=str,
-              required=True,
-              help='Model to be used for the generation.'
+@click.option("--dataset", type=str, required=True, help="Dataset's path")
+@click.option("--prompt", type=str, required=True, help="Prompt to be used.")
+@click.option(
+    "--model_name_or_path",
+    type=str,
+    required=True,
+    help="Model to be used for the generation.",
 )
-@click.option('--cache_dir',
-              type=str,
-              required=True,
-              help='Cache directory for the model.'
+@click.option(
+    "--cache_dir", type=str, required=True, help="Cache directory for the model."
 )
-@click.option('--output_path',
-              type=str,
-              required=True,
-              help='Path to save the generations.'
+@click.option(
+    "--output_path", type=str, required=True, help="Path to save the generations."
 )
-def main(url, prompt, model_name_or_path, cache_dir, output_path):
+def main(dataset, prompt, model_name_or_path, cache_dir, output_path):
 
     model_class = find_model_class(model_name_or_path)
     model = model_class(model_name_or_path, cache_dir)
 
-    cases = get_prompts(url)
+    data = pd.read_csv(dataset)
 
     prompts = []
     generations = []
 
-    for case in tqdm(cases):
+    for _, case in tqdm(data.iterrows()):
 
-        input_text = prompt.format(case)
+        input_text = prompt.format(case["text"])
         prompts.append(prompt)
 
         generated = model.generate_text(input_text)
 
         generations.append(generated[0])
 
-    df = pd.DataFrame.from_dict({"utterance":cases, "prompt":prompts, model_name_or_path:generations})
+    df = pd.DataFrame.from_dict(
+        {
+            "utterance": data["text"].to_list(),
+            "prompt": prompts,
+            model_name_or_path: generations,
+            "category": data["category"].to_list(),
+        }
+    )
     df.to_csv(output_path)
 
 
 if __name__ == "__main__":
     main()
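With this change the generator reads utterances from a CSV with text and category columns instead of fetching them from a URL. A hypothetical invocation (paths, prompt, and model choice are illustrative; any model name containing "t0", "gpt", or "keybert" routes to the matching generator class):

python scripts/zs_generator.py \
    --dataset BANKING77-OOS/id-oos/test/test.csv \
    --prompt "Classify the banking intent: {}" \
    --model_name_or_path bigscience/T0_3B \
    --cache_dir ./cache \
    --output_path ../results/t0_generations.csv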
34 changes: 16 additions & 18 deletions scripts/zs_generator_snli.py
@@ -4,21 +4,17 @@
 from transformers import pipeline
 from tqdm import tqdm
 
+
 @click.command()
-@click.option('--datafile',
-              type=str,
-              required=True,
-              help='Path to the datafile'
-)
-@click.option('--model_name_or_path',
-              type=str,
-              default=None,
-              help='Model to be used for the generation.'
+@click.option("--datafile", type=str, required=True, help="Path to the datafile")
+@click.option(
+    "--model_name_or_path",
+    type=str,
+    default=None,
+    help="Model to be used for the generation.",
 )
-@click.option('--output_path',
-              type=str,
-              required=True,
-              help='Path to save the generations.'
+@click.option(
+    "--output_path", type=str, required=True, help="Path to save the generations."
 )
 def execute(datafile, output_path, model_name_or_path=None):
 
@@ -28,7 +24,7 @@ def execute(datafile, output_path, model_name_or_path=None):
     classifier = pipeline("zero-shot-classification")
 
     data = pd.read_csv(datafile)
-    data = data.to_dict('records')
+    data = data.to_dict("records")
 
     utterances = []
     classes = []
@@ -44,14 +40,16 @@ def execute(datafile, output_path, model_name_or_path=None):
         res = classifier(case["utterance"], labels)
 
         label = "Unknown"
-        if res["scores"][0]>0.5:
-            label =res['labels'][0]
+        if res["scores"][0] > 0.5:
+            label = res["labels"][0]
 
         generations.append(label)
 
     df = pd.DataFrame.from_dict(
-        {"utterance": utterances, "classes": classes, "snli": generations})
+        {"utterance": utterances, "classes": classes, "snli": generations}
+    )
     df.to_csv(output_path)
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     execute()
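The classifier here is Hugging Face's default zero-shot-classification pipeline, and the 0.5 threshold decides whether the top label is trusted at all. A self-contained sketch of that thresholding logic, with a made-up candidate label set:

from transformers import pipeline

classifier = pipeline("zero-shot-classification")

labels = ["card arrival", "top up", "exchange rate"]  # hypothetical candidate classes
res = classifier("when will my new card get here?", labels)

# Scores come back sorted best-first; accept the top label only above 0.5.
label = res["labels"][0] if res["scores"][0] > 0.5 else "Unknown"
print(label)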