diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 39712f690..6711a0c2e 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -149,8 +149,8 @@ class ImageInputs: # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache. # Please note that if the `input_ids` is later used in the model forward, - # you also need to clamp the values within the range of [0, vocab_size) to avoid illegal - # cuda memory access. + # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound + # errors in cuda kernels. See also llava.py for example. ret.pad_values = [x % (1 << 30) for x in ret.image_hashes] optional_args = [ @@ -172,8 +172,8 @@ class ImageInputs: # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache. # Please note that if the `input_ids` is later used in the model forward, - # you also need to clamp the values within the range of [0, vocab_size) to avoid illegal - # cuda memory access. + # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound + # errors in cuda kernels. See also llava.py for example. self.image_hashes += other.image_hashes self.pad_values = [x % (1 << 30) for x in self.image_hashes] diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a72a62b69..d35e61676 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -568,15 +568,17 @@ class Scheduler: ) req.extend_image_inputs(image_inputs) - if len(req.origin_input_ids) > self.max_req_input_len: - req.finished_reason = FINISH_ABORT( - "Image request length is longer than the KV cache pool size or " - "the max context length. " - "Abort this request because you cannot truncate the image embeds" + if len(req.origin_input_ids) >= self.max_req_input_len: + logger.error( + "Multimodal prompt is too long after expanding multimodal tokens. " + f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. " ) - req.image_inputs = None req.origin_input_ids = [0] + req.image_inputs = None req.sampling_params.max_new_tokens = 0 + req.finished_reason = FINISH_ABORT( + "Multimodal prompt is too long. Check server logs for details." + ) self.waiting_queue.append(req) return diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 12993f390..eb1784145 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -134,7 +134,6 @@ class LlavaBaseForCausalLM(nn.Module): image_inputs = forward_batch.image_inputs if forward_batch.forward_mode.is_extend(): - bs = forward_batch.batch_size # Got List[List[str]] extend it to List[str] # The length of the List should be equal to batch size modalities_list = [] @@ -142,7 +141,7 @@ class LlavaBaseForCausalLM(nn.Module): for im in image_inputs: if im and im.modalities is not None: modalities_list.extend(im.modalities) - if im and im.image_offsets is not None: + if im and im.image_offsets: max_image_offset.append(max(im.image_offsets)) else: max_image_offset.append(-1) @@ -159,6 +158,7 @@ class LlavaBaseForCausalLM(nn.Module): need_vision = start_positions <= np.array(max_image_offset) if need_vision.any(): + bs = forward_batch.batch_size pixel_values = [ image_inputs[i].pixel_values for i in range(bs) if need_vision[i] ]