Skip to content

Commit

Permalink
fix:app in XAgentGen
Browse files Browse the repository at this point in the history
  • Loading branch information
AL-377 committed Nov 27, 2023
1 parent a72d3af commit c9ac0ca
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions XAgentGen/app.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from xgen.parser import FunctionParser
from xgen.server.datamodel import *
from xgen.server.message_formater import format
from xgen.parser import FunctionParser
import xgen.text.generate as generate
from xgen.models.transformers import Transformers, TransformersTokenizer
from xgen.models.transformers import XTransformers
from outlines.models.transformers import TransformersTokenizer
from vllm.sampling_params import LogitsProcessor
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
Expand All @@ -19,6 +19,7 @@
import uvicorn
from transformers import AutoTokenizer


app = FastAPI()

app.add_middleware(
Expand Down Expand Up @@ -72,10 +73,10 @@ def __init__(self, extra_arguments, functions, function_call, tokenizer_path, de
outline_tokenizer = TransformersTokenizer(tokenizer_path)
fake_model = Dict()
fake_model.device = device
model = Transformers(fake_model, outline_tokenizer)
model = XTransformers(fake_model, outline_tokenizer)
self.dp.create_all_functions_model(extra_arguments, functions, function_call)
regex_list = self.dp.models_to_regex()
self.generator = generate.choice(model, regex_list)
self.generator = generate.multi_regex(model, regex_list)

def __call__(self, generated_token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
generated_token_ids = torch.LongTensor(generated_token_ids).view(1, -1).to(logits.device)
Expand Down

0 comments on commit c9ac0ca

Please sign in to comment.