vlm: enable radix cache for qwen-vl models (#5349)
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -463,6 +463,8 @@ class EmbeddingReqInput:
|
||||
image_data: Optional[
|
||||
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
|
||||
] = None
|
||||
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
|
||||
audio_data: Optional[Union[List[str], str]] = None
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||
# The request id.
|
||||
|
||||
@@ -10,12 +10,13 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
Modality,
|
||||
MultimodalDataItem,
|
||||
MultimodalInputs,
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import print_warning_once
|
||||
from sglang.srt.utils import flatten_nested_list, print_warning_once
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -97,31 +98,80 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
||||
return padded_ids
|
||||
|
||||
|
||||
class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
|
||||
class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
|
||||
"""In this pattern, data tokens should be represented as repetitions of a single token
|
||||
e.g. <image><image>....<image>, or <audio><audio>...<audio>
|
||||
"""
|
||||
|
||||
def __init__(self, image_token_id: torch.Tensor) -> None:
|
||||
self.image_token_id = image_token_id
|
||||
def __init__(self, token_ids: List[int]) -> None:
|
||||
self.token_ids = token_ids
|
||||
|
||||
def pad_input_tokens(self, input_ids: List[int], mm_inputs) -> List[int]:
|
||||
def pad_input_tokens(
|
||||
self, input_ids: List[int], mm_inputs: MultimodalInputs
|
||||
) -> List[int]:
|
||||
"""
|
||||
This function will replace the data-tokens in between with pad_values accordingly
|
||||
Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
|
||||
and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.
|
||||
"""
|
||||
pad_values = [item.pad_value for item in mm_inputs.mm_items]
|
||||
assert len(pad_values) != 0
|
||||
if not pad_values:
|
||||
# No multimodal items, return original input_ids
|
||||
return input_ids
|
||||
if not input_ids:
|
||||
return []
|
||||
|
||||
input_ids_tensor = torch.tensor(input_ids)
|
||||
mask = torch.isin(input_ids_tensor, self.image_token_id)
|
||||
device = input_ids_tensor.device
|
||||
token_ids_tensor = torch.tensor(self.token_ids, device=device)
|
||||
mask = torch.isin(input_ids_tensor, token_ids_tensor)
|
||||
|
||||
num_image_tokens = mask.sum().item()
|
||||
repeated_pad_values = torch.tensor(pad_values).repeat(
|
||||
num_image_tokens // len(pad_values) + 1
|
||||
)[:num_image_tokens]
|
||||
if not mask.any():
|
||||
# No tokens match token_ids, return original input_ids
|
||||
return input_ids
|
||||
|
||||
input_ids_tensor[mask] = repeated_pad_values
|
||||
return input_ids_tensor.tolist()
|
||||
# Find contiguous regions
|
||||
padded_mask = torch.cat(
|
||||
(
|
||||
torch.tensor([False], device=device),
|
||||
mask,
|
||||
torch.tensor([False], device=device),
|
||||
)
|
||||
)
|
||||
# Find indices where the mask value changes
|
||||
diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
|
||||
|
||||
# Start indices are where False changes to True
|
||||
starts = diff_indices[::2]
|
||||
# End indices are where True changes to False (exclusive index)
|
||||
ends = diff_indices[1::2]
|
||||
|
||||
# Check if the number of regions matches the number of pad values
|
||||
if len(starts) != len(pad_values):
|
||||
# Maybe log a warning here?
|
||||
num_regions = len(starts)
|
||||
num_pad_values = len(pad_values)
|
||||
if num_regions > 0 and num_pad_values > 0:
|
||||
pad_values = (pad_values * (num_regions // num_pad_values + 1))[
|
||||
:num_regions
|
||||
]
|
||||
else: # If no regions or no pad_values, this loop won't run anyway.
|
||||
pad_values = [] # Ensure pad_values is empty if starts is empty
|
||||
|
||||
# Create a copy to modify
|
||||
output_ids_tensor = input_ids_tensor.clone()
|
||||
|
||||
# Replace tokens in each region with the corresponding pad value
|
||||
# Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
|
||||
for i in range(min(len(starts), len(pad_values))):
|
||||
start_idx = starts[i]
|
||||
end_idx = ends[i]
|
||||
pad_value = pad_values[i]
|
||||
if pad_value is not None: # Ensure pad_value is not None before assignment
|
||||
output_ids_tensor[start_idx:end_idx] = pad_value
|
||||
else:
|
||||
logger.warning(f"Skipping region {i} due to None pad_value.")
|
||||
|
||||
return output_ids_tensor.tolist()
|
||||
|
||||
|
||||
def get_embedding_and_mask(
|
||||
@@ -150,7 +200,6 @@ def get_embedding_and_mask(
|
||||
).unsqueeze(-1)
|
||||
|
||||
num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
|
||||
|
||||
if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
|
||||
logger.warning(
|
||||
f"Number of tokens in multimodal embedding does not match those in the input text."
|
||||
@@ -190,13 +239,13 @@ def embed_mm_inputs(
|
||||
audio_data_embedding_func: Callable[
|
||||
[List[MultimodalDataItem]], torch.Tensor
|
||||
] = None,
|
||||
placeholder_token_ids: List[int] = None,
|
||||
placeholder_tokens: dict[Modality, List[int]] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
|
||||
|
||||
Args:
|
||||
placeholder_token_ids: denoting the token of multimodal data in input_ids.
|
||||
placeholder_tokens: denoting the token of multimodal data in input_ids.
|
||||
If none, the pad_values of multimodal items are used
|
||||
|
||||
Returns:
|
||||
@@ -208,9 +257,17 @@ def embed_mm_inputs(
|
||||
|
||||
# 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
|
||||
# we assume that multimodal data are represented with its pad_values in input_ids
|
||||
placeholder_token_ids = placeholder_token_ids or [
|
||||
item.pad_value for item in mm_inputs.mm_items
|
||||
]
|
||||
# See `pad_input_ids` for more detail
|
||||
|
||||
# if placeholder_tokens is specified
|
||||
if placeholder_tokens is not None:
|
||||
placeholder_token_ids = flatten_nested_list(
|
||||
[placeholder_token for placeholder_token in placeholder_tokens.values()]
|
||||
)
|
||||
else:
|
||||
placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]
|
||||
|
||||
assert isinstance(placeholder_token_ids[0], int)
|
||||
|
||||
placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
|
||||
|
||||
@@ -233,7 +290,7 @@ def embed_mm_inputs(
|
||||
using_all_items = False
|
||||
if len(appearing_items) == 0:
|
||||
# This happens mostly when arg placeholder_token_ids is passed
|
||||
logger.warning_once(
|
||||
logger.warning(
|
||||
"No multimodal data item's pad value exist in placeholder ids. Using all items"
|
||||
)
|
||||
using_all_items = True
|
||||
@@ -253,7 +310,8 @@ def embed_mm_inputs(
|
||||
data_embedding_func=image_data_embedding_func,
|
||||
embedding_items=items,
|
||||
placeholder_tensor=(
|
||||
placeholder_tensor
|
||||
# use the specified modality token to identify the location to embed
|
||||
placeholder_tokens[Modality.IMAGE]
|
||||
if using_all_items
|
||||
else torch.tensor(
|
||||
[item.pad_value for item in items],
|
||||
@@ -275,7 +333,7 @@ def embed_mm_inputs(
|
||||
data_embedding_func=audio_data_embedding_func,
|
||||
embedding_items=items,
|
||||
placeholder_tensor=(
|
||||
placeholder_tensor
|
||||
placeholder_tokens[Modality.AUDIO]
|
||||
if using_all_items
|
||||
else torch.tensor(
|
||||
[item.pad_value for item in items],
|
||||
@@ -296,7 +354,7 @@ def embed_mm_inputs(
|
||||
input_ids.clamp_(min=0, max=vocab_size - 1)
|
||||
inputs_embeds = input_embedding(input_ids)
|
||||
|
||||
# 4. scatter embeddings into input embedding
|
||||
# 4. Scatter embeddings into input embedding
|
||||
for embedding, mask in zip(embeddings, masks):
|
||||
mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(
|
||||
@@ -316,7 +374,7 @@ def general_mm_embed_routine(
|
||||
audio_data_embedding_func: Callable[
|
||||
[List[MultimodalDataItem]], torch.Tensor
|
||||
] = None,
|
||||
placeholder_token_ids: List[int] = None,
|
||||
placeholder_tokens: dict[Modality, List[int]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -328,7 +386,6 @@ def general_mm_embed_routine(
|
||||
audio_data_embedding_func : the function returning the image embedding
|
||||
|
||||
Returns:
|
||||
inputs_embedding
|
||||
forwarded hidden states
|
||||
|
||||
"""
|
||||
@@ -346,9 +403,9 @@ def general_mm_embed_routine(
|
||||
input_embedding=embed_tokens,
|
||||
image_data_embedding_func=image_data_embedding_func,
|
||||
audio_data_embedding_func=audio_data_embedding_func,
|
||||
placeholder_token_ids=placeholder_token_ids,
|
||||
placeholder_tokens=placeholder_tokens,
|
||||
)
|
||||
# once used, mm_inputs is useless
|
||||
# once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
|
||||
# just being defensive here
|
||||
forward_batch.mm_inputs = None
|
||||
else:
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from transformers import BaseImageProcessorFast
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality
|
||||
@@ -92,7 +93,12 @@ class BaseMultimodalProcessor(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def process_mm_data_async(
|
||||
self, image_data, input_text, max_req_input_len, **kwargs
|
||||
self,
|
||||
image_data,
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -104,6 +110,8 @@ class BaseMultimodalProcessor(ABC):
|
||||
from decord import VideoReader, cpu
|
||||
|
||||
# Before processing inputs
|
||||
if not image_data or len(image_data) == 0:
|
||||
return []
|
||||
estimated_frames_list = []
|
||||
for image in image_data:
|
||||
if isinstance(image, str) and image.startswith("video:"):
|
||||
@@ -215,6 +223,9 @@ class BaseMultimodalProcessor(ABC):
|
||||
discard_alpha_channel: if True, discards the alpha channel in the returned images
|
||||
|
||||
"""
|
||||
|
||||
if image_data is None:
|
||||
image_data = []
|
||||
if isinstance(multimodal_tokens.image_token, int):
|
||||
multimodal_tokens.image_token = (
|
||||
self._processor.tokenizer.convert_ids_to_tokens(
|
||||
@@ -229,6 +240,8 @@ class BaseMultimodalProcessor(ABC):
|
||||
prompt = self._processor.tokenizer.decode(prompt)
|
||||
else:
|
||||
prompt = prompt
|
||||
|
||||
assert isinstance(prompt, str)
|
||||
if return_text:
|
||||
import re
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -35,7 +36,13 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
self.IMAGE_TOKEN = "<image>"
|
||||
|
||||
async def process_mm_data_async(
|
||||
self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
@@ -45,7 +52,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
base_output = self.load_mm_data(
|
||||
input_ids,
|
||||
input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||
max_req_input_len=max_req_input_len,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import List, Union
|
||||
|
||||
from transformers.utils import logging
|
||||
|
||||
from sglang.srt.managers.multimodal_processor import (
|
||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||
)
|
||||
@@ -13,7 +11,6 @@ from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
|
||||
|
||||
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
|
||||
# will be removed in the future
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
@@ -28,7 +25,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_ids,
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
*args,
|
||||
@@ -41,7 +38,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_ids,
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||
max_req_input_len=max_req_input_len,
|
||||
|
||||
@@ -17,7 +17,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_ids,
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
@@ -31,7 +31,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
|
||||
processor = self._processor
|
||||
|
||||
base_out = self.load_mm_data(
|
||||
prompt=input_ids,
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token=processor.image_token
|
||||
|
||||
@@ -51,9 +51,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_ids,
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
):
|
||||
audio_data = request_obj.audio_data
|
||||
if not image_data and not audio_data:
|
||||
@@ -64,7 +65,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
audio_data = [audio_data]
|
||||
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_ids,
|
||||
prompt=input_text,
|
||||
max_req_input_len=max_req_input_len,
|
||||
audio_data=audio_data,
|
||||
image_data=image_data,
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import List, Union
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||
)
|
||||
@@ -27,6 +28,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
||||
self.image_token_id = hf_config.image_token_id
|
||||
self.video_token_id = hf_config.video_token_id
|
||||
self.vision_start_token_id = hf_config.vision_start_token_id
|
||||
self.vision_end_token_id = hf_config.vision_end_token_id
|
||||
self.NUM_TOKEN_PER_FRAME = 770
|
||||
self.IMAGE_FACTOR = 28
|
||||
self.MIN_PIXELS = 4 * 28 * 28
|
||||
@@ -36,20 +39,18 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
prompt,
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
base_output = self.load_mm_data(
|
||||
prompt=prompt,
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||
max_req_input_len=max_req_input_len,
|
||||
@@ -116,29 +117,53 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
async def resize_image_async(image):
|
||||
return resize_image(image)
|
||||
|
||||
resize_tasks = [resize_image_async(image) for image in base_output.images]
|
||||
resized_images = await asyncio.gather(*resize_tasks)
|
||||
if base_output.images:
|
||||
resize_tasks = [resize_image_async(image) for image in base_output.images]
|
||||
base_output.images = await asyncio.gather(*resize_tasks)
|
||||
|
||||
ret = self.process_mm_data(
|
||||
input_text=base_output.input_text,
|
||||
images=resized_images,
|
||||
images=base_output.images,
|
||||
)
|
||||
|
||||
image_grid_thws = torch.concat([ret["image_grid_thw"]])
|
||||
return {
|
||||
"input_ids": ret["input_ids"].flatten().tolist(),
|
||||
"mm_items": [
|
||||
items = []
|
||||
|
||||
input_ids = ret["input_ids"].flatten().tolist()
|
||||
if "pixel_values" in ret:
|
||||
items += [
|
||||
MultimodalDataItem(
|
||||
pixel_values=ret["pixel_values"],
|
||||
image_grid_thws=image_grid_thws,
|
||||
image_grid_thws=torch.concat([ret["image_grid_thw"]]),
|
||||
# TODO
|
||||
video_grid_thws=None,
|
||||
second_per_grid_ts=ret.get("second_per_grid_ts", None),
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
],
|
||||
]
|
||||
|
||||
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
|
||||
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
|
||||
image_token_id=self.image_token_id,
|
||||
video_token_id=self.video_token_id,
|
||||
vision_start_token_id=self.vision_start_token_id,
|
||||
model_type=self.hf_config.model_type,
|
||||
tokens_per_second=getattr(
|
||||
self.hf_config.vision_config, "tokens_per_second", None
|
||||
),
|
||||
input_ids=torch.tensor(input_ids).unsqueeze(0),
|
||||
image_grid_thw=ret.get("image_grid_thw", None),
|
||||
video_grid_thw=ret.get("video_grid_thw", None),
|
||||
second_per_grid_ts=ret.get("second_per_grid_ts", None),
|
||||
)
|
||||
mrope_positions = mrope_positions.squeeze(1)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"mm_items": items,
|
||||
"im_start_id": self.IM_START_TOKEN_ID,
|
||||
"im_end_id": self.IM_END_TOKEN_ID,
|
||||
"im_token_id": self.image_token_id,
|
||||
"video_token_id": self.video_token_id,
|
||||
"mrope_positions": mrope_positions,
|
||||
"mrope_position_delta": mrope_position_delta,
|
||||
}
|
||||
|
||||
@@ -285,6 +285,7 @@ class MultimodalInputs:
|
||||
num_image_tokens: Optional[int] = None
|
||||
|
||||
# QWen2-VL related
|
||||
mrope_positions: Optional[torch.Tensor] = None
|
||||
mrope_position_delta: Optional[torch.Tensor] = None
|
||||
|
||||
# image
|
||||
@@ -310,16 +311,12 @@ class MultimodalInputs:
|
||||
assert isinstance(ret.mm_items, list)
|
||||
ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
|
||||
|
||||
assert len(ret.mm_items) != 0
|
||||
|
||||
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
||||
# Please note that if the `input_ids` is later used in the model forward,
|
||||
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
|
||||
# errors in cuda kernels. See also llava.py for example.
|
||||
for item in ret.mm_items:
|
||||
item.set_pad_value()
|
||||
|
||||
optional_args = [
|
||||
"mrope_positions",
|
||||
"mrope_position_delta",
|
||||
"im_token_id",
|
||||
"im_start_id",
|
||||
"im_end_id",
|
||||
@@ -350,20 +347,26 @@ class MultimodalInputs:
|
||||
merge image inputs when requests are being merged
|
||||
"""
|
||||
|
||||
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
||||
# Please note that if the `input_ids` is later used in the model forward,
|
||||
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
|
||||
# errors in cuda kernels. See also llava.py for example.
|
||||
|
||||
# args needed to be merged
|
||||
optional_args = [
|
||||
"mm_items",
|
||||
"image_pad_len",
|
||||
"mrope_position_delta",
|
||||
]
|
||||
for arg in optional_args:
|
||||
self_arg = getattr(self, arg, None)
|
||||
if self_arg is not None:
|
||||
setattr(self, arg, self_arg + getattr(other, arg))
|
||||
|
||||
mrope_positions = self.mrope_positions
|
||||
if mrope_positions is not None:
|
||||
if other.mrope_positions is None:
|
||||
self.mrope_positions = mrope_positions
|
||||
else:
|
||||
self.mrope_positions = torch.cat(
|
||||
[self.mrope_positions, other.mrope_positions], dim=1
|
||||
)
|
||||
|
||||
# other args would be kept intact
|
||||
|
||||
|
||||
|
||||
@@ -419,7 +419,10 @@ class TokenizerManager:
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
|
||||
image_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
||||
obj.image_data, input_text or input_ids, obj, self.max_req_input_len
|
||||
image_data=obj.image_data,
|
||||
input_text=input_text or input_ids,
|
||||
request_obj=obj,
|
||||
max_req_input_len=self.max_req_input_len,
|
||||
)
|
||||
if image_inputs and "input_ids" in image_inputs:
|
||||
input_ids = image_inputs["input_ids"]
|
||||
|
||||
Reference in New Issue
Block a user