Fixed the issue where TPOT (time per output token) with eagle3 enabled was worse than with eagle3 disabled. (#9404)
This commit is contained in:
@@ -453,12 +453,13 @@ class EagleVerifyInput:
|
|||||||
sampling_info.top_ks, self.draft_token_num, dim=0
|
sampling_info.top_ks, self.draft_token_num, dim=0
|
||||||
),
|
),
|
||||||
) # (bs * draft_token_num, vocab_size)
|
) # (bs * draft_token_num, vocab_size)
|
||||||
target_probs = top_p_renorm_prob(
|
if not torch.all(sampling_info.top_ps == 1.0):
|
||||||
target_probs,
|
target_probs = top_p_renorm_prob(
|
||||||
torch.repeat_interleave(
|
target_probs,
|
||||||
sampling_info.top_ps, self.draft_token_num, dim=0
|
torch.repeat_interleave(
|
||||||
),
|
sampling_info.top_ps, self.draft_token_num, dim=0
|
||||||
)
|
),
|
||||||
|
)
|
||||||
target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
|
target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
|
||||||
|
|
||||||
draft_probs = torch.zeros(
|
draft_probs = torch.zeros(
|
||||||
|
|||||||
Reference in New Issue
Block a user