[Misc] Adding Speculative decoding to Throughput Benchmarking script #5223

Open: wants to merge 1 commit into base: main
adding spec decode to benchmark throughput
abhibambhaniya committed Jun 3, 2024
commit 3906a128dbf3e45998d78d0b044b3bcb03a14c0f
16 changes: 13 additions & 3 deletions benchmarks/benchmark_throughput.py
@@ -62,6 +62,8 @@ def sample_requests(
 def run_vllm(
     requests: List[Tuple[str, int, int]],
     model: str,
+    speculative_model: str,
+    num_speculative_tokens: int,
     tokenizer: str,
     quantization: Optional[str],
     tensor_parallel_size: int,
@@ -78,12 +80,15 @@ def run_vllm(
     enable_prefix_caching: bool,
     enable_chunked_prefill: bool,
     max_num_batched_tokens: int,
+    use_v2_block_manager: bool,
     gpu_memory_utilization: float = 0.9,
     download_dir: Optional[str] = None,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
         model=model,
+        speculative_model=speculative_model,
+        num_speculative_tokens=num_speculative_tokens,
         tokenizer=tokenizer,
         quantization=quantization,
         tensor_parallel_size=tensor_parallel_size,
@@ -99,6 +104,7 @@ def run_vllm(
         enable_prefix_caching=enable_prefix_caching,
         download_dir=download_dir,
         enable_chunked_prefill=enable_chunked_prefill,
+        use_v2_block_manager=use_v2_block_manager,
         max_num_batched_tokens=max_num_batched_tokens,
     )

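For context: speculative decoding has a small draft model propose `num_speculative_tokens` tokens per step, which the target model then verifies in a single forward pass. Below is a minimal sketch of the engine configuration this patch wires up, using the standard opt-6.7b/opt-125m draft/target pairing as an example; the model names and token count are illustrative, not part of this PR.

```python
from vllm import LLM, SamplingParams

# The target model (`model`) verifies tokens proposed by the smaller
# draft model (`speculative_model`).
llm = LLM(
    model="facebook/opt-6.7b",              # example target model
    speculative_model="facebook/opt-125m",  # example draft model
    num_speculative_tokens=5,               # draft tokens proposed per step
    use_v2_block_manager=True,              # spec decode required the v2
                                            # block manager at the time of this PR
)

sampling = SamplingParams(temperature=0.8, max_tokens=64)
outputs = llm.generate(["The future of AI is"], sampling)
print(outputs[0].outputs[0].text)
```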
@@ -219,14 +225,15 @@ def main(args: argparse.Namespace):
 
     if args.backend == "vllm":
         elapsed_time = run_vllm(
-            requests, args.model, args.tokenizer, args.quantization,
+            requests, args.model, args.speculative_model,
+            args.num_speculative_tokens, args.tokenizer, args.quantization,
             args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
             args.trust_remote_code, args.dtype, args.max_model_len,
             args.enforce_eager, args.kv_cache_dtype,
             args.quantization_param_path, args.device,
             args.enable_prefix_caching, args.enable_chunked_prefill,
-            args.max_num_batched_tokens, args.gpu_memory_utilization,
-            args.download_dir)
+            args.max_num_batched_tokens, args.use_v2_block_manager,
+            args.gpu_memory_utilization, args.download_dir)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -275,6 +282,8 @@ def main(args: argparse.Namespace):
                         help="Output length for each request. Overrides the "
                         "output length from the dataset.")
     parser.add_argument("--model", type=str, default="facebook/opt-125m")
+    parser.add_argument('--speculative-model', type=str, default=None)
+    parser.add_argument('--num-speculative-tokens', type=int, default=None)
     parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
@@ -291,6 +300,7 @@
                         default=1000,
                         help="Number of prompts to process.")
     parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument('--use-v2-block-manager', action='store_true')
     parser.add_argument("--hf-max-batch-size",
                         type=int,
                         default=None,
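
With the patch applied, speculative decoding can be exercised directly from the benchmark's command line. A sample invocation, assuming a ShareGPT-style dataset file and the same example model pair as above (the dataset path and model names are placeholders, not part of this PR):

```bash
python benchmarks/benchmark_throughput.py \
    --backend vllm \
    --dataset ShareGPT_V3_unfiltered_cleaned_split.json \
    --model facebook/opt-6.7b \
    --speculative-model facebook/opt-125m \
    --num-speculative-tokens 5 \
    --use-v2-block-manager
```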