@@ -1,61 +1,37 @@
+# reference: https://medium.com/@geronimo7/llms-multi-gpu-inference-with-accelerate-5a8333e4c5db
+
 from accelerate import Accelerator
 from accelerate.utils import gather_object
 from transformers import AutoModelForCausalLM, AutoTokenizer
-import torch, time, json
-import argparse
 from datasets import load_dataset
+
+import argparse
+import torch, time, json, os
+from pathlib import Path
 from tqdm import tqdm
-import warnings
-warnings.filterwarnings("ignore")
-import os
 from datetime import timedelta
 from accelerate.utils import InitProcessGroupKwargs
 
+import warnings
+warnings.filterwarnings("ignore")
+
 kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=36000))
 accelerator = Accelerator(kwargs_handlers=[kwargs])
 
-parser = argparse.ArgumentParser()
-parser.add_argument('--model', type=str, default='UCLA-AGI/zephyr-7b-sft-full-SPIN-iter0')
-parser.add_argument('--data_frac', type=int, default=0)
-parser.add_argument('--frac_len', type=int, default=0)
-parser.add_argument('--output_dir', type=str, default='generated/iter1')
-parser.add_argument('--batch_size', type=int, default=16)
-parser.add_argument('--input_dir', type=str, default='UCLA-AGI/SPIN_iter0')
-parser.add_argument('--split', type=str, default='train')
-
-args = parser.parse_args()
-model_path = args.model
-data_frac = args.data_frac
-batch_size = args.batch_size
-
-if not os.path.exists(args.output_dir):
-    os.makedirs(args.output_dir)
-
-# load a base model and tokenizer
-model = AutoModelForCausalLM.from_pretrained(
-    model_path,
-    device_map={"": accelerator.process_index},
-    torch_dtype=torch.bfloat16,
-)
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-tokenizer.pad_token = tokenizer.eos_token
-
-# load data
-data = load_dataset(args.input_dir, split=args.split)
-data = data.shuffle(seed=42)
-if args.frac_len > 0:
-    sub_len = args.frac_len
-    if sub_len*(data_frac+1) > len(data):
-        data = data[sub_len*data_frac:]['chosen']
-    else:
-        data = data[sub_len*data_frac:sub_len*(data_frac+1)]['chosen']
-
-prompts_all = ["### Instruction: " + data[idx][0]['content'] + "\n\n### Response: " for idx in range(len(data))]
-prompts_old = [data[idx][0]['content'] for idx in range(len(data))]
-corrects_all = [data[idx][1]['content'] for idx in range(len(data))]
-
-# batch, left pad (for inference), and tokenize
+def parse_arguments():
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', type=str, default='UCLA-AGI/zephyr-7b-sft-full-SPIN-iter0')
+    parser.add_argument('--data_frac', type=int, default=0)
+    parser.add_argument('--frac_len', type=int, default=0)
+    parser.add_argument('--output_dir', type=str, default='generated/iter1')
+    parser.add_argument('--batch_size', type=int, default=16)
+    parser.add_argument('--input_dir', type=str, default='UCLA-AGI/SPIN_iter0')
+    parser.add_argument('--split', type=str, default='train')
+    return parser.parse_args()
+
 def prepare_prompts(prompts, tokenizer, batch_size=4):
+    """Prepare prompts for tokenization."""
    batches = [prompts[i:i+batch_size] for i in range(0, len(prompts), batch_size)]
    batches_tok = []
    tokenizer.padding_side = "left"
@@ -72,43 +48,76 @@ def prepare_prompts(prompts, tokenizer, batch_size=4):
    tokenizer.padding_side = "right"
    return batches_tok
 
-# sync GPUs and start the timer
-accelerator.wait_for_everyone()
-start = time.time()
-
-# divide the prompt list onto the available GPUs
-with accelerator.split_between_processes(prompts_all) as prompts:
-    results = []
-
-    # have each GPU do inference in batches
-    prompt_batches = prepare_prompts(prompts, tokenizer, batch_size=args.batch_size)
-
-    for prompts_tokenized in tqdm(prompt_batches):
-        # set max_new_tokens smaller for faster inference
-        outputs_tokenized = model.generate(**prompts_tokenized, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)
-
-        # remove prompt from gen. tokens
-        outputs_tokenized = [ tok_out[len(tok_in):]
-            for tok_in, tok_out in zip(prompts_tokenized["input_ids"], outputs_tokenized) ]
-        # decode gen. tokens
-        outputs = tokenizer.batch_decode(outputs_tokenized)
-        results.extend(outputs)
-
-# collect results from all the GPUs and remove paddings
-results_gathered = gather_object(results)
-results = [r.replace("</s>","").lstrip() for r in results_gathered]
-
-if accelerator.is_local_main_process:
-    timediff = time.time()-start
-    print(f"time elapsed: {timediff}")
-
-    # collecting data
-    for idx in range(len(corrects_all)):
-        d = {"chosen": [{"role": "user", "content": prompts_old[idx]}, {"role": "assistant", "content": corrects_all[idx]}], "rejected": [{"role": "user", "content": prompts_old[idx]}, {"role": "assistant", "content": results[idx]}]}
-        if args.split == 'test':
-            filename = f"{args.output_dir}/loser_{data_frac}_test.jsonl"
-        else:
-            filename = f"{args.output_dir}/loser_{data_frac}.jsonl"
-        with open(filename, 'a') as f:
-            json.dump(d, f)
-            f.write('\n')
+def main():
+    args = parse_arguments()
+    model_path = args.model
+    data_frac = args.data_frac
+    batch_size = args.batch_size
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # load a base model and tokenizer
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        device_map={"": accelerator.process_index},
+        torch_dtype=torch.bfloat16,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    tokenizer.pad_token = tokenizer.eos_token
+
+    # load data
+    data = load_dataset(args.input_dir, split=args.split)
+    data = data.shuffle(seed=42)
+    if args.frac_len > 0:
+        sub_len = args.frac_len
+        if sub_len*(data_frac+1) > len(data):
+            data = data[sub_len*data_frac:]['chosen']
+        else:
+            data = data[sub_len*data_frac:sub_len*(data_frac+1)]['chosen']
+
+    prompts_all = ["### Instruction: " + data[idx][0]['content'] + "\n\n### Response: " for idx in range(len(data))]
+    prompts_old = [data[idx][0]['content'] for idx in range(len(data))]
+    corrects_all = [data[idx][1]['content'] for idx in range(len(data))]
+
+    # sync GPUs and start the timer
+    accelerator.wait_for_everyone()
+    start = time.time()
+
+    # divide the prompt list onto the available GPUs
+    with accelerator.split_between_processes(prompts_all) as prompts:
+        results = []
+        prompt_batches = prepare_prompts(prompts, tokenizer, batch_size=args.batch_size)
+
+        for prompts_tokenized in tqdm(prompt_batches):
+            # set max_new_tokens smaller for faster inference
+            outputs_tokenized = model.generate(**prompts_tokenized, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)
+
+            # remove prompt from gen. tokens
+            outputs_tokenized = [ tok_out[len(tok_in):]
+                for tok_in, tok_out in zip(prompts_tokenized["input_ids"], outputs_tokenized) ]
+            # decode gen. tokens
+            outputs = tokenizer.batch_decode(outputs_tokenized)
+            results.extend(outputs)
+
+    # collect results from all the GPUs and remove paddings
+    results_gathered = gather_object(results)
+    results = [r.replace("</s>","").lstrip() for r in results_gathered]
+
+    if accelerator.is_local_main_process:
+        timediff = time.time()-start
+        print(f"time elapsed: {timediff}")
+
+        # collecting data
+        for idx in range(len(corrects_all)):
+            d = {"chosen": [{"role": "user", "content": prompts_old[idx]}, {"role": "assistant", "content": corrects_all[idx]}], "rejected": [{"role": "user", "content": prompts_old[idx]}, {"role": "assistant", "content": results[idx]}]}
+            if args.split == 'test':
+                filename = f"{args.output_dir}/loser_{data_frac}_test.jsonl"
+            else:
+                filename = f"{args.output_dir}/loser_{data_frac}.jsonl"
+            with open(filename, 'a') as f:
+                json.dump(d, f)
+                f.write('\n')
+
+
+if __name__ == "__main__":
+    main()
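
For context, the multi-GPU pattern this script relies on (split the prompt list across processes, let each GPU work on its own shard, then gather everything back) reduces to the minimal sketch below. It is a standalone illustration, not part of the commit: the `items` list and the upper-casing step are placeholders for the prompts and the generation loop, and it assumes the snippet is started with `accelerate launch`.

from accelerate import Accelerator
from accelerate.utils import gather_object

accelerator = Accelerator()

# full workload, identical on every process
items = [f"prompt {i}" for i in range(8)]

# each process receives only its slice of the list
with accelerator.split_between_processes(items) as shard:
    local_results = [s.upper() for s in shard]  # placeholder for the real per-batch model.generate(...) work

# collect the per-process results back into one list, visible on every process
results = gather_object(local_results)
if accelerator.is_main_process:
    print(results)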