fix: prevent crashes due to logit bias dimension mismatch (#7685)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
@@ -362,6 +363,11 @@ class EagleVerifyInput:
|
||||
)
|
||||
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
|
||||
|
||||
if bs != len(sampling_info):
|
||||
sampling_info = copy.deepcopy(sampling_info)
|
||||
# NOTE: retrive_index are the indices of the requests that are kept.
|
||||
sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)
|
||||
|
||||
# Apply the custom logit processors if registered in the sampling info.
|
||||
if sampling_info.has_custom_logit_processor:
|
||||
apply_custom_logit_processor(
|
||||
|
||||
Reference in New Issue
Block a user