Skip to content

Commit

Permalink
Add explanatory comments to the KV-cache handling and topK_genrate call sites.
Browse files Browse the repository at this point in the history
  • Loading branch information
w32zhong committed Sep 22, 2024
1 parent aea9554 commit f7d7043
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 0 deletions.
3 changes: 3 additions & 0 deletions model/cnets.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,8 @@ def forward(

if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
#print(past_key_values_length, hidden_states.shape)

seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = hidden_states.device if hidden_states is not None else inputs_embeds.device
Expand Down Expand Up @@ -784,6 +786,7 @@ def topK_genrate(self, hidden_states, input_ids, head, logits_processor,max_leng
# hidden_states: [1, L, 4096]
# out_hidden: [1, L, 4096]

# save stable kv cache!
self.stable_kv=past_key_values

# take out the last hidden!
Expand Down
3 changes: 3 additions & 0 deletions model/ea_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def forward(
if output_orig:
orig = self.base_model.lm_head(outputs[0])
hidden_states = outputs[0].clone()

if init:
if logits_processor is not None:
logits = orig[:, -1]
Expand All @@ -193,6 +194,7 @@ def forward(
# hidden_states: [B, L, 4096]
# input_ids: [B, L]
# ea_logits: (torch.cat(ss_token),torch.cat(ss_prob),ss_op)
#print('ea_model.forward() topK_genrate')
ea_logits = self.ea_layer.topK_genrate(hidden_states, input_ids, self.base_model.lm_head, logits_processor)
# ea_logits[0]: [1+4+4+1+1=11, topk=10] where 11 is the draft token tree non-leaf size

Expand Down Expand Up @@ -452,6 +454,7 @@ def ea_generate(
time_stats.push('#new tokens per iteration', accept_length.item() + 1)

last_input_ids = input_ids
# calling cnets::topK_genrate() in update_inference_inputs (model/utils.py)
input_ids, tree_logits, new_token, hidden_state, sample_token = update_inference_inputs(
input_ids,
candidates,
Expand Down
1 change: 1 addition & 0 deletions model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ def update_inference_inputs(

# after init=True, this gets called recurrently.
time_stats.start('tree_drafting')
#print('update_inference_inputs() topK_genrate')
tree_logits = model.ea_layer.topK_genrate(accept_hidden_state_new,
input_ids=torch.cat((input_ids, token.to(input_ids.device)), dim=1),
head=model.base_model.lm_head, logits_processor=logits_processor)
Expand Down

0 comments on commit f7d7043

Please sign in to comment.