FIX: outlines integration
AL-377 committed Nov 24, 2023
1 parent 007cfc9 commit 02febc5
Showing 7 changed files with 50 additions and 968 deletions.
2 changes: 1 addition & 1 deletion XAgentGen/xgen/models/__init__.py
@@ -1 +1 @@
from .transformers import transformers,Transformers,TransformersTokenizer
from .transformers import XTransformers
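
The one-line re-export is the visible surface of this commit: instead of vendoring outlines' transformers module wholesale, the package now exposes a single subclass, so downstream imports change accordingly. A minimal sketch based on the two lines above:

# before this commit (names re-exported from the vendored copy)
from xgen.models import transformers, Transformers, TransformersTokenizer

# after this commit (the single outlines-backed subclass)
from xgen.models import XTransformers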
220 changes: 11 additions & 209 deletions XAgentGen/xgen/models/transformers.py
@@ -5,62 +5,15 @@
TopKLogitsWarper,
TopPLogitsWarper,
)
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
from datasets.fingerprint import Hasher
from transformers.file_utils import SPIECE_UNDERLINE

from outlines.models.tokenizer import Tokenizer

from outlines.models.transformers import Transformers,TransformersTokenizer
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer

__all__ = ["transformers"]


import torch
KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...]


def get_llama_tokenizer_types():
""" Get all the Llama tokenizer types/classes that need work-arounds.
When they can't be imported, a dummy class is created.
"""
try:
from transformers.models.llama import LlamaTokenizer
except ImportError:

class LlamaTokenizer: # type: ignore
pass

try:
from transformers.models.llama import LlamaTokenizerFast
except ImportError:

class LlamaTokenizerFast: # type: ignore
pass

try:
from transformers.models.code_llama import CodeLlamaTokenizer
except ImportError:

class CodeLlamaTokenizer: # type: ignore
pass

try:
from transformers.models.code_llama import CodeLlamaTokenizerFast
except ImportError:

class CodeLlamaTokenizerFast: # type: ignore
pass

return (
LlamaTokenizer,
LlamaTokenizerFast,
CodeLlamaTokenizer,
CodeLlamaTokenizerFast,
)

def prepare_logits_processor(
temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
@@ -77,22 +30,15 @@ def prepare_logits_processor(
processor_list.append(TopKLogitsWarper(top_k))
return processor_list




class Transformers:
"""Represents a `transformers` model."""
class XTransformers(Transformers):
def __init__(
self,
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
):
self.device = model.device
self.model = model
self.tokenizer = tokenizer
super().__init__(model, tokenizer)
self.logits_processor = None


def reset(self):
self.tokenizer.prompt_tokens = 0
self.tokenizer.completion_tokens = 0
@@ -116,154 +62,10 @@ def forward(
attention_mask: torch.LongTensor,
past_key_values: Optional[Tuple] = None,
) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
"""Compute a forward pass through the transformer model.
Parameters
----------
input_ids
The input token ids. Must be one or two dimensional.
attention_mask
The attention mask. Must be one or two dimensional.
past_key_values
A tuple of tuples containing the cached key and value tensors for each
attention head.
Returns
-------
The computed logits and the new cached key and value tensors.
"""
assert 0 < input_ids.ndim < 3

if past_key_values:
input_ids = input_ids[..., -1].unsqueeze(-1)

output = self.model(
input_ids,
attention_mask=attention_mask,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
past_key_values=past_key_values,
)

if self.logits_processor:
next_token_logits = self.logits_processor(input_ids,output.logits[...,-1,:])
else:
next_token_logits = output.logits[..., -1, :]

return next_token_logits, output.past_key_values

def __call__(
self,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
past_key_values: Optional[Tuple] = None,
) -> torch.FloatTensor:
return self.forward(input_ids, attention_mask, past_key_values)[0]


class TransformersTokenizer(Tokenizer):
"""Represents a tokenizer for models in the `transformers` library."""

def __init__(self, model_name: str, **kwargs):
from transformers import AutoTokenizer

kwargs.setdefault("padding_side", "left")
self.model_name = model_name
self.kwargs = kwargs
self.tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
self.eos_token_id = self.tokenizer.eos_token_id
self.eos_token = self.tokenizer.eos_token

if not self.tokenizer.pad_token_id:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.pad_token_id = self.eos_token_id
else:
self.pad_token_id = self.tokenizer.pad_token_id
self.pad_token = self.tokenizer.pad_token

self.special_tokens = set(self.tokenizer.all_special_tokens)

self.vocabulary = self.tokenizer.get_vocab()
self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())

self.prompt_tokens = 0
self.completion_tokens = 0


def encode(
self, prompt: Union[str, List[str]], **kwargs
) -> Tuple[torch.LongTensor, torch.LongTensor]:
kwargs["padding"] = True
kwargs["return_tensors"] = "pt"
output = self.tokenizer(prompt, **kwargs)
self.prompt_tokens = output["input_ids"].shape[1]

return output["input_ids"], output["attention_mask"]

def decode(self, token_ids: torch.LongTensor) -> List[str]:
text = self.tokenizer.batch_decode(token_ids)
self.completion_tokens = token_ids.shape[1]
return text

def convert_token_to_string(self, token: str) -> str:
string = self.tokenizer.convert_tokens_to_string([token])

if self.is_llama:
# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string

return string

def __eq__(self, other):
if isinstance(other, type(self)):
return other.model_name == self.model_name and other.kwargs == self.kwargs
return NotImplemented

def __hash__(self):
return hash(Hasher.hash(self.tokenizer))


def transformers(
model_name: str,
device: Optional[str] = None,
model_kwargs: dict = {},
tokenizer_kwargs: dict = {},
):
"""Instantiate a model from the `transformers` library and its tokenizer.
Parameters
----------
model_name
The name of the model as listed on Hugging Face's model page.
device
The device(s) on which the model should be loaded. This overrides
the `device_map` entry in `model_kwargs` when provided.
model_kwargs
A dictionary that contains the keyword arguments to pass to the
`from_pretrained` method when loading the model.
tokenizer_kwargs
A dictionary that contains the keyword arguments to pass to the
`from_pretrained` method when loading the tokenizer.
Returns
-------
A `TransformersModel` model instance.
"""
try:
from transformers import AutoModelForCausalLM
except ImportError:
raise ImportError(
"The `transformers` library needs to be installed in order to use `transformers` models."
)

if device is not None:
model_kwargs["device_map"] = device

model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
tokenizer = TransformersTokenizer(model_name, **tokenizer_kwargs)

return Transformers(model, tokenizer)
next_token_logits, output_past_key_values = super().forward(input_ids, attention_mask, past_key_values)

if self.logits_processor:
    next_token_logits = self.logits_processor(input_ids, next_token_logits)

return next_token_logits, output_past_key_values
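
Taken together, the rewritten module keeps only token accounting (reset) and logits post-processing, and delegates everything else to outlines' Transformers base class. A minimal usage sketch, assuming TransformersTokenizer behaves as in the deleted vendored copy above and that add_logits_processor (not visible in this diff, but called from function_parser.py below) accepts the same sampling keys as prepare_logits_processor:

from transformers import AutoModelForCausalLM
from outlines.models.transformers import TransformersTokenizer
from xgen.models import XTransformers

model = AutoModelForCausalLM.from_pretrained("gpt2")  # any causal LM
tokenizer = TransformersTokenizer("gpt2")
xmodel = XTransformers(model, tokenizer)

# sampling keys assumed from prepare_logits_processor's signature
xmodel.add_logits_processor(
    {"temperature": 0.7, "repetition_penalty": 1.1, "top_p": 0.9, "top_k": 50}
)

input_ids, attention_mask = tokenizer.encode("Hello")
logits, kv_cache = xmodel.forward(input_ids, attention_mask)  # warped next-token logits

Because forward now filters the logits returned by super().forward, the same warpers apply to every decoding step, constrained or not.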
8 changes: 4 additions & 4 deletions XAgentGen/xgen/parser/function_parser.py
@@ -4,7 +4,7 @@
from enum import Enum
import json
import copy
from xgen.text.json_schema import build_regex_from_schema
from outlines.text.json_schema import build_regex_from_object
import xgen.models as models
import xgen.text.generate as generate
import torch
@@ -296,7 +296,7 @@ def models_to_regex(self):

json_schema = self.post_process(json_schema)
schema = json.dumps(json_schema)
self.regex_strs.append(build_regex_from_schema(schema))
self.regex_strs.append(build_regex_from_object(schema))
return self.regex_strs


@@ -331,7 +331,7 @@ def post_process(self,schema):

return com_schema

def create_generator(self,model:models.Transformers,function_info:Dict[str,Any],generate_params:Dict = {}):
def create_generator(self,model:models.XTransformers,function_info:Dict[str,Any],generate_params:Dict = {}):
"""
@param: model: the transformer model
@param: functions: a list of functions
@@ -348,7 +348,7 @@ def create_generator(self,model:models.Transformers,function_info:Dict[str,Any],
self.model = model
# temperature and so on
self.model.add_logits_processor(generate_params)
self.generator = generate.choice(self.model, regex_list,generate_params.get("max_tokens"))
self.generator = generate.multi_regex(self.model, regex_list,generate_params.get("max_tokens"))
return self.generator

def check(self,call_info:str):
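
On the parser side only two things change: the schema-to-regex helper now comes from outlines (build_regex_from_object), and the generator factory targets the new multi_regex entry point. A hedged sketch of the call flow, with the parser class that owns these methods and the exact shape of function_info both assumed, since they are collapsed in this diff:

# parser is an instance of the class defining models_to_regex/create_generator
generator = parser.create_generator(
    xmodel,                                   # an XTransformers instance
    function_info,                            # function/tool schemas (shape assumed)
    {"temperature": 0.7, "max_tokens": 512},  # forwarded to add_logits_processor
)
call_info = generator(prompt)  # outlines-style call convention, assumed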
1 change: 0 additions & 1 deletion XAgentGen/xgen/text/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion XAgentGen/xgen/text/generate/__init__.py
@@ -1 +1 @@
from .regex import choice, float, integer, json, regex
from .regex import multi_regex
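
The generate package now exports a single constrained-decoding entry point in place of outlines' choice/float/integer/json/regex family. Judging from create_generator above, multi_regex takes the wrapped model, the list of regex patterns produced by models_to_regex (one per candidate function schema), and a token budget. A sketch under those assumptions:

import xgen.text.generate as generate

regex_list = parser.models_to_regex()  # one pattern per tool schema
generator = generate.multi_regex(xmodel, regex_list, 512)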
