Fixed the issue where eagle3 TPOT was not as good as without eagle3. (#9404)
This commit is contained in:
@@ -453,12 +453,13 @@ class EagleVerifyInput:
|
||||
sampling_info.top_ks, self.draft_token_num, dim=0
|
||||
),
|
||||
) # (bs * draft_token_num, vocab_size)
|
||||
target_probs = top_p_renorm_prob(
|
||||
target_probs,
|
||||
torch.repeat_interleave(
|
||||
sampling_info.top_ps, self.draft_token_num, dim=0
|
||||
),
|
||||
)
|
||||
if not torch.all(sampling_info.top_ps == 1.0):
|
||||
target_probs = top_p_renorm_prob(
|
||||
target_probs,
|
||||
torch.repeat_interleave(
|
||||
sampling_info.top_ps, self.draft_token_num, dim=0
|
||||
),
|
||||
)
|
||||
target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
|
||||
|
||||
draft_probs = torch.zeros(
|
||||
|
||||
Reference in New Issue
Block a user