From 0e7a5b26945c7a21dbaff10254477d2d3de779ff Mon Sep 17 00:00:00 2001 From: J Date: Wed, 23 Jul 2025 15:30:55 -0700 Subject: [PATCH] fix: prevent crashes due to logit bias dimension mismatch (#7685) --- python/sglang/srt/sampling/sampling_batch_info.py | 11 ++++++----- python/sglang/srt/speculative/eagle_utils.py | 6 ++++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index f88082e69..bcdadbe11 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -322,6 +322,12 @@ class SamplingBatchInfo: # Set the flag to True if any of the two has custom logit processor self.has_custom_logit_processor = True + # Merge logit bias - note this has to come before the temperatures tensor update! Otherwise will cause crashes. + # See note below on len(self) and len(other). + self.logit_bias = merge_bias_tensor( + self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0 + ) + # Note: because the __len()__ operator is defined on the temperatures tensor, # please make sure any merge operation with len(self) or len(other) is done before # the merge operation of the temperatures tensor below. @@ -340,11 +346,6 @@ class SamplingBatchInfo: self.need_top_k_sampling |= other.need_top_k_sampling self.need_min_p_sampling |= other.need_min_p_sampling - # Merge logit bias - self.logit_bias = merge_bias_tensor( - self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0 - ) - def merge_bias_tensor( lhs: Optional[torch.Tensor], diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 83724b385..7f7e21e96 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import logging import os import time @@ -362,6 +363,11 @@ class EagleVerifyInput: ) accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda") + if bs != len(sampling_info): + sampling_info = copy.deepcopy(sampling_info) + # NOTE: retrive_index are the indices of the requests that are kept. + sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index) + # Apply the custom logit processors if registered in the sampling info. if sampling_info.has_custom_logit_processor: apply_custom_logit_processor(