fix: prevent crashes due to logit bias dimension mismatch (#7685)
This commit is contained in:
@@ -322,6 +322,12 @@ class SamplingBatchInfo:
|
|||||||
# Set the flag to True if any of the two has custom logit processor
|
# Set the flag to True if any of the two has custom logit processor
|
||||||
self.has_custom_logit_processor = True
|
self.has_custom_logit_processor = True
|
||||||
|
|
||||||
|
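# Merge logit bias. This must happen before the temperatures tensor is merged: len(self) and len(other) are derived from the temperatures tensor, so merging it first changes those lengths and causes a dimension-mismatch crash.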
|
||||||
|
# See note below on len(self) and len(other).
|
||||||
|
self.logit_bias = merge_bias_tensor(
|
||||||
|
self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
|
||||||
|
)
|
||||||
|
|
||||||
# Note: because the __len__() operator is defined on the temperatures tensor,
|
# Note: because the __len__() operator is defined on the temperatures tensor,
|
||||||
# please make sure any merge operation with len(self) or len(other) is done before
|
# please make sure any merge operation with len(self) or len(other) is done before
|
||||||
# the merge operation of the temperatures tensor below.
|
# the merge operation of the temperatures tensor below.
|
||||||
@@ -340,11 +346,6 @@ class SamplingBatchInfo:
|
|||||||
self.need_top_k_sampling |= other.need_top_k_sampling
|
self.need_top_k_sampling |= other.need_top_k_sampling
|
||||||
self.need_min_p_sampling |= other.need_min_p_sampling
|
self.need_min_p_sampling |= other.need_min_p_sampling
|
||||||
|
|
||||||
# Merge logit bias
|
|
||||||
self.logit_bias = merge_bias_tensor(
|
|
||||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def merge_bias_tensor(
|
def merge_bias_tensor(
|
||||||
lhs: Optional[torch.Tensor],
|
lhs: Optional[torch.Tensor],
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
@@ -362,6 +363,11 @@ class EagleVerifyInput:
|
|||||||
)
|
)
|
||||||
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
|
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
if bs != len(sampling_info):
|
||||||
|
sampling_info = copy.deepcopy(sampling_info)
|
||||||
|
# NOTE: retrive_index are the indices of the requests that are kept.
|
||||||
|
sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)
|
||||||
|
|
||||||
# Apply the custom logit processors if registered in the sampling info.
|
# Apply the custom logit processors if registered in the sampling info.
|
||||||
if sampling_info.has_custom_logit_processor:
|
if sampling_info.has_custom_logit_processor:
|
||||||
apply_custom_logit_processor(
|
apply_custom_logit_processor(
|
||||||
|
|||||||
Reference in New Issue
Block a user