Fixed the issue where eagle3 TPOT was not as good as without eagle3. (#9404)

This commit is contained in:
jiapingW
2025-08-21 07:42:01 +08:00
committed by GitHub
parent c10b8e6a0f
commit e99729c9f3

View File

@@ -453,12 +453,13 @@ class EagleVerifyInput:
sampling_info.top_ks, self.draft_token_num, dim=0 sampling_info.top_ks, self.draft_token_num, dim=0
), ),
) # (bs * draft_token_num, vocab_size) ) # (bs * draft_token_num, vocab_size)
target_probs = top_p_renorm_prob( if not torch.all(sampling_info.top_ps == 1.0):
target_probs, target_probs = top_p_renorm_prob(
torch.repeat_interleave( target_probs,
sampling_info.top_ps, self.draft_token_num, dim=0 torch.repeat_interleave(
), sampling_info.top_ps, self.draft_token_num, dim=0
) ),
)
target_probs = target_probs.reshape(bs, self.draft_token_num, -1) target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
draft_probs = torch.zeros( draft_probs = torch.zeros(