[Server] use fastchat.model.model_adapter.get_conversation_template method to get model template (vllm-project#357)
akxxsb authored Jul 5, 2023
1 parent 98fe8cb commit 3d64cf0
Showing 1 changed file with 5 additions and 17 deletions.
22 changes: 5 additions & 17 deletions vllm/entrypoints/openai/api_server.py
@@ -13,8 +13,9 @@
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-from fastchat.conversation import (Conversation, SeparatorStyle,
-                                   get_conv_template)
+from fastchat.conversation import Conversation, SeparatorStyle
+from fastchat.model.model_adapter import get_conversation_template
+
 import uvicorn
 
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -36,7 +37,6 @@
 
 logger = init_logger(__name__)
 served_model = None
-chat_template = None
 app = fastapi.FastAPI()
 
 
@@ -63,7 +63,7 @@ async def check_model(request) -> Optional[JSONResponse]:
 
 
 async def get_gen_prompt(request) -> str:
-    conv = get_conv_template(chat_template)
+    conv = get_conversation_template(request.model)
     conv = Conversation(
         name=conv.name,
         system=conv.system,
@@ -560,14 +560,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
         help="The model name used in the API. If not "
         "specified, the model name will be the same as "
         "the huggingface name.")
-    parser.add_argument(
-        "--chat-template",
-        type=str,
-        default=None,
-        help="The chat template name used in the ChatCompletion endpoint. If "
-        "not specified, we use the API model name as the template name. See "
-        "https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py "
-        "for the list of available templates.")
+
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
 
@@ -586,11 +579,6 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
     else:
         served_model = args.model
 
-    if args.chat_template is not None:
-        chat_template = args.chat_template
-    else:
-        chat_template = served_model
-
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())
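For context, here is a minimal sketch of the difference between the two FastChat lookup paths this commit swaps. It is illustrative only: the template name "vicuna_v1.1" and the model path "lmsys/vicuna-7b-v1.3" are example inputs, not values taken from this commit.

from fastchat.conversation import get_conv_template
from fastchat.model.model_adapter import get_conversation_template

# Old path: look up a conversation template by its registered name, which
# the server took from --chat-template (or fell back to the served model
# name). An unregistered name raises a KeyError.
conv_old = get_conv_template("vicuna_v1.1")

# New path: let FastChat's model adapter infer the template from the model
# name/path that the client sends in the request; unrecognized models fall
# back to a generic default template instead of raising.
conv_new = get_conversation_template("lmsys/vicuna-7b-v1.3")
print(conv_new.name)

Because the per-request model name now drives template selection, the --chat-template flag and the module-level chat_template global become dead code, which is why the rest of the diff removes them.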
