Large Language Models (LLMs) have seen adoption and integration into a variety of pipelines to extract structured data and even act autonomously through projects such as LangChain and AutoGPT. In such use cases, the system's success depends on the LLM's output adhering to a specific format or syntax, e.g. JSON. While LLMs can follow formatting instructions provided in their prompts, deviations can occur (especially with complex formatting instructions), causing failures when the output is parsed.
This project explores guided LM sampling under syntax constraints: user-defined syntax constraints are enforced during token sampling, guaranteeing that the final generated text adheres to the given syntax.
Currently, only greedy and single-beam sampling are supported.
- Explore currently supported constraints with Vicuna-13B in oobabooga's text-generation web UI
- Test applying syntax constraints to non-instruction-finetuned LMs
- Walkthrough of developing and applying a new syntax constraint
Install our fork of 🤗 transformers:

```bash
pip install git+https://github.com/IsaacRe/transformers@syntactically-constrained-sampling
```

Then either:

- Install the most recent version of this repo:

  ```bash
  pip install git+https://github.com/IsaacRe/Syntactically-Constrained-Sampling
  ```

- Or build and install a local version:

  ```bash
  python -m build && pip install dist/sampling-constraints-<VERSION>.tar.gz
  ```
Try out some of the constraints below:
Have an idea for a useful constraint? Open an issue!
Force LM output to follow a specific JSON schema, defined in TypeScript(-ish) notation:
```python
from transformers.pipelines import pipeline

schema = """
{
    output: string,
    array_output: []number,
    optional_output?: number,
    nested_schema: []{
        inner_output: string
    }
}
"""

pipe = pipeline(model='gpt2')
output = pipe('Input', enforce_json_schema=schema)
```
Example output:
```json
{
    "output": "some text",
    "array_output": [1, 2, 3],
    "nested_schema": [
        {"inner_output": "more text"},
        {"inner_output": "even more text"}
    ]
}
```
Force the LLM to select from a particular set of outputs:
```python
from transformers.pipelines import pipeline

options = 'Option A,Option B'
pipe = pipeline(model='gpt2')
output = pipe('Input', enforce_one_of=options)
```
Example output:
```
Option A
```
The current approach relies on an `IncrementalParser` class to periodically check the validity of sampled sequences under its particular syntax constraint. Inheriting classes must implement the following methods (a minimal example follows this list):

`append(self, chars: str) -> None`
- append `chars` to the string being parsed and continue parsing
- raise `ParseFailure` if the string deviates from the expected format

`copy(self) -> IncrementalParser`
- return a copy of the parser and its internal state
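As a concrete illustration, here is a minimal sketch of an `IncrementalParser` that accepts only digit characters. The import path is an assumption for the sketch; the package may expose these names differently.

```python
# Minimal sketch, assuming `IncrementalParser` and `ParseFailure` are
# importable as below (an assumption, not the documented API).
from sampling_constraints import IncrementalParser, ParseFailure

class DigitParser(IncrementalParser):
    def __init__(self, parsed: str = ""):
        self.parsed = parsed  # internal parse state: everything consumed so far

    def append(self, chars: str) -> None:
        # continue parsing; reject anything outside 0-9
        if not all(c.isdigit() for c in chars):
            raise ParseFailure(f"expected digits, got {chars!r}")
        self.parsed += chars

    def copy(self) -> "DigitParser":
        # snapshot the parser so candidate tokens can be tried safely
        return DigitParser(self.parsed)
```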
Parsing a sequence is carried out by making successive calls to `append`, each time passing a new token to append to the string being parsed. The validity of a sequence after a new token is appended can be checked by creating a copy of the current parser, calling its `append` method with the candidate token, and checking whether a `ParseFailure` is raised.
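That check might look like the following sketch (the helper name is ours, not the library's):

```python
from sampling_constraints import IncrementalParser, ParseFailure  # assumed import path

# Hypothetical helper: test a candidate token without mutating the live parser.
def token_is_valid(parser: IncrementalParser, token: str) -> bool:
    candidate = parser.copy()    # copy the current parse state
    try:
        candidate.append(token)  # try extending the parse with the token
    except ParseFailure:
        return False             # token would break the syntax constraint
    return True
```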
Using this fork of 🤗 transformers, we hook into the generation loop and enforce syntax constraints defined via an `IncrementalParser` before each sampling step. For each token in the tokenizer's vocabulary, the hook checks whether a `ParseFailure` is raised when appending it to a copy of the current parse state. Once the LM's forward pass completes, logits corresponding to next tokens that caused a `ParseFailure` are suppressed (effectively removing them from the distribution of tokens that may be sampled).
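The suppression step amounts to masking logits, roughly as in this sketch (illustrative only, not the fork's actual code; `token_is_valid` is the assumed helper from above):

```python
from typing import List
import torch

# Illustrative sketch of suppressing syntax-breaking tokens before sampling.
# `logits` is the [vocab_size] tensor from the LM's forward pass and
# `vocab` maps token ids to their decoded strings.
def suppress_invalid_tokens(logits: torch.Tensor, parser, vocab: List[str]) -> torch.Tensor:
    for token_id, token in enumerate(vocab):
        if not token_is_valid(parser, token):
            logits[token_id] = float("-inf")  # removed from the sampling distribution
    return logits
```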
Iterating over all tokens in an LM's vocabulary is expensive. To save time we can use `TokenGroup`s to include or exclude sets of tokens from sampling during each generation step, without having to check continued syntax validity for each token in that set. `TokenGroup` classes must implement the static `filter` method:

`filter(token: str) -> bool`
- returns whether the passed token is a member of this group
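For example, a hypothetical group of all-whitespace tokens might look like this:

```python
from sampling_constraints import TokenGroup  # assumed import path

# Hypothetical TokenGroup whose members consist solely of whitespace.
class WhitespaceGroup(TokenGroup):
    @staticmethod
    def filter(token: str) -> bool:
        # a token belongs to this group iff it is non-empty whitespace
        return len(token) > 0 and token.isspace()
```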
Incremental parsers can leverage token groups by implementing the following methods (sketched after the next paragraph):

`valid_token_group(self) -> Type[TokenGroup]`
- returns a `TokenGroup` class defining a subset of tokens allowed to be sampled next

`invalid_token_group(self) -> Type[TokenGroup]`
- returns a `TokenGroup` class defining a subset of tokens not allowed to be sampled next
The check handler filters the tokenizer vocab into groups before generation begins and calls the above methods to check for valid/invalid token groups before checking per-token validity. If present, tokens in the invalid token group are suppressed. Remaining tokens not present in either group are then checked for syntax validity.
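As a sketch of how a parser might use this, a JSON-style parser sitting between two values could report the whitespace group above as always valid, so the handler never has to parse-check those tokens individually (all names here are illustrative):

```python
from typing import Type

# Illustrative only: between JSON values, whitespace can never break the
# syntax, so the whole WhitespaceGroup is reported as safe to sample.
class BetweenValuesParser(IncrementalParser):
    # append/copy omitted for brevity
    def valid_token_group(self) -> Type[TokenGroup]:
        return WhitespaceGroup
```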
There may be times during parsing when a particular character or sequence of characters is required to maintain syntax validity. An `IncrementalParser` can implement `get_next` to skip token validity checks in this case:

`get_next(self) -> List[str]`
- returns an exhaustive list of valid next sequences to constrain generation, or an empty list to leave generation unconstrained
The check handler evaluates the result of `get_next` before all other checks. If one or more sequences are returned, only tokens that form a prefix of one of them are allowed and all other tokens are suppressed; otherwise, checking continues as usual.
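For instance, a parser that has just consumed a JSON key could return the only legal continuation (a hypothetical sketch; `expecting_colon` is an assumed internal flag):

```python
from typing import List

# Sketch: after a JSON key, the only valid continuation is '": '.
class KeyValueParser(IncrementalParser):
    # append/copy omitted for brevity
    def get_next(self) -> List[str]:
        if self.expecting_colon:  # hypothetical internal state flag
            return ['": ']        # constrain generation to this sequence
        return []                 # empty list leaves generation unconstrained
```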