Generate: Add text streamer decoding options (huggingface#22544)
gante authored Apr 4, 2023
1 parent 41a2f35 commit 1905384
Showing 2 changed files with 77 additions and 56 deletions.
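A minimal usage sketch of the options this commit adds (`skip_prompt`, plus arbitrary decode keyword arguments such as `skip_special_tokens`), assuming a small causal LM like `distilgpt2`; the prompt string and generation length are illustrative only:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

inputs = tokenizer("Streaming generated text is", return_tensors="pt")

# skip_prompt=True suppresses the prompt tokens; any extra keyword arguments
# (here skip_special_tokens=True) are forwarded to tokenizer.decode().
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
model.generate(**inputs, streamer=streamer, max_new_tokens=20)  # prints text as it is generated
```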
94 changes: 38 additions & 56 deletions src/transformers/generation/streamers.py
@@ -42,6 +42,10 @@ class TextStreamer(BaseStreamer):
Parameters:
tokenizer (`AutoTokenizer`):
The tokenizer used to decode the tokens.
skip_prompt (`bool`, *optional*, defaults to `False`):
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
decode_kwargs (`dict`, *optional*):
Additional keyword arguments to pass to the tokenizer's `decode` method.
Examples:
@@ -59,10 +63,15 @@ class TextStreamer(BaseStreamer):
```
"""

- def __init__(self, tokenizer: "AutoTokenizer"):
+ def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
self.tokenizer = tokenizer
self.skip_prompt = skip_prompt
self.decode_kwargs = decode_kwargs

# variables used in the streaming process
self.token_cache = []
self.print_len = 0
self.next_tokens_are_prompt = True

def put(self, value):
"""
@@ -73,11 +82,15 @@ def put(self, value):
elif len(value.shape) > 1:
value = value[0]

if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return

# Add the new token to the cache and decodes the entire thing.
self.token_cache.extend(value.tolist())
- text = self.tokenizer.decode(self.token_cache)
+ text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

- # After symbol for a new line, we flush the cache.
+ # After the symbol for a new line, we flush the cache.
if text.endswith("\n"):
printable_text = text[self.print_len :]
self.token_cache = []
@@ -94,30 +107,34 @@ def end(self):
"""Flushes any remaining cache and prints a newline to stdout."""
# Flush the cache, if it exists
if len(self.token_cache) > 0:
- text = self.tokenizer.decode(self.token_cache)
+ text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
printable_text = text[self.print_len :]
self.token_cache = []
self.print_len = 0
else:
printable_text = ""

# Print a newline (and the remaining text, if any)
self.next_tokens_are_prompt = True
self.on_finalized_text(printable_text, stream_end=True)

- def on_finalized_text(self, token: str, stream_end: bool = False):
- """Prints the new text to stdout."""
- print(token, flush=True, end="" if not stream_end else None)
+ def on_finalized_text(self, text: str, stream_end: bool = False):
+ """Prints the new text to stdout. If the stream is ending, also prints a newline."""
+ print(text, flush=True, end="" if not stream_end else None)


- class TextIteratorStreamer(BaseStreamer):
+ class TextIteratorStreamer(TextStreamer):
"""
Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
- useful for applications that want to use the generated text in a non-blocking way (e.g. in an interactive Gradio
- demo).
+ useful for applications that benefit from accessing the generated text in a non-blocking way (e.g. in an interactive
+ Gradio demo).
Parameters:
tokenizer (`AutoTokenizer`):
The tokenizer used to decode the tokens.
skip_prompt (`bool`, *optional*, defaults to `False`):
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
decode_kwargs (`dict`, *optional*):
Additional keyword arguments to pass to the tokenizer's `decode` method.
Examples:
@@ -142,58 +159,23 @@ class TextIteratorStreamer(BaseStreamer):
```
"""

- def __init__(self, tokenizer: "AutoTokenizer"):
- self.tokenizer = tokenizer
- self.token_cache = []
- self.print_len = 0
- self.queue = Queue()
+ def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
+ super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+ self.text_queue = Queue()
self.stop_signal = None

def on_finalized_text(self, text: str, stream_end: bool = False):
"""Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
self.text_queue.put(text)
if stream_end:
self.text_queue.put(self.stop_signal)

def __iter__(self):
return self

def __next__(self):
- value = self.queue.get()
+ value = self.text_queue.get()
if value == self.stop_signal:
raise StopIteration()
else:
return value

- def put(self, value):
- """
- Recives tokens, decodes them, and pushes text to the queue as soon as it form entire words.
- """
- if len(value.shape) > 1 and value.shape[0] > 1:
- raise ValueError("TextStreamer only supports batch size 1")
- elif len(value.shape) > 1:
- value = value[0]
-
- # Add the new token to the cache and decodes the entire thing.
- self.token_cache.extend(value.tolist())
- text = self.tokenizer.decode(self.token_cache)
-
- # After symbol for a new line, we flush the cache.
- if text.endswith("\n"):
- printable_text = text[self.print_len :]
- self.token_cache = []
- self.print_len = 0
- # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
- # which may change with the subsequent token -- there are probably smarter ways to do this!)
- else:
- printable_text = text[self.print_len : text.rfind(" ") + 1]
- self.print_len += len(printable_text)
- self.queue.put(printable_text)
-
- def end(self):
- """Flushes any remaining cache and puts the stop signal in the queue."""
- # Flush the cache, if it exists
- if len(self.token_cache) > 0:
- text = self.tokenizer.decode(self.token_cache)
- printable_text = text[self.print_len :]
- self.token_cache = []
- self.print_len = 0
- else:
- printable_text = ""
-
- self.queue.put(printable_text)
- self.queue.put(self.stop_signal) # Put the stop signal
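Because `TextIteratorStreamer` now inherits the decoding logic from `TextStreamer` and only overrides `on_finalized_text`, the usual consumption pattern is unchanged. A sketch, with the background thread and the prompt being assumptions about typical usage rather than part of this diff:

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

inputs = tokenizer("Streaming tokens one chunk at a time", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until it finishes, so it runs in a worker thread while the
# main thread drains the text queue through the iterator protocol.
generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": 20}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

generated_text = ""
for new_text in streamer:  # __next__ raises StopIteration once the stop signal arrives
    generated_text += new_text
thread.join()
print(generated_text)
```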
39 changes: 39 additions & 0 deletions tests/generation/test_streamers.py
@@ -23,6 +23,8 @@


if is_torch_available():
import torch

from transformers import AutoModelForCausalLM


@@ -63,3 +65,40 @@ def test_iterator_streamer_matches_non_streaming(self):
streamer_text += new_text

self.assertEqual(streamer_text, greedy_text)

def test_text_streamer_skip_prompt(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
model.config.eos_token_id = -1

input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
new_greedy_ids = greedy_ids[:, input_ids.shape[1] :]
new_greedy_text = tokenizer.decode(new_greedy_ids[0])

with CaptureStdout() as cs:
streamer = TextStreamer(tokenizer, skip_prompt=True)
model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer)
# The greedy text should be printed to stdout, except for the final "\n" in the streamer
streamer_text = cs.out[:-1]

self.assertEqual(streamer_text, new_greedy_text)

def test_text_streamer_decode_kwargs(self):
# Tests that we can pass `decode_kwargs` to the streamer to control how the tokens are decoded. Must be tested
# with actual models -- the dummy models' tokenizers are not aligned with their models, and
# `skip_special_tokens=True` has no effect on them
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2").to(torch_device)
model.config.eos_token_id = -1

input_ids = torch.ones((1, 5), device=torch_device).long() * model.config.bos_token_id
with CaptureStdout() as cs:
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer)

# The prompt contains a special token, so the streamer should not print it. As such, the output text, when
# re-tokenized, must only contain one token
streamer_text = cs.out[:-1] # Remove the final "\n"
streamer_text_tokenized = tokenizer(streamer_text, return_tensors="pt")
self.assertEqual(streamer_text_tokenized.input_ids.shape, (1, 1))
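A side effect of the refactor is that both streamers route every chunk of output through `on_finalized_text`, so a custom sink only needs to override that hook. A hypothetical sketch (the `FileStreamer` name and its file-appending behaviour are illustrative, not part of the library):

```python
from transformers import TextStreamer


class FileStreamer(TextStreamer):
    """Hypothetical streamer that appends finalized text to a file instead of printing it."""

    def __init__(self, tokenizer, path: str, skip_prompt: bool = False, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.path = path

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # Receives word-complete chunks; stream_end=True marks the final flush.
        with open(self.path, "a") as f:
            f.write(text)
            if stream_end:
                f.write("\n")
```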
