Refactor vlm embedding routine to use precomputed feature (#6543)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
@@ -252,40 +252,36 @@ def get_embedding_chunk(
     return embedding_chunk, start_index, end_index


-def get_embedding_and_mask(
+def _get_precomputed_embedding(
+    items: List[MultimodalDataItem],
+) -> Optional[torch.Tensor]:
+    """
+    If all items have precomputed_features, return their concatenation.
+    If some but not all have precomputed_features, raise NotImplementedError.
+    If none have precomputed_features, return None.
+    """
+    precomputed_features = [item.precomputed_features for item in items]
+    if any(feature is not None for feature in precomputed_features):
+        if not all(feature is not None for feature in precomputed_features):
+            raise NotImplementedError(
+                "MM inputs where only some items are precomputed."
+            )
+        result = torch.concat(precomputed_features)
+        # some models' embeddings are 3-dim; reshape to 2-dim (similar to get_embedding_chunk)
+        result = result.reshape(-1, result.shape[-1])
+        return result
+    return None
+
+
+def _get_chunked_prefill_embedding(
     data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
     embedding_items: List[MultimodalDataItem],
-    placeholder_tensor: torch.Tensor,
-    input_ids: torch.Tensor,
     items_size: List[int],
     prefix_length: List[int],
     extend_length: List[int],
     items_offset_list: List[List[Tuple[int, int]]],
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.
-
-    Args:
-        data_embedding_func: Function that generates embeddings for multimodal items
-        embedding_items: List of multimodal items to embed
-        placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
-        input_ids: The input token IDs tensor
-        items_size: Cumulative sizes of multimodal items per request
-        prefix_length: Prefix lengths for each request
-        extend_length: Sequence lengths for each request
-        items_offset_list: List of offset ranges for multimodal items in each request
-
-    Returns:
-        A tuple containing:
-        - The generated embeddings tensor
-        - A boolean mask tensor indicating where these embeddings should be placed
-
-    Raises:
-        AssertionError: If the number of multimodal tokens in input_ids doesn't match
-            the number of tokens in the generated embeddings
-    """
-    # 1. Get the embedding
-    # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
+) -> Optional[torch.Tensor]:
+    # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
     embedding_list = []
     for i in range(len(items_size) - 1):
         if items_size[i] == items_size[i + 1]:
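Illustrative sketch (not part of the diff) of the all-or-nothing contract `_get_precomputed_embedding` enforces; `FakeItem` is a hypothetical stand-in for `MultimodalDataItem`:

    from dataclasses import dataclass
    from typing import Optional
    import torch

    @dataclass
    class FakeItem:  # hypothetical stand-in for MultimodalDataItem
        precomputed_features: Optional[torch.Tensor]

    all_pre = [FakeItem(torch.randn(1, 3, 8)), FakeItem(torch.randn(1, 3, 8))]
    # all items precomputed: concat -> (2, 3, 8), reshape(-1, 8) -> (6, 8)
    none_pre = [FakeItem(None), FakeItem(None)]
    # no items precomputed: returns None, caller falls back to the encoder path
    mixed = [FakeItem(torch.randn(1, 3, 8)), FakeItem(None)]
    # mixed: raises NotImplementedError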
@@ -321,21 +317,28 @@ def get_embedding_and_mask(
         embedding_cache.free(embedding_items_hash)
         embedding_list.append(embedding_per_req_chunk)
     if len(embedding_list) == 0:
-        return None, None
-    embedding = torch.concat(embedding_list, dim=0)
-    # 2. Check the embedding
-    num_mm_tokens_in_embedding = embedding.shape[0]
-    special_multimodal_mask = torch.isin(
-        input_ids,
-        placeholder_tensor,
-    ).unsqueeze(-1)
-
-    num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
+        return None
+    return torch.concat(embedding_list, dim=0)
+
+
+def _get_multimodal_mask(
+    input_ids: torch.Tensor, placeholder_tensor: torch.Tensor
+) -> torch.Tensor:
+    return torch.isin(input_ids, placeholder_tensor).unsqueeze(-1)
+
+
+def _adjust_embedding_length(
+    embedding: torch.Tensor,
+    mask: torch.Tensor,
+    logger,
+) -> torch.Tensor:
+    num_mm_tokens_in_embedding = embedding.shape[0]
+    num_mm_tokens_in_input_ids = mask.sum().item()
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text. "
             f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
-            f"tokens from multimodal embeddings."
+            "tokens from multimodal embeddings."
         )
         if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
             chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
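Standalone sketch (not from the diff) of what `_get_multimodal_mask` computes; the token id 151655 is only an assumed image-placeholder id:

    import torch

    input_ids = torch.tensor([11, 151655, 151655, 42])
    placeholder_tensor = torch.tensor([151655])
    mask = torch.isin(input_ids, placeholder_tensor).unsqueeze(-1)
    # mask has shape (4, 1): [[False], [True], [True], [False]]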
@@ -353,7 +356,54 @@ def get_embedding_and_mask(
             raise RuntimeError(
                 f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
             )
-    return embedding, special_multimodal_mask
+    return embedding
+
+
+def get_embedding_and_mask(
+    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
+    embedding_items: List[MultimodalDataItem],
+    placeholder_tensor: torch.Tensor,
+    input_ids: torch.Tensor,
+    items_size: List[int],
+    prefix_length: List[int],
+    extend_length: List[int],
+    items_offset_list: List[List[Tuple[int, int]]],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.
+
+    Args:
+        data_embedding_func: Function that generates embeddings for multimodal items
+        embedding_items: List of multimodal items to embed
+        placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
+        input_ids: The input token IDs tensor
+        items_size: Cumulative sizes of multimodal items per request
+        prefix_length: Prefix lengths for each request
+        extend_length: Sequence lengths for each request
+        items_offset_list: List of offset ranges for multimodal items in each request
+
+    Returns:
+        A tuple containing:
+        - The generated embeddings tensor
+        - A boolean mask tensor indicating where these embeddings should be placed
+    """
+    # 1. Get embedding
+    embedding = _get_precomputed_embedding(embedding_items)
+    if embedding is None:
+        embedding = _get_chunked_prefill_embedding(
+            data_embedding_func,
+            embedding_items,
+            items_size,
+            prefix_length,
+            extend_length,
+            items_offset_list,
+        )
+        if embedding is None:
+            return None, None
+    # 2. Get mask
+    special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)
+    # 3. Adjust embedding length if needed
+    embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
+    return embedding, special_multimodal_mask
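A hedged sketch of how a call site might consume the returned pair; the diff does not show the caller, and `masked_scatter` here is only an assumed placement strategy:

    import torch

    hidden = 8
    inputs_embeds = torch.zeros(4, hidden)                   # assumed text embeddings
    mm_embedding = torch.randn(2, hidden)                    # from get_embedding_and_mask
    mask = torch.tensor([[False], [True], [True], [False]])  # special_multimodal_mask
    inputs_embeds = inputs_embeds.masked_scatter(
        mask.expand_as(inputs_embeds), mm_embedding
    )  # rows 1 and 2 now hold the two multimodal embeddings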
@@ -144,12 +144,11 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):

         if base_output.images:
             if images_are_preprocessed:
-                image_grid_thw = torch.concat(
-                    [
-                        torch.as_tensor(item.image_grid_thws)
-                        for item in base_output.images
-                    ]
-                )
+                all_image_grid_thws = [
+                    item.image_grid_thws
+                    for item in base_output.images
+                    if item.image_grid_thws is not None
+                ]
                 all_pixel_values = [
                     item.pixel_values
                     for item in base_output.images
@@ -160,6 +159,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
                     for item in base_output.images
                     if item.precomputed_features is not None
                 ]
+                image_grid_thw = (
+                    torch.concat(all_image_grid_thws) if all_image_grid_thws else None
+                )
                 pixel_values = (
                     torch.concat(all_pixel_values) if all_pixel_values else None
                 )
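The guarded concat above avoids calling `torch.concat` on an empty list; in isolation (assumed grid values, not from the diff):

    import torch

    grids = [torch.tensor([[1, 2, 2]]), torch.tensor([[1, 4, 4]])]
    image_grid_thw = torch.concat(grids) if grids else None  # tensor of shape (2, 3)

    grids = []  # e.g. every image arrived with precomputed_features only
    image_grid_thw = torch.concat(grids) if grids else None  # None; concat([]) would raise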
@@ -282,13 +282,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
         """
-        if any(item.precomputed_features is not None for item in items):
-            if not all(item.precomputed_features is not None for item in items):
-                raise NotImplementedError(
-                    "MM inputs where only some items are precomputed."
-                )
-            return torch.concat([item.precomputed_features for item in items])
-
         # Process images one by one to handle flatten_batch=True constraint in vision_tower
         all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
         vision_outputs_list = []
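`flatten_nested_list` is used but not defined in this hunk; a minimal implementation consistent with the call (an assumption, not the repo's code) could be:

    def flatten_nested_list(nested):
        # Recursively flatten arbitrarily nested lists: [[a, b], [c]] -> [a, b, c]
        if isinstance(nested, list):
            return [leaf for item in nested for leaf in flatten_nested_list(item)]
        return [nested]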
@@ -499,12 +499,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-        if any(item.precomputed_features is not None for item in items):
-            if not all(item.precomputed_features is not None for item in items):
-                raise NotImplementedError(
-                    "MM inputs where only some items are precomputed."
-                )
-            return torch.concat([item.precomputed_features for item in items])
         # in qwen-vl, last dim is the same
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype
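The fallback path above, sketched in isolation; 1176 (= 2 x 14 x 14 x 3 patch elements) and bfloat16 are assumed values standing in for the real patch width and `self.visual.dtype`:

    import torch

    items_pixel_values = [torch.randn(16, 1176), torch.randn(4, 1176)]  # per-item patches
    pixel_values = torch.cat(items_pixel_values, dim=0).type(torch.bfloat16)  # (20, 1176)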
@@ -486,12 +486,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-        if any(item.precomputed_features is not None for item in items):
-            if not all(item.precomputed_features is not None for item in items):
-                raise NotImplementedError(
-                    "MM inputs where only some items are precomputed."
-                )
-            return torch.concat([item.precomputed_features for item in items])
         # in qwen-vl, last dim is the same
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype