From f50a6cf4435bd39b854efcf00814bc796b7f9b21 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 29 Nov 2024 03:15:58 -0800 Subject: [PATCH] Fix hash collision for multi modal models (#2256) --- python/sglang/srt/managers/schedule_batch.py | 45 ++++++++----------- python/sglang/srt/managers/scheduler.py | 16 ++++--- .../sglang/srt/managers/session_controller.py | 3 -- .../sglang/srt/managers/tokenizer_manager.py | 1 + python/sglang/srt/models/llava.py | 5 +++ python/sglang/srt/models/qwen2_vl.py | 11 ++++- 6 files changed, 42 insertions(+), 39 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index c4daa8a07..39712f690 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -124,7 +124,7 @@ class FINISH_ABORT(BaseFinishReason): class ImageInputs: """The image related inputs.""" - pixel_values: torch.Tensor + pixel_values: Union[torch.Tensor, np.array] image_hashes: Optional[list] = None image_sizes: Optional[list] = None image_offsets: Optional[list] = None @@ -132,7 +132,7 @@ class ImageInputs: modalities: Optional[list] = None num_image_tokens: Optional[int] = None - image_embeds: Optional[List[torch.Tensor]] = None + # Llava related aspect_ratio_ids: Optional[List[torch.Tensor]] = None aspect_ratio_mask: Optional[List[torch.Tensor]] = None @@ -141,21 +141,17 @@ class ImageInputs: mrope_position_delta: Optional[torch.Tensor] = None @staticmethod - def from_dict(obj, vocab_size): - # Use image hash as fake token_ids, which is then used for prefix matching + def from_dict(obj: dict): ret = ImageInputs( pixel_values=obj["pixel_values"], image_hashes=obj["image_hashes"], ) - if not isinstance(ret.image_hashes, list): - ret.pad_values = [ - (ret.image_hashes) % vocab_size, - (ret.image_hashes >> 16) % vocab_size, - (ret.image_hashes >> 32) % vocab_size, - (ret.image_hashes >> 64) % vocab_size, - ] - else: - ret.pad_values = [x % vocab_size 
for x in ret.image_hashes] + + # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache. + # Please note that if the `input_ids` is later used in the model forward, + # you also need to clamp the values within the range of [0, vocab_size) to avoid illegal + # cuda memory access. + ret.pad_values = [x % (1 << 30) for x in ret.image_hashes] optional_args = [ "image_sizes", @@ -170,21 +166,16 @@ class ImageInputs: return ret - def merge(self, other, vocab_size): + def merge(self, other): assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:] self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values]) - if isinstance(self.image_hashes, list) and isinstance(other.image_hashes, list): - self.image_hashes += other.image_hashes - self.pad_values = [x % vocab_size for x in self.image_hashes] - else: - self.image_hashes = hash(tuple(self.image_hashes, other.image_hashes)) - self.pad_values = [ - (self.image_hashes) % vocab_size, - (self.image_hashes >> 16) % vocab_size, - (self.image_hashes >> 32) % vocab_size, - (self.image_hashes >> 64) % vocab_size, - ] + # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache. + # Please note that if the `input_ids` is later used in the model forward, + # you also need to clamp the values within the range of [0, vocab_size) to avoid illegal + # cuda memory access. 
+ self.image_hashes += other.image_hashes + self.pad_values = [x % (1 << 30) for x in self.image_hashes] optional_args = [ "image_sizes", @@ -297,11 +288,11 @@ class Req: # The number of cached tokens, that were already cached in the KV cache self.cached_tokens = 0 - def extend_image_inputs(self, image_inputs, vocab_size): + def extend_image_inputs(self, image_inputs): if self.image_inputs is None: self.image_inputs = image_inputs else: - self.image_inputs.merge(image_inputs, vocab_size) + self.image_inputs.merge(image_inputs) # whether request reached finished condition def finished(self) -> bool: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e0a67b435..a72a62b69 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -526,8 +526,9 @@ class Scheduler: self, recv_req: TokenizedGenerateReqInput, ): + # Create a new request if recv_req.session_id is None or recv_req.session_id not in self.sessions: - # Create a new request + if recv_req.input_embeds is not None: # Generate fake input_ids based on the length of input_embeds seq_length = len(recv_req.input_embeds) @@ -558,20 +559,20 @@ class Scheduler: self.waiting_queue.append(req) return - # Image inputs + # Handle image inputs if recv_req.image_inputs is not None: - image_inputs = ImageInputs.from_dict( - recv_req.image_inputs, self.model_config.vocab_size - ) + image_inputs = ImageInputs.from_dict(recv_req.image_inputs) + # Expand a single image token into multiple dummy tokens for receiving image embeddings req.origin_input_ids = self.pad_input_ids_func( req.origin_input_ids, image_inputs ) - req.extend_image_inputs(image_inputs, self.model_config.vocab_size) + req.extend_image_inputs(image_inputs) if len(req.origin_input_ids) > self.max_req_input_len: req.finished_reason = FINISH_ABORT( "Image request length is longer than the KV cache pool size or " - "the max context length aborting because you cannot truncate 
the image embeds" + "the max context length. " + "Abort this request because you cannot truncate the image embeds" ) req.image_inputs = None req.origin_input_ids = [0] @@ -579,6 +580,7 @@ class Scheduler: self.waiting_queue.append(req) return + # Copy more attributes req.return_logprob = recv_req.return_logprob req.top_logprobs_num = recv_req.top_logprobs_num req.stream = recv_req.stream diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py index f267a0dc2..dc5a1b670 100644 --- a/python/sglang/srt/managers/session_controller.py +++ b/python/sglang/srt/managers/session_controller.py @@ -10,10 +10,7 @@ # limitations under the License. # ============================================================================== -import copy import uuid -from dataclasses import dataclass -from typing import Optional from sglang.srt.managers.io_struct import TokenizedGenerateReqInput from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 3b3998bec..77bc91218 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -216,6 +216,7 @@ class TokenizerManager: input_ids = obj.input_ids if self.is_generation: + # TODO: also support getting embeddings for multimodal models image_inputs: Dict = await self.image_processor.process_images_async( obj.image_data, input_text or input_ids, obj ) diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 514c7c1bd..12993f390 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -147,6 +147,11 @@ class LlavaBaseForCausalLM(nn.Module): else: max_image_offset.append(-1) + # Clamp input ids. This is because the input_ids for the image tokens are + # filled with the hash values of the image for the prefix matching in the radix attention. 
+ # These values are useless because their embeddings will be replaced by vision embeddings anyway. + input_ids.clamp_(min=0, max=self.config.vocab_size - 1) + # Embed text inputs input_embeds = self.language_model.model.embed_tokens(input_ids) diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 0258ed332..dc58383ee 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -597,13 +597,15 @@ class Qwen2VLForConditionalGeneration(nn.Module): image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. `None` if no images are passed. """ + if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": + positions = forward_batch.mrope_positions + image_inputs = None if forward_batch.image_inputs is not None: image_inputs = [ img for img in forward_batch.image_inputs if img is not None ] - if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": - positions = forward_batch.mrope_positions + if ( forward_batch.forward_mode.is_decode() or image_inputs is None @@ -617,6 +619,11 @@ class Qwen2VLForConditionalGeneration(nn.Module): f"(3, seq_len) positions, but got {positions.size()}" ) + # Clamp input ids. This is because the input_ids for the image tokens are + # filled with the hash values of the image for the prefix matching in the radix attention. + # These values are useless because their embeddings will be replaced by vision embeddings anyway. + input_ids.clamp_(min=0, max=self.config.vocab_size - 1) + inputs_embeds = self.model.embed_tokens(input_ids) extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu