fix: prevent crashes due to logit bias dimension mismatch (#7685)

This commit is contained in:
J
2025-07-23 15:30:55 -07:00
committed by GitHub
parent 4953f4ca9a
commit 0e7a5b2694
2 changed files with 12 additions and 5 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import copy
import logging
import os
import time
@@ -362,6 +363,11 @@ class EagleVerifyInput:
)
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
if bs != len(sampling_info):
sampling_info = copy.deepcopy(sampling_info)
# NOTE: retrive_index are the indices of the requests that are kept.
sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)
# Apply the custom logit processors if registered in the sampling info.
if sampling_info.has_custom_logit_processor:
apply_custom_logit_processor(