fix EagleVerifyInput (#3378)
This commit is contained in:
@@ -177,29 +177,21 @@ class EagleVerifyInput:
|
||||
spec_steps: int,
|
||||
num_verify_token: int,
|
||||
):
|
||||
score_list = torch.cat(score_list, dim=1).flatten(
|
||||
1
|
||||
) # b, n, topk; n= 1 + (num_steps-1) * self.topk
|
||||
ss_token_list = torch.cat(
|
||||
token_list, dim=1
|
||||
) # b, (self.topk + (num_steps-1) * self.topk)
|
||||
top_scores = torch.topk(score_list, num_verify_token - 1, dim=-1)
|
||||
top_scores_index = top_scores.indices
|
||||
top_scores_index = torch.sort(top_scores_index).values
|
||||
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
|
||||
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1)
|
||||
parent_list = torch.cat(parents_list[:-1], dim=1)
|
||||
tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
|
||||
parent_list,
|
||||
top_scores_index,
|
||||
seq_lens,
|
||||
seq_lens_sum,
|
||||
topk,
|
||||
spec_steps,
|
||||
num_verify_token,
|
||||
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
|
||||
build_tree_kernel(
|
||||
verified_id,
|
||||
score_list,
|
||||
token_list,
|
||||
parents_list,
|
||||
seq_lens,
|
||||
seq_lens_sum,
|
||||
topk,
|
||||
spec_steps,
|
||||
num_verify_token,
|
||||
)
|
||||
)
|
||||
return cls(
|
||||
draft_tokens.flatten(),
|
||||
draft_tokens,
|
||||
tree_mask,
|
||||
position,
|
||||
retrive_index,
|
||||
|
||||
Reference in New Issue
Block a user