Skip to content

Commit

Permalink
Feature: support for llama2 prompt adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
xusenlin committed Jul 19, 2023
1 parent 59ea4ac commit 5bc42e2
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 12 deletions.
5 changes: 3 additions & 2 deletions api/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,14 @@ def generate_stream(
)

input_ids = tokenizer(prompt).input_ids
output_ids = list(input_ids)

if model.config.is_encoder_decoder:
max_src_len = context_len
else: # truncate
max_src_len = context_len - max_new_tokens - 8
max_src_len = context_len - max_new_tokens - 1

input_ids = input_ids[-max_src_len:]
output_ids = list(input_ids)
input_echo_len = len(input_ids)

if model.config.is_encoder_decoder:
Expand Down Expand Up @@ -348,6 +348,7 @@ def generate_stream(
"seq_length",
"max_position_embeddings",
"max_seq_len",
"model_max_length",
]


Expand Down
53 changes: 43 additions & 10 deletions api/prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List, Dict

from api.protocol import Role

# A global registry for all prompt adapters
prompt_adapters = []

Expand Down Expand Up @@ -33,9 +35,9 @@ def generate_prompt(self, messages: List[Dict[str, str]]) -> str:
user_content = []
for message in messages:
role, content = message["role"], message["content"]
if role in ["user", "system"]:
if role in [Role.USER, Role.SYSTEM]:
user_content.append(content)
elif role in ["assistant", "AI"]:
elif role == Role.ASSISTANT:
prompt += self.user_prompt.format("\n".join(user_content))
prompt += self.assistant_prompt.format(content)
user_content = []
Expand Down Expand Up @@ -63,9 +65,9 @@ def generate_prompt(self, messages: List[Dict[str, str]]) -> str:
i = 0
for message in messages:
role, content = message["role"], message["content"]
if role in ["user", "system"]:
if role in [Role.USER, Role.SYSTEM]:
user_content.append(content)
elif role in ["assistant", "AI"]:
elif role == Role.ASSISTANT:
u_content = "\n".join(user_content)
prompt += f"[Round {i}]\n{self.user_prompt.format(u_content)}"
prompt += self.assistant_prompt.format(content)
Expand Down Expand Up @@ -96,9 +98,9 @@ def generate_prompt(self, messages: List[Dict[str, str]]) -> str:
i = 1
for message in messages:
role, content = message["role"], message["content"]
if role in ["user", "system"]:
if role in [Role.USER, Role.SYSTEM]:
user_content.append(content)
elif role in ["assistant", "AI"]:
elif role == Role.ASSISTANT:
u_content = "\n".join(user_content)
prompt += f"[Round {i}]\n\n{self.user_prompt.format(u_content)}"
prompt += self.assistant_prompt.format(content)
Expand Down Expand Up @@ -268,9 +270,9 @@ def match(self, model_name):
def generate_prompt(self, messages: List[Dict[str, str]]) -> str:
prompt = "<|system|>\n<|end|>\n"
for message in messages:
if message["role"] == "system":
if message["role"] == Role.SYSTEM:
prompt += self.system_prompt.format(message["content"])
if message["role"] == "user":
if message["role"] == Role.USER:
prompt += self.user_prompt.format(message["content"])
else:
prompt += self.assistant_prompt.format(message["content"])
Expand All @@ -294,9 +296,9 @@ def match(self, model_name):
def generate_prompt(self, messages: List[Dict[str, str]]) -> str:
prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
for message in messages:
if message["role"] == "system":
if message["role"] == Role.SYSTEM:
prompt += self.system_prompt.format(message["content"])
if message["role"] == "user":
if message["role"] == Role.ASSISTANT:
prompt += self.user_prompt.format(message["content"])
else:
prompt += self.assistant_prompt.format(message["content"])
Expand All @@ -306,6 +308,36 @@ def generate_prompt(self, messages: List[Dict[str, str]]) -> str:
return prompt


class Llama2ChatPromptAdapter(BasePromptAdapter):
    """Prompt adapter for Llama-2 chat models.

    Implements the official chat template from
    https://github.com/facebookresearch/llama/blob/main/llama/generation.py:

        <s>[INST] <<SYS>>\\n{system}\\n<</SYS>>\\n\\n{user} [/INST] {answer} </s><s>[INST] {user} [/INST]
    """

    # Fix: the official template opens with "[INST] " BEFORE the <<SYS>> block;
    # without it the model sees a malformed first turn.
    system_prompt = "<s>[INST] <<SYS>>\n{}\n<</SYS>>\n\n"
    user_prompt = "[INST] {} "
    assistant_prompt = "[/INST] {} </s><s>"
    # Stop strings: generation must halt before the model starts a new turn.
    stop = ["[INST]", "[/INST]"]

    def match(self, model_name):
        # Substring match against the served model name, e.g. "llama2-7b-chat".
        return "llama2" in model_name

    def generate_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Render ``messages`` into a single Llama-2 chat prompt string.

        The first message is assumed to be the opening user turn and is placed
        directly inside the system [INST] block (per the official template);
        subsequent turns alternate between "[INST] {user} " and
        "[/INST] {answer} </s><s>" segments.  A trailing "[/INST] " cues the
        model to produce the next assistant answer.
        """
        prompt = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
        prompt = self.system_prompt.format(prompt)
        for i, message in enumerate(messages):
            if i == 0:
                # Opening user turn goes right after the system block.
                # Fix: the official template separates the user text from the
                # following "[/INST]" with a space.
                prompt += message["content"] + " "
            else:
                if message["role"] == Role.USER:
                    prompt += self.user_prompt.format(message["content"])
                else:
                    prompt += self.assistant_prompt.format(message["content"])

        prompt += "[/INST] "

        return prompt


register_prompt_adapter(ChatGLMPromptAdapter)
register_prompt_adapter(ChatGLM2PromptAdapter)
register_prompt_adapter(MossPromptAdapter)
Expand All @@ -321,5 +353,6 @@ def generate_prompt(self, messages: List[Dict[str, str]]) -> str:
register_prompt_adapter(BaiChuanPromptAdapter)
register_prompt_adapter(StarChatPromptAdapter)
register_prompt_adapter(AquilaChatPromptAdapter)
register_prompt_adapter(Llama2ChatPromptAdapter)

register_prompt_adapter(BasePromptAdapter)
7 changes: 7 additions & 0 deletions api/protocol.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import secrets
import time
from enum import Enum
from typing import Literal, Optional, List, Dict, Any, Union

from pydantic import BaseModel, Field


class Role(str, Enum):
    """Chat-message roles, interchangeable with their plain-string values.

    Subclassing ``str`` means members compare equal to the raw role strings
    used in OpenAI-style message dicts, e.g. ``Role.USER == "user"``, so
    prompt adapters can accept either form without conversion.
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"


class ErrorResponse(BaseModel):
object: str = "error"
message: str
Expand Down

0 comments on commit 5bc42e2

Please sign in to comment.