add prompt_template parameter (#78)
* prompt_template parameter

* read DEFAULT_SYSTEM_PROMPT from config and try to handle Mistral correctly, but it's still pretty ugly

* import config

* remove BOS token, because MLC adds BOS

* don't strip prompt

* add a note
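
For context, a hypothetical caller-side use of the new prompt_template input (the model reference and client call are illustrative only and not part of this commit):

import replicate  # assumes the standard Replicate Python client

output = replicate.run(
    "owner/mistral-7b-instruct-v0.1-mlc",  # placeholder model reference
    input={
        "prompt": "Summarize what MLC does in one sentence.",
        "system_prompt": "You are terse.",
        # Overrides the template from the model's config.py; predict.py fills in
        # the {system_prompt} and {prompt} placeholders.
        "prompt_template": "[INST] {system_prompt}{prompt} [/INST]",
    },
)
print("".join(output))
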
technillogue authored Dec 1, 2023
1 parent 207416a commit a733140
Showing 5 changed files with 27 additions and 14 deletions.
2 changes: 1 addition & 1 deletion models/llama-2-70b-chat-hf-mlc/config.py
@@ -33,7 +33,7 @@
)

# Inference config
-USE_SYSTEM_PROMPT = False
+USE_SYSTEM_PROMPT = True

ENGINE = MLCvLLMEngine
ENGINE_KWARGS = {
6 changes: 5 additions & 1 deletion models/mistral-7b-instruct-v0.1-mlc/config.py
@@ -31,7 +31,11 @@
)

# Inference config
-USE_SYSTEM_PROMPT = False
+USE_SYSTEM_PROMPT = True
+
+# from mistral: "<s>[INST] + Instruction [/INST] Model answer</s>[INST] Follow-up instruction [/INST]"
+PROMPT_TEMPLATE = "[INST] {system_prompt}{prompt} [/INST]"
+DEFAULT_SYSTEM_PROMPT = "Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity. "

ENGINE = MLCvLLMEngine
ENGINE_KWARGS = {
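
For reference, a minimal sketch of what this template yields once predict.py fills it in (the user prompt is illustrative and the system prompt is abridged from the default above; MLC supplies the leading "<s>" BOS token itself, which is why the template omits it):

PROMPT_TEMPLATE = "[INST] {system_prompt}{prompt} [/INST]"
system_prompt = "Always assist with care, respect, and truth. "  # abridged default
prompt = "Write a haiku about GPUs."
print(PROMPT_TEMPLATE.format(system_prompt=system_prompt, prompt=prompt))
# [INST] Always assist with care, respect, and truth. Write a haiku about GPUs. [/INST]
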
30 changes: 20 additions & 10 deletions predict.py
@@ -8,7 +8,7 @@

import torch
from cog import BasePredictor, ConcatenateIterator, Input, Path

+import config
from config import ENGINE, ENGINE_KWARGS, USE_SYSTEM_PROMPT
from src.download import Downloader
from src.utils import seed_all, delay_prints
@@ -19,10 +19,13 @@
# These are components of the prompt that should not be changed by the users
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-PROMPT_TEMPLATE = f"{B_INST} {B_SYS}{{system_prompt}}{E_SYS}{{instruction}} {E_INST}"
+# normally this would start with <s>, but MLC adds it
+PROMPT_TEMPLATE = f"{B_INST} {B_SYS}{{system_prompt}}{E_SYS}{{prompt}} {E_INST}"
+PROMPT_TEMPLATE = getattr(config, "PROMPT_TEMPLATE", PROMPT_TEMPLATE)

# Users may want to change the system prompt, but we use the recommended system prompt by default
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant."""
+DEFAULT_SYSTEM_PROMPT = getattr(config, "DEFAULT_SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT)

# Temporary hack to disable Top K from the API. We should get rid of this once engines + configs are better standardized.
USE_TOP_K = ENGINE.__name__ not in ("MLCEngine", "MLCvLLMEngine")
@@ -90,7 +93,7 @@ def predict(
self,
prompt: str = Input(description="Prompt to send to the model."),
system_prompt: str = Input(
description="System prompt to send to the model. This is prepended to the prompt and helps guide system behavior.",
description="System prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Should not be blank.",
default=DEFAULT_SYSTEM_PROMPT,
),
max_new_tokens: int = Input(
@@ -136,6 +139,10 @@ def predict(
debug: bool = Input(
description="provide debugging output in logs", default=False
),
+prompt_template: str = Input(
+description="Template for formatting the prompt",
+default=PROMPT_TEMPLATE,
+),
# return_logits: bool = Input(
# description="if set, only return logits for the first token. only useful for testing, etc.",
# default=False,
@@ -149,14 +156,17 @@
if stop_sequences:
stop_sequences = stop_sequences.split(",")

-if USE_SYSTEM_PROMPT:
-prompt = (
-prompt.strip("\n").removeprefix(B_INST).removesuffix(E_INST).strip()
+if USE_SYSTEM_PROMPT and prompt_template:
+if B_SYS not in prompt_template:
+if system_prompt.strip() and not system_prompt.endswith(" "):
+# mistral doesn't have a SYS token, there's just a space between the system prompt and
+system_prompt = system_prompt.strip() + " "
+print("Added a space to your system prompt")
+prompt = prompt_template.format(
+system_prompt=system_prompt, prompt=prompt
)
-prompt = PROMPT_TEMPLATE.format(
-system_prompt=system_prompt.strip(), instruction=prompt.strip()
-)

+# MLC adds BOS token
+prompt = prompt.removeprefix("<s>")
print(f"Your formatted prompt is: \n{prompt}")

if replicate_weights:
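
For comparison, a small sketch of what the default Llama-2 chat template defined in predict.py produces when no per-model PROMPT_TEMPLATE is configured (the example prompts are illustrative; the "<s>" prefix would again be stripped because MLC adds the BOS token):

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
PROMPT_TEMPLATE = f"{B_INST} {B_SYS}{{system_prompt}}{E_SYS}{{prompt}} {E_INST}"

print(PROMPT_TEMPLATE.format(
    system_prompt="You are a helpful, respectful and honest assistant.",
    prompt="What is MLC?",
))
# [INST] <<SYS>>
# You are a helpful, respectful and honest assistant.
# <</SYS>>
#
# What is MLC? [/INST]
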
2 changes: 0 additions & 2 deletions src/config_utils.py
@@ -1,8 +1,6 @@
"""
An entirely self-contained config parsing util that should, if all goes well, dramatically simplify our configuration.
"""


from typing import List, Optional

from pydantic import BaseModel
1 change: 1 addition & 0 deletions src/inference_engines/mlc_engine.py
@@ -43,6 +43,7 @@ def __init__(
model_path = os.path.join(weights_path, "params")
self.cm = ChatModule(model=model_path, chat_config=chat_config)

+# this isn't used!
tokenizer_path = os.path.join(weights_path, "params")
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

