Skip to content

Commit

Permalink
Change inputa format
Browse files Browse the repository at this point in the history
  • Loading branch information
DIC authored and DIC committed Feb 3, 2023
1 parent d48e478 commit 26bd54f
Showing 1 changed file with 27 additions and 32 deletions.
59 changes: 27 additions & 32 deletions scripts/zs_generator.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,65 @@
import pandas as pd
import click
from tqdm import tqdm
from zs_model_generators import t0_generator, gpt_generator, bloom_generator, keybert_generator
from zs_model_generators import t0_generator, gpt_generator, keybert_generator

from generate_prompts import get_prompts

def find_model_class(model_name_or_path):

if "t0" in model_name_or_path.lower():
return t0_generator
elif "gpt" in model_name_or_path.lower():
return gpt_generator
elif "bloom" in model_name_or_path.lower():
return bloom_generator
elif "keybert" in model_name_or_path.lower():
return keybert_generator
else:
raise ValueError(f"{model_name_or_path} is not supported")


@click.command()
@click.option('--url',
type=str,
required=True,
help='Url to input examples.'
)
@click.option('--prompt',
type=str,
required=True,
help='Prompt to be used.'
)
@click.option('--model_name_or_path',
type=str,
required=True,
help='Model to be used for the generation.'
@click.option("--dataset", type=str, required=True, help="Dataset's path")
@click.option("--prompt", type=str, required=True, help="Prompt to be used.")
@click.option(
"--model_name_or_path",
type=str,
required=True,
help="Model to be used for the generation.",
)
@click.option('--cache_dir',
type=str,
required=True,
help='Cache directory for the model.'
@click.option(
"--cache_dir", type=str, required=True, help="Cache directory for the model."
)
@click.option('--output_path',
type=str,
required=True,
help='Path to save the generations.'
@click.option(
"--output_path", type=str, required=True, help="Path to save the generations."
)
def main(url, prompt, model_name_or_path, cache_dir, output_path):
def main(dataset, prompt, model_name_or_path, cache_dir, output_path):

model_class = find_model_class(model_name_or_path)
model = model_class(model_name_or_path, cache_dir)

cases = get_prompts(url)
data = pd.read_csv(dataset)

prompts = []
generations = []

for case in tqdm(cases):
for case in tqdm(data.iterrows()):

input_text = prompt.format(case)
input_text = prompt.format(case["text"])
prompts.append(prompt)

generated = model.generate_text(input_text)

generations.append(generated[0])

df = pd.DataFrame.from_dict({"utterance":cases, "prompt":prompts, model_name_or_path:generations})
df = pd.DataFrame.from_dict(
{
"utterance": data["cases"].to_list(),
"prompt": prompts,
model_name_or_path: generations,
"category": data["category"].to_list(),
}
)
df.to_csv(output_path)


if __name__ == "__main__":
main()

0 comments on commit 26bd54f

Please sign in to comment.