GraphInferenceEngine support
poedator committed Apr 15, 2024
1 parent 9192d99 commit 4cc555d
Showing 1 changed file with 18 additions and 7 deletions.
benchmark_inference.py (25 changes: 18 additions & 7 deletions)
@@ -4,7 +4,7 @@
 import time
 import torch
 
-from Engine.Engine import GraphInferenceEngineTG
+from Engine.Engine import GraphInferenceEngineTG, GraphInferenceEngine
 from Engine.offload_engine import OffloadEngine
 from utils import _make_causal_mask
 from tqdm import tqdm, trange
@@ -37,28 +37,39 @@ def benchmark(args):
 
     if args.offloading:
         graph_engine = OffloadEngine(max_length=args.max_length, model_name_or_path=args.model, dtype=dtype, device=device) # set stay_layers?
+        graph_engine.inference(input_ids=prefix, storage_ids=prefix_storage_ids, position_ids=prefix_position_ids, \
+                               attn_mask=attn_mask[..., :args.prefix_length,:args.prefix_length])
     else:
-        graph_engine = GraphInferenceEngineTG(max_length=args.max_length, model_name_or_path=args.model, dtype=dtype, device=device, offloading=args.offloading)
-
-    graph_engine.inference(input_ids=prefix, storage_ids=prefix_storage_ids, position_ids=prefix_position_ids, attn_mask=attn_mask[..., :args.prefix_length,:args.prefix_length])
+        graph_engine = GraphInferenceEngine(max_length=args.max_length, model_name_or_path=args.model, dtype=dtype, device=device)
+        print("initializing GraphInferenceEngine model")
+        graph_capture_list = [1,2,4,8,16,32,64,128,256,512,768,1024]
+        graph_engine.initialize_cuda_graph(graph_capture_list)
+        graph_engine.inference(input_ids=prefix, storage_ids=prefix_storage_ids, position_ids=prefix_position_ids, \
+                               attn_mask=attn_mask[..., :args.prefix_length,:])
+    print("test run OK")
 
     avg_forward_pass_times = []
     for decode_length in decode_lengths:
         input_ids = torch.randint(low=3, high=30000, size=(1, decode_length), device=device)
         storage_ids = torch.arange(decode_length, device=device) + args.prefix_length
         position_ids = storage_ids.clone().unsqueeze(0)
-        curr_attn_mask = attn_mask[..., args.prefix_length: args.prefix_length + decode_length,:args.prefix_length + decode_length].clone()
+        if args.offloading:
+            curr_attn_mask = attn_mask[..., args.prefix_length: args.prefix_length + decode_length,:args.prefix_length + decode_length].clone()
+        else:
+            curr_attn_mask = attn_mask[..., args.prefix_length: args.prefix_length + decode_length,:].clone()
 
         for _ in trange(args.warmup, desc=f"warmup, {decode_length=}", leave=False):
             graph_engine.inference(input_ids=input_ids, storage_ids=storage_ids, position_ids=position_ids, attn_mask=curr_attn_mask)
-            graph_engine.set_kv_len(args.prefix_length)
+            if args.offloading:
+                graph_engine.set_kv_len(args.prefix_length)
 
         torch.cuda.synchronize()
         t1 = time.time()
 
         for _ in trange(args.num_repeats, desc=f"measuring, {decode_length=}", leave=False):
             graph_engine.inference(input_ids=input_ids, storage_ids=storage_ids, position_ids=position_ids, attn_mask=curr_attn_mask)
-            graph_engine.set_kv_len(args.prefix_length)
+            if args.offloading:
+                graph_engine.set_kv_len(args.prefix_length)
 
         torch.cuda.synchronize()
         t2 = time.time()
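Read as a whole, the new non-offloading path constructs a GraphInferenceEngine, pre-captures CUDA graphs for a list of decode lengths, and runs a test forward pass. The sketch below reconstructs that flow from the hunks above; it is a minimal sketch that assumes only the Engine.Engine import and the call signatures exactly as they appear in this diff, with placeholder values for the model path and lengths and an inline causal mask standing in for utils._make_causal_mask:

import torch
from Engine.Engine import GraphInferenceEngine  # module layout as in the import hunk above

max_length, prefix_length = 2048, 128           # placeholder values
device, dtype = "cuda:0", torch.float16

# Inline additive causal mask (placeholder for utils._make_causal_mask):
# 0 on and below the diagonal, a large negative value above it.
attn_mask = torch.full((1, 1, max_length, max_length), torch.finfo(dtype).min,
                       dtype=dtype, device=device)
attn_mask = torch.triu(attn_mask, diagonal=1)

engine = GraphInferenceEngine(max_length=max_length, model_name_or_path="path/to/model",
                              dtype=dtype, device=device)
# Capture one CUDA graph per decode length the benchmark will exercise.
engine.initialize_cuda_graph([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024])

prefix = torch.randint(low=3, high=30000, size=(1, prefix_length), device=device)
storage_ids = torch.arange(prefix_length, device=device)
position_ids = storage_ids.unsqueeze(0)
# The mask is sliced to the full max_length width (note the trailing ":"),
# keeping tensor shapes static across replays of the captured graphs.
engine.inference(input_ids=prefix, storage_ids=storage_ids,
                 position_ids=position_ids,
                 attn_mask=attn_mask[..., :prefix_length, :])

Presumably the full-width mask slice, and the fact that set_kv_len is now called only on the offloading path, both follow from captured CUDA graphs replaying with fixed tensor shapes, while OffloadEngine tracks the KV length explicitly.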

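For context on initialize_cuda_graph: capturing a CUDA graph records the kernel-launch sequence for one fixed set of tensor shapes and lets it be replayed with a single launch. The snippet below is a generic PyTorch capture/replay example using the standard torch.cuda.CUDAGraph API; it illustrates the mechanism only and is not this repository's implementation:

import torch

model = torch.nn.Linear(4096, 4096).half().cuda()
static_in = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

# PyTorch requires a warmup on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_out = model(static_in)
torch.cuda.current_stream().wait_stream(s)

# Capture the forward pass for this one static shape.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)

# Replay: refill the static input buffer in place, then relaunch the whole
# recorded kernel sequence at once.
static_in.copy_(torch.randn_like(static_in))
graph.replay()
print(static_out.sum())

Because each captured graph is tied to one shape, one graph per entry in graph_capture_list (i.e. per decode length) is the natural arrangement, which is presumably what the commit sets up.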