feat: support stopping words in python backend. (TabbyML#32)
* Improve python backend

* Update lockfile

* Support stop words in python backend

* Support LanguagePresets for triton

* Update pre-commit
wsxiaoys authored Mar 29, 2023
1 parent 2f31418 commit be7894a
Showing 10 changed files with 358 additions and 298 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -11,6 +11,7 @@ repos:
    rev: 5.12.0
    hooks:
      - id: isort
        args: ["--profile", "black"]
  - repo: https://github.com/PyCQA/autoflake
    rev: v2.0.2
    hooks:
434 changes: 203 additions & 231 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions pyproject.toml
@@ -20,11 +20,10 @@ toml = "^0.10.2"
gitpython = "^3.1.31"
peft = {git = "https://github.com/huggingface/peft.git", rev = "v0.2.0"}
duckdb = "^0.7.1"

torch = "^2.0.0"

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.1.1"
torch = "^2.0.0"

[build-system]
requires = ["poetry-core"]
3 changes: 1 addition & 2 deletions tabby/server/__init__.py
@@ -6,9 +6,8 @@
from fastapi.responses import JSONResponse

from . import events
from .backend import PythonModelService, TritonService
from .models import CompletionRequest, CompletionResponse
from .python import PythonModelService
from .triton import TritonService

app = FastAPI(
    title="TabbyServer",
2 changes: 2 additions & 0 deletions tabby/server/backend/__init__.py
@@ -0,0 +1,2 @@
from .python import PythonModelService
from .triton import TritonService
19 changes: 19 additions & 0 deletions tabby/server/backend/language_presets.py
@@ -0,0 +1,19 @@
from typing import List

from pydantic import BaseModel


class LanguagePreset(BaseModel):
    max_length: int
    stop_words: List[str]


PythonPreset = LanguagePreset(
    max_length=128, stop_words=["\n\n", "\ndef", "\n#", "\nimport", "\nfrom", "\nclass"]
)

JavascriptPreset = LanguagePreset(
    max_length=128, stop_words=["\n\n", "\nfunction", "\n//", "\nimport", "\nclass"]
)

LanguagePresets = {"python": PythonPreset, "javascript": JavascriptPreset}
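
For reference, a backend resolves a preset by language name at request time (the key is hard-coded to "python" for now, per the FIXMEs below); a minimal usage sketch:

from tabby.server.backend.language_presets import LanguagePresets

preset = LanguagePresets["python"]
print(preset.max_length)      # 128
print(preset.stop_words[:2])  # ['\n\n', '\ndef']
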
105 changes: 105 additions & 0 deletions tabby/server/backend/python.py
@@ -0,0 +1,105 @@
import time
from typing import List

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

from ..models import Choice, CompletionRequest, CompletionResponse
from .language_presets import LanguagePresets
from .utils import random_completion_id, trim_with_stopwords


class PythonModelService:
    def __init__(
        self,
        model_name,
    ):
        # Prefer bf16 on GPUs that support it; otherwise fall back to fp32.
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, local_files_only=True
        )
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=dtype,
                device_map="auto" if torch.cuda.is_available() else None,
                local_files_only=True,
            )
            .to(device)
            .eval()
        )
        self.stopping_criteria_mappings = {}

    def generate(self, request: CompletionRequest) -> List[Choice]:
        # FIXME(meng): read preset from request.
        preset_name = "python"
        preset = LanguagePresets[preset_name]

        stopping_criteria_list = self.stopping_criteria_for_preset(preset_name)

        input_ids = self.tokenizer.encode(request.prompt, return_tensors="pt").to(
            self.device
        )
        res = self.model.generate(
            input_ids,
            max_length=preset.max_length,
            stopping_criteria=stopping_criteria_list,
        )
        # Keep only newly generated tokens, then trim any trailing stop word.
        output_ids = res[0][len(input_ids[0]) :]
        text = trim_with_stopwords(self.tokenizer.decode(output_ids), preset.stop_words)
        return [Choice(index=0, text=text)]

    def stopping_criteria_for_preset(self, name: str) -> StoppingCriteriaList:
        return StoppingCriteriaList(
            [
                StopWordsIdsCriteria(
                    [self.tokenizer.encode(x) for x in LanguagePresets[name].stop_words]
                )
            ]
        )

    def __call__(self, request: CompletionRequest) -> CompletionResponse:
        choices = self.generate(request)
        return CompletionResponse(
            id=random_completion_id(), created=int(time.time()), choices=choices
        )


class StopWordsIdsCriteria(StoppingCriteria):
    def __init__(self, stop_words_ids: List[List[int]]):
        self.stop_words_ids = stop_words_ids

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        if len(input_ids) != 1:
            raise ValueError("Only a batch size of 1 is supported")

        # FIXME(meng): trie based lookup.
        tokens = input_ids[0]
        for stop_word in self.stop_words_ids:
            if len(tokens) < len(stop_word):
                continue

            # Negative indexing compares the trailing tokens to the stop word.
            matched = True
            for i in range(len(stop_word)):
                if tokens[i - len(stop_word)] != stop_word[i]:
                    matched = False
                    break

            if matched:
                return True

        return False
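
The matching loop above leans on Python's negative indexing: tokens[i - len(stop_word)] reads the last len(stop_word) generated tokens. A standalone sketch with made-up token ids:

# Made-up token ids, for illustration only.
tokens = [5, 9, 17, 42]  # everything generated so far
stop_word = [17, 42]     # one tokenized stop word

# tokens[i - len(stop_word)] indexes from the end: tokens[-2], then tokens[-1].
matched = all(
    tokens[i - len(stop_word)] == stop_word[i] for i in range(len(stop_word))
)
assert matched  # the sequence ends with the stop word, so generation stops
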
33 changes: 10 additions & 23 deletions tabby/server/triton.py → tabby/server/backend/triton.py
@@ -1,5 +1,3 @@
import random
import string
import time
from typing import List

@@ -8,7 +6,9 @@
from transformers import AutoTokenizer
from tritonclient.utils import InferenceServerException, np_to_triton_dtype

from .models import Choice, CompletionRequest, CompletionResponse
from ..models import Choice, CompletionRequest, CompletionResponse
from .language_presets import LanguagePresets
from .utils import random_completion_id, trim_with_stopwords


class TritonService:
@@ -25,12 +25,13 @@ def __init__(
        )

    def generate(self, data: CompletionRequest) -> List[Choice]:
        # FIXME(meng): Make following vars configurable
        n = 1
        np_type = np.uint32
        max_tokens = 128
        model_name = "fastertransformer"
        stop_words = ["\n\n"]

        # FIXME(meng): read preset from request.
        preset_name = "python"
        preset = LanguagePresets[preset_name]

        prompt = data.prompt
        input_start_ids = np.expand_dims(self.tokenizer.encode(prompt), 0)
@@ -39,10 +40,10 @@ def generate(self, data: CompletionRequest) -> List[Choice]:
        input_len = prompt_len * np.ones([input_start_ids.shape[0], 1]).astype(np_type)

        prompt_tokens: int = input_len[0][0]
        output_len = np.ones_like(input_len).astype(np_type) * max_tokens
        output_len = np.ones_like(input_len).astype(np_type) * preset.max_length

        stop_word_list = np.repeat(
            to_word_list_format([stop_words], self.tokenizer),
            to_word_list_format([preset.stop_words], self.tokenizer),
            input_start_ids.shape[0],
            axis=0,
        )
@@ -68,7 +69,7 @@
            self.tokenizer.decode(out[prompt_len : prompt_len + g])
            for g, out in zip(gen_len, output_data)
        ]
        trimmed = [trim_with_stopwords(d, stop_words) for d in decoded]
        trimmed = [trim_with_stopwords(d, preset.stop_words) for d in decoded]
        return [Choice(index=i, text=text) for i, text in enumerate(trimmed)]

    def __call__(self, data: CompletionRequest) -> CompletionResponse:
@@ -86,20 +87,6 @@ def prepare_tensor(name: str, tensor_input):
    return t


def random_completion_id():
    return "cmpl-" + "".join(
        random.choice(string.ascii_letters + string.digits) for _ in range(29)
    )


def trim_with_stopwords(output: str, stopwords: list) -> str:
    for w in sorted(stopwords, key=len, reverse=True):
        if output.endswith(w):
            output = output[: -len(w)]
            break
    return output


def to_word_list_format(word_dict, tokenizer):
    flat_ids = []
    offsets = []
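
The body of to_word_list_format is truncated in this view; by the usual FasterTransformer convention it packs each batch's stop words into an int32 array of shape [batch, 2, max_len], with row 0 holding the concatenated token ids and row 1 their cumulative end offsets, padded with -1. A hedged sketch with made-up token ids:

import numpy as np

# Assumed layout only; the helper's exact body is elided above.
# Suppose "\n\n" tokenizes to [185, 185] and "\ndef" to [185, 563].
flat_ids = [185, 185, 185, 563]  # all stop-word ids, concatenated
offsets = [2, 4, -1, -1]         # cumulative ends, padded with -1

word_list = np.array([[flat_ids, offsets]], dtype=np.int32)
print(word_list.shape)  # (1, 2, 4)
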
16 changes: 16 additions & 0 deletions tabby/server/backend/utils.py
@@ -0,0 +1,16 @@
import random
import string


def random_completion_id():
    return "cmpl-" + "".join(
        random.choice(string.ascii_letters + string.digits) for _ in range(29)
    )


def trim_with_stopwords(output: str, stopwords: list) -> str:
    for w in sorted(stopwords, key=len, reverse=True):
        if output.endswith(w):
            output = output[: -len(w)]
            break
    return output
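
A quick, illustrative check of trim_with_stopwords using the Python preset's first two stop words:

from tabby.server.backend.utils import trim_with_stopwords

completion = "def add(a, b):\n    return a + b\n\n"
print(repr(trim_with_stopwords(completion, ["\n\n", "\ndef"])))
# 'def add(a, b):\n    return a + b'
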
40 changes: 0 additions & 40 deletions tabby/server/python.py

This file was deleted.
