Fix hash collision for multi modal models (#2256)
This commit is contained in:
@@ -124,7 +124,7 @@ class FINISH_ABORT(BaseFinishReason):
|
||||
class ImageInputs:
|
||||
"""The image related inputs."""
|
||||
|
||||
pixel_values: torch.Tensor
|
||||
pixel_values: Union[torch.Tensor, np.array]
|
||||
image_hashes: Optional[list] = None
|
||||
image_sizes: Optional[list] = None
|
||||
image_offsets: Optional[list] = None
|
||||
@@ -132,7 +132,7 @@ class ImageInputs:
|
||||
modalities: Optional[list] = None
|
||||
num_image_tokens: Optional[int] = None
|
||||
|
||||
image_embeds: Optional[List[torch.Tensor]] = None
|
||||
# Llava related
|
||||
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
|
||||
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
||||
|
||||
@@ -141,21 +141,17 @@ class ImageInputs:
|
||||
mrope_position_delta: Optional[torch.Tensor] = None
|
||||
|
||||
@staticmethod
|
||||
def from_dict(obj, vocab_size):
|
||||
# Use image hash as fake token_ids, which is then used for prefix matching
|
||||
def from_dict(obj: dict):
|
||||
ret = ImageInputs(
|
||||
pixel_values=obj["pixel_values"],
|
||||
image_hashes=obj["image_hashes"],
|
||||
)
|
||||
if not isinstance(ret.image_hashes, list):
|
||||
ret.pad_values = [
|
||||
(ret.image_hashes) % vocab_size,
|
||||
(ret.image_hashes >> 16) % vocab_size,
|
||||
(ret.image_hashes >> 32) % vocab_size,
|
||||
(ret.image_hashes >> 64) % vocab_size,
|
||||
]
|
||||
else:
|
||||
ret.pad_values = [x % vocab_size for x in ret.image_hashes]
|
||||
|
||||
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
||||
# Please note that if the `input_ids` is later used in the model forward,
|
||||
# you also need to clamp the values within the range of [0, vocab_size) to avoid illegal
|
||||
# cuda memory access.
|
||||
ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]
|
||||
|
||||
optional_args = [
|
||||
"image_sizes",
|
||||
@@ -170,21 +166,16 @@ class ImageInputs:
|
||||
|
||||
return ret
|
||||
|
||||
def merge(self, other, vocab_size):
|
||||
def merge(self, other):
|
||||
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
|
||||
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
|
||||
|
||||
if isinstance(self.image_hashes, list) and isinstance(other.image_hashes, list):
|
||||
self.image_hashes += other.image_hashes
|
||||
self.pad_values = [x % vocab_size for x in self.image_hashes]
|
||||
else:
|
||||
self.image_hashes = hash(tuple(self.image_hashes, other.image_hashes))
|
||||
self.pad_values = [
|
||||
(self.image_hashes) % vocab_size,
|
||||
(self.image_hashes >> 16) % vocab_size,
|
||||
(self.image_hashes >> 32) % vocab_size,
|
||||
(self.image_hashes >> 64) % vocab_size,
|
||||
]
|
||||
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
||||
# Please note that if the `input_ids` is later used in the model forward,
|
||||
# you also need to clamp the values within the range of [0, vocab_size) to avoid illegal
|
||||
# cuda memory access.
|
||||
self.image_hashes += other.image_hashes
|
||||
self.pad_values = [x % (1 << 30) for x in self.image_hashes]
|
||||
|
||||
optional_args = [
|
||||
"image_sizes",
|
||||
@@ -297,11 +288,11 @@ class Req:
|
||||
# The number of cached tokens, that were already cached in the KV cache
|
||||
self.cached_tokens = 0
|
||||
|
||||
def extend_image_inputs(self, image_inputs, vocab_size):
|
||||
def extend_image_inputs(self, image_inputs):
|
||||
if self.image_inputs is None:
|
||||
self.image_inputs = image_inputs
|
||||
else:
|
||||
self.image_inputs.merge(image_inputs, vocab_size)
|
||||
self.image_inputs.merge(image_inputs)
|
||||
|
||||
# whether request reached finished condition
|
||||
def finished(self) -> bool:
|
||||
|
||||
@@ -526,8 +526,9 @@ class Scheduler:
|
||||
self,
|
||||
recv_req: TokenizedGenerateReqInput,
|
||||
):
|
||||
# Create a new request
|
||||
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
|
||||
# Create a new request
|
||||
|
||||
if recv_req.input_embeds is not None:
|
||||
# Generate fake input_ids based on the length of input_embeds
|
||||
seq_length = len(recv_req.input_embeds)
|
||||
@@ -558,20 +559,20 @@ class Scheduler:
|
||||
self.waiting_queue.append(req)
|
||||
return
|
||||
|
||||
# Image inputs
|
||||
# Handle image inputs
|
||||
if recv_req.image_inputs is not None:
|
||||
image_inputs = ImageInputs.from_dict(
|
||||
recv_req.image_inputs, self.model_config.vocab_size
|
||||
)
|
||||
image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
|
||||
# Expand a single image token into multiple dummy tokens for receiving image embeddings
|
||||
req.origin_input_ids = self.pad_input_ids_func(
|
||||
req.origin_input_ids, image_inputs
|
||||
)
|
||||
req.extend_image_inputs(image_inputs, self.model_config.vocab_size)
|
||||
req.extend_image_inputs(image_inputs)
|
||||
|
||||
if len(req.origin_input_ids) > self.max_req_input_len:
|
||||
req.finished_reason = FINISH_ABORT(
|
||||
"Image request length is longer than the KV cache pool size or "
|
||||
"the max context length aborting because you cannot truncate the image embeds"
|
||||
"the max context length. "
|
||||
"Abort this request because you cannot truncate the image embeds"
|
||||
)
|
||||
req.image_inputs = None
|
||||
req.origin_input_ids = [0]
|
||||
@@ -579,6 +580,7 @@ class Scheduler:
|
||||
self.waiting_queue.append(req)
|
||||
return
|
||||
|
||||
# Copy more attributes
|
||||
req.return_logprob = recv_req.return_logprob
|
||||
req.top_logprobs_num = recv_req.top_logprobs_num
|
||||
req.stream = recv_req.stream
|
||||
|
||||
@@ -10,10 +10,7 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import copy
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
|
||||
|
||||
@@ -216,6 +216,7 @@ class TokenizerManager:
|
||||
input_ids = obj.input_ids
|
||||
|
||||
if self.is_generation:
|
||||
# TODO: also support getting embeddings for multimodal models
|
||||
image_inputs: Dict = await self.image_processor.process_images_async(
|
||||
obj.image_data, input_text or input_ids, obj
|
||||
)
|
||||
|
||||
@@ -147,6 +147,11 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
else:
|
||||
max_image_offset.append(-1)
|
||||
|
||||
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||
# filled with the hash values of the image for the prefix matching in the radix attention.
|
||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
||||
|
||||
# Embed text inputs
|
||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
|
||||
|
||||
@@ -597,13 +597,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
|
||||
`None` if no images are passed.
|
||||
"""
|
||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||
positions = forward_batch.mrope_positions
|
||||
|
||||
image_inputs = None
|
||||
if forward_batch.image_inputs is not None:
|
||||
image_inputs = [
|
||||
img for img in forward_batch.image_inputs if img is not None
|
||||
]
|
||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||
positions = forward_batch.mrope_positions
|
||||
|
||||
if (
|
||||
forward_batch.forward_mode.is_decode()
|
||||
or image_inputs is None
|
||||
@@ -617,6 +619,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
f"(3, seq_len) positions, but got {positions.size()}"
|
||||
)
|
||||
|
||||
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||
# filled with the hash values of the image for the prefix matching in the radix attention.
|
||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
||||
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
|
||||
|
||||
Reference in New Issue
Block a user