fix: prevent crashes due to logit bias dimension mismatch (#7685)
This commit is contained in:
@@ -322,6 +322,12 @@ class SamplingBatchInfo:
         # Set the flag to True if any of the two has custom logit processor
             self.has_custom_logit_processor = True

         # Merge logit bias - note this has to come before the temperatures tensor update! Otherwise will cause crashes.
         # See note below on len(self) and len(other).
         self.logit_bias = merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
         )

         # Note: because the __len()__ operator is defined on the temperatures tensor,
         # please make sure any merge operation with len(self) or len(other) is done before
         # the merge operation of the temperatures tensor below.
@@ -340,11 +346,6 @@ class SamplingBatchInfo:
         self.need_top_k_sampling |= other.need_top_k_sampling
         self.need_min_p_sampling |= other.need_min_p_sampling

         # Merge logit bias
         self.logit_bias = merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
         )


 def merge_bias_tensor(
     lhs: Optional[torch.Tensor],
||||
Reference in New Issue
Block a user