prompt.py
import argparse

import torch
from torch.nn.utils.rnn import pad_sequence

from Models.load_models import load_model, MODEL_TYPES

def process_api_batch(batch, model_fn, max_length):
    # API-backed models take raw text, so call the model once per entry.
    output = [model_fn(entry['text'], max_length) for entry in batch]
    return output

def process_local_batch(batch, model_fn, max_length):
    # Local models take pre-tokenized input; pad each batch to a uniform
    # length. padding_value=0 assumes 0 is the tokenizer's pad token id.
    input_ids = [entry['tokens'] for entry in batch]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    if hasattr(model_fn, 'generate'):
        output = model_fn.generate(input_ids=input_ids, max_length=max_length)
    else:
        raise ValueError("Invalid model function")
    print(f"Output is {output} for input {input_ids}")  # Debug output
    return output

# Dispatch table mapping each model type to its batch processor.
BATCH_PROCESSORS = {
    "api": process_api_batch,
    "local": process_local_batch,
}

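# Supporting a new backend only requires registering a processor with the same
# (batch, model_fn, max_length) signature, e.g. a hypothetical streaming
# backend (not part of this repo):
#
#     BATCH_PROCESSORS["stream"] = process_stream_batch
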
def process_batch(batch, model_fn, max_length, model_type):
    process_fn = BATCH_PROCESSORS.get(model_type)
    if process_fn is None:
        raise ValueError(f"Invalid model type: {model_type}")
    return process_fn(batch, model_fn, max_length)

def batch_prompt(dataset, model, batch_size, num_strings, max_length):
    '''
    Parameters:
        dataset: each row is a dict with keys 'text' (the prefix text) and
            'tokens' (the prefix token tensor)
        model: key of the model to use
        batch_size: batch size for model prompting
        num_strings: total number of strings to generate
        max_length: max number of tokens to generate
    Returns:
        output: each row contains the output tokens of the model
    '''
    model_fn = load_model(model)
    model_type = MODEL_TYPES[model]
    output = []
    if num_strings <= len(dataset):
        for i in range(0, num_strings, batch_size):
            batch = dataset[i:i + batch_size]
            batch_output = process_batch(batch, model_fn, max_length, model_type)
            output.extend(batch_output)
    else:
        # Wraparound: num_strings exceeds the dataset length, so prefixes are
        # reused. This path is fragile and best avoided.
        num_iterations = (num_strings + len(dataset) - 1) // len(dataset)
        for _ in range(num_iterations):
            for i in range(0, len(dataset), batch_size):
                batch = dataset[i:i + batch_size]
                batch_output = process_batch(batch, model_fn, max_length, model_type)
                output.extend(batch_output)
                if len(output) >= num_strings:
                    break
            if len(output) >= num_strings:
                break
    # Trim any overshoot from the final batch.
    output = output[:num_strings]
    return output
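# A minimal usage sketch (assumes 'mamba-3b' is a key known to load_model and
# MODEL_TYPES, and that each prefix tensor is a 1-D LongTensor; the token ids
# below are illustrative only):
#
#     prefixes = [{'text': 'Once upon a time',
#                  'tokens': torch.tensor([101, 2028, 2588, 1037, 2051])}]
#     generations = batch_prompt(prefixes, 'mamba-3b', batch_size=1,
#                                num_strings=1, max_length=20)
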
if __name__ == "__main__":
    # Parse configuration
    parser = argparse.ArgumentParser(description="Generate strings using a given model and prefix map.")
    parser.add_argument("--prefix_map", type=str, required=True, help="The directory name of the prefix map to use (e.g., 'mamba-3b_corpus_10000').")
    parser.add_argument("--model", type=str, required=True, help="The model to use for generating strings (e.g., 'mamba-3b').")
    parser.add_argument("--num_strings", type=int, required=True, help="The total number of strings to generate. If this exceeds the dataset size, prefixes are reused.")
    parser.add_argument("--batch_size", type=int, default=10, help="The number of strings to generate in each batch.")
    parser.add_argument("--max_length", type=int, default=20, help="Max output length from the model. Outputs may be shorter than this.")
    args = parser.parse_args()

    # Unpack the arguments
    prefix_map_dir = args.prefix_map
    model = args.model
    num_strings = args.num_strings
    batch_size = args.batch_size
    max_length = args.max_length

    prefix_map = torch.load(f"Prefixes/{prefix_map_dir}/prefix_map.pt")
    output = batch_prompt(prefix_map, model, batch_size, num_strings, max_length)
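    # The script currently discards the generations. A minimal sketch for
    # persisting them alongside the prefix map (the output filename is an
    # assumption, not part of the original pipeline):
    #
    #     torch.save(output, f"Prefixes/{prefix_map_dir}/generations.pt")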