-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path app.py
42 lines (35 loc) · 1.43 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from potassium import Potassium, Request, Response
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM
# Hugging Face Hub repo id of the GPTQ-quantized CodeLlama 7B Python model;
# both the tokenizer and the quantized weights are loaded from it.
MODEL_NAME_OR_PATH = "TheBloke/CodeLlama-7B-Python-GPTQ"
# Torch device string ("cuda:0" = the first CUDA GPU).
DEVICE = "cuda:0"
# The Potassium application; init() and handler() below register with it
# through the @app.init / @app.handler() decorators.
app = Potassium("CodeLlama-7B-Python-GPTQ")
@app.init
def init() -> dict:
    """Load the tokenizer and GPTQ-quantized model once, at app startup.

    Returns:
        The Potassium context dict, holding the loaded ``"model"`` and
        ``"tokenizer"`` under those keys.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
    model = AutoGPTQForCausalLM.from_quantized(
        MODEL_NAME_OR_PATH,
        use_safetensors=True,
        trust_remote_code=False,
        # Use the module-level DEVICE constant rather than repeating the
        # literal, so the target GPU is configured in exactly one place.
        device=DEVICE,
        use_triton=False,
        quantize_config=None,
        inject_fused_attention=False,
    )
    return {
        "model": model,
        "tokenizer": tokenizer,
    }
@app.handler()
def handler(context: dict, request: Request) -> Response:
    """Generate a completion for the prompt in the request's JSON body.

    JSON fields:
        prompt: text to complete (required; must be a string).
        max_new_tokens: generation budget, default 512.
        temperature: sampling temperature, default 0.7. A value <= 0
            selects greedy decoding instead of sampling.

    Returns:
        200 with ``{"outputs": <text>}``, where ``<text>`` is
        ``tokenizer.decode`` of the full generated sequence (prompt tokens
        included). Returns 400 with ``{"error": ...}`` when the prompt is
        missing or the numeric fields cannot be parsed.
    """
    model = context.get("model")
    tokenizer = context.get("tokenizer")

    # request.json can be None when the body is absent or not JSON.
    payload = request.json or {}

    prompt = payload.get("prompt")
    if not isinstance(prompt, str):
        return Response(json={"error": "'prompt' must be a string"}, status=400)

    try:
        max_new_tokens = int(payload.get("max_new_tokens", 512))
        temperature = float(payload.get("temperature", 0.7))
    except (TypeError, ValueError):
        return Response(
            json={"error": "'max_new_tokens' must be an integer and 'temperature' a number"},
            status=400,
        )

    # Place inputs on the same configured device the model was loaded onto.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)

    # transformers only honours `temperature` when sampling is enabled;
    # without do_sample=True it decodes greedily and ignores it. A
    # temperature of 0 is invalid for sampling, so treat <= 0 as greedy.
    do_sample = temperature > 0
    gen_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample:
        gen_kwargs["temperature"] = temperature

    output = model.generate(inputs=input_ids, **gen_kwargs)
    result = tokenizer.decode(output[0])
    return Response(json={"outputs": result}, status=200)
if __name__ == "__main__":
    # Serve the app only when this file is executed directly, not on import.
    app.serve()