fix EagleVerifyInput (#3378)
This commit is contained in:
@@ -177,29 +177,21 @@ class EagleVerifyInput:
|
|||||||
spec_steps: int,
|
spec_steps: int,
|
||||||
num_verify_token: int,
|
num_verify_token: int,
|
||||||
):
|
):
|
||||||
score_list = torch.cat(score_list, dim=1).flatten(
|
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
|
||||||
1
|
build_tree_kernel(
|
||||||
) # b, n, topk; n= 1 + (num_steps-1) * self.topk
|
verified_id,
|
||||||
ss_token_list = torch.cat(
|
score_list,
|
||||||
token_list, dim=1
|
token_list,
|
||||||
) # b, (self.topk + (num_steps-1) * self.topk)
|
parents_list,
|
||||||
top_scores = torch.topk(score_list, num_verify_token - 1, dim=-1)
|
seq_lens,
|
||||||
top_scores_index = top_scores.indices
|
seq_lens_sum,
|
||||||
top_scores_index = torch.sort(top_scores_index).values
|
topk,
|
||||||
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
|
spec_steps,
|
||||||
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1)
|
num_verify_token,
|
||||||
parent_list = torch.cat(parents_list[:-1], dim=1)
|
)
|
||||||
tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
|
|
||||||
parent_list,
|
|
||||||
top_scores_index,
|
|
||||||
seq_lens,
|
|
||||||
seq_lens_sum,
|
|
||||||
topk,
|
|
||||||
spec_steps,
|
|
||||||
num_verify_token,
|
|
||||||
)
|
)
|
||||||
return cls(
|
return cls(
|
||||||
draft_tokens.flatten(),
|
draft_tokens,
|
||||||
tree_mask,
|
tree_mask,
|
||||||
position,
|
position,
|
||||||
retrive_index,
|
retrive_index,
|
||||||
|
|||||||
Reference in New Issue
Block a user