Refactor vlm embedding routine to use precomputed feature (#6543)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
Xinyuan Tong
2025-05-24 18:39:21 -07:00
committed by GitHub
parent 0d47788025
commit 681fdc264b
8 changed files with 285 additions and 203 deletions

View File

@@ -252,40 +252,36 @@ def get_embedding_chunk(
return embedding_chunk, start_index, end_index
def get_embedding_and_mask(
def _get_precomputed_embedding(
items: List[MultimodalDataItem],
) -> Optional[torch.Tensor]:
"""
If all items have precomputed_features, return their concatenation.
If some but not all have precomputed_features, raise NotImplementedError.
If none have precomputed_features, return None.
"""
precomputed_features = [item.precomputed_features for item in items]
if any(feature is not None for feature in precomputed_features):
if not all(feature is not None for feature in precomputed_features):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
result = torch.concat(precomputed_features)
# some models embedding is 3-dim, reshape it to 2-dim (similar to get_embedding_chunk)
result = result.reshape(-1, result.shape[-1])
return result
return None
def _get_chunked_prefill_embedding(
data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
embedding_items: List[MultimodalDataItem],
placeholder_tensor: torch.Tensor,
input_ids: torch.Tensor,
items_size: List[int],
prefix_length: List[int],
extend_length: List[int],
items_offset_list: List[List[Tuple[int, int]]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.
Args:
data_embedding_func: Function that generates embeddings for multimodal items
embedding_items: List of multimodal items to embed
placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
input_ids: The input token IDs tensor
items_size: Cumulative sizes of multimodal items per request
prefix_length: Prefix lengths for each request
extend_length: Sequence lengths for each request
items_offset_list: List of offset ranges for multimodal items in each request
Returns:
A tuple containing:
- The generated embeddings tensor
- A boolean mask tensor indicating where these embeddings should be placed
Raises:
AssertionError: If the number of multimodal tokens in input_ids doesn't match
the number of tokens in the generated embeddings
"""
# 1. Get the embedding
# Calculate embedding for each request, try to get it from cache to avoid repeated calculation
) -> Optional[torch.Tensor]:
# Calculate embedding for each request, try to get it from cache to avoid repeated calculation
embedding_list = []
for i in range(len(items_size) - 1):
if items_size[i] == items_size[i + 1]:
@@ -321,21 +317,28 @@ def get_embedding_and_mask(
embedding_cache.free(embedding_items_hash)
embedding_list.append(embedding_per_req_chunk)
if len(embedding_list) == 0:
return None, None
embedding = torch.concat(embedding_list, dim=0)
# 2. Check the embedding
num_mm_tokens_in_embedding = embedding.shape[0]
special_multimodal_mask = torch.isin(
input_ids,
placeholder_tensor,
).unsqueeze(-1)
return None
return torch.concat(embedding_list, dim=0)
num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
def _get_multimodal_mask(
input_ids: torch.Tensor, placeholder_tensor: torch.Tensor
) -> torch.Tensor:
return torch.isin(input_ids, placeholder_tensor).unsqueeze(-1)
def _adjust_embedding_length(
embedding: torch.Tensor,
mask: torch.Tensor,
logger,
) -> torch.Tensor:
num_mm_tokens_in_embedding = embedding.shape[0]
num_mm_tokens_in_input_ids = mask.sum().item()
if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
logger.warning(
f"Number of tokens in multimodal embedding does not match those in the input text. "
f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
"tokens from multimodal embeddings."
f"tokens from multimodal embeddings."
)
if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
@@ -353,7 +356,54 @@ def get_embedding_and_mask(
raise RuntimeError(
f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
)
return embedding
def get_embedding_and_mask(
data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
embedding_items: List[MultimodalDataItem],
placeholder_tensor: torch.Tensor,
input_ids: torch.Tensor,
items_size: List[int],
prefix_length: List[int],
extend_length: List[int],
items_offset_list: List[List[Tuple[int, int]]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.
Args:
data_embedding_func: Function that generates embeddings for multimodal items
embedding_items: List of multimodal items to embed
placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
input_ids: The input token IDs tensor
items_size: Cumulative sizes of multimodal items per request
prefix_length: Prefix lengths for each request
extend_length: Sequence lengths for each request
items_offset_list: List of offset ranges for multimodal items in each request
Returns:
A tuple containing:
- The generated embeddings tensor
- A boolean mask tensor indicating where these embeddings should be placed
"""
# 1. Get embedding
embedding = _get_precomputed_embedding(embedding_items)
if embedding is None:
embedding = _get_chunked_prefill_embedding(
data_embedding_func,
embedding_items,
items_size,
prefix_length,
extend_length,
items_offset_list,
)
if embedding is None:
return None, None
# 2. Get mask
special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)
# 3. Adjust embedding length if needed
embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
return embedding, special_multimodal_mask

View File

@@ -144,12 +144,11 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
if base_output.images:
if images_are_preprocessed:
image_grid_thw = torch.concat(
[
torch.as_tensor(item.image_grid_thws)
for item in base_output.images
]
)
all_image_grid_thws = [
item.image_grid_thws
for item in base_output.images
if item.image_grid_thws is not None
]
all_pixel_values = [
item.pixel_values
for item in base_output.images
@@ -160,6 +159,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
for item in base_output.images
if item.precomputed_features is not None
]
image_grid_thw = (
torch.concat(all_image_grid_thws) if all_image_grid_thws else None
)
pixel_values = (
torch.concat(all_pixel_values) if all_pixel_values else None
)

View File

@@ -282,13 +282,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
if any(item.precomputed_features is not None for item in items):
if not all(item.precomputed_features is not None for item in items):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
return torch.concat([item.precomputed_features for item in items])
# Process images one by one to handle flatten_batch=True constraint in vision_tower
all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
vision_outputs_list = []

View File

@@ -499,12 +499,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
if any(item.precomputed_features is not None for item in items):
if not all(item.precomputed_features is not None for item in items):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
return torch.concat([item.precomputed_features for item in items])
# in qwen-vl, last dim is the same
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
self.visual.dtype

View File

@@ -486,12 +486,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
if any(item.precomputed_features is not None for item in items):
if not all(item.precomputed_features is not None for item in items):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
return torch.concat([item.precomputed_features for item in items])
# in qwen-vl, last dim is the same
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
self.visual.dtype