Fix hash collision for multi modal models (#2256)
This commit is contained in:
@@ -124,7 +124,7 @@ class FINISH_ABORT(BaseFinishReason):
|
|||||||
class ImageInputs:
|
class ImageInputs:
|
||||||
"""The image related inputs."""
|
"""The image related inputs."""
|
||||||
|
|
||||||
pixel_values: torch.Tensor
|
pixel_values: Union[torch.Tensor, np.array]
|
||||||
image_hashes: Optional[list] = None
|
image_hashes: Optional[list] = None
|
||||||
image_sizes: Optional[list] = None
|
image_sizes: Optional[list] = None
|
||||||
image_offsets: Optional[list] = None
|
image_offsets: Optional[list] = None
|
||||||
@@ -132,7 +132,7 @@ class ImageInputs:
|
|||||||
modalities: Optional[list] = None
|
modalities: Optional[list] = None
|
||||||
num_image_tokens: Optional[int] = None
|
num_image_tokens: Optional[int] = None
|
||||||
|
|
||||||
image_embeds: Optional[List[torch.Tensor]] = None
|
# Llava related
|
||||||
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
|
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
|
||||||
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
||||||
|
|
||||||
@@ -141,21 +141,17 @@ class ImageInputs:
|
|||||||
mrope_position_delta: Optional[torch.Tensor] = None
|
mrope_position_delta: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_dict(obj, vocab_size):
|
def from_dict(obj: dict):
|
||||||
# Use image hash as fake token_ids, which is then used for prefix matching
|
|
||||||
ret = ImageInputs(
|
ret = ImageInputs(
|
||||||
pixel_values=obj["pixel_values"],
|
pixel_values=obj["pixel_values"],
|
||||||
image_hashes=obj["image_hashes"],
|
image_hashes=obj["image_hashes"],
|
||||||
)
|
)
|
||||||
if not isinstance(ret.image_hashes, list):
|
|
||||||
ret.pad_values = [
|
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
||||||
(ret.image_hashes) % vocab_size,
|
# Please note that if the `input_ids` is later used in the model forward,
|
||||||
(ret.image_hashes >> 16) % vocab_size,
|
# you also need to clamp the values within the range of [0, vocab_size) to avoid illegal
|
||||||
(ret.image_hashes >> 32) % vocab_size,
|
# cuda memory access.
|
||||||
(ret.image_hashes >> 64) % vocab_size,
|
ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]
|
||||||
]
|
|
||||||
else:
|
|
||||||
ret.pad_values = [x % vocab_size for x in ret.image_hashes]
|
|
||||||
|
|
||||||
optional_args = [
|
optional_args = [
|
||||||
"image_sizes",
|
"image_sizes",
|
||||||
@@ -170,21 +166,16 @@ class ImageInputs:
|
|||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def merge(self, other, vocab_size):
|
def merge(self, other):
|
||||||
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
|
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
|
||||||
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
|
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
|
||||||
|
|
||||||
if isinstance(self.image_hashes, list) and isinstance(other.image_hashes, list):
|
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
||||||
self.image_hashes += other.image_hashes
|
# Please note that if the `input_ids` is later used in the model forward,
|
||||||
self.pad_values = [x % vocab_size for x in self.image_hashes]
|
# you also need to clamp the values within the range of [0, vocab_size) to avoid illegal
|
||||||
else:
|
# cuda memory access.
|
||||||
self.image_hashes = hash(tuple(self.image_hashes, other.image_hashes))
|
self.image_hashes += other.image_hashes
|
||||||
self.pad_values = [
|
self.pad_values = [x % (1 << 30) for x in self.image_hashes]
|
||||||
(self.image_hashes) % vocab_size,
|
|
||||||
(self.image_hashes >> 16) % vocab_size,
|
|
||||||
(self.image_hashes >> 32) % vocab_size,
|
|
||||||
(self.image_hashes >> 64) % vocab_size,
|
|
||||||
]
|
|
||||||
|
|
||||||
optional_args = [
|
optional_args = [
|
||||||
"image_sizes",
|
"image_sizes",
|
||||||
@@ -297,11 +288,11 @@ class Req:
|
|||||||
# The number of cached tokens, that were already cached in the KV cache
|
# The number of cached tokens, that were already cached in the KV cache
|
||||||
self.cached_tokens = 0
|
self.cached_tokens = 0
|
||||||
|
|
||||||
def extend_image_inputs(self, image_inputs, vocab_size):
|
def extend_image_inputs(self, image_inputs):
|
||||||
if self.image_inputs is None:
|
if self.image_inputs is None:
|
||||||
self.image_inputs = image_inputs
|
self.image_inputs = image_inputs
|
||||||
else:
|
else:
|
||||||
self.image_inputs.merge(image_inputs, vocab_size)
|
self.image_inputs.merge(image_inputs)
|
||||||
|
|
||||||
# whether request reached finished condition
|
# whether request reached finished condition
|
||||||
def finished(self) -> bool:
|
def finished(self) -> bool:
|
||||||
|
|||||||
@@ -526,8 +526,9 @@ class Scheduler:
|
|||||||
self,
|
self,
|
||||||
recv_req: TokenizedGenerateReqInput,
|
recv_req: TokenizedGenerateReqInput,
|
||||||
):
|
):
|
||||||
|
# Create a new request
|
||||||
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
|
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
|
||||||
# Create a new request
|
|
||||||
if recv_req.input_embeds is not None:
|
if recv_req.input_embeds is not None:
|
||||||
# Generate fake input_ids based on the length of input_embeds
|
# Generate fake input_ids based on the length of input_embeds
|
||||||
seq_length = len(recv_req.input_embeds)
|
seq_length = len(recv_req.input_embeds)
|
||||||
@@ -558,20 +559,20 @@ class Scheduler:
|
|||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Image inputs
|
# Handle image inputs
|
||||||
if recv_req.image_inputs is not None:
|
if recv_req.image_inputs is not None:
|
||||||
image_inputs = ImageInputs.from_dict(
|
image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
|
||||||
recv_req.image_inputs, self.model_config.vocab_size
|
# Expand a single image token into multiple dummy tokens for receiving image embeddings
|
||||||
)
|
|
||||||
req.origin_input_ids = self.pad_input_ids_func(
|
req.origin_input_ids = self.pad_input_ids_func(
|
||||||
req.origin_input_ids, image_inputs
|
req.origin_input_ids, image_inputs
|
||||||
)
|
)
|
||||||
req.extend_image_inputs(image_inputs, self.model_config.vocab_size)
|
req.extend_image_inputs(image_inputs)
|
||||||
|
|
||||||
if len(req.origin_input_ids) > self.max_req_input_len:
|
if len(req.origin_input_ids) > self.max_req_input_len:
|
||||||
req.finished_reason = FINISH_ABORT(
|
req.finished_reason = FINISH_ABORT(
|
||||||
"Image request length is longer than the KV cache pool size or "
|
"Image request length is longer than the KV cache pool size or "
|
||||||
"the max context length aborting because you cannot truncate the image embeds"
|
"the max context length. "
|
||||||
|
"Abort this request because you cannot truncate the image embeds"
|
||||||
)
|
)
|
||||||
req.image_inputs = None
|
req.image_inputs = None
|
||||||
req.origin_input_ids = [0]
|
req.origin_input_ids = [0]
|
||||||
@@ -579,6 +580,7 @@ class Scheduler:
|
|||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Copy more attributes
|
||||||
req.return_logprob = recv_req.return_logprob
|
req.return_logprob = recv_req.return_logprob
|
||||||
req.top_logprobs_num = recv_req.top_logprobs_num
|
req.top_logprobs_num = recv_req.top_logprobs_num
|
||||||
req.stream = recv_req.stream
|
req.stream = recv_req.stream
|
||||||
|
|||||||
@@ -10,10 +10,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
import copy
|
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
|
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
|
||||||
|
|||||||
@@ -216,6 +216,7 @@ class TokenizerManager:
|
|||||||
input_ids = obj.input_ids
|
input_ids = obj.input_ids
|
||||||
|
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
|
# TODO: also support getting embeddings for multimodal models
|
||||||
image_inputs: Dict = await self.image_processor.process_images_async(
|
image_inputs: Dict = await self.image_processor.process_images_async(
|
||||||
obj.image_data, input_text or input_ids, obj
|
obj.image_data, input_text or input_ids, obj
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -147,6 +147,11 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
max_image_offset.append(-1)
|
max_image_offset.append(-1)
|
||||||
|
|
||||||
|
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||||
|
# filled with the hash values of the image for the prefix matching in the radix attention.
|
||||||
|
# These values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||||
|
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
||||||
|
|
||||||
# Embed text inputs
|
# Embed text inputs
|
||||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
|||||||
@@ -597,13 +597,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
|
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
|
||||||
`None` if no images are passed.
|
`None` if no images are passed.
|
||||||
"""
|
"""
|
||||||
|
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||||
|
positions = forward_batch.mrope_positions
|
||||||
|
|
||||||
image_inputs = None
|
image_inputs = None
|
||||||
if forward_batch.image_inputs is not None:
|
if forward_batch.image_inputs is not None:
|
||||||
image_inputs = [
|
image_inputs = [
|
||||||
img for img in forward_batch.image_inputs if img is not None
|
img for img in forward_batch.image_inputs if img is not None
|
||||||
]
|
]
|
||||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
|
||||||
positions = forward_batch.mrope_positions
|
|
||||||
if (
|
if (
|
||||||
forward_batch.forward_mode.is_decode()
|
forward_batch.forward_mode.is_decode()
|
||||||
or image_inputs is None
|
or image_inputs is None
|
||||||
@@ -617,6 +619,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
f"(3, seq_len) positions, but got {positions.size()}"
|
f"(3, seq_len) positions, but got {positions.size()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||||
|
# filled with the hash values of the image for the prefix matching in the radix attention.
|
||||||
|
# These values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||||
|
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
||||||
|
|
||||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||||
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
|
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
|
||||||
|
|||||||
Reference in New Issue
Block a user