Skip to content

Commit

Permalink
[Bugfix] Fix a bug in RequestOutput.finished (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Jun 22, 2023
1 parent 2e0d314 commit 14f0b39
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 8 deletions.
2 changes: 1 addition & 1 deletion examples/llm_engine_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def main(args: argparse.Namespace):

request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished():
if request_output.finished:
print(request_output)

if not (engine.has_unfinished_requests() or test_prompts):
Expand Down
2 changes: 1 addition & 1 deletion vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ async def generate(
yield request_output

# Once finished, release the resources of the sequence group.
if request_output.finished():
if request_output.finished:
if self.log_requests:
logger.info(f"Finished request {request_id}.")

Expand Down
2 changes: 1 addition & 1 deletion vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished():
if output.finished:
outputs.append(output)
if use_tqdm:
pbar.update(1)
Expand Down
12 changes: 7 additions & 5 deletions vllm/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,13 @@ def __init__(
prompt: str,
prompt_token_ids: List[int],
outputs: List[CompletionOutput],
finished: bool,
) -> None:
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.outputs = outputs
self.finished = finished

@classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
Expand Down Expand Up @@ -95,13 +97,13 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
# Every sequence in the sequence group should have the same prompt.
prompt = top_n_seqs[0].prompt
prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
return cls(seq_group.request_id, prompt, prompt_token_ids, outputs)
finished = seq_group.is_finished()
return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
finished)

def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"outputs={self.outputs})")

def finished(self) -> bool:
return all(output.finished() for output in self.outputs)
f"outputs={self.outputs}, "
f"finished={self.finished})")

0 comments on commit 14f0b39

Please sign in to comment.