Commit 4189c5c

update generate.py
1 parent 951f6de

File tree

1 file changed: +94 -85 lines

spin/generate.py

@@ -1,61 +1,37 @@
+# reference: https://medium.com/@geronimo7/llms-multi-gpu-inference-with-accelerate-5a8333e4c5db
+
 from accelerate import Accelerator
 from accelerate.utils import gather_object
 from transformers import AutoModelForCausalLM, AutoTokenizer
-import torch, time, json
-import argparse
 from datasets import load_dataset
+
+import argparse
+import torch, time, json, os
+from pathlib import Path
 from tqdm import tqdm
-import warnings
-warnings.filterwarnings("ignore")
-import os
 from datetime import timedelta
 from accelerate.utils import InitProcessGroupKwargs

+import warnings
+warnings.filterwarnings("ignore")
+
 kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=36000))
 accelerator = Accelerator(kwargs_handlers=[kwargs])

-parser = argparse.ArgumentParser()
-parser.add_argument('--model', type=str, default='UCLA-AGI/zephyr-7b-sft-full-SPIN-iter0')
-parser.add_argument('--data_frac', type=int, default=0)
-parser.add_argument('--frac_len', type=int, default=0)
-parser.add_argument('--output_dir', type=str, default='generated/iter1')
-parser.add_argument('--batch_size', type=int, default=16)
-parser.add_argument('--input_dir', type=str, default='UCLA-AGI/SPIN_iter0')
-parser.add_argument('--split', type=str, default='train')
-
-args = parser.parse_args()
-model_path = args.model
-data_frac = args.data_frac
-batch_size = args.batch_size
-
-if not os.path.exists(args.output_dir):
-    os.makedirs(args.output_dir)
-
-# load a base model and tokenizer
-model = AutoModelForCausalLM.from_pretrained(
-    model_path,
-    device_map={"": accelerator.process_index},
-    torch_dtype=torch.bfloat16,
-)
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-tokenizer.pad_token = tokenizer.eos_token
-
-# load data
-data = load_dataset(args.input_dir, split=args.split)
-data = data.shuffle(seed=42)
-if args.frac_len > 0:
-    sub_len = args.frac_len
-    if sub_len*(data_frac+1) > len(data):
-        data = data[sub_len*data_frac:]['chosen']
-    else:
-        data = data[sub_len*data_frac:sub_len*(data_frac+1)]['chosen']
-
-prompts_all = ["### Instruction: " + data[idx][0]['content'] + "\n\n### Response: " for idx in range(len(data))]
-prompts_old = [data[idx][0]['content'] for idx in range(len(data))]
-corrects_all = [data[idx][1]['content'] for idx in range(len(data))]
-
-# batch, left pad (for inference), and tokenize
+def parse_arguments():
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', type=str, default='UCLA-AGI/zephyr-7b-sft-full-SPIN-iter0')
+    parser.add_argument('--data_frac', type=int, default=0)
+    parser.add_argument('--frac_len', type=int, default=0)
+    parser.add_argument('--output_dir', type=str, default='generated/iter1')
+    parser.add_argument('--batch_size', type=int, default=16)
+    parser.add_argument('--input_dir', type=str, default='UCLA-AGI/SPIN_iter0')
+    parser.add_argument('--split', type=str, default='train')
+    return parser.parse_args()
+
 def prepare_prompts(prompts, tokenizer, batch_size=4):
+    """Prepare prompts for tokenization."""
     batches=[prompts[i:i + batch_size] for i in range(0, len(prompts), batch_size)]
     batches_tok=[]
     tokenizer.padding_side="left"
@@ -72,43 +48,76 @@ def prepare_prompts(prompts, tokenizer, batch_size=4):
     tokenizer.padding_side="right"
     return batches_tok

-# sync GPUs and start the timer
-accelerator.wait_for_everyone()
-start=time.time()
-
-# divide the prompt list onto the available GPUs
-with accelerator.split_between_processes(prompts_all) as prompts:
-    results = []
-
-    # have each GPU do inference in batches
-    prompt_batches=prepare_prompts(prompts, tokenizer, batch_size=args.batch_size)
-
-    for prompts_tokenized in tqdm(prompt_batches):
-        # set max_new_tokens smaller for faster inference
-        outputs_tokenized=model.generate(**prompts_tokenized, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)
-
-        # remove prompt from gen. tokens
-        outputs_tokenized=[ tok_out[len(tok_in):]
-            for tok_in, tok_out in zip(prompts_tokenized["input_ids"], outputs_tokenized) ]
-        # decode gen. tokens
-        outputs=tokenizer.batch_decode(outputs_tokenized)
-        results.extend(outputs)
-
-# collect results from all the GPUs and remove paddings
-results_gathered=gather_object(results)
-results = [r.replace("</s>","").lstrip() for r in results_gathered]
-
-if accelerator.is_local_main_process:
-    timediff=time.time()-start
-    print(f"time elapsed: {timediff}")
-
-    # collecting data
-    for idx in range(len(corrects_all)):
-        d = {"chosen": [{"role": "user", "content": prompts_old[idx]}, {"role": "assistant", "content": corrects_all[idx]}], "rejected": [{"role": "user", "content": prompts_old[idx]}, {"role": "assistant", "content": results[idx]}]}
-        if args.split == 'test':
-            filename = f"{args.output_dir}/loser_{data_frac}_test.jsonl"
+def main():
+    args = parse_arguments()
+    model_path = args.model
+    data_frac = args.data_frac
+    batch_size = args.batch_size
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # load a base model and tokenizer
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        device_map={"": accelerator.process_index},
+        torch_dtype=torch.bfloat16,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    tokenizer.pad_token = tokenizer.eos_token
+
+    # load data
+    data = load_dataset(args.input_dir, split=args.split)
+    data = data.shuffle(seed=42)
+    if args.frac_len > 0:
+        sub_len = args.frac_len
+        if sub_len*(data_frac+1) > len(data):
+            data = data[sub_len*data_frac:]['chosen']
         else:
-            filename = f"{args.output_dir}/loser_{data_frac}.jsonl"
-        with open(filename, 'a') as f:
-            json.dump(d, f)
-            f.write('\n')
+            data = data[sub_len*data_frac:sub_len*(data_frac+1)]['chosen']
+
+    prompts_all = ["### Instruction: " + data[idx][0]['content'] + "\n\n### Response: " for idx in range(len(data))]
+    prompts_old = [data[idx][0]['content'] for idx in range(len(data))]
+    corrects_all = [data[idx][1]['content'] for idx in range(len(data))]
+
+    # sync GPUs and start the timer
+    accelerator.wait_for_everyone()
+    start=time.time()
+
+    # divide the prompt list onto the available GPUs
+    with accelerator.split_between_processes(prompts_all) as prompts:
+        results = []
+        prompt_batches=prepare_prompts(prompts, tokenizer, batch_size=args.batch_size)
+
+        for prompts_tokenized in tqdm(prompt_batches):
+            # set max_new_tokens smaller for faster inference
+            outputs_tokenized=model.generate(**prompts_tokenized, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)
+
+            # remove prompt from gen. tokens
+            outputs_tokenized=[ tok_out[len(tok_in):]
+                for tok_in, tok_out in zip(prompts_tokenized["input_ids"], outputs_tokenized) ]
+            # decode gen. tokens
+            outputs=tokenizer.batch_decode(outputs_tokenized)
+            results.extend(outputs)
+
+    # collect results from all the GPUs and remove paddings
+    results_gathered=gather_object(results)
+    results = [r.replace("</s>","").lstrip() for r in results_gathered]
+
+    if accelerator.is_local_main_process:
+        timediff=time.time()-start
+        print(f"time elapsed: {timediff}")
+
+        # collecting data
+        for idx in range(len(corrects_all)):
+            d = {"chosen": [{"role": "user", "content": prompts_old[idx]}, {"role": "assistant", "content": corrects_all[idx]}], "rejected": [{"role": "user", "content": prompts_old[idx]}, {"role": "assistant", "content": results[idx]}]}
+            if args.split == 'test':
+                filename = f"{args.output_dir}/loser_{data_frac}_test.jsonl"
+            else:
+                filename = f"{args.output_dir}/loser_{data_frac}.jsonl"
+            with open(filename, 'a') as f:
+                json.dump(d, f)
+                f.write('\n')
+
+
+if __name__ == "__main__":
+    main()
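
A side note on prepare_prompts in the diff above: it flips tokenizer.padding_side to "left" before tokenizing and restores "right" afterwards because decoder-only models generate from the last position of each row, so padding must sit to the left of shorter prompts. A minimal sketch of the effect, using gpt2 purely as a small stand-in tokenizer (not part of this commit; the script pads with the tokenizer of args.model):

from transformers import AutoTokenizer

# gpt2 is a stand-in; any causal-LM tokenizer behaves the same way
tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token
tok.padding_side = "left"

batch = tok(["hi", "a much longer prompt"], padding=True, return_tensors="pt")
print(batch["input_ids"])
# the shorter row is padded on the left, so every row ends in real prompt
# text and model.generate() continues from actual tokens, not from padding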
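With the __main__ guard in place, the script is launched one process per GPU in the usual Accelerate way, e.g. accelerate launch spin/generate.py --input_dir UCLA-AGI/SPIN_iter0 --output_dir generated/iter1 --data_frac 0 (flags as defined in parse_arguments). The main process then appends one JSON object per prompt to loser_{data_frac}.jsonl. A minimal sketch of reading those pairs back, assuming the default train split and data_frac 0 (this reader is illustrative, not part of the commit):

import json

pairs = []
# path assumes --output_dir generated/iter1 and --data_frac 0
with open("generated/iter1/loser_0.jsonl") as f:
    for line in f:
        record = json.loads(line)
        prompt   = record["chosen"][0]["content"]    # user turn, identical in both lists
        chosen   = record["chosen"][1]["content"]    # ground-truth response from the dataset
        rejected = record["rejected"][1]["content"]  # response generated by the current model
        pairs.append((prompt, chosen, rejected))

print(f"loaded {len(pairs)} preference pairs")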
