Commit

fix max seq len (vllm-project#489)
LiuXiaoxuanPKU authored Jul 18, 2023
1 parent 20b0d88 commit b4b195b
Showing 4 changed files with 8 additions and 8 deletions.
vllm/config.py (4 changes: 2 additions & 2 deletions)
@@ -204,10 +204,10 @@ class SchedulerConfig:
     """

     def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
-                 max_seq_len: int) -> None:
+                 max_model_len: int) -> None:
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_seqs = max_num_seqs
-        self.max_seq_len = max_seq_len
+        self.max_model_len = max_model_len


 _STR_DTYPE_TO_TORCH_DTYPE = {
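
For orientation, a minimal sketch of the renamed field in use; the numeric values below are illustrative, not vLLM defaults:

# Illustrative construction of the post-rename config.
config = SchedulerConfig(max_num_batched_tokens=2560,
                         max_num_seqs=256,
                         max_model_len=2048)
assert config.max_model_len == 2048  # renamed from max_seq_len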
vllm/core/scheduler.py (4 changes: 3 additions & 1 deletion)
@@ -190,7 +190,9 @@ def _schedule(
                break

            num_prompt_tokens = seq_group.get_seqs()[0].get_len()
-            if num_prompt_tokens > self.scheduler_config.max_seq_len:
+            if num_prompt_tokens > min(
+                    self.scheduler_config.max_model_len,
+                    self.scheduler_config.max_num_batched_tokens):
                logger.warning(
                    f"Input prompt ({num_prompt_tokens} tokens) is too long"
                    " and exceeds limit of "
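
The new guard caps prompts at whichever is smaller: the model's context length or the per-iteration token budget. A standalone sketch of the same logic; the numbers and the prompt_limit name are illustrative, not the scheduler's actual locals:

# Hypothetical numbers: a 2048-token model with a 2560-token batch budget.
max_model_len = 2048
max_num_batched_tokens = 2560

prompt_limit = min(max_model_len, max_num_batched_tokens)
num_prompt_tokens = 3000
if num_prompt_tokens > prompt_limit:
    # Mirrors the scheduler's warning; such a prompt is never scheduled.
    print(f"Input prompt ({num_prompt_tokens} tokens) is too long "
          f"and exceeds limit of {prompt_limit}")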
vllm/engine/arg_utils.py (5 changes: 2 additions & 3 deletions)
@@ -155,11 +155,10 @@ def create_engine_configs(
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.worker_use_ray)
-        model_max_len = getattr(model_config.hf_config,
+        max_model_len = getattr(model_config.hf_config,
                                'max_position_embeddings', float('inf'))
-        max_seq_len = min(self.max_num_batched_tokens, model_max_len)
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs, max_seq_len)
+                                           self.max_num_seqs, max_model_len)
        return model_config, cache_config, parallel_config, scheduler_config
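
The getattr fallback means a model whose HuggingFace config does not declare max_position_embeddings is treated as having no length limit. A sketch with a stand-in config object; SimpleNamespace substitutes here for the real transformers config class:

from types import SimpleNamespace

# Config that declares a position-embedding limit.
hf_config = SimpleNamespace(max_position_embeddings=2048)
print(getattr(hf_config, 'max_position_embeddings', float('inf')))  # 2048

# Config without the attribute falls back to infinity (no cap).
bare_config = SimpleNamespace()
print(getattr(bare_config, 'max_position_embeddings', float('inf')))  # inf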
vllm/engine/llm_engine.py (3 changes: 1 addition & 2 deletions)
@@ -300,8 +300,7 @@ def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
                    continue

                # Check if the sequence has reached max_seq_len.
-                if (seq.get_len() >
-                        self.scheduler.scheduler_config.max_seq_len):
+                if seq.get_len() > self.scheduler_config.max_model_len:
                    self.scheduler.free_seq(
                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                    continue
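
In isolation, the length cap behaves as below; this Sequence stand-in is hypothetical and far simpler than vLLM's actual class:

class Sequence:
    """Hypothetical stand-in exposing only get_len()."""

    def __init__(self, num_tokens: int) -> None:
        self.num_tokens = num_tokens

    def get_len(self) -> int:
        return self.num_tokens

max_model_len = 2048
seq = Sequence(num_tokens=2049)
if seq.get_len() > max_model_len:
    # In the engine this frees the sequence as FINISHED_LENGTH_CAPPED.
    print("sequence capped at model length")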
