refactor: multimodal data (#4754)
This commit is contained in:
@@ -72,7 +72,8 @@ def eval_mmmu(args):
|
|||||||
if suffix:
|
if suffix:
|
||||||
contents += [{"type": "text", "text": suffix}]
|
contents += [{"type": "text", "text": suffix}]
|
||||||
messages = [{"role": "user", "content": contents}]
|
messages = [{"role": "user", "content": contents}]
|
||||||
model_inputs = processor.apply_chat_template(
|
try:
|
||||||
|
model_inputs = processor.tokenizer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
tokenize=True,
|
tokenize=True,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
@@ -80,9 +81,29 @@ def eval_mmmu(args):
|
|||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
).to(model.device)
|
).to(model.device)
|
||||||
input_len = model_inputs["input_ids"].shape[-1]
|
input_len = model_inputs["input_ids"].shape[-1]
|
||||||
generation = model.generate(**model_inputs, generation_config=generation_config)
|
generation = model.generate(
|
||||||
|
**model_inputs, generation_config=generation_config
|
||||||
|
)
|
||||||
generation = generation[0][input_len:]
|
generation = generation[0][input_len:]
|
||||||
response = processor.decode(generation, skip_special_tokens=True)
|
response = processor.decode(generation, skip_special_tokens=True)
|
||||||
|
except:
|
||||||
|
contents = []
|
||||||
|
if prefix:
|
||||||
|
contents += [prefix]
|
||||||
|
image = PIL.Image.open(sample["image_path"])
|
||||||
|
contents += [image]
|
||||||
|
if suffix:
|
||||||
|
contents += [suffix]
|
||||||
|
messages = [{"role": "user", "content": contents}]
|
||||||
|
response = model.chat(
|
||||||
|
msgs=messages,
|
||||||
|
tokenizer=processor.tokenizer,
|
||||||
|
sampling=False,
|
||||||
|
max_new_tokens=sampling_params["max_new_tokens"],
|
||||||
|
use_tts_template=False,
|
||||||
|
generate_audio=False,
|
||||||
|
temperature=0.0,
|
||||||
|
)
|
||||||
print(f"response: {response}")
|
print(f"response: {response}")
|
||||||
process_result(response, sample, answer_dict, out_samples)
|
process_result(response, sample, answer_dict, out_samples)
|
||||||
|
|
||||||
|
|||||||
@@ -442,6 +442,8 @@ def calculate_ins_level_acc(results: Dict):
|
|||||||
|
|
||||||
|
|
||||||
def process_result(response, sample, answer_dict, out_samples):
|
def process_result(response, sample, answer_dict, out_samples):
|
||||||
|
if response is None:
|
||||||
|
return
|
||||||
if sample["question_type"] == "multiple-choice":
|
if sample["question_type"] == "multiple-choice":
|
||||||
pred_ans = parse_multi_choice_response(
|
pred_ans = parse_multi_choice_response(
|
||||||
response, sample["all_choices"], sample["index2ans"]
|
response, sample["all_choices"], sample["index2ans"]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Multimodality utils
|
Multi-modality utils
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
@@ -9,11 +9,13 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import (
|
||||||
|
MultimodalDataItem,
|
||||||
MultimodalInputs,
|
MultimodalInputs,
|
||||||
global_server_args_dict,
|
global_server_args_dict,
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
from sglang.srt.utils import print_warning_once
|
||||||
from sglang.utils import logger
|
from sglang.utils import logger
|
||||||
|
|
||||||
|
|
||||||
@@ -26,7 +28,7 @@ class MultiModalityDataPaddingPattern:
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def pad_input_tokens(
|
def pad_input_tokens(
|
||||||
self, input_ids: List[int], image_inputs: MultimodalInputs
|
self, input_ids: List[int], mm_inputs: MultimodalInputs
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Pad the input ids sequence containing data tokens, and replace them with pad_values
|
Pad the input ids sequence containing data tokens, and replace them with pad_values
|
||||||
@@ -49,13 +51,13 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
|||||||
"""
|
"""
|
||||||
This function will replace the data-tokens inbetween with pad_values accordingly
|
This function will replace the data-tokens inbetween with pad_values accordingly
|
||||||
"""
|
"""
|
||||||
pad_values = mm_inputs.pad_values
|
pad_values = [item.pad_value for item in mm_inputs.mm_items]
|
||||||
data_token_pairs = self.data_token_id_pairs
|
data_token_pairs = self.data_token_id_pairs
|
||||||
mm_inputs.image_offsets = []
|
mm_inputs.data_offsets = []
|
||||||
if data_token_pairs is None:
|
if data_token_pairs is None:
|
||||||
data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
|
data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
|
||||||
if data_token_pairs is None:
|
if data_token_pairs is None:
|
||||||
logger.warning(
|
print_warning_once(
|
||||||
"No data_token_pairs provided, RadixAttention might be influenced."
|
"No data_token_pairs provided, RadixAttention might be influenced."
|
||||||
)
|
)
|
||||||
return input_ids
|
return input_ids
|
||||||
@@ -77,10 +79,10 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
|||||||
|
|
||||||
if input_ids[start_idx] in start_token_ids:
|
if input_ids[start_idx] in start_token_ids:
|
||||||
data_idx += 1
|
data_idx += 1
|
||||||
mm_inputs.image_offsets += [start_idx]
|
mm_inputs.data_offsets += [start_idx]
|
||||||
|
|
||||||
if data_idx >= len(mm_inputs.pad_values):
|
if data_idx >= len(pad_values):
|
||||||
data_idx = len(mm_inputs.pad_values) - 1
|
data_idx = len(pad_values) - 1
|
||||||
|
|
||||||
num_tokens = end_idx - start_idx - 1
|
num_tokens = end_idx - start_idx - 1
|
||||||
pad_value = pad_values[data_idx]
|
pad_value = pad_values[data_idx]
|
||||||
@@ -94,68 +96,19 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
|||||||
return padded_ids
|
return padded_ids
|
||||||
|
|
||||||
|
|
||||||
class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
|
|
||||||
"""In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
|
|
||||||
which needs first to be expanded to multiple tokens, then replaced with their padding values
|
|
||||||
|
|
||||||
This strategy should be used when a single data token represents content that should
|
|
||||||
be expanded to multiple tokens during processing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
|
|
||||||
) -> None:
|
|
||||||
self.num_data_token_calc_func = num_data_token_calc_func
|
|
||||||
|
|
||||||
def pad_input_tokens(
|
|
||||||
self, input_ids: List[int], mm_inputs: MultimodalInputs
|
|
||||||
) -> List[int]:
|
|
||||||
"""
|
|
||||||
This function will follow the procedure of:
|
|
||||||
1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
|
|
||||||
2. the padded data tokens will be replaced with their pad_values
|
|
||||||
"""
|
|
||||||
image_grid_thws = mm_inputs.image_grid_thws
|
|
||||||
pad_values = mm_inputs.pad_values
|
|
||||||
|
|
||||||
image_indices = [
|
|
||||||
idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
|
|
||||||
]
|
|
||||||
|
|
||||||
mm_inputs.image_offsets = []
|
|
||||||
|
|
||||||
input_ids_with_image = []
|
|
||||||
for image_cnt, _ in enumerate(image_grid_thws):
|
|
||||||
# print(f"image_cnt {image_cnt}")
|
|
||||||
num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
|
|
||||||
if image_cnt == 0:
|
|
||||||
non_image_tokens = input_ids[: image_indices[image_cnt]]
|
|
||||||
else:
|
|
||||||
non_image_tokens = input_ids[
|
|
||||||
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
|
|
||||||
]
|
|
||||||
input_ids_with_image.extend(non_image_tokens)
|
|
||||||
mm_inputs.image_offsets.append(len(input_ids_with_image))
|
|
||||||
pad_ids = pad_values * (
|
|
||||||
(num_image_tokens + len(pad_values)) // len(pad_values)
|
|
||||||
)
|
|
||||||
input_ids_with_image.extend(pad_ids[:num_image_tokens])
|
|
||||||
input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
|
|
||||||
|
|
||||||
return input_ids_with_image
|
|
||||||
|
|
||||||
|
|
||||||
class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
|
class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
|
||||||
"""In this pattern, data tokens should be represented as image tokens (e.g. <image><image>....<image>)"""
|
"""In this pattern, data tokens should be represented as repetitions of a single token
|
||||||
|
e.g. <image><image>....<image>, or <audio><audio>...<audio>
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, image_token_id: torch.Tensor) -> None:
|
def __init__(self, image_token_id: torch.Tensor) -> None:
|
||||||
self.image_token_id = image_token_id
|
self.image_token_id = image_token_id
|
||||||
|
|
||||||
def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]:
|
def pad_input_tokens(self, input_ids: List[int], mm_inputs) -> List[int]:
|
||||||
"""
|
"""
|
||||||
This function will replace the data-tokens in between with pad_values accordingly
|
This function will replace the data-tokens in between with pad_values accordingly
|
||||||
"""
|
"""
|
||||||
pad_values = image_inputs.pad_values
|
pad_values = [item.pad_value for item in mm_inputs.mm_items]
|
||||||
assert len(pad_values) != 0
|
assert len(pad_values) != 0
|
||||||
|
|
||||||
input_ids_tensor = torch.tensor(input_ids)
|
input_ids_tensor = torch.tensor(input_ids)
|
||||||
@@ -170,109 +123,183 @@ class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern
|
|||||||
return input_ids_tensor.tolist()
|
return input_ids_tensor.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_and_mask(
|
||||||
|
data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
|
||||||
|
embedding_items: List[MultimodalDataItem],
|
||||||
|
placeholder_tensor: torch.Tensor,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get the multimodal embedding and its mask from input_ids
|
||||||
|
|
||||||
|
"""
|
||||||
|
# 1. Get the embedding
|
||||||
|
embedding = data_embedding_func(embedding_items)
|
||||||
|
|
||||||
|
# 2. Check the embedding
|
||||||
|
if embedding.dim() == 2:
|
||||||
|
num_mm_tokens_in_embedding = embedding.shape[0]
|
||||||
|
else:
|
||||||
|
num_mm_tokens_in_embedding = embedding.shape[0] * embedding.shape[1]
|
||||||
|
|
||||||
|
# the mask of multimodal tokens from input_ids
|
||||||
|
special_multimodal_mask = torch.isin(
|
||||||
|
input_ids,
|
||||||
|
placeholder_tensor,
|
||||||
|
).unsqueeze(-1)
|
||||||
|
|
||||||
|
num_mm_tokens_in_input_ids = special_multimodal_mask.sum()
|
||||||
|
if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
|
||||||
|
logger.warning(
|
||||||
|
f"Number of tokens in multimodal embedding does not match those in the input text."
|
||||||
|
f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
|
||||||
|
"tokens from multimodal embeddings."
|
||||||
|
)
|
||||||
|
if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
|
||||||
|
# TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
|
||||||
|
# a fix may be cache the unfinished multimodal embedding for future reuse, determine the tokens to embed with
|
||||||
|
# extend_start_loc and extend_seq_lens
|
||||||
|
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
|
||||||
|
if chunked_prefill_size != -1:
|
||||||
|
logger.warning(
|
||||||
|
"You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
|
||||||
|
)
|
||||||
|
# extract from the end: this is a compromise
|
||||||
|
if embedding.dim() == 2:
|
||||||
|
embedding = embedding[-num_mm_tokens_in_input_ids:, :]
|
||||||
|
else:
|
||||||
|
num_multimodal = num_mm_tokens_in_input_ids // embedding.shape[0]
|
||||||
|
embedding = embedding[-num_multimodal:, :]
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Insufficient multimodal embedding length. This is an internal error"
|
||||||
|
)
|
||||||
|
|
||||||
|
return embedding, special_multimodal_mask
|
||||||
|
|
||||||
|
|
||||||
def embed_mm_inputs(
|
def embed_mm_inputs(
|
||||||
mm_input: MultimodalInputs,
|
mm_inputs: MultimodalInputs,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
input_embedding: nn.Embedding,
|
input_embedding: nn.Embedding,
|
||||||
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
|
image_data_embedding_func: Callable[
|
||||||
|
[List[MultimodalDataItem]], torch.Tensor
|
||||||
|
] = None,
|
||||||
|
audio_data_embedding_func: Callable[
|
||||||
|
[List[MultimodalDataItem]], torch.Tensor
|
||||||
|
] = None,
|
||||||
placeholder_token_ids: List[int] = None,
|
placeholder_token_ids: List[int] = None,
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Calculate the image embeddings if necessary, then scatter the result with
|
Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
|
||||||
the help of a boolean mask denoting the embed locations
|
|
||||||
|
Args:
|
||||||
|
placeholder_token_ids: denoting the token of multimodal data in input_ids.
|
||||||
|
If none, the pad_values of multimodal items are used
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
final embedding: Optional[torch.Tensor]
|
final embedding: Optional[torch.Tensor]
|
||||||
"""
|
"""
|
||||||
if mm_input is None:
|
|
||||||
|
if mm_inputs is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
placeholder_token_ids = placeholder_token_ids or mm_input.pad_values
|
# 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
|
||||||
|
# we assume that multimodal data are represented with its pad_values in input_ids
|
||||||
|
placeholder_token_ids = placeholder_token_ids or [
|
||||||
|
item.pad_value for item in mm_inputs.mm_items
|
||||||
|
]
|
||||||
|
|
||||||
# boolean masking the special tokens
|
placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
|
||||||
special_image_mask = torch.isin(
|
|
||||||
input_ids,
|
|
||||||
torch.tensor(placeholder_token_ids, device=input_ids.device),
|
|
||||||
).unsqueeze(-1)
|
|
||||||
|
|
||||||
num_image_tokens_in_input_ids = special_image_mask.sum()
|
placeholder_masks = torch.isin(input_ids, placeholder_tensor)
|
||||||
# print(f"{num_image_tokens_in_input_ids}")
|
|
||||||
# print(f"{input_ids}")
|
|
||||||
|
|
||||||
# return
|
appearing_pad_values = torch.unique(
|
||||||
if num_image_tokens_in_input_ids == 0:
|
input_ids[placeholder_masks], return_counts=False
|
||||||
# unexpected
|
)
|
||||||
|
|
||||||
|
if appearing_pad_values.numel() == 0:
|
||||||
|
# all been prefixed
|
||||||
inputs_embeds = input_embedding(input_ids)
|
inputs_embeds = input_embedding(input_ids)
|
||||||
else:
|
else:
|
||||||
# print(f"Getting image feature")
|
appearing_items = [
|
||||||
image_embedding = mm_data_embedding_func(mm_input)
|
item
|
||||||
|
for item in mm_inputs.mm_items
|
||||||
|
if item.pad_value is not None and item.pad_value in appearing_pad_values
|
||||||
|
]
|
||||||
|
|
||||||
# print(f"image_embedding: {image_embedding.shape}")
|
using_all_items = False
|
||||||
|
if len(appearing_items) == 0:
|
||||||
if image_embedding.dim() == 2:
|
# This happens mostly when arg placeholder_token_ids is passed
|
||||||
num_image_tokens_in_embedding = image_embedding.shape[0]
|
logger.warning_once(
|
||||||
else:
|
"No multimodal data item's pad value exist in placeholder ids. Using all items"
|
||||||
num_image_tokens_in_embedding = (
|
|
||||||
image_embedding.shape[0] * image_embedding.shape[1]
|
|
||||||
)
|
|
||||||
if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
|
|
||||||
num_image = num_image_tokens_in_input_ids // image_embedding.shape[1]
|
|
||||||
image_embedding = image_embedding[:num_image, :]
|
|
||||||
logger.warning(
|
|
||||||
f"Number of images does not match number of special image tokens in the input text. "
|
|
||||||
f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
|
|
||||||
"tokens from image embeddings."
|
|
||||||
)
|
)
|
||||||
|
using_all_items = True
|
||||||
|
appearing_items = mm_inputs.mm_items
|
||||||
|
|
||||||
# TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
|
embeddings, masks = [], []
|
||||||
# a fix may be cache the unfinished image embedding for future reuse, determine the tokens to embed with
|
|
||||||
# extend_start_loc and extend_seq_lens
|
|
||||||
if num_image_tokens_in_input_ids > num_image_tokens_in_embedding:
|
|
||||||
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
|
|
||||||
if chunked_prefill_size != -1:
|
|
||||||
logger.warning(
|
|
||||||
"You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# 2. Get multimodal embedding separately
|
||||||
|
# TODO: make this more generic
|
||||||
|
# Try get image embedding if any
|
||||||
|
if (
|
||||||
|
any(True for item in appearing_items if item.is_image())
|
||||||
|
and image_data_embedding_func
|
||||||
|
):
|
||||||
|
items = [item for item in appearing_items if item.is_image()]
|
||||||
|
embedding, mask = get_embedding_and_mask(
|
||||||
|
data_embedding_func=image_data_embedding_func,
|
||||||
|
embedding_items=items,
|
||||||
|
placeholder_tensor=(
|
||||||
|
placeholder_tensor
|
||||||
|
if using_all_items
|
||||||
|
else torch.tensor(
|
||||||
|
[item.pad_value for item in items],
|
||||||
|
device=input_ids.device,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
input_ids=input_ids,
|
||||||
|
)
|
||||||
|
embeddings += [embedding]
|
||||||
|
masks += [mask]
|
||||||
|
|
||||||
|
# Try get audio embedding if any
|
||||||
|
if (
|
||||||
|
any(True for item in appearing_items if item.is_audio())
|
||||||
|
and audio_data_embedding_func
|
||||||
|
):
|
||||||
|
items = [item for item in appearing_items if item.is_audio()]
|
||||||
|
embedding, mask = get_embedding_and_mask(
|
||||||
|
data_embedding_func=audio_data_embedding_func,
|
||||||
|
embedding_items=items,
|
||||||
|
placeholder_tensor=(
|
||||||
|
placeholder_tensor
|
||||||
|
if using_all_items
|
||||||
|
else torch.tensor(
|
||||||
|
[item.pad_value for item in items],
|
||||||
|
device=input_ids.device,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
input_ids=input_ids,
|
||||||
|
)
|
||||||
|
embeddings += [embedding]
|
||||||
|
masks += [mask]
|
||||||
|
|
||||||
|
# 3. Get input embeddings
|
||||||
vocab_size = input_embedding.num_embeddings
|
vocab_size = input_embedding.num_embeddings
|
||||||
# Important: clamp after getting original image regions
|
# Important: clamp after getting original multimodal regions
|
||||||
# Clamp input ids. This is because the input_ids for the image tokens are
|
# Clamp input ids. This is because the input_ids for the multimodal tokens are
|
||||||
# filled with the hash values of the image for the prefix matching in the radix attention.
|
# filled with the hash values of the multimodal for the prefix matching in the radix attention.
|
||||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||||
input_ids.clamp_(min=0, max=vocab_size - 1)
|
input_ids.clamp_(min=0, max=vocab_size - 1)
|
||||||
inputs_embeds = input_embedding(input_ids)
|
inputs_embeds = input_embedding(input_ids)
|
||||||
|
|
||||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
# 4. scatter embeddings into input embedding
|
||||||
inputs_embeds.device
|
for embedding, mask in zip(embeddings, masks):
|
||||||
)
|
mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(
|
inputs_embeds = inputs_embeds.masked_scatter(
|
||||||
special_image_mask,
|
mask,
|
||||||
image_embedding.to(inputs_embeds.device, inputs_embeds.dtype),
|
embedding.to(inputs_embeds.device, inputs_embeds.dtype),
|
||||||
)
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
|
|
||||||
def embed_image_embedding(
|
|
||||||
inputs_embeds: torch.Tensor,
|
|
||||||
image_embedding: torch.Tensor,
|
|
||||||
image_bounds: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
scatter image_embedding into inputs_embeds according to image_bounds
|
|
||||||
"""
|
|
||||||
if len(image_bounds) > 0:
|
|
||||||
image_indices = torch.stack(
|
|
||||||
[
|
|
||||||
torch.arange(start, end, dtype=torch.long)
|
|
||||||
for start, end in image_bounds.tolist()
|
|
||||||
]
|
|
||||||
).to(inputs_embeds.device)
|
|
||||||
|
|
||||||
inputs_embeds.scatter_(
|
|
||||||
0,
|
|
||||||
image_indices.view(-1, 1).repeat(1, inputs_embeds.shape[-1]),
|
|
||||||
image_embedding.view(-1, image_embedding.shape[-1]),
|
|
||||||
)
|
)
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
|
|
||||||
@@ -280,28 +307,43 @@ def embed_image_embedding(
|
|||||||
def general_mm_embed_routine(
|
def general_mm_embed_routine(
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
embed_tokens: nn.Embedding,
|
language_model: nn.Module,
|
||||||
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
|
image_data_embedding_func: Callable[
|
||||||
|
[List[MultimodalDataItem]], torch.Tensor
|
||||||
|
] = None,
|
||||||
|
audio_data_embedding_func: Callable[
|
||||||
|
[List[MultimodalDataItem]], torch.Tensor
|
||||||
|
] = None,
|
||||||
placeholder_token_ids: List[int] = None,
|
placeholder_token_ids: List[int] = None,
|
||||||
):
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
a general wrapper function to get final input embeds from multimodal models
|
A general wrapper function to get final input embeds from multimodal models with a language model as causal model
|
||||||
with a language model as causal model
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
|
placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
|
||||||
|
image_data_embedding_func : the function returning the image embedding
|
||||||
|
audio_data_embedding_func : the function returning the image embedding
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
inputs_embedding
|
||||||
|
forwarded hidden states
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
assert hasattr(language_model, "get_input_embeddings")
|
||||||
|
embed_tokens = language_model.get_input_embeddings()
|
||||||
if (
|
if (
|
||||||
not forward_batch.forward_mode.is_decode()
|
not forward_batch.forward_mode.is_decode()
|
||||||
and forward_batch.contains_mm_inputs()
|
and forward_batch.contains_mm_inputs()
|
||||||
):
|
):
|
||||||
image = forward_batch.merge_mm_inputs()
|
mm_input = forward_batch.merge_mm_inputs()
|
||||||
inputs_embeds = embed_mm_inputs(
|
inputs_embeds = embed_mm_inputs(
|
||||||
mm_input=image,
|
mm_inputs=mm_input,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
input_embedding=embed_tokens,
|
input_embedding=embed_tokens,
|
||||||
mm_data_embedding_func=mm_data_embedding_func,
|
image_data_embedding_func=image_data_embedding_func,
|
||||||
|
audio_data_embedding_func=audio_data_embedding_func,
|
||||||
placeholder_token_ids=placeholder_token_ids,
|
placeholder_token_ids=placeholder_token_ids,
|
||||||
)
|
)
|
||||||
# once used, mm_inputs is useless
|
# once used, mm_inputs is useless
|
||||||
@@ -310,7 +352,13 @@ def general_mm_embed_routine(
|
|||||||
else:
|
else:
|
||||||
inputs_embeds = embed_tokens(input_ids)
|
inputs_embeds = embed_tokens(input_ids)
|
||||||
|
|
||||||
return inputs_embeds
|
hidden_states = language_model(
|
||||||
|
input_ids=None,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
input_embeds=inputs_embeds,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def get_multimodal_data_bounds(
|
def get_multimodal_data_bounds(
|
||||||
@@ -322,15 +370,13 @@ def get_multimodal_data_bounds(
|
|||||||
Returns:
|
Returns:
|
||||||
[bounds_count, 2]
|
[bounds_count, 2]
|
||||||
"""
|
"""
|
||||||
# All the images in the batch should share the same special image
|
# All the multimodal data in the batch should share the same special bound token ids.
|
||||||
# bound token ids.
|
|
||||||
start_tokens = [s for s, _e in token_pairs]
|
start_tokens = [s for s, _e in token_pairs]
|
||||||
end_tokens = [e for _s, e in token_pairs]
|
end_tokens = [e for _s, e in token_pairs]
|
||||||
|
|
||||||
assert all(isinstance(t, int) for t in start_tokens)
|
assert all(isinstance(t, int) for t in start_tokens)
|
||||||
assert all(isinstance(t, int) for t in end_tokens)
|
assert all(isinstance(t, int) for t in end_tokens)
|
||||||
|
|
||||||
# print(input_ids)
|
|
||||||
start_cond = torch.isin(
|
start_cond = torch.isin(
|
||||||
input_ids, torch.tensor(start_tokens, device=input_ids.device)
|
input_ids, torch.tensor(start_tokens, device=input_ids.device)
|
||||||
)
|
)
|
||||||
@@ -339,7 +385,7 @@ def get_multimodal_data_bounds(
|
|||||||
(data_start_tokens,) = torch.where(start_cond)
|
(data_start_tokens,) = torch.where(start_cond)
|
||||||
(data_end_tokens,) = torch.where(end_cond)
|
(data_end_tokens,) = torch.where(end_cond)
|
||||||
|
|
||||||
# the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the images
|
# the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
|
||||||
if len(data_start_tokens) != len(data_end_tokens):
|
if len(data_start_tokens) != len(data_end_tokens):
|
||||||
if (
|
if (
|
||||||
len(data_start_tokens) + 1 == len(data_end_tokens)
|
len(data_start_tokens) + 1 == len(data_end_tokens)
|
||||||
@@ -352,14 +398,14 @@ def get_multimodal_data_bounds(
|
|||||||
data_start_tokens,
|
data_start_tokens,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
valid_image_nums = min(len(data_start_tokens), len(data_end_tokens))
|
valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
|
||||||
|
|
||||||
if valid_image_nums == 0:
|
if valid_mm_data_nums == 0:
|
||||||
return torch.zeros((0, 2), device=input_ids.device)
|
return torch.zeros((0, 2), device=input_ids.device)
|
||||||
|
|
||||||
# Filter out pairs where start_token >= end_token
|
# Filter out pairs where start_token >= end_token
|
||||||
valid_pairs = []
|
valid_pairs = []
|
||||||
for i in range(valid_image_nums):
|
for i in range(valid_mm_data_nums):
|
||||||
start_token = data_start_tokens[i]
|
start_token = data_start_tokens[i]
|
||||||
end_token = data_end_tokens[i]
|
end_token = data_end_tokens[i]
|
||||||
if start_token < end_token:
|
if start_token < end_token:
|
||||||
|
|||||||
@@ -64,5 +64,3 @@ def get_mm_processor(
|
|||||||
f"No processor registered for architecture: {hf_config.architectures}.\n"
|
f"No processor registered for architecture: {hf_config.architectures}.\n"
|
||||||
f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
|
f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.image_proce
|
|
||||||
|
|||||||
@@ -8,18 +8,10 @@ from typing import Optional
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import transformers
|
|
||||||
from decord import VideoReader, cpu
|
from decord import VideoReader, cpu
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from sglang.srt.utils import load_audio, load_image, logger
|
from sglang.srt.utils import encode_video, load_audio, load_image, logger
|
||||||
|
|
||||||
global global_processor
|
|
||||||
|
|
||||||
|
|
||||||
def get_global_processor():
|
|
||||||
global global_processor
|
|
||||||
return global_processor
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@@ -27,9 +19,6 @@ class BaseMultiModalProcessorOutput:
|
|||||||
# input_text, with each frame of video/image represented with a image_token
|
# input_text, with each frame of video/image represented with a image_token
|
||||||
input_text: str
|
input_text: str
|
||||||
|
|
||||||
mm_data_hashes: Optional[list[int]]
|
|
||||||
# images
|
|
||||||
image_sizes: Optional[list[int]]
|
|
||||||
# frames loaded from image and video, in given order
|
# frames loaded from image and video, in given order
|
||||||
images: Optional[list[PIL.Image]] = None
|
images: Optional[list[PIL.Image]] = None
|
||||||
|
|
||||||
@@ -37,7 +26,7 @@ class BaseMultiModalProcessorOutput:
|
|||||||
audios: Optional[list[np.ndarray]] = None
|
audios: Optional[list[np.ndarray]] = None
|
||||||
|
|
||||||
def normalize(self):
|
def normalize(self):
|
||||||
for field_name in ["data_hashes", "image_sizes", "images", "audios"]:
|
for field_name in ["image_sizes", "images", "audios"]:
|
||||||
field = getattr(self, field_name, None)
|
field = getattr(self, field_name, None)
|
||||||
if field is not None and isinstance(field, list) and len(field) == 0:
|
if field is not None and isinstance(field, list) and len(field) == 0:
|
||||||
setattr(self, field_name, None)
|
setattr(self, field_name, None)
|
||||||
@@ -67,28 +56,35 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
# FIXME: not accurate, model and image specific
|
# FIXME: not accurate, model and image specific
|
||||||
self.NUM_TOKEN_PER_FRAME = 330
|
self.NUM_TOKEN_PER_FRAME = 330
|
||||||
|
|
||||||
# Initialize global processor first
|
self.io_executor = concurrent.futures.ThreadPoolExecutor(
|
||||||
init_global_processor(self, server_args)
|
max_workers=int(os.environ.get("SGLANG_IO_WORKERS", 4))
|
||||||
|
)
|
||||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
self.cpu_executor = concurrent.futures.ProcessPoolExecutor(
|
||||||
initializer=init_global_processor,
|
|
||||||
mp_context=mp.get_context("fork"),
|
mp_context=mp.get_context("fork"),
|
||||||
initargs=(
|
max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
|
||||||
self,
|
|
||||||
server_args,
|
|
||||||
),
|
|
||||||
max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_processor(self, server_args):
|
def process_mm_data(
|
||||||
"""Init the global processor for multi modal models."""
|
self, input_text, images=None, videos=None, audios=None, **kwargs
|
||||||
from sglang.srt.hf_transformers_utils import get_processor
|
):
|
||||||
|
"""
|
||||||
|
process multimodal data with transformers AutoProcessor
|
||||||
|
"""
|
||||||
|
if images is not None:
|
||||||
|
kwargs["images"] = images
|
||||||
|
if videos is not None:
|
||||||
|
kwargs["videos"] = videos
|
||||||
|
if audios is not None:
|
||||||
|
kwargs["audios"] = audios
|
||||||
|
|
||||||
return get_processor(
|
processor = self._processor
|
||||||
server_args.tokenizer_path,
|
result = processor.__call__(
|
||||||
tokenizer_mode=server_args.tokenizer_mode,
|
text=[input_text],
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
@@ -115,33 +111,9 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
|
|
||||||
return estimated_frames_list
|
return estimated_frames_list
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def encode_video(video_path, frame_count_limit=None):
|
|
||||||
if not os.path.exists(video_path):
|
|
||||||
logger.error(f"Video {video_path} does not exist")
|
|
||||||
return []
|
|
||||||
|
|
||||||
if frame_count_limit == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def uniform_sample(l, n):
|
|
||||||
gap = len(l) / n
|
|
||||||
idxs = [int(i * gap + gap / 2) for i in range(n)]
|
|
||||||
return [l[i] for i in idxs]
|
|
||||||
|
|
||||||
vr = VideoReader(video_path, ctx=cpu(0))
|
|
||||||
sample_fps = round(vr.get_avg_fps() / 1) # FPS
|
|
||||||
frame_indices = [i for i in range(0, len(vr), sample_fps)]
|
|
||||||
if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
|
|
||||||
frame_indices = uniform_sample(frame_indices, frame_count_limit)
|
|
||||||
|
|
||||||
frames = vr.get_batch(frame_indices).asnumpy()
|
|
||||||
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
|
|
||||||
return frames
|
|
||||||
|
|
||||||
def load_mm_data(
|
def load_mm_data(
|
||||||
self,
|
self,
|
||||||
input_ids: list[int],
|
prompt: str,
|
||||||
multimodal_tokens: MultimodalSpecialTokens,
|
multimodal_tokens: MultimodalSpecialTokens,
|
||||||
max_req_input_len: int,
|
max_req_input_len: int,
|
||||||
image_data: Optional[list] = None,
|
image_data: Optional[list] = None,
|
||||||
@@ -167,11 +139,13 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
else:
|
else:
|
||||||
multimodal_tokens.image_token = multimodal_tokens.image_token
|
multimodal_tokens.image_token = multimodal_tokens.image_token
|
||||||
|
|
||||||
if isinstance(input_ids, list) and return_text:
|
assert isinstance(prompt, str)
|
||||||
assert len(input_ids) and isinstance(input_ids[0], int)
|
|
||||||
input_text = self._processor.tokenizer.decode(input_ids)
|
if isinstance(prompt, list) and return_text:
|
||||||
|
assert len(prompt) and isinstance(prompt[0], int)
|
||||||
|
prompt = self._processor.tokenizer.decode(prompt)
|
||||||
else:
|
else:
|
||||||
input_text = input_ids
|
prompt = prompt
|
||||||
if return_text:
|
if return_text:
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@@ -181,7 +155,7 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
+ ")"
|
+ ")"
|
||||||
)
|
)
|
||||||
# split text into list of normal text and special tokens
|
# split text into list of normal text and special tokens
|
||||||
text_parts = re.split(pattern, input_text)
|
text_parts = re.split(pattern, prompt)
|
||||||
|
|
||||||
# TODO(mick): load from server_args, env, or sampling_params
|
# TODO(mick): load from server_args, env, or sampling_params
|
||||||
MAX_NUM_FRAMES = 30
|
MAX_NUM_FRAMES = 30
|
||||||
@@ -217,7 +191,7 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
):
|
):
|
||||||
# video
|
# video
|
||||||
path = image_file[len("video:") :]
|
path = image_file[len("video:") :]
|
||||||
frames = BaseMultimodalProcessor.encode_video(
|
frames = encode_video(
|
||||||
path, frame_count_limit=frames_to_process
|
path, frame_count_limit=frames_to_process
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -254,19 +228,9 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
raise RuntimeError(f"An exception occurred while loading images: {e}")
|
raise RuntimeError(f"An exception occurred while loading images: {e}")
|
||||||
|
|
||||||
out = BaseMultiModalProcessorOutput(
|
out = BaseMultiModalProcessorOutput(
|
||||||
mm_data_hashes=hashes,
|
|
||||||
image_sizes=image_sizes,
|
|
||||||
images=images,
|
images=images,
|
||||||
audios=audios,
|
audios=audios,
|
||||||
input_text=new_text,
|
input_text=new_text,
|
||||||
)
|
)
|
||||||
out.normalize()
|
out.normalize()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
|
|
||||||
"""
|
|
||||||
Init the global processor for multimodal models."""
|
|
||||||
global global_processor
|
|
||||||
transformers.logging.set_verbosity_error()
|
|
||||||
global_processor = sglang_processor._build_processor(server_args=server_args)
|
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
import asyncio
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
BaseMultimodalProcessor,
|
BaseMultimodalProcessor,
|
||||||
get_global_processor,
|
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||||
from sglang.srt.models.clip import CLIPModel
|
from sglang.srt.models.clip import CLIPModel
|
||||||
from sglang.srt.utils import load_image
|
from sglang.srt.utils import load_image
|
||||||
|
|
||||||
@@ -15,29 +14,6 @@ class ClipImageProcessor(BaseMultimodalProcessor):
|
|||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _process_single_image_task(images, input_text):
|
|
||||||
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
|
|
||||||
return get_global_processor()(
|
|
||||||
images=images, text=input_text, return_tensors="pt"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _process_single_image(self, images, input_text):
|
|
||||||
if self.executor is not None:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
image_inputs = await loop.run_in_executor(
|
|
||||||
self.executor,
|
|
||||||
ClipImageProcessor._process_single_image_task,
|
|
||||||
images,
|
|
||||||
input_text,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
image_inputs = self._processor(
|
|
||||||
images=images, text=[input_text], return_tensors="pt"
|
|
||||||
)
|
|
||||||
|
|
||||||
return image_inputs
|
|
||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||||
):
|
):
|
||||||
@@ -56,8 +32,13 @@ class ClipImageProcessor(BaseMultimodalProcessor):
|
|||||||
else:
|
else:
|
||||||
images = load_image(image_data[0])[0]
|
images = load_image(image_data[0])[0]
|
||||||
|
|
||||||
image_inputs = await self._process_single_image(images, input_text)
|
image_inputs = self.process_mm_data(input_text=input_text, images=images)
|
||||||
image_inputs["data_hashes"] = [hash(str(image_data))]
|
image_inputs["data_hashes"] = [hash(str(image_data))]
|
||||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||||
|
image_inputs["mm_items"] = [
|
||||||
|
MultimodalDataItem(
|
||||||
|
pixel_values=image_inputs["pixel_values"], modality=Modality.IMAGE
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
return image_inputs
|
return image_inputs
|
||||||
|
|||||||
@@ -16,15 +16,14 @@
|
|||||||
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||||
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
import asyncio
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
BaseMultimodalProcessor,
|
BaseMultimodalProcessor,
|
||||||
MultimodalSpecialTokens,
|
MultimodalSpecialTokens,
|
||||||
get_global_processor,
|
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||||
from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
|
from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
|
||||||
|
|
||||||
|
|
||||||
@@ -35,51 +34,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
|||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor)
|
||||||
self.IMAGE_TOKEN = "<image>"
|
self.IMAGE_TOKEN = "<image>"
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _process_images_task(image, input_text, max_req_input_len):
|
|
||||||
processor = get_global_processor()
|
|
||||||
res = processor.__call__(
|
|
||||||
conversations=input_text, images=image, max_req_input_len=max_req_input_len
|
|
||||||
)
|
|
||||||
|
|
||||||
image_token_id = processor.image_token_id
|
|
||||||
|
|
||||||
res["im_token_id"] = image_token_id
|
|
||||||
return res
|
|
||||||
|
|
||||||
async def _process_images(self, image_data, input_text, max_req_input_len):
|
|
||||||
if self.executor is not None:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
image_inputs = await loop.run_in_executor(
|
|
||||||
self.executor,
|
|
||||||
DeepseekVL2ImageProcessor._process_images_task,
|
|
||||||
image_data,
|
|
||||||
input_text,
|
|
||||||
max_req_input_len,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
image_inputs = self._process_images_task(
|
|
||||||
image_data, input_text, max_req_input_len
|
|
||||||
)
|
|
||||||
|
|
||||||
return image_inputs
|
|
||||||
|
|
||||||
async def _process_images(self, image_data, input_text, max_req_input_len):
|
|
||||||
if self.executor is not None:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
image_inputs = await loop.run_in_executor(
|
|
||||||
self.executor,
|
|
||||||
DeepseekVL2ImageProcessor._process_images_task,
|
|
||||||
image_data,
|
|
||||||
input_text,
|
|
||||||
max_req_input_len,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
image_inputs = self._process_images_task(
|
|
||||||
image_data, input_text, max_req_input_len
|
|
||||||
)
|
|
||||||
return image_inputs
|
|
||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
|
self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
|
||||||
):
|
):
|
||||||
@@ -89,8 +43,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
|||||||
if not isinstance(image_data, list):
|
if not isinstance(image_data, list):
|
||||||
image_data = [image_data]
|
image_data = [image_data]
|
||||||
|
|
||||||
images, image_sizes = [], []
|
|
||||||
|
|
||||||
image_token = self.IMAGE_TOKEN
|
image_token = self.IMAGE_TOKEN
|
||||||
base_output = self.load_mm_data(
|
base_output = self.load_mm_data(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -98,8 +50,11 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
|||||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||||
max_req_input_len=max_req_input_len,
|
max_req_input_len=max_req_input_len,
|
||||||
)
|
)
|
||||||
res = await self._process_images(
|
res = self.process_mm_data(
|
||||||
base_output.images, base_output.input_text, max_req_input_len
|
input_text=base_output.input_text,
|
||||||
|
images=base_output.images,
|
||||||
|
max_req_input_len=max_req_input_len,
|
||||||
|
conversations=base_output.input_text,
|
||||||
)
|
)
|
||||||
images_seq_mask = res["images_seq_mask"]
|
images_seq_mask = res["images_seq_mask"]
|
||||||
images_spatial_crop = res["images_spatial_crop"]
|
images_spatial_crop = res["images_spatial_crop"]
|
||||||
@@ -107,13 +62,17 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
|||||||
batched_images_spatial_crop.append(images_spatial_crop)
|
batched_images_spatial_crop.append(images_spatial_crop)
|
||||||
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
|
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
|
||||||
|
|
||||||
|
items = []
|
||||||
|
item = MultimodalDataItem(
|
||||||
|
pixel_values=res["images"],
|
||||||
|
modality=Modality.IMAGE,
|
||||||
|
image_emb_mask=images_seq_mask,
|
||||||
|
image_spatial_crop=batched_images_spatial_crop,
|
||||||
|
)
|
||||||
|
items += [item]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"mm_items": items,
|
||||||
"input_ids": res["input_ids"].tolist(),
|
"input_ids": res["input_ids"].tolist(),
|
||||||
"pixel_values": res["images"],
|
"im_token_id": self._processor.image_token_id,
|
||||||
"im_token_id": res["im_token_id"],
|
|
||||||
"data_hashes": base_output.mm_data_hashes,
|
|
||||||
"image_sizes": image_sizes,
|
|
||||||
"images_emb_mask": images_seq_mask,
|
|
||||||
"image_spatial_crop": batched_images_spatial_crop,
|
|
||||||
"modalities": request_obj.modalities or ["image"],
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ from sglang.srt.managers.multimodal_processor import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
MultimodalSpecialTokens,
|
MultimodalSpecialTokens,
|
||||||
get_global_processor,
|
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||||
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
|
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
|
||||||
|
|
||||||
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
|
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
|
||||||
@@ -25,28 +25,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
|||||||
self.IM_START_TOKEN_ID = hf_config.boi_token_index
|
self.IM_START_TOKEN_ID = hf_config.boi_token_index
|
||||||
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
|
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
|
||||||
|
|
||||||
async def _process_single_image(self, images, input_text) -> dict:
|
|
||||||
if isinstance(images, list) and len(images) == 0:
|
|
||||||
images = None
|
|
||||||
processor = get_global_processor()
|
|
||||||
result = processor.__call__(
|
|
||||||
text=[input_text],
|
|
||||||
images=images,
|
|
||||||
padding=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
# if RGBA, this needs to be set
|
|
||||||
# images_kwargs={
|
|
||||||
# "input_data_format": ChannelDimension.FIRST
|
|
||||||
# }
|
|
||||||
)
|
|
||||||
|
|
||||||
pixel_values = getattr(result, "pixel_values", None)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"input_ids": result.input_ids,
|
|
||||||
"pixel_values": pixel_values,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
self,
|
self,
|
||||||
image_data: List[Union[str, bytes]],
|
image_data: List[Union[str, bytes]],
|
||||||
@@ -63,21 +41,28 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
|||||||
|
|
||||||
image_token = self.IMAGE_TOKEN
|
image_token = self.IMAGE_TOKEN
|
||||||
base_output = self.load_mm_data(
|
base_output = self.load_mm_data(
|
||||||
input_ids=input_ids,
|
prompt=input_ids,
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||||
max_req_input_len=max_req_input_len,
|
max_req_input_len=max_req_input_len,
|
||||||
discard_alpha_channel=True,
|
discard_alpha_channel=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
ret = await self._process_single_image(
|
ret = self.process_mm_data(
|
||||||
input_text=base_output.input_text, images=base_output.images
|
input_text=base_output.input_text, images=base_output.images
|
||||||
)
|
)
|
||||||
|
|
||||||
|
items = []
|
||||||
|
for i, image in enumerate(base_output.images):
|
||||||
|
item = MultimodalDataItem(
|
||||||
|
pixel_values=ret["pixel_values"][i],
|
||||||
|
modality=Modality.IMAGE,
|
||||||
|
)
|
||||||
|
items += [item]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"mm_items": items,
|
||||||
"input_ids": ret["input_ids"].flatten().tolist(),
|
"input_ids": ret["input_ids"].flatten().tolist(),
|
||||||
"pixel_values": ret["pixel_values"],
|
|
||||||
"data_hashes": base_output.mm_data_hashes,
|
|
||||||
"im_start_id": self.IM_START_TOKEN_ID,
|
"im_start_id": self.IM_START_TOKEN_ID,
|
||||||
"im_end_id": self.IM_END_TOKEN_ID,
|
"im_end_id": self.IM_END_TOKEN_ID,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
import asyncio
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
BaseMultimodalProcessor,
|
BaseMultimodalProcessor,
|
||||||
MultimodalSpecialTokens,
|
MultimodalSpecialTokens,
|
||||||
get_global_processor,
|
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||||
from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
|
from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
|
||||||
|
|
||||||
|
|
||||||
@@ -15,37 +14,6 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
|
|||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _process_images_task(images, input_text):
|
|
||||||
processor = get_global_processor()
|
|
||||||
result = processor.__call__(
|
|
||||||
prompt=input_text, images=images, return_tensors="pt"
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"input_ids": result["input_ids"],
|
|
||||||
"pixel_values": result["pixel_values"],
|
|
||||||
"images_emb_mask": result["images_emb_mask"],
|
|
||||||
"im_start_id": processor.image_start_id,
|
|
||||||
"im_end_id": processor.image_end_id,
|
|
||||||
"im_token_id": processor.image_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _process_images(self, images, input_text):
|
|
||||||
if self.executor is not None:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
image_inputs = await loop.run_in_executor(
|
|
||||||
self.executor,
|
|
||||||
JanusProImageProcessor._process_images_task,
|
|
||||||
images,
|
|
||||||
input_text,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
image_inputs = self._processor(
|
|
||||||
images=images, text=input_text, return_tensors="pt"
|
|
||||||
)
|
|
||||||
|
|
||||||
return image_inputs
|
|
||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
self,
|
self,
|
||||||
image_data: List[Union[str, bytes]],
|
image_data: List[Union[str, bytes]],
|
||||||
@@ -60,25 +28,31 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
|
|||||||
if not isinstance(image_data, list):
|
if not isinstance(image_data, list):
|
||||||
image_data = [image_data]
|
image_data = [image_data]
|
||||||
|
|
||||||
|
processor = self._processor
|
||||||
|
|
||||||
base_out = self.load_mm_data(
|
base_out = self.load_mm_data(
|
||||||
input_ids=input_ids,
|
prompt=input_ids,
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
multimodal_tokens=MultimodalSpecialTokens(
|
multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
|
||||||
image_token="<image_placeholder>"
|
|
||||||
),
|
|
||||||
max_req_input_len=max_req_input_len,
|
max_req_input_len=max_req_input_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
images = base_out.images
|
images = base_out.images
|
||||||
res = await self._process_images(images=images, input_text=base_out.input_text)
|
res = self.process_mm_data(
|
||||||
# print(res)
|
input_text=base_out.input_text,
|
||||||
# print(base_out)
|
prompt=base_out.input_text,
|
||||||
# print("", res["images_emb_mask"].shape)
|
images=images,
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
|
"mm_items": [
|
||||||
|
MultimodalDataItem(
|
||||||
|
pixel_values=res["pixel_values"],
|
||||||
|
image_emb_mask=res["images_emb_mask"],
|
||||||
|
modality=Modality.IMAGE,
|
||||||
|
)
|
||||||
|
],
|
||||||
"input_ids": res["input_ids"].flatten().tolist(),
|
"input_ids": res["input_ids"].flatten().tolist(),
|
||||||
"pixel_values": res["pixel_values"],
|
"im_start_id": processor.image_start_id,
|
||||||
"images_emb_mask": res["images_emb_mask"],
|
"im_end_id": processor.image_end_id,
|
||||||
"data_hashes": base_out.mm_data_hashes,
|
"im_token_id": processor.image_id,
|
||||||
"im_start_id": res["im_start_id"],
|
|
||||||
"im_end_id": res["im_end_id"],
|
|
||||||
"im_token_id": res["im_token_id"],
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import numpy as np
|
|||||||
|
|
||||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
BaseMultimodalProcessor,
|
BaseMultimodalProcessor,
|
||||||
get_global_processor,
|
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||||
from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM
|
from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM
|
||||||
from sglang.srt.models.llavavid import LlavaVidForCausalLM
|
from sglang.srt.models.llavavid import LlavaVidForCausalLM
|
||||||
@@ -25,11 +25,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
|||||||
image_data: Union[str, bytes],
|
image_data: Union[str, bytes],
|
||||||
image_aspect_ratio: Optional[str] = None,
|
image_aspect_ratio: Optional[str] = None,
|
||||||
image_grid_pinpoints: Optional[str] = None,
|
image_grid_pinpoints: Optional[str] = None,
|
||||||
image_processor=None,
|
processor=None,
|
||||||
):
|
):
|
||||||
processor = get_global_processor()
|
|
||||||
|
|
||||||
image_processor = image_processor or processor.image_processor
|
image_processor = processor.image_processor
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image, image_size = load_image(image_data)
|
image, image_size = load_image(image_data)
|
||||||
@@ -72,18 +71,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
|||||||
async def _process_single_image(
|
async def _process_single_image(
|
||||||
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
|
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
|
||||||
):
|
):
|
||||||
if self.executor is not None:
|
if self.cpu_executor is not None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return await loop.run_in_executor(
|
return await loop.run_in_executor(
|
||||||
self.executor,
|
self.cpu_executor,
|
||||||
LlavaImageProcessor._process_single_image_task,
|
LlavaImageProcessor._process_single_image_task,
|
||||||
image_data,
|
image_data,
|
||||||
aspect_ratio,
|
aspect_ratio,
|
||||||
grid_pinpoints,
|
grid_pinpoints,
|
||||||
|
self._processor,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self._process_single_image_task(
|
return self._process_single_image_task(
|
||||||
image_data, aspect_ratio, grid_pinpoints
|
image_data,
|
||||||
|
aspect_ratio,
|
||||||
|
grid_pinpoints,
|
||||||
|
self._processor.image_processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
@@ -134,14 +137,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
|||||||
pixel_values, image_hash, image_size = await self._process_single_image(
|
pixel_values, image_hash, image_size = await self._process_single_image(
|
||||||
image_data[0], aspect_ratio, grid_pinpoints
|
image_data[0], aspect_ratio, grid_pinpoints
|
||||||
)
|
)
|
||||||
data_hashes = [image_hash]
|
|
||||||
image_sizes = [image_size]
|
image_sizes = [image_size]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid image data: {image_data}")
|
raise ValueError(f"Invalid image data: {image_data}")
|
||||||
|
modality = Modality.IMAGE
|
||||||
|
if isinstance(request_obj.modalities, list):
|
||||||
|
if request_obj.modalities[0] == "multi-images":
|
||||||
|
modality = Modality.MULTI_IMAGES
|
||||||
|
elif request_obj.modalities[0] == "video":
|
||||||
|
modality = Modality.VIDEO
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"pixel_values": pixel_values,
|
"mm_items": [
|
||||||
"data_hashes": data_hashes,
|
MultimodalDataItem(
|
||||||
"image_sizes": image_sizes,
|
pixel_values=pixel_values,
|
||||||
"modalities": request_obj.modalities or ["image"],
|
image_sizes=image_sizes,
|
||||||
|
modality=modality,
|
||||||
|
)
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
import asyncio
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from transformers import BaseImageProcessorFast
|
||||||
|
|
||||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
BaseMultimodalProcessor,
|
BaseMultimodalProcessor,
|
||||||
MultimodalSpecialTokens,
|
MultimodalSpecialTokens,
|
||||||
get_global_processor,
|
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||||
from sglang.srt.models.minicpmo import MiniCPMO
|
from sglang.srt.models.minicpmo import MiniCPMO
|
||||||
from sglang.srt.models.minicpmv import MiniCPMV
|
from sglang.srt.models.minicpmv import MiniCPMV
|
||||||
|
|
||||||
@@ -21,19 +21,23 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
self.image_token = "(<image>./</image>)"
|
self.image_token = "(<image>./</image>)"
|
||||||
self.audio_token = "(<audio>./</audio>)"
|
self.audio_token = "(<audio>./</audio>)"
|
||||||
|
|
||||||
@staticmethod
|
def process_data_task(self, input_text, images=None, audios=None):
|
||||||
def _process_data_task(input_text, images=None, audios=None):
|
|
||||||
|
|
||||||
if isinstance(images, list) and len(images) == 0:
|
if isinstance(images, list) and len(images) == 0:
|
||||||
images = None
|
images = None
|
||||||
if isinstance(audios, list) and len(audios) == 0:
|
if isinstance(audios, list) and len(audios) == 0:
|
||||||
audios = None
|
audios = None
|
||||||
result = get_global_processor().__call__(
|
processor = self._processor
|
||||||
|
args = {}
|
||||||
|
if isinstance(processor, BaseImageProcessorFast):
|
||||||
|
args["device"] = "cuda"
|
||||||
|
result = self._processor.__call__(
|
||||||
text=input_text,
|
text=input_text,
|
||||||
images=images,
|
images=images,
|
||||||
audios=audios,
|
audios=audios,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
chunk_input=True,
|
chunk_input=True,
|
||||||
|
**args,
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"input_ids": result.input_ids,
|
"input_ids": result.input_ids,
|
||||||
@@ -44,23 +48,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
"audio_bounds": getattr(result, "audio_bounds", None),
|
"audio_bounds": getattr(result, "audio_bounds", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _process_data(self, images, input_text, audios=None):
|
|
||||||
if self.executor is not None:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
multimodal_data_inputs = await loop.run_in_executor(
|
|
||||||
self.executor,
|
|
||||||
MiniCPMMultimodalProcessor._process_data_task,
|
|
||||||
input_text,
|
|
||||||
images,
|
|
||||||
audios,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
multimodal_data_inputs = self._processor(
|
|
||||||
images=images, text=input_text, audios=audios, return_tensors="pt"
|
|
||||||
)
|
|
||||||
|
|
||||||
return multimodal_data_inputs
|
|
||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
self,
|
self,
|
||||||
image_data: List[Union[str, bytes]],
|
image_data: List[Union[str, bytes]],
|
||||||
@@ -77,7 +64,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
audio_data = [audio_data]
|
audio_data = [audio_data]
|
||||||
|
|
||||||
base_output = self.load_mm_data(
|
base_output = self.load_mm_data(
|
||||||
input_ids=input_ids,
|
prompt=input_ids,
|
||||||
max_req_input_len=max_req_input_len,
|
max_req_input_len=max_req_input_len,
|
||||||
audio_data=audio_data,
|
audio_data=audio_data,
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
@@ -88,9 +75,9 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
if base_output is None:
|
if base_output is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
res = await self._process_data(
|
res = self.process_mm_data(
|
||||||
images=base_output.images,
|
|
||||||
input_text=base_output.input_text,
|
input_text=base_output.input_text,
|
||||||
|
images=base_output.images,
|
||||||
audios=base_output.audios,
|
audios=base_output.audios,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -142,23 +129,33 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
tgt_sizes_flat += [tgt_n]
|
tgt_sizes_flat += [tgt_n]
|
||||||
|
|
||||||
pixel_values = pixel_values_flat
|
pixel_values = pixel_values_flat
|
||||||
if len(tgt_sizes_flat) == 0:
|
|
||||||
tgt_sizes = None
|
items = []
|
||||||
else:
|
if len(pixel_values) != 0:
|
||||||
tgt_sizes = torch.stack(tgt_sizes_flat)
|
item = MultimodalDataItem(
|
||||||
if not isinstance(res["audio_features"], list):
|
pixel_values=pixel_values,
|
||||||
res["audio_features"] = [res["audio_features"]]
|
tgt_size=tgt_sizes_flat,
|
||||||
|
modality=Modality.IMAGE,
|
||||||
|
)
|
||||||
|
items += [item]
|
||||||
|
|
||||||
|
if (
|
||||||
|
"audio_features" in res
|
||||||
|
and res["audio_features"] is not None
|
||||||
|
and len(res["audio_features"]) != 0
|
||||||
|
):
|
||||||
|
item = MultimodalDataItem(
|
||||||
|
audio_features=[res["audio_features"]],
|
||||||
|
audio_feature_lens=res["audio_feature_lens"],
|
||||||
|
modality=Modality.AUDIO,
|
||||||
|
)
|
||||||
|
items += [item]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"mm_items": items,
|
||||||
"input_ids": res["input_ids"].flatten().tolist(),
|
"input_ids": res["input_ids"].flatten().tolist(),
|
||||||
"pixel_values": pixel_values,
|
|
||||||
"tgt_sizes": tgt_sizes,
|
|
||||||
"data_hashes": base_output.mm_data_hashes,
|
|
||||||
"modalities": request_obj.modalities or ["image"],
|
|
||||||
"audio_start_id": audio_start_id,
|
"audio_start_id": audio_start_id,
|
||||||
"audio_end_id": audio_end_id,
|
"audio_end_id": audio_end_id,
|
||||||
"audio_features": res["audio_features"],
|
|
||||||
"audio_bounds": res["audio_bounds"],
|
|
||||||
"audio_feature_lens": res["audio_feature_lens"],
|
|
||||||
"im_token_id": im_token_id,
|
"im_token_id": im_token_id,
|
||||||
"im_start_id": tokenizer.im_start_id,
|
"im_start_id": tokenizer.im_start_id,
|
||||||
"im_end_id": tokenizer.im_end_id,
|
"im_end_id": tokenizer.im_end_id,
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
import asyncio
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
BaseMultimodalProcessor,
|
BaseMultimodalProcessor,
|
||||||
get_global_processor,
|
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||||
from sglang.srt.models.mllama import MllamaForConditionalGeneration
|
from sglang.srt.models.mllama import MllamaForConditionalGeneration
|
||||||
from sglang.srt.utils import load_image
|
from sglang.srt.utils import load_image
|
||||||
|
|
||||||
@@ -15,25 +14,6 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
|
|||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _process_single_image_task(images, input_text):
|
|
||||||
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
|
|
||||||
return get_global_processor()(images, input_text, return_tensors="pt")
|
|
||||||
|
|
||||||
async def _process_single_image(self, images, input_text):
|
|
||||||
if self.executor is not None:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
image_inputs = await loop.run_in_executor(
|
|
||||||
self.executor,
|
|
||||||
MllamaImageProcessor._process_single_image_task,
|
|
||||||
images,
|
|
||||||
input_text,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
image_inputs = self._processor(images, input_text, return_tensors="pt")
|
|
||||||
|
|
||||||
return image_inputs
|
|
||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||||
):
|
):
|
||||||
@@ -52,8 +32,15 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
|
|||||||
else:
|
else:
|
||||||
images = load_image(image_data[0])[0]
|
images = load_image(image_data[0])[0]
|
||||||
|
|
||||||
image_inputs = await self._process_single_image(images, input_text)
|
image_inputs = self.process_mm_data(input_text=input_text, images=images)
|
||||||
image_inputs["data_hashes"] = [hash(str(image_data))]
|
|
||||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||||
|
image_inputs["mm_items"] = [
|
||||||
|
MultimodalDataItem(
|
||||||
|
pixel_values=image_inputs["pixel_values"],
|
||||||
|
aspect_ratio_id=image_inputs["aspect_ratio_ids"],
|
||||||
|
aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
|
||||||
|
modality=Modality.IMAGE,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
return image_inputs
|
return image_inputs
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import math
|
import math
|
||||||
import time
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -11,8 +10,8 @@ from sglang.srt.managers.multimodal_processor import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
MultimodalSpecialTokens,
|
MultimodalSpecialTokens,
|
||||||
get_global_processor,
|
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||||
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||||
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
|
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
|
||||||
|
|
||||||
@@ -34,45 +33,15 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
self.MAX_PIXELS = 16384 * 28 * 28
|
self.MAX_PIXELS = 16384 * 28 * 28
|
||||||
self.MAX_RATIO = 200
|
self.MAX_RATIO = 200
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _process_images_task(images, input_text, _hf_config):
|
|
||||||
if isinstance(images, list) and len(images) == 0:
|
|
||||||
images = None
|
|
||||||
result = get_global_processor().__call__(
|
|
||||||
text=[input_text], images=images, padding=True, return_tensors="pt"
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"input_ids": result.input_ids,
|
|
||||||
"pixel_values": getattr(result, "pixel_values", None),
|
|
||||||
"image_grid_thw": getattr(result, "image_grid_thw", None),
|
|
||||||
"second_per_grid_ts": getattr(result, "second_per_grid_ts", None),
|
|
||||||
"video_grid_thws": getattr(result, "video_grid_thws", None),
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _process_single_image(self, images, input_text) -> dict:
|
|
||||||
if self.executor is not None:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return await loop.run_in_executor(
|
|
||||||
self.executor,
|
|
||||||
Qwen2_5VLImageProcessor._process_images_task,
|
|
||||||
images,
|
|
||||||
input_text,
|
|
||||||
self.hf_config,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return self._process_images_task(images, input_text, self.hf_config)
|
|
||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
self,
|
self,
|
||||||
image_data: List[Union[str, bytes]],
|
image_data: List[Union[str, bytes]],
|
||||||
input_ids,
|
prompt,
|
||||||
request_obj,
|
request_obj,
|
||||||
max_req_input_len,
|
max_req_input_len,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
start = time.time()
|
|
||||||
if not image_data:
|
if not image_data:
|
||||||
return None
|
return None
|
||||||
if isinstance(image_data, str):
|
if isinstance(image_data, str):
|
||||||
@@ -80,7 +49,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
|
|
||||||
image_token = self.IMAGE_TOKEN
|
image_token = self.IMAGE_TOKEN
|
||||||
base_output = self.load_mm_data(
|
base_output = self.load_mm_data(
|
||||||
input_ids=input_ids,
|
prompt=prompt,
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||||
max_req_input_len=max_req_input_len,
|
max_req_input_len=max_req_input_len,
|
||||||
@@ -144,24 +113,32 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||||
return math.floor(number / factor) * factor
|
return math.floor(number / factor) * factor
|
||||||
|
|
||||||
images = [resize_image(image) for image in base_output.images]
|
async def resize_image_async(image):
|
||||||
|
return resize_image(image)
|
||||||
|
|
||||||
ret = await self._process_single_image(
|
resize_tasks = [resize_image_async(image) for image in base_output.images]
|
||||||
images=images, input_text=base_output.input_text
|
resized_images = await asyncio.gather(*resize_tasks)
|
||||||
|
|
||||||
|
ret = self.process_mm_data(
|
||||||
|
input_text=base_output.input_text,
|
||||||
|
images=resized_images,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_grid_thws = torch.concat([ret["image_grid_thw"]])
|
image_grid_thws = torch.concat([ret["image_grid_thw"]])
|
||||||
video_grid_thws = None
|
|
||||||
return {
|
return {
|
||||||
"input_ids": ret["input_ids"].flatten().tolist(),
|
"input_ids": ret["input_ids"].flatten().tolist(),
|
||||||
"pixel_values": ret["pixel_values"],
|
"mm_items": [
|
||||||
"data_hashes": base_output.mm_data_hashes,
|
MultimodalDataItem(
|
||||||
"modalities": request_obj.modalities or ["image"],
|
pixel_values=ret["pixel_values"],
|
||||||
"image_grid_thws": image_grid_thws,
|
image_grid_thws=image_grid_thws,
|
||||||
"video_grid_thws": video_grid_thws,
|
# TODO
|
||||||
|
video_grid_thws=None,
|
||||||
|
second_per_grid_ts=ret.get("second_per_grid_ts", None),
|
||||||
|
modality=Modality.IMAGE,
|
||||||
|
)
|
||||||
|
],
|
||||||
"im_start_id": self.IM_START_TOKEN_ID,
|
"im_start_id": self.IM_START_TOKEN_ID,
|
||||||
"im_end_id": self.IM_END_TOKEN_ID,
|
"im_end_id": self.IM_END_TOKEN_ID,
|
||||||
"im_token_id": self.image_token_id,
|
"im_token_id": self.image_token_id,
|
||||||
"video_token_id": self.video_token_id,
|
"video_token_id": self.video_token_id,
|
||||||
"second_per_grid_ts": ret["second_per_grid_ts"],
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum, auto
|
||||||
|
|
||||||
# Copyright 2023-2024 SGLang Team
|
# Copyright 2023-2024 SGLang Team
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -51,7 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
|
|||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_compiler_backend
|
from sglang.srt.utils import flatten_nested_list, get_compiler_backend
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
@@ -143,165 +145,185 @@ class FINISH_ABORT(BaseFinishReason):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Modality(Enum):
|
||||||
|
IMAGE = auto()
|
||||||
|
MULTI_IMAGES = auto()
|
||||||
|
VIDEO = auto()
|
||||||
|
AUDIO = auto()
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class MultimodalInputs:
|
class MultimodalDataItem:
|
||||||
"""The image related inputs."""
|
"""
|
||||||
|
A single multimodal data item, from a single image/video/audio or other source
|
||||||
|
"""
|
||||||
|
|
||||||
pixel_values: Union[torch.Tensor, np.array]
|
modality: Modality
|
||||||
data_hashes: Optional[list] = None
|
|
||||||
image_sizes: Optional[list] = None
|
|
||||||
image_offsets: Optional[list] = None
|
|
||||||
image_pad_len: Optional[list] = None
|
|
||||||
pad_values: Optional[list] = None
|
|
||||||
modalities: Optional[list] = None
|
|
||||||
num_image_tokens: Optional[int] = None
|
|
||||||
|
|
||||||
# Llava related
|
hash: int = None
|
||||||
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
|
pad_value: int = None
|
||||||
|
|
||||||
|
aspect_ratio_id: Optional[List[torch.Tensor]] = None
|
||||||
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
||||||
|
|
||||||
# QWen2-VL related
|
image_sizes: Tuple[int, int] = None
|
||||||
# [num_of_images, t, h, w]
|
image_offsets: Optional[list] = None
|
||||||
image_grid_thws: torch.Tensor = None
|
|
||||||
mrope_position_delta: Optional[torch.Tensor] = None
|
# the real data, pixel_values or audio_features
|
||||||
# Qwen2-VL video related
|
# data: Union[List[torch.Tensor], List[np.array]]
|
||||||
video_token_id: Optional[int] = None
|
pixel_values: Union[torch.Tensor, np.array] = None
|
||||||
video_grid_thws: List[Tuple[int, int, int]] = None
|
image_grid_thws: Union[torch.Tensor, np.array] = None
|
||||||
|
video_grid_thws: Union[torch.Tensor, np.array] = None
|
||||||
|
|
||||||
|
image_emb_mask: Optional[torch.Tensor] = None
|
||||||
|
image_spatial_crop: Optional[torch.Tensor] = None
|
||||||
second_per_grid_ts: Optional[List[torch.Tensor]] = None
|
second_per_grid_ts: Optional[List[torch.Tensor]] = None
|
||||||
|
|
||||||
# deepseek vl2 related
|
# [num_images, (n, w, h)]
|
||||||
images_emb_mask: Optional[List[torch.Tensor]] = None
|
tgt_size: Tuple[int, int] = None
|
||||||
image_spatial_crop: Optional[List[torch.Tensor]] = None
|
|
||||||
|
|
||||||
# The id of the single-image placeholder token
|
audio_features: Union[torch.Tensor, np.array] = None
|
||||||
|
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_empty_list(l):
|
||||||
|
if l is None:
|
||||||
|
return True
|
||||||
|
return len([item for item in flatten_nested_list(l) if item is not None]) == 0
|
||||||
|
|
||||||
|
def set_pad_value(self):
|
||||||
|
"""
|
||||||
|
Set the pad value after first hashing the data
|
||||||
|
"""
|
||||||
|
|
||||||
|
def hash_feature(f):
|
||||||
|
if isinstance(f, list):
|
||||||
|
return hash(tuple(flatten_nested_list(f)))
|
||||||
|
elif isinstance(f, np.ndarray):
|
||||||
|
arr = np.ascontiguousarray(f)
|
||||||
|
arr_bytes = arr.tobytes()
|
||||||
|
return hash(arr_bytes)
|
||||||
|
return hash(f)
|
||||||
|
|
||||||
|
if self.is_audio():
|
||||||
|
self.hash = hash_feature(self.audio_features)
|
||||||
|
else:
|
||||||
|
self.hash = hash_feature(self.pixel_values)
|
||||||
|
|
||||||
|
assert self.hash is not None
|
||||||
|
self.pad_value = self.hash % (1 << 30)
|
||||||
|
|
||||||
|
def is_audio(self):
|
||||||
|
return (
|
||||||
|
self.modality == Modality.AUDIO
|
||||||
|
) and not MultimodalDataItem.is_empty_list(self.audio_features)
|
||||||
|
|
||||||
|
def is_image(self):
|
||||||
|
return (
|
||||||
|
self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
|
||||||
|
) and not MultimodalDataItem.is_empty_list(self.pixel_values)
|
||||||
|
|
||||||
|
def is_video(self):
|
||||||
|
return (
|
||||||
|
self.modality == Modality.VIDEO
|
||||||
|
) and not MultimodalDataItem.is_empty_list(self.pixel_values)
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
...
|
||||||
|
# TODO
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class MultimodalInputs:
|
||||||
|
"""The multimodal data related inputs."""
|
||||||
|
|
||||||
|
# items of data
|
||||||
|
mm_items: List[MultimodalDataItem]
|
||||||
|
image_pad_len: Optional[list] = None
|
||||||
|
num_image_tokens: Optional[int] = None
|
||||||
|
|
||||||
|
# QWen2-VL related
|
||||||
|
mrope_position_delta: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# image
|
||||||
im_token_id: Optional[torch.Tensor] = None
|
im_token_id: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# All the images in the batch should share the same special image
|
|
||||||
# bound token ids.
|
|
||||||
im_start_id: Optional[int] = None
|
im_start_id: Optional[int] = None
|
||||||
im_end_id: Optional[int] = None
|
im_end_id: Optional[int] = None
|
||||||
slice_start_id: Optional[int] = None
|
slice_start_id: Optional[int] = None
|
||||||
slice_end_id: Optional[int] = None
|
slice_end_id: Optional[int] = None
|
||||||
# [num_images, 2 (w, h)]
|
|
||||||
tgt_sizes: Optional[list] = None
|
# video
|
||||||
|
video_token_id: Optional[int] = None
|
||||||
|
|
||||||
# audio
|
# audio
|
||||||
audio_start_id: Optional[torch.Tensor] = None
|
audio_start_id: Optional[torch.Tensor] = None
|
||||||
audio_end_id: Optional[torch.Tensor] = None
|
audio_end_id: Optional[torch.Tensor] = None
|
||||||
audio_features: Optional[List[torch.Tensor]] = None
|
|
||||||
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_dict(obj: dict):
|
def from_dict(obj: dict):
|
||||||
ret = MultimodalInputs(
|
ret = MultimodalInputs(
|
||||||
pixel_values=obj["pixel_values"],
|
mm_items=obj["mm_items"],
|
||||||
data_hashes=obj["data_hashes"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert isinstance(ret.mm_items, list)
|
||||||
|
ret.mm_items = [
|
||||||
|
item
|
||||||
|
for item in ret.mm_items
|
||||||
|
if item.is_audio() or item.is_image() or item.is_video()
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(ret.mm_items) != 0
|
||||||
|
|
||||||
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
||||||
# Please note that if the `input_ids` is later used in the model forward,
|
# Please note that if the `input_ids` is later used in the model forward,
|
||||||
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
|
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
|
||||||
# errors in cuda kernels. See also llava.py for example.
|
# errors in cuda kernels. See also llava.py for example.
|
||||||
ret.pad_values = [x % (1 << 30) for x in ret.data_hashes]
|
for item in ret.mm_items:
|
||||||
|
item.set_pad_value()
|
||||||
|
|
||||||
optional_args = [
|
optional_args = [
|
||||||
"image_sizes",
|
|
||||||
"modalities",
|
"modalities",
|
||||||
"aspect_ratio_ids",
|
|
||||||
"aspect_ratio_mask",
|
|
||||||
"image_grid_thws",
|
|
||||||
"images_emb_mask",
|
|
||||||
"image_spatial_crop",
|
|
||||||
"im_token_id",
|
"im_token_id",
|
||||||
"im_start_id",
|
"im_start_id",
|
||||||
"im_end_id",
|
"im_end_id",
|
||||||
"slice_start_id",
|
"slice_start_id",
|
||||||
"slice_end_id",
|
"slice_end_id",
|
||||||
"tgt_sizes",
|
|
||||||
"audio_start_id",
|
"audio_start_id",
|
||||||
"audio_end_id",
|
"audio_end_id",
|
||||||
"audio_features",
|
|
||||||
"audio_feature_lens",
|
|
||||||
]
|
]
|
||||||
for arg in optional_args:
|
for arg in optional_args:
|
||||||
if arg in obj:
|
if arg in obj:
|
||||||
setattr(ret, arg, obj[arg])
|
setattr(ret, arg, obj[arg])
|
||||||
|
|
||||||
# validate
|
|
||||||
assert (
|
|
||||||
isinstance(ret.pixel_values, torch.Tensor)
|
|
||||||
or isinstance(ret.pixel_values, np.ndarray)
|
|
||||||
or isinstance(ret.pixel_values, list)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert ret.audio_features is None or isinstance(ret.audio_features, list)
|
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def contains_image_inputs(self) -> bool:
|
def contains_image_inputs(self) -> bool:
|
||||||
""" """
|
""" """
|
||||||
return self.pixel_values is not None and self.pixel_values != []
|
return any(item.is_image() for item in self.mm_items)
|
||||||
|
|
||||||
def contains_audio_inputs(self) -> bool:
|
def contains_audio_inputs(self) -> bool:
|
||||||
""" """
|
""" """
|
||||||
return self.audio_features is not None and self.audio_features != []
|
return any(item.is_audio() for item in self.mm_items)
|
||||||
|
|
||||||
|
def collect_image_inputs(self) -> List[torch.Tensor]:
|
||||||
|
return [item.pixel_values for item in self.mm_items if item.is_image()]
|
||||||
|
|
||||||
def merge(self, other: MultimodalInputs):
|
def merge(self, other: MultimodalInputs):
|
||||||
"""
|
"""
|
||||||
merge image inputs when requests are being merged
|
merge image inputs when requests are being merged
|
||||||
"""
|
"""
|
||||||
if isinstance(self.pixel_values, list):
|
|
||||||
# in some rare cases, pixel values are list of patches with different shapes
|
|
||||||
# e.g. minicpm
|
|
||||||
self.pixel_values += other.pixel_values
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
|
|
||||||
), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
|
|
||||||
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
|
|
||||||
|
|
||||||
# args would be stacked along first dim
|
|
||||||
# usually these are already tensors
|
|
||||||
stack_args = [
|
|
||||||
# TODO: merge with image_grid_thws, basically the same thing
|
|
||||||
"tgt_sizes",
|
|
||||||
"image_spatial_crop",
|
|
||||||
]
|
|
||||||
for arg in stack_args:
|
|
||||||
if getattr(self, arg, None) is None:
|
|
||||||
setattr(self, arg, getattr(other, arg, None))
|
|
||||||
elif getattr(other, arg, None) is not None:
|
|
||||||
# self and other both not None
|
|
||||||
setattr(
|
|
||||||
self,
|
|
||||||
arg,
|
|
||||||
torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.image_grid_thws is None:
|
|
||||||
self.image_grid_thws = other.image_grid_thws
|
|
||||||
elif other.image_grid_thws is not None:
|
|
||||||
self.image_grid_thws = torch.concat(
|
|
||||||
[self.image_grid_thws, other.image_grid_thws]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
||||||
# Please note that if the `input_ids` is later used in the model forward,
|
# Please note that if the `input_ids` is later used in the model forward,
|
||||||
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
|
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
|
||||||
# errors in cuda kernels. See also llava.py for example.
|
# errors in cuda kernels. See also llava.py for example.
|
||||||
self.data_hashes += other.data_hashes
|
|
||||||
self.pad_values = [x % (1 << 30) for x in self.data_hashes]
|
|
||||||
|
|
||||||
# args needed to be merged
|
# args needed to be merged
|
||||||
optional_args = [
|
optional_args = [
|
||||||
"audio_features",
|
"items",
|
||||||
"image_sizes",
|
|
||||||
"image_offsets",
|
"image_offsets",
|
||||||
"image_pad_len",
|
"image_pad_len",
|
||||||
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
|
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
|
||||||
"aspect_ratio_ids",
|
|
||||||
"aspect_ratio_mask",
|
|
||||||
"images_emb_mask",
|
|
||||||
]
|
]
|
||||||
for arg in optional_args:
|
for arg in optional_args:
|
||||||
self_arg = getattr(self, arg, None)
|
self_arg = getattr(self, arg, None)
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
|||||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
|
|||||||
@@ -1,11 +1,6 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
from collections import defaultdict
|
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
||||||
|
|
||||||
|
|||||||
@@ -355,11 +355,6 @@ class ForwardBatch:
|
|||||||
for mm_input in valid_inputs[1:]:
|
for mm_input in valid_inputs[1:]:
|
||||||
merged.merge(mm_input)
|
merged.merge(mm_input)
|
||||||
|
|
||||||
if isinstance(merged.pixel_values, np.ndarray):
|
|
||||||
merged.pixel_values = torch.from_numpy(merged.pixel_values)
|
|
||||||
if isinstance(merged.audio_features, np.ndarray):
|
|
||||||
merged.audio_features = torch.from_numpy(merged.audio_features)
|
|
||||||
|
|
||||||
return merged
|
return merged
|
||||||
|
|
||||||
def contains_image_inputs(self) -> bool:
|
def contains_image_inputs(self) -> bool:
|
||||||
|
|||||||
@@ -251,16 +251,15 @@ class ModelRunner:
|
|||||||
self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
|
self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
|
||||||
|
|
||||||
if self.is_multimodal:
|
if self.is_multimodal:
|
||||||
self.mem_fraction_static *= 0.95
|
self.mem_fraction_static *= 0.90
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
|
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
|
||||||
f"because this is a multimodal model."
|
f"because this is a multimodal model."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.model_config.hf_config.architectures == [
|
logger.info(
|
||||||
"MllamaForConditionalGeneration"
|
"Automatically turn off --chunked-prefill-size for multimodal model."
|
||||||
]:
|
)
|
||||||
logger.info("Automatically turn off --chunked-prefill-size for mllama.")
|
|
||||||
server_args.chunked_prefill_size = -1
|
server_args.chunked_prefill_size = -1
|
||||||
|
|
||||||
if self.model_config.hf_config.architectures == [
|
if self.model_config.hf_config.architectures == [
|
||||||
@@ -269,18 +268,7 @@ class ModelRunner:
|
|||||||
"Qwen2_5_VLForConditionalGeneration"
|
"Qwen2_5_VLForConditionalGeneration"
|
||||||
]:
|
]:
|
||||||
# TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
|
# TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
|
||||||
logger.info(
|
logger.info("Automatically disable radix cache for qwen-vl series.")
|
||||||
"Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
|
|
||||||
)
|
|
||||||
server_args.chunked_prefill_size = -1
|
|
||||||
server_args.disable_radix_cache = True
|
|
||||||
|
|
||||||
if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
|
|
||||||
# TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
|
|
||||||
logger.info(
|
|
||||||
"Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
|
|
||||||
)
|
|
||||||
server_args.chunked_prefill_size = -1
|
|
||||||
server_args.disable_radix_cache = True
|
server_args.disable_radix_cache = True
|
||||||
|
|
||||||
if server_args.enable_deepep_moe:
|
if server_args.enable_deepep_moe:
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|||||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
||||||
from sglang.srt.model_executor.model_runner import ForwardBatch
|
from sglang.srt.model_executor.model_runner import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.utils import add_prefix
|
from sglang.srt.utils import add_prefix, flatten_nested_list
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionEmbeddings(nn.Module):
|
class CLIPVisionEmbeddings(nn.Module):
|
||||||
@@ -368,7 +368,6 @@ class CLIPVisionTransformer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
pixel_values: torch.Tensor,
|
pixel_values: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
hidden_states = self.embeddings(pixel_values.to(self.device))
|
hidden_states = self.embeddings(pixel_values.to(self.device))
|
||||||
hidden_states = self.pre_layrnorm(hidden_states)
|
hidden_states = self.pre_layrnorm(hidden_states)
|
||||||
|
|
||||||
@@ -456,12 +455,18 @@ class CLIPModel(nn.Module):
|
|||||||
get_embedding: bool = True,
|
get_embedding: bool = True,
|
||||||
):
|
):
|
||||||
assert get_embedding, "CLIPEmbeddingModel is only used for embedding"
|
assert get_embedding, "CLIPEmbeddingModel is only used for embedding"
|
||||||
image_inputs = None
|
mm_inputs = []
|
||||||
if forward_batch.mm_inputs is not None:
|
if forward_batch.mm_inputs is not None:
|
||||||
image_inputs = forward_batch.mm_inputs
|
mm_inputs = forward_batch.mm_inputs
|
||||||
|
pixel_values_list = [
|
||||||
if image_inputs is not None and image_inputs[0] is not None:
|
item.pixel_values
|
||||||
vision_outputs = self.vision_model(image_inputs[0].pixel_values)
|
for item in flatten_nested_list(
|
||||||
|
[mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if len(pixel_values_list) != 0:
|
||||||
|
pixel_values = torch.concat(pixel_values_list)
|
||||||
|
vision_outputs = self.vision_model(pixel_values)
|
||||||
pooled_output = vision_outputs[:, 0, :]
|
pooled_output = vision_outputs[:, 0, :]
|
||||||
image_embeds = self.visual_projection(pooled_output)
|
image_embeds = self.visual_projection(pooled_output)
|
||||||
image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)
|
image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ from sglang.srt.managers.mm_utils import (
|
|||||||
MultiModalityDataPaddingPatternTokenPairs,
|
MultiModalityDataPaddingPatternTokenPairs,
|
||||||
general_mm_embed_routine,
|
general_mm_embed_routine,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
|
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.llama import LlamaForCausalLM
|
from sglang.srt.models.llama import LlamaForCausalLM
|
||||||
@@ -1959,8 +1959,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
pixel_values = image_input.pixel_values
|
pixel_values = torch.concat([item.pixel_values for item in items], dim=0)
|
||||||
bs, n = pixel_values.shape[0:2]
|
bs, n = pixel_values.shape[0:2]
|
||||||
pixel_values = pixel_values.to(
|
pixel_values = pixel_values.to(
|
||||||
device=self.vision_model.device, dtype=self.vision_model.dtype
|
device=self.vision_model.device, dtype=self.vision_model.dtype
|
||||||
@@ -1976,7 +1976,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|||||||
return images_embeds
|
return images_embeds
|
||||||
|
|
||||||
def get_input_embeddings(self) -> nn.Embedding:
|
def get_input_embeddings(self) -> nn.Embedding:
|
||||||
return self.language_model.model.embed_tokens
|
return self.language_model.get_input_embeddings()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -1984,22 +1984,17 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
|
get_embedding: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
hidden_states = general_mm_embed_routine(
|
||||||
inputs_embeds = general_mm_embed_routine(
|
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
embed_tokens=self.get_input_embeddings(),
|
image_data_embedding_func=self.get_image_feature,
|
||||||
mm_data_embedding_func=self.get_image_feature,
|
language_model=self.language_model,
|
||||||
|
positions=positions,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.language_model(
|
return hidden_states
|
||||||
input_ids=None,
|
|
||||||
positions=positions,
|
|
||||||
forward_batch=forward_batch,
|
|
||||||
input_embeds=inputs_embeds,
|
|
||||||
get_embedding=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
|
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
|
||||||
return self.gen_aligner(self.gen_embed(image_ids))
|
return self.gen_aligner(self.gen_embed(image_ids))
|
||||||
|
|||||||
@@ -1308,6 +1308,9 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.dp_size = get_attention_dp_size()
|
self.dp_size = get_attention_dp_size()
|
||||||
|
|
||||||
|
def get_input_embeddings(self) -> nn.Embedding:
|
||||||
|
return self.model.embed_tokens
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -11,7 +11,11 @@ from sglang.srt.configs.deepseekvl2 import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.linear import ReplicatedLinear
|
from sglang.srt.layers.linear import ReplicatedLinear
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
from sglang.srt.managers.mm_utils import (
|
||||||
|
MultiModalityDataPaddingPatternImageTokens,
|
||||||
|
general_mm_embed_routine,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
|
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
|
||||||
@@ -150,7 +154,6 @@ class DeepseekVL2MlpProjector(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
# todo
|
|
||||||
class DeepseekVL2ForCausalLM(nn.Module):
|
class DeepseekVL2ForCausalLM(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -215,32 +218,15 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
**kwargs: object,
|
**kwargs: object,
|
||||||
):
|
):
|
||||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
hs = general_mm_embed_routine(
|
||||||
if (
|
|
||||||
forward_batch.forward_mode.is_extend()
|
|
||||||
and forward_batch.contains_image_inputs()
|
|
||||||
):
|
|
||||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
|
||||||
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
|
|
||||||
for idx, image in enumerate(forward_batch.mm_inputs):
|
|
||||||
if image is None:
|
|
||||||
continue
|
|
||||||
start_idx = extend_start_loc_cpu[idx]
|
|
||||||
end_idx = start_idx + extend_seq_lens_cpu[idx]
|
|
||||||
images_emb_mask = image.images_emb_mask.to(device="cuda")
|
|
||||||
image_features = self.get_image_feature(image)
|
|
||||||
input_embeds[start_idx:end_idx] = input_embeds[
|
|
||||||
start_idx:end_idx
|
|
||||||
].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
|
|
||||||
|
|
||||||
outputs = self.language_model.forward(
|
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
input_embeds=input_embeds,
|
image_data_embedding_func=self.get_image_feature,
|
||||||
|
language_model=self.language_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
return outputs
|
return hs
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
@@ -263,21 +249,34 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
|||||||
weights_loader(param, loaded_weight)
|
weights_loader(param, loaded_weight)
|
||||||
|
|
||||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
||||||
return input_ids
|
helper = MultiModalityDataPaddingPatternImageTokens(
|
||||||
|
image_token_id=image_inputs.im_token_id
|
||||||
|
)
|
||||||
|
return helper.pad_input_tokens(input_ids, image_inputs)
|
||||||
|
|
||||||
def get_image_feature(self, image_input: MultimodalInputs):
|
def get_image_feature(self, items: List[MultimodalDataItem]):
|
||||||
pixel_values = image_input.pixel_values.type(
|
|
||||||
next(self.vision.parameters()).dtype
|
images_spatial_crop = torch.cat(
|
||||||
).to(device=next(self.vision.parameters()).device)
|
[item.image_spatial_crop for item in items], dim=0
|
||||||
image_feature = self.vision.forward_features(pixel_values)
|
)
|
||||||
|
|
||||||
|
assert images_spatial_crop.dim() == 3
|
||||||
|
|
||||||
|
# TODO: can it be batched ?
|
||||||
|
images_in_this_batch = []
|
||||||
|
for item in items:
|
||||||
|
assert item.pixel_values.dim() == 4
|
||||||
|
image_feature = self.vision.forward_features(
|
||||||
|
item.pixel_values.type(next(self.vision.parameters()).dtype).to(
|
||||||
|
device=next(self.vision.parameters()).device
|
||||||
|
)
|
||||||
|
)
|
||||||
images_embeds = self.projector(image_feature)
|
images_embeds = self.projector(image_feature)
|
||||||
_, hw, n_dim = images_embeds.shape
|
_, hw, n_dim = images_embeds.shape
|
||||||
h = w = int(hw**0.5)
|
h = w = int(hw**0.5)
|
||||||
tile_index = 0
|
tile_index = 0
|
||||||
images_in_this_batch = []
|
for jdx in range(item.image_spatial_crop.shape[1]):
|
||||||
images_spatial_crop = image_input.image_spatial_crop
|
num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
|
||||||
for jdx in range(images_spatial_crop.shape[1]):
|
|
||||||
num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
|
|
||||||
if num_width_tiles == 0 or num_height_tiles == 0:
|
if num_width_tiles == 0 or num_height_tiles == 0:
|
||||||
break
|
break
|
||||||
num_tiles_in_image = num_width_tiles * num_height_tiles
|
num_tiles_in_image = num_width_tiles * num_height_tiles
|
||||||
@@ -300,7 +299,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
|||||||
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
|
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
|
||||||
|
|
||||||
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
|
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
|
||||||
global_features = torch.cat([global_features, new_lines_in_global], dim=1)
|
global_features = torch.cat(
|
||||||
|
[global_features, new_lines_in_global], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
# [h, w + 1, D] -> [h * (w + 1), D]
|
# [h, w + 1, D] -> [h * (w + 1), D]
|
||||||
global_features = global_features.view(-1, n_dim)
|
global_features = global_features.view(-1, n_dim)
|
||||||
|
|||||||
@@ -21,14 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import (
|
from transformers import AutoModel, Gemma3Config, PreTrainedModel
|
||||||
AutoModel,
|
|
||||||
BatchFeature,
|
|
||||||
Gemma3Config,
|
|
||||||
Gemma3Processor,
|
|
||||||
PreTrainedModel,
|
|
||||||
)
|
|
||||||
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
|
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_processor
|
from sglang.srt.hf_transformers_utils import get_processor
|
||||||
from sglang.srt.layers.layernorm import Gemma3RMSNorm
|
from sglang.srt.layers.layernorm import Gemma3RMSNorm
|
||||||
@@ -38,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
|
|||||||
MultiModalityDataPaddingPatternTokenPairs,
|
MultiModalityDataPaddingPatternTokenPairs,
|
||||||
general_mm_embed_routine,
|
general_mm_embed_routine,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
from sglang.srt.managers.schedule_batch import (
|
||||||
|
MultimodalDataItem,
|
||||||
|
MultimodalInputs,
|
||||||
|
flatten_nested_list,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import (
|
from sglang.srt.model_loader.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
@@ -274,17 +271,16 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
return self.language_model.get_attention_sliding_window_size()
|
return self.language_model.get_attention_sliding_window_size()
|
||||||
|
|
||||||
def get_image_feature(self, image_input: MultimodalInputs):
|
def get_image_feature(self, items: List[MultimodalDataItem]):
|
||||||
"""
|
"""
|
||||||
Projects the last hidden state from the vision model into language model space.
|
Projects the last hidden state from the vision model into language model space.
|
||||||
|
|
||||||
Args:
|
|
||||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
|
|
||||||
The tensors corresponding to the input images.
|
|
||||||
Returns:
|
Returns:
|
||||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||||
"""
|
"""
|
||||||
pixel_values = image_input.pixel_values
|
pixel_values = torch.stack(
|
||||||
|
flatten_nested_list([item.pixel_values for item in items]), dim=0
|
||||||
|
)
|
||||||
pixel_values = pixel_values.to("cuda")
|
pixel_values = pixel_values.to("cuda")
|
||||||
pixel_values = pixel_values.to(dtype=self.language_model.dtype())
|
pixel_values = pixel_values.to(dtype=self.language_model.dtype())
|
||||||
|
|
||||||
@@ -292,61 +288,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|||||||
image_features = self.multi_modal_projector(vision_outputs)
|
image_features = self.multi_modal_projector(vision_outputs)
|
||||||
return image_features
|
return image_features
|
||||||
|
|
||||||
def embed_mm_inputs(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
forward_batch: ForwardBatch,
|
|
||||||
image_input: MultimodalInputs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if input_ids is None:
|
|
||||||
raise ValueError("Unimplemented")
|
|
||||||
# boolean-masking image tokens
|
|
||||||
special_image_mask = torch.isin(
|
|
||||||
input_ids,
|
|
||||||
torch.tensor(image_input.pad_values, device=input_ids.device),
|
|
||||||
).unsqueeze(-1)
|
|
||||||
num_image_tokens_in_input_ids = special_image_mask.sum()
|
|
||||||
|
|
||||||
inputs_embeds = None
|
|
||||||
if num_image_tokens_in_input_ids == 0:
|
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
|
||||||
return inputs_embeds
|
|
||||||
else:
|
|
||||||
# print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
|
|
||||||
image_features = self.get_image_feature(image_input.pixel_values)
|
|
||||||
|
|
||||||
# print(f"image tokens from image embeddings: {image_features.numel()}")
|
|
||||||
num_image_tokens_in_embedding = (
|
|
||||||
image_features.shape[0] * image_features.shape[1]
|
|
||||||
)
|
|
||||||
|
|
||||||
if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
|
|
||||||
num_image = num_image_tokens_in_input_ids // image_features.shape[1]
|
|
||||||
image_features = image_features[:num_image, :]
|
|
||||||
logger.warning(
|
|
||||||
f"Number of images does not match number of special image tokens in the input text. "
|
|
||||||
f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
|
|
||||||
"tokens from image embeddings."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Important: clamp after extracting original image boundaries
|
|
||||||
input_ids.clamp_(min=0, max=self.vocab_size - 1)
|
|
||||||
|
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
|
||||||
|
|
||||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
|
||||||
inputs_embeds.device
|
|
||||||
)
|
|
||||||
|
|
||||||
image_features = image_features.to(
|
|
||||||
inputs_embeds.device, inputs_embeds.dtype
|
|
||||||
)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(
|
|
||||||
special_image_mask, image_features
|
|
||||||
)
|
|
||||||
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -405,22 +346,15 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
llm_input_ids = input_ids
|
llm_input_ids = input_ids
|
||||||
|
|
||||||
inputs_embeds = general_mm_embed_routine(
|
hs = general_mm_embed_routine(
|
||||||
input_ids=llm_input_ids,
|
input_ids=llm_input_ids,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
embed_tokens=self.get_input_embeddings(),
|
language_model=self.language_model,
|
||||||
mm_data_embedding_func=self.get_image_feature,
|
image_data_embedding_func=self.get_image_feature,
|
||||||
)
|
|
||||||
|
|
||||||
outputs = self.language_model(
|
|
||||||
input_ids=None,
|
|
||||||
positions=positions,
|
positions=positions,
|
||||||
forward_batch=forward_batch,
|
|
||||||
input_embeds=inputs_embeds,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return outputs
|
return hs
|
||||||
|
|
||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
return self.language_model.tie_weights()
|
return self.language_model.tie_weights()
|
||||||
|
|||||||
@@ -428,6 +428,9 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
return self.pooler(hidden_states, forward_batch)
|
return self.pooler(hidden_states, forward_batch)
|
||||||
|
|
||||||
|
def get_input_embeddings(self) -> nn.Embedding:
|
||||||
|
return self.model.embed_tokens
|
||||||
|
|
||||||
def get_hidden_dim(self, module_name):
|
def get_hidden_dim(self, module_name):
|
||||||
# return input_dim, output_dim
|
# return input_dim, output_dim
|
||||||
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
|
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from transformers import (
|
|||||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
|
||||||
from sglang.srt.mm_utils import (
|
from sglang.srt.mm_utils import (
|
||||||
get_anyres_image_grid_shape,
|
get_anyres_image_grid_shape,
|
||||||
unpad_image,
|
unpad_image,
|
||||||
@@ -42,17 +42,21 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.models.llama import LlamaForCausalLM
|
from sglang.srt.models.llama import LlamaForCausalLM
|
||||||
from sglang.srt.models.mistral import MistralForCausalLM
|
from sglang.srt.models.mistral import MistralForCausalLM
|
||||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||||
from sglang.srt.utils import add_prefix
|
from sglang.srt.utils import add_prefix, flatten_nested_list
|
||||||
|
|
||||||
|
|
||||||
class LlavaBaseForCausalLM(nn.Module):
|
class LlavaBaseForCausalLM(nn.Module):
|
||||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
||||||
image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
|
image_sizes = flatten_nested_list(
|
||||||
|
[item.image_sizes for item in image_inputs.mm_items]
|
||||||
|
)
|
||||||
|
|
||||||
|
pad_values = [item.pad_value for item in image_inputs.mm_items]
|
||||||
|
|
||||||
# hardcode for spatial_unpad + anyres
|
# hardcode for spatial_unpad + anyres
|
||||||
if image_inputs.modalities is not None and (
|
if any(
|
||||||
"multi-images" in image_inputs.modalities
|
item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO
|
||||||
or "video" in image_inputs.modalities
|
for item in image_inputs.mm_items
|
||||||
):
|
):
|
||||||
image_aspect_ratio = "pad"
|
image_aspect_ratio = "pad"
|
||||||
else:
|
else:
|
||||||
@@ -66,7 +70,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
math.ceil(self.image_size / self.patch_size / 2) ** 2
|
math.ceil(self.image_size / self.patch_size / 2) ** 2
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
new_image_feature_len = self.image_feature_len # multiimage
|
new_image_feature_len = self.image_feature_len # multi-image
|
||||||
|
|
||||||
height = width = self.num_patches_per_side
|
height = width = self.num_patches_per_side
|
||||||
if "anyres" in image_aspect_ratio:
|
if "anyres" in image_aspect_ratio:
|
||||||
@@ -101,7 +105,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
# old_len + pad_len - 1, because we need to remove image_token_id
|
# old_len + pad_len - 1, because we need to remove image_token_id
|
||||||
input_ids = (
|
input_ids = (
|
||||||
input_ids[:offset]
|
input_ids[:offset]
|
||||||
+ [pad_values[image_idx]] * new_image_feature_len
|
+ [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
|
||||||
+ input_ids[offset + 1 :]
|
+ input_ids[offset + 1 :]
|
||||||
)
|
)
|
||||||
offset_list.append(offset)
|
offset_list.append(offset)
|
||||||
@@ -150,8 +154,8 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
modalities_list = []
|
modalities_list = []
|
||||||
max_image_offset = []
|
max_image_offset = []
|
||||||
for im in image_inputs:
|
for im in image_inputs:
|
||||||
if im and im.modalities is not None:
|
if im:
|
||||||
modalities_list.extend(im.modalities)
|
modalities_list.extend([item.modality for item in im.mm_items])
|
||||||
if im and im.image_offsets:
|
if im and im.image_offsets:
|
||||||
max_image_offset.append(
|
max_image_offset.append(
|
||||||
np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
|
np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
|
||||||
@@ -164,11 +168,19 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
|
|
||||||
if need_vision.any():
|
if need_vision.any():
|
||||||
bs = forward_batch.batch_size
|
bs = forward_batch.batch_size
|
||||||
pixel_values = [
|
pixel_values = flatten_nested_list(
|
||||||
image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
|
[
|
||||||
|
[item.pixel_values for item in image_inputs[i].mm_items]
|
||||||
|
for i in range(bs)
|
||||||
|
if need_vision[i]
|
||||||
]
|
]
|
||||||
|
)
|
||||||
image_sizes = [
|
image_sizes = [
|
||||||
image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
|
flatten_nested_list(
|
||||||
|
[item.image_sizes for item in image_inputs[i].mm_items]
|
||||||
|
)
|
||||||
|
for i in range(bs)
|
||||||
|
if need_vision[i]
|
||||||
]
|
]
|
||||||
|
|
||||||
########## Encode Image ########
|
########## Encode Image ########
|
||||||
@@ -197,13 +209,13 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
new_image_features = []
|
new_image_features = []
|
||||||
height = width = self.num_patches_per_side
|
height = width = self.num_patches_per_side
|
||||||
for image_idx, image_feature in enumerate(image_features):
|
for image_idx, image_feature in enumerate(image_features):
|
||||||
if modalities_list[image_idx] == "image":
|
if modalities_list[image_idx] == Modality.IMAGE:
|
||||||
image_aspect_ratio = (
|
image_aspect_ratio = (
|
||||||
self.config.image_aspect_ratio
|
self.config.image_aspect_ratio
|
||||||
) # single image
|
) # single image
|
||||||
elif (
|
elif (
|
||||||
modalities_list[image_idx] == "multi-images"
|
modalities_list[image_idx] == Modality.MULTI_IMAGES
|
||||||
or modalities_list[image_idx] == "video"
|
or modalities_list[image_idx] == Modality.VIDEO
|
||||||
):
|
):
|
||||||
image_aspect_ratio = "pad" # multi image
|
image_aspect_ratio = "pad" # multi image
|
||||||
# image_aspect_ratio = (
|
# image_aspect_ratio = (
|
||||||
@@ -212,7 +224,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
if (
|
if (
|
||||||
image_feature.shape[0] > 1
|
image_feature.shape[0] > 1
|
||||||
and "anyres" in image_aspect_ratio
|
and "anyres" in image_aspect_ratio
|
||||||
and modalities_list[image_idx] == "image"
|
and modalities_list[image_idx] == Modality.IMAGE
|
||||||
):
|
):
|
||||||
base_image_feature = image_feature[0]
|
base_image_feature = image_feature[0]
|
||||||
image_feature = image_feature[1:]
|
image_feature = image_feature[1:]
|
||||||
@@ -312,7 +324,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
image_feature = image_feature.unsqueeze(0)
|
image_feature = image_feature.unsqueeze(0)
|
||||||
else:
|
else:
|
||||||
if modalities_list[image_idx] == "video": # video
|
if modalities_list[image_idx] == Modality.VIDEO: # video
|
||||||
# 2x2 pooling
|
# 2x2 pooling
|
||||||
num_of_frames = image_feature.shape[0]
|
num_of_frames = image_feature.shape[0]
|
||||||
image_feature = image_feature.view(
|
image_feature = image_feature.view(
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
|
|||||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
from sglang.srt.managers.schedule_batch import MultimodalInputs, flatten_nested_list
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.llama import LlamaForCausalLM
|
from sglang.srt.models.llama import LlamaForCausalLM
|
||||||
@@ -58,7 +58,7 @@ class LlavaVidForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
||||||
pad_values = image_inputs.pad_values
|
pad_values = [item.pad_value for item in image_inputs.mm_items]
|
||||||
new_image_feature_len = self.image_feature_len
|
new_image_feature_len = self.image_feature_len
|
||||||
|
|
||||||
pad_ids = pad_values * (
|
pad_ids = pad_values * (
|
||||||
@@ -133,11 +133,19 @@ class LlavaVidForCausalLM(nn.Module):
|
|||||||
need_vision = start_positions <= np.array(max_image_offset)
|
need_vision = start_positions <= np.array(max_image_offset)
|
||||||
|
|
||||||
if need_vision.any():
|
if need_vision.any():
|
||||||
pixel_values = [
|
pixel_values = flatten_nested_list(
|
||||||
image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
|
[
|
||||||
|
[item.pixel_values for item in image_inputs[i].mm_items]
|
||||||
|
for i in range(bs)
|
||||||
|
if need_vision[i]
|
||||||
]
|
]
|
||||||
|
)
|
||||||
image_offsets = [
|
image_offsets = [
|
||||||
image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
|
flatten_nested_list(
|
||||||
|
[item.image_offsets for item in image_inputs[i].mm_items]
|
||||||
|
)
|
||||||
|
for i in range(bs)
|
||||||
|
if need_vision[i]
|
||||||
]
|
]
|
||||||
|
|
||||||
########## Encode Image ########
|
########## Encode Image ########
|
||||||
@@ -246,7 +254,8 @@ class LlavaVidForCausalLM(nn.Module):
|
|||||||
"model.mm_projector.2": "multi_modal_projector.linear_2",
|
"model.mm_projector.2": "multi_modal_projector.linear_2",
|
||||||
"model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
|
"model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
|
||||||
"model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
|
"model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
|
||||||
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
|
"model.vision_tower.vision_tower": "vision_tower",
|
||||||
|
# Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
|
||||||
"model.image_newline": "language_model.model.image_newline",
|
"model.image_newline": "language_model.model.image_newline",
|
||||||
}
|
}
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
|
|||||||
@@ -40,16 +40,19 @@ from transformers.models.whisper.modeling_whisper import (
|
|||||||
from sglang.srt.layers.quantization import QuantizationConfig
|
from sglang.srt.layers.quantization import QuantizationConfig
|
||||||
from sglang.srt.managers.mm_utils import (
|
from sglang.srt.managers.mm_utils import (
|
||||||
MultiModalityDataPaddingPatternTokenPairs,
|
MultiModalityDataPaddingPatternTokenPairs,
|
||||||
embed_mm_inputs,
|
general_mm_embed_routine,
|
||||||
get_multimodal_data_bounds,
|
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
from sglang.srt.managers.schedule_batch import (
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
MultimodalDataItem,
|
||||||
|
MultimodalInputs,
|
||||||
|
flatten_nested_list,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.minicpmv import (
|
from sglang.srt.models.minicpmv import (
|
||||||
Idefics2VisionTransformer,
|
Idefics2VisionTransformer,
|
||||||
MiniCPMVBaseModel,
|
MiniCPMBaseModel,
|
||||||
Resampler2_5,
|
Resampler2_5,
|
||||||
)
|
)
|
||||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||||
@@ -1409,7 +1412,7 @@ class MultiModalProjector(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class MiniCPMO(MiniCPMVBaseModel):
|
class MiniCPMO(MiniCPMBaseModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
@@ -1537,7 +1540,7 @@ class MiniCPMO(MiniCPMVBaseModel):
|
|||||||
|
|
||||||
return input_lengths_after_cnn, input_lengths_after_pooling
|
return input_lengths_after_cnn, input_lengths_after_pooling
|
||||||
|
|
||||||
def get_audio_embedding_streaming(self, multimodal_input: MultimodalInputs):
|
def get_audio_embedding_streaming(self, items: List[MultimodalDataItem]):
|
||||||
r"""
|
r"""
|
||||||
Extract audio embeddings in a streaming manner using cached key-value pairs.
|
Extract audio embeddings in a streaming manner using cached key-value pairs.
|
||||||
|
|
||||||
@@ -1545,26 +1548,15 @@ class MiniCPMO(MiniCPMVBaseModel):
|
|||||||
for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended
|
for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended
|
||||||
for streaming scenarios.
|
for streaming scenarios.
|
||||||
|
|
||||||
Args:
|
|
||||||
multimodal_input (dict):
|
|
||||||
- **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
|
|
||||||
- **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[List[torch.Tensor]]: audio embeddings
|
List[List[torch.Tensor]]: audio embeddings
|
||||||
"""
|
"""
|
||||||
# print("audio embedding")
|
wavforms = flatten_nested_list(
|
||||||
|
[item.audio_features for item in items if item.audio_features]
|
||||||
wavforms = (
|
|
||||||
[]
|
|
||||||
if multimodal_input.audio_features is None
|
|
||||||
else multimodal_input.audio_features
|
|
||||||
)
|
)
|
||||||
# list, [[x1, x2], [y1], [z1]]
|
# list, [[x1, x2], [y1], [z1]]
|
||||||
audio_feature_lens_raw = (
|
audio_feature_lens_raw = flatten_nested_list(
|
||||||
[]
|
[item.audio_feature_lens for item in items if item.audio_feature_lens]
|
||||||
if multimodal_input.audio_feature_lens is None
|
|
||||||
else multimodal_input.audio_feature_lens
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# exist audio
|
# exist audio
|
||||||
@@ -1650,7 +1642,7 @@ class MiniCPMO(MiniCPMVBaseModel):
|
|||||||
ret[i, start:ending] = True
|
ret[i, start:ending] = True
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def get_audio_embedding(self, multimodal_input: MultimodalInputs, chunk_length=-1):
|
def get_audio_embedding(self, items: List[MultimodalDataItem], chunk_length=-1):
|
||||||
r"""
|
r"""
|
||||||
Extract full audio embeddings with optional chunk-based attention.
|
Extract full audio embeddings with optional chunk-based attention.
|
||||||
|
|
||||||
@@ -1659,31 +1651,25 @@ class MiniCPMO(MiniCPMVBaseModel):
|
|||||||
not use key-value caching and is suitable for non-streaming inference.
|
not use key-value caching and is suitable for non-streaming inference.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
multimodal_input (dict):
|
|
||||||
- **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
|
|
||||||
- **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
|
|
||||||
chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based
|
chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based
|
||||||
attention (>0) during embedding computation.
|
attention (>0) during embedding computation.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[List[torch.Tensor]]: audio embeddings
|
List[List[torch.Tensor]]: audio embeddings
|
||||||
"""
|
"""
|
||||||
# print("audio embedding")
|
|
||||||
# (bs, 80, frames) or [], multi audios need filled in advance
|
# (bs, 80, frames) or [], multi audios need filled in advance
|
||||||
wavforms = (
|
wavforms = flatten_nested_list(
|
||||||
[]
|
[item.audio_features for item in items if item.audio_features]
|
||||||
if multimodal_input.audio_features is None
|
|
||||||
else multimodal_input.audio_features
|
|
||||||
)
|
)
|
||||||
# list, [[x1, x2], [y1], [z1]]
|
# list, [[x1, x2], [y1], [z1]]
|
||||||
audio_feature_lens_raw = (
|
audio_feature_lens_raw = flatten_nested_list(
|
||||||
[]
|
[item.audio_feature_lens for item in items if item.audio_feature_lens]
|
||||||
if multimodal_input.audio_feature_lens is None
|
|
||||||
else multimodal_input.audio_feature_lens
|
|
||||||
)
|
)
|
||||||
|
|
||||||
final_audio_embeds = []
|
final_audio_embeds = []
|
||||||
|
|
||||||
|
assert isinstance(wavforms, list)
|
||||||
|
assert isinstance(wavforms[0], torch.Tensor)
|
||||||
# exist audio
|
# exist audio
|
||||||
for wavform in wavforms:
|
for wavform in wavforms:
|
||||||
if len(wavform) > 0:
|
if len(wavform) > 0:
|
||||||
@@ -1757,86 +1743,46 @@ class MiniCPMO(MiniCPMVBaseModel):
|
|||||||
final_audio_embeds.append(target_audio_embeds)
|
final_audio_embeds.append(target_audio_embeds)
|
||||||
return final_audio_embeds
|
return final_audio_embeds
|
||||||
|
|
||||||
|
def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
|
embedding = self.get_omni_embedding(
|
||||||
|
items=items,
|
||||||
|
chunk_length=self.config.audio_chunk_length,
|
||||||
|
stream_input=False,
|
||||||
|
)
|
||||||
|
return embedding
|
||||||
|
|
||||||
def get_omni_embedding(
|
def get_omni_embedding(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
items: List[MultimodalDataItem],
|
||||||
multimodal_input: MultimodalInputs,
|
|
||||||
input_embeds: torch.Tensor,
|
|
||||||
forward_mode: ForwardMode,
|
|
||||||
chunk_length=-1,
|
chunk_length=-1,
|
||||||
stream_input=False,
|
stream_input=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
multimodal_input:
|
|
||||||
input_embeds:
|
|
||||||
chunk_length: whisper use full attention or chunk attention
|
chunk_length: whisper use full attention or chunk attention
|
||||||
stream_input: use streaming audio embedding
|
stream_input: use streaming audio embedding
|
||||||
Returns:
|
Returns:
|
||||||
final embeddings with audio feature
|
final embeddings with audio feature
|
||||||
"""
|
"""
|
||||||
input_embeds = input_embeds.unsqueeze(0)
|
|
||||||
if not forward_mode.is_decode() and multimodal_input.contains_audio_inputs():
|
|
||||||
audio_bounds = get_multimodal_data_bounds(
|
|
||||||
input_ids=input_ids,
|
|
||||||
pad_values=multimodal_input.pad_values,
|
|
||||||
token_pairs=[
|
|
||||||
(multimodal_input.audio_start_id, multimodal_input.audio_end_id)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
if audio_bounds.numel() == 0:
|
|
||||||
input_embeds = input_embeds.squeeze(0)
|
|
||||||
# TODO
|
|
||||||
logger.warn("Unimplemented logic. Please try disabling chunked prefill")
|
|
||||||
return input_embeds
|
|
||||||
audio_bounds = audio_bounds.unsqueeze(0)
|
|
||||||
bs = len(input_embeds)
|
|
||||||
|
|
||||||
if stream_input:
|
if stream_input:
|
||||||
audio_embeddings = self.get_audio_embedding_streaming(multimodal_input)
|
audio_embeddings = self.get_audio_embedding_streaming(items)
|
||||||
else:
|
else:
|
||||||
audio_embeddings = self.get_audio_embedding(
|
audio_embeddings = self.get_audio_embedding(items, chunk_length)
|
||||||
multimodal_input, chunk_length
|
bs = len(audio_embeddings)
|
||||||
)
|
|
||||||
# batch size
|
# batch size
|
||||||
assert len(audio_embeddings) == len(input_embeds)
|
audio_embs = torch.cat(flatten_nested_list(audio_embeddings), dim=0)
|
||||||
if len(audio_embeddings) > 0:
|
|
||||||
if self.config.chunk_input:
|
|
||||||
for i in range(bs):
|
|
||||||
audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
|
|
||||||
device=input_embeds.device, dtype=input_embeds.dtype
|
|
||||||
)
|
|
||||||
audio_start_pos = 0
|
|
||||||
for bound in audio_bounds[i]:
|
|
||||||
audio_len = bound[1] - bound[0] + 1
|
|
||||||
input_embeds[0, bound[0] : bound[1] + 1] = audio_embs[
|
|
||||||
audio_start_pos : audio_start_pos + audio_len, :
|
|
||||||
]
|
|
||||||
audio_start_pos += audio_len
|
|
||||||
else:
|
|
||||||
for i in range(bs):
|
|
||||||
audio_embs = audio_embeddings[i]
|
|
||||||
bounds = audio_bounds[i]
|
|
||||||
for embs, bound in zip(audio_embs, bounds):
|
|
||||||
audio_indices = torch.arange(
|
|
||||||
bound[0], bound[1], dtype=torch.long
|
|
||||||
).to(input_embeds.device)
|
|
||||||
|
|
||||||
if embs.shape[0] != len(audio_indices):
|
return audio_embs
|
||||||
raise ValueError(
|
|
||||||
f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
f"to input indices of length {len(audio_indices)}"
|
# list of tensors
|
||||||
)
|
pixel_values = flatten_nested_list([item.pixel_values for item in items])
|
||||||
input_embeds[i, audio_indices] = embs.to(input_embeds.dtype)
|
tgt_sizes = torch.stack(
|
||||||
input_embeds = input_embeds.squeeze(0)
|
flatten_nested_list([item.tgt_size for item in items]), dim=0
|
||||||
return input_embeds
|
)
|
||||||
|
assert len(pixel_values) == tgt_sizes.shape[0]
|
||||||
|
|
||||||
def get_image_features(
|
|
||||||
self,
|
|
||||||
image_inputs: MultimodalInputs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
pixel_values = image_inputs.pixel_values
|
|
||||||
tgt_sizes = image_inputs.tgt_sizes
|
|
||||||
device = self.vpm.embeddings.position_embedding.weight.device
|
device = self.vpm.embeddings.position_embedding.weight.device
|
||||||
dtype = self.vpm.embeddings.position_embedding.weight.dtype
|
dtype = self.vpm.embeddings.position_embedding.weight.dtype
|
||||||
all_pixel_values_lst = [
|
all_pixel_values_lst = [
|
||||||
@@ -1845,10 +1791,10 @@ class MiniCPMO(MiniCPMVBaseModel):
|
|||||||
|
|
||||||
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
|
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
|
||||||
assert isinstance(max_patches, int)
|
assert isinstance(max_patches, int)
|
||||||
|
|
||||||
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
|
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
|
||||||
all_pixel_values_lst, batch_first=True, padding_value=0.0
|
all_pixel_values_lst, batch_first=True, padding_value=0.0
|
||||||
)
|
)
|
||||||
|
|
||||||
B, L, _ = all_pixel_values.shape
|
B, L, _ = all_pixel_values.shape
|
||||||
all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
|
all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
|
||||||
patch_attn_mask = torch.zeros(
|
patch_attn_mask = torch.zeros(
|
||||||
@@ -1875,53 +1821,23 @@ class MiniCPMO(MiniCPMVBaseModel):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
inputs_embeds = None
|
|
||||||
# TODO(mick): optimize the logic here: clamp, merge and embedding should happens at most once
|
|
||||||
if (
|
|
||||||
not forward_batch.forward_mode.is_decode()
|
|
||||||
and forward_batch.contains_image_inputs()
|
|
||||||
):
|
|
||||||
mm_inputs = forward_batch.merge_mm_inputs()
|
|
||||||
inputs_embeds = embed_mm_inputs(
|
|
||||||
mm_input=mm_inputs,
|
|
||||||
input_ids=input_ids,
|
|
||||||
input_embedding=self.get_input_embeddings(),
|
|
||||||
mm_data_embedding_func=self.get_image_features,
|
|
||||||
placeholder_token_ids=[mm_inputs.im_token_id] + mm_inputs.pad_values,
|
|
||||||
)
|
|
||||||
|
|
||||||
input_ids = input_ids.clamp(
|
|
||||||
min=0, max=self.get_input_embeddings().num_embeddings - 1
|
|
||||||
)
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.llm.get_input_embeddings(input_ids)
|
|
||||||
if (
|
|
||||||
not forward_batch.forward_mode.is_decode()
|
|
||||||
and self.config.init_audio
|
|
||||||
and forward_batch.contains_audio_inputs()
|
|
||||||
):
|
|
||||||
mm_input = forward_batch.merge_mm_inputs()
|
mm_input = forward_batch.merge_mm_inputs()
|
||||||
inputs_embeds = self.get_omni_embedding(
|
placeholder_token_ids = (
|
||||||
|
([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items])
|
||||||
|
if forward_batch.contains_mm_inputs()
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
hidden_states = general_mm_embed_routine(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
multimodal_input=mm_input,
|
|
||||||
input_embeds=inputs_embeds,
|
|
||||||
forward_mode=forward_batch.forward_mode,
|
|
||||||
chunk_length=self.config.audio_chunk_length,
|
|
||||||
stream_input=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
forward_batch.mm_inputs = None
|
|
||||||
|
|
||||||
hidden_states = self.llm.model(
|
|
||||||
input_ids=None,
|
|
||||||
positions=positions,
|
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
input_embeds=inputs_embeds,
|
language_model=self.llm,
|
||||||
)
|
image_data_embedding_func=self.get_image_feature,
|
||||||
|
audio_data_embedding_func=self.get_audio_feature,
|
||||||
return self.logits_processor(
|
placeholder_token_ids=placeholder_token_ids,
|
||||||
input_ids, hidden_states, self.llm.lm_head, forward_batch
|
positions=positions,
|
||||||
)
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -54,12 +54,12 @@ from sglang.srt.managers.mm_utils import (
|
|||||||
MultiModalityDataPaddingPatternTokenPairs,
|
MultiModalityDataPaddingPatternTokenPairs,
|
||||||
general_mm_embed_routine,
|
general_mm_embed_routine,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
|
from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
|
||||||
from sglang.srt.utils import add_prefix
|
from sglang.srt.utils import add_prefix, flatten_nested_list
|
||||||
|
|
||||||
RawImageType = Union[Image.Image, torch.Tensor]
|
RawImageType = Union[Image.Image, torch.Tensor]
|
||||||
|
|
||||||
@@ -661,7 +661,7 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
|
|||||||
return tuple(int(x) for x in version_str.split("."))
|
return tuple(int(x) for x in version_str.split("."))
|
||||||
|
|
||||||
|
|
||||||
class MiniCPMVBaseModel(nn.Module):
|
class MiniCPMBaseModel(nn.Module):
|
||||||
"""
|
"""
|
||||||
The abstract class of MiniCPMV can only be inherited, but cannot be
|
The abstract class of MiniCPMV can only be inherited, but cannot be
|
||||||
instantiated.
|
instantiated.
|
||||||
@@ -853,7 +853,7 @@ class MiniCPMVBaseModel(nn.Module):
|
|||||||
return vlm_embedding, vision_hidden_states
|
return vlm_embedding, vision_hidden_states
|
||||||
|
|
||||||
def get_input_embeddings(self) -> nn.Embedding:
|
def get_input_embeddings(self) -> nn.Embedding:
|
||||||
return self.llm.get_input_embedding()
|
return self.llm.get_input_embeddings()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -862,23 +862,14 @@ class MiniCPMVBaseModel(nn.Module):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
inputs_embeds = general_mm_embed_routine(
|
hidden_states = general_mm_embed_routine(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
embed_tokens=self.get_input_embeddings(),
|
image_data_embedding_func=self.get_image_feature,
|
||||||
mm_data_embedding_func=self.get_image_features,
|
language_model=self.llm,
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.llm.model(
|
|
||||||
input_ids=None,
|
|
||||||
positions=positions,
|
positions=positions,
|
||||||
forward_batch=forward_batch,
|
|
||||||
input_embeds=inputs_embeds,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.logits_processor(
|
|
||||||
input_ids, hidden_states, self.llm.lm_head, forward_batch
|
|
||||||
)
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
def init_llm(
|
def init_llm(
|
||||||
self,
|
self,
|
||||||
@@ -913,11 +904,11 @@ class MiniCPMVBaseModel(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_image_features(self, image_inputs: MultimodalInputs) -> torch.Tensor:
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class MiniCPMV2_6(MiniCPMVBaseModel):
|
class MiniCPMV2_6(MiniCPMBaseModel):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": [
|
"qkv_proj": [
|
||||||
"q_proj",
|
"q_proj",
|
||||||
@@ -1023,14 +1014,13 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
|||||||
)
|
)
|
||||||
return vision_embedding
|
return vision_embedding
|
||||||
|
|
||||||
def get_image_features(
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
self,
|
|
||||||
image_inputs: MultimodalInputs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# list of tensors
|
# list of tensors
|
||||||
pixel_values = image_inputs.pixel_values
|
pixel_values = flatten_nested_list([item.pixel_values for item in items])
|
||||||
|
tgt_sizes = torch.stack(
|
||||||
tgt_sizes = image_inputs.tgt_sizes
|
flatten_nested_list([item.tgt_size for item in items]), dim=0
|
||||||
|
)
|
||||||
|
assert len(pixel_values) == tgt_sizes.shape[0]
|
||||||
|
|
||||||
device = self.vpm.embeddings.position_embedding.weight.device
|
device = self.vpm.embeddings.position_embedding.weight.device
|
||||||
dtype = self.vpm.embeddings.position_embedding.weight.dtype
|
dtype = self.vpm.embeddings.position_embedding.weight.dtype
|
||||||
@@ -1040,10 +1030,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
|||||||
|
|
||||||
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
|
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
|
||||||
assert isinstance(max_patches, int)
|
assert isinstance(max_patches, int)
|
||||||
|
|
||||||
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
|
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
|
||||||
all_pixel_values_lst, batch_first=True, padding_value=0.0
|
all_pixel_values_lst, batch_first=True, padding_value=0.0
|
||||||
)
|
)
|
||||||
|
|
||||||
B, L, _ = all_pixel_values.shape
|
B, L, _ = all_pixel_values.shape
|
||||||
all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
|
all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
|
||||||
patch_attn_mask = torch.zeros(
|
patch_attn_mask = torch.zeros(
|
||||||
|
|||||||
@@ -796,14 +796,16 @@ class MllamaForConditionalGeneration(nn.Module):
|
|||||||
self.logits_processor = LogitsProcessor(config.text_config)
|
self.logits_processor = LogitsProcessor(config.text_config)
|
||||||
self.capture_mode = False
|
self.capture_mode = False
|
||||||
|
|
||||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
|
||||||
pixel_values = image_inputs.pixel_values
|
pixel_values = torch.cat(
|
||||||
pad_values = image_inputs.pad_values
|
[item.pixel_values for item in mm_inputs.mm_items], dim=0
|
||||||
|
)
|
||||||
|
pad_values = [item.pad_value for item in mm_inputs.mm_items]
|
||||||
|
|
||||||
num_concurrent_media, num_tiles = pixel_values.shape[1:3]
|
num_concurrent_media, num_tiles = pixel_values.shape[1:3]
|
||||||
num_patches = self.vision_model.num_patches
|
num_patches = self.vision_model.num_patches
|
||||||
image_len = num_concurrent_media * num_tiles * num_patches
|
image_len = num_concurrent_media * num_tiles * num_patches
|
||||||
image_inputs.num_image_tokens = image_len
|
mm_inputs.num_image_tokens = image_len
|
||||||
|
|
||||||
pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))
|
pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))
|
||||||
|
|
||||||
@@ -815,10 +817,16 @@ class MllamaForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
# pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
|
# pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
|
||||||
max_num_images = max_num_tiles = bs = 0
|
max_num_images = max_num_tiles = bs = 0
|
||||||
for i, im in enumerate(forward_batch.mm_inputs):
|
for i, mm_input in enumerate(forward_batch.mm_inputs):
|
||||||
if not forward_batch.encoder_cached[i] and im is not None:
|
|
||||||
max_num_images = max(max_num_images, im.pixel_values.shape[1])
|
if not forward_batch.encoder_cached[i] and mm_input is not None:
|
||||||
max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
|
pixel_values = torch.cat(
|
||||||
|
[item.pixel_values for item in mm_input.mm_items], dim=0
|
||||||
|
)
|
||||||
|
# max_num_images = max(max_num_images, sum(1 if item.is_image() else 0 for item in mm_input.items))
|
||||||
|
max_num_images = max(max_num_images, pixel_values.shape[1])
|
||||||
|
|
||||||
|
max_num_tiles = max(max_num_tiles, pixel_values.shape[2])
|
||||||
bs += 1
|
bs += 1
|
||||||
|
|
||||||
if max_num_images * max_num_tiles * bs == 0:
|
if max_num_images * max_num_tiles * bs == 0:
|
||||||
@@ -842,17 +850,24 @@ class MllamaForConditionalGeneration(nn.Module):
|
|||||||
)
|
)
|
||||||
i = 0
|
i = 0
|
||||||
encoder_lens_need = []
|
encoder_lens_need = []
|
||||||
for k, im in enumerate(forward_batch.mm_inputs):
|
|
||||||
if forward_batch.encoder_cached[k] or im is None:
|
for k, mm_input in enumerate(forward_batch.mm_inputs):
|
||||||
|
if forward_batch.encoder_cached[k] or mm_input is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
encoder_lens_need.append(forward_batch.encoder_lens[k])
|
encoder_lens_need.append(forward_batch.encoder_lens[k])
|
||||||
for j in range(im.pixel_values.shape[1]):
|
pixel_values = torch.cat(
|
||||||
img = im.pixel_values[0, j]
|
[item.pixel_values for item in mm_input.mm_items], dim=0
|
||||||
|
)
|
||||||
|
for j in range(pixel_values.shape[1]):
|
||||||
|
img = pixel_values[0, j]
|
||||||
num_tiles = img.shape[0]
|
num_tiles = img.shape[0]
|
||||||
batched_images[i, j, :num_tiles] = img
|
batched_images[i, j, :num_tiles] = img
|
||||||
batched_ar_ids[i, j] = im.aspect_ratio_ids[0, j]
|
batched_ar_ids[i, j] = mm_input.mm_items[0].aspect_ratio_id[0, j]
|
||||||
batched_ar_mask[i, j, :num_tiles] = im.aspect_ratio_mask[0, j]
|
|
||||||
|
batched_ar_mask[i, j, :num_tiles] = mm_input.mm_items[
|
||||||
|
0
|
||||||
|
].aspect_ratio_mask[0, j]
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need
|
return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need
|
||||||
|
|||||||
@@ -261,11 +261,14 @@ class Qwen2Model(nn.Module):
|
|||||||
)
|
)
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
if hasattr(self.config, "scale_emb"):
|
if hasattr(self.config, "scale_emb"):
|
||||||
return self.embed_tokens(input_ids) * self.config.scale_emb
|
return self.get_input_embeddings()(input_ids) * self.config.scale_emb
|
||||||
else:
|
else:
|
||||||
return self.embed_tokens(input_ids)
|
return self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
|
def get_input_embeddings(self) -> nn.Embedding:
|
||||||
|
return self.embed_tokens
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -358,10 +361,10 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||||
|
|
||||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
return self.model.get_input_embeddings(input_ids)
|
return self.model.get_input_embedding(input_ids)
|
||||||
|
|
||||||
def get_input_embedding(self) -> nn.Embedding:
|
def get_input_embeddings(self) -> nn.Embedding:
|
||||||
return self.model.embed_tokens
|
return self.model.embed_tokens
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|||||||
@@ -30,22 +30,13 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from transformers import AutoModel, Qwen2VLConfig
|
from transformers import Qwen2VLConfig
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
|
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
|
||||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
|
||||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||||
Qwen2_5_VLConfig,
|
|
||||||
Qwen2_5_VLVisionConfig,
|
Qwen2_5_VLVisionConfig,
|
||||||
)
|
)
|
||||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
|
||||||
Qwen2_5_VLForConditionalGeneration,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.distributed import (
|
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
)
|
|
||||||
from sglang.srt.hf_transformers_utils import get_processor
|
from sglang.srt.hf_transformers_utils import get_processor
|
||||||
from sglang.srt.layers.attention.vision import VisionAttention
|
from sglang.srt.layers.attention.vision import VisionAttention
|
||||||
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||||
@@ -57,7 +48,7 @@ from sglang.srt.managers.mm_utils import (
|
|||||||
MultiModalityDataPaddingPatternTokenPairs,
|
MultiModalityDataPaddingPatternTokenPairs,
|
||||||
general_mm_embed_routine,
|
general_mm_embed_routine,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.qwen2 import Qwen2Model
|
from sglang.srt.models.qwen2 import Qwen2Model
|
||||||
@@ -513,19 +504,24 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
|||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||||
|
|
||||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
|
||||||
# Get all special token IDs
|
# Get all special token IDs
|
||||||
im_start_id: int = image_inputs.im_start_id
|
im_start_id: int = mm_inputs.im_start_id
|
||||||
im_end_id: int = image_inputs.im_end_id
|
im_end_id: int = mm_inputs.im_end_id
|
||||||
|
|
||||||
media_token_pairs = [(im_start_id, im_end_id)]
|
media_token_pairs = [(im_start_id, im_end_id)]
|
||||||
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
|
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
|
||||||
|
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||||
|
|
||||||
return pattern.pad_input_tokens(input_ids, image_inputs)
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
|
# in qwen-vl, last dim is the same
|
||||||
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
|
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
||||||
pixel_values = image_input.pixel_values.type(self.visual.dtype)
|
self.visual.dtype
|
||||||
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
|
)
|
||||||
|
image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
|
||||||
|
assert pixel_values.dim() == 2, pixel_values.dim()
|
||||||
|
assert image_grid_thws.dim() == 2, image_grid_thws.dim()
|
||||||
|
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
|
||||||
return image_embeds
|
return image_embeds
|
||||||
|
|
||||||
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
|
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
|
||||||
@@ -570,18 +566,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
|||||||
f"(3, seq_len) positions, but got {positions.size()}"
|
f"(3, seq_len) positions, but got {positions.size()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
inputs_embeds = general_mm_embed_routine(
|
hidden_states = general_mm_embed_routine(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
embed_tokens=self.get_input_embeddings(),
|
language_model=self.model,
|
||||||
mm_data_embedding_func=self.get_image_feature,
|
image_data_embedding_func=self.get_image_feature,
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.model(
|
|
||||||
input_ids=None,
|
|
||||||
positions=positions,
|
positions=positions,
|
||||||
forward_batch=forward_batch,
|
|
||||||
input_embeds=inputs_embeds,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not get_embedding:
|
if not get_embedding:
|
||||||
@@ -594,9 +584,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
|||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
("qkv_proj", "q_proj", "q"),
|
(".qkv_proj", ".q_proj", "q"),
|
||||||
("qkv_proj", "k_proj", "k"),
|
(".qkv_proj", ".k_proj", "k"),
|
||||||
("qkv_proj", "v_proj", "v"),
|
(".qkv_proj", ".v_proj", "v"),
|
||||||
("gate_up_proj", "up_proj", 1),
|
("gate_up_proj", "up_proj", 1),
|
||||||
("gate_up_proj", "gate_proj", 0),
|
("gate_up_proj", "gate_proj", 0),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ from sglang.srt.managers.mm_utils import (
|
|||||||
MultiModalityDataPaddingPatternTokenPairs,
|
MultiModalityDataPaddingPatternTokenPairs,
|
||||||
general_mm_embed_routine,
|
general_mm_embed_routine,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.qwen2 import Qwen2Model
|
from sglang.srt.models.qwen2 import Qwen2Model
|
||||||
@@ -472,18 +472,24 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
# Use grid_t * grid_w * grid_h to pad tokens for each image
|
# Use grid_t * grid_w * grid_h to pad tokens for each image
|
||||||
# add replaced padding by unique image hash
|
# add replaced padding by unique image hash
|
||||||
def pad_input_ids(self, input_ids: List[int], multi_modal_inputs: MultimodalInputs):
|
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
|
||||||
# Get all special token IDs
|
# Get all special token IDs
|
||||||
im_start_id: int = multi_modal_inputs.im_start_id
|
im_start_id: int = mm_inputs.im_start_id
|
||||||
im_end_id: int = multi_modal_inputs.im_end_id
|
im_end_id: int = mm_inputs.im_end_id
|
||||||
|
|
||||||
media_token_pairs = [(im_start_id, im_end_id)]
|
media_token_pairs = [(im_start_id, im_end_id)]
|
||||||
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
|
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
|
||||||
return pattern.pad_input_tokens(input_ids, multi_modal_inputs)
|
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||||
|
|
||||||
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
pixel_values = image_input.pixel_values.type(self.visual.dtype)
|
# in qwen-vl, last dim is the same
|
||||||
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
|
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
||||||
|
self.visual.dtype
|
||||||
|
)
|
||||||
|
image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
|
||||||
|
assert pixel_values.dim() == 2, pixel_values.dim()
|
||||||
|
assert image_grid_thws.dim() == 2, image_grid_thws.dim()
|
||||||
|
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
|
||||||
return image_embeds
|
return image_embeds
|
||||||
|
|
||||||
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
|
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
|
||||||
@@ -527,27 +533,20 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
"multimodal section rotary embedding requires "
|
"multimodal section rotary embedding requires "
|
||||||
f"(3, seq_len) positions, but got {positions.size()}"
|
f"(3, seq_len) positions, but got {positions.size()}"
|
||||||
)
|
)
|
||||||
|
hidden_states = general_mm_embed_routine(
|
||||||
inputs_embeds = general_mm_embed_routine(
|
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
embed_tokens=self.get_input_embeddings(),
|
language_model=self.model,
|
||||||
mm_data_embedding_func=self.get_image_feature,
|
image_data_embedding_func=self.get_image_feature,
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.model(
|
|
||||||
input_ids=None,
|
|
||||||
positions=positions,
|
positions=positions,
|
||||||
forward_batch=forward_batch,
|
|
||||||
input_embeds=inputs_embeds,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not get_embedding:
|
if get_embedding:
|
||||||
|
return self.pooler(hidden_states, forward_batch)
|
||||||
|
else:
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
return self.pooler(hidden_states, forward_batch)
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -897,6 +897,7 @@ def v1_chat_generate_request(
|
|||||||
request_ids: List[str] = None,
|
request_ids: List[str] = None,
|
||||||
):
|
):
|
||||||
input_ids = []
|
input_ids = []
|
||||||
|
prompts = []
|
||||||
sampling_params_list = []
|
sampling_params_list = []
|
||||||
image_data_list = []
|
image_data_list = []
|
||||||
audio_data_list = []
|
audio_data_list = []
|
||||||
@@ -916,6 +917,7 @@ def v1_chat_generate_request(
|
|||||||
# - audio_data: None or a list of audio strings (URLs).
|
# - audio_data: None or a list of audio strings (URLs).
|
||||||
# None skips any image processing in GenerateReqInput.
|
# None skips any image processing in GenerateReqInput.
|
||||||
strict_tag = None
|
strict_tag = None
|
||||||
|
prompt = ""
|
||||||
if not isinstance(request.messages, str):
|
if not isinstance(request.messages, str):
|
||||||
# Apply chat template and its stop strings.
|
# Apply chat template and its stop strings.
|
||||||
tools = None
|
tools = None
|
||||||
@@ -1005,11 +1007,13 @@ def v1_chat_generate_request(
|
|||||||
image_data = None
|
image_data = None
|
||||||
audio_data = None
|
audio_data = None
|
||||||
modalities = []
|
modalities = []
|
||||||
|
prompt = request.messages
|
||||||
input_ids.append(prompt_ids)
|
input_ids.append(prompt_ids)
|
||||||
return_logprobs.append(request.logprobs)
|
return_logprobs.append(request.logprobs)
|
||||||
logprob_start_lens.append(-1)
|
logprob_start_lens.append(-1)
|
||||||
top_logprobs_nums.append(request.top_logprobs or 0)
|
top_logprobs_nums.append(request.top_logprobs or 0)
|
||||||
lora_paths.append(request.lora_path)
|
lora_paths.append(request.lora_path)
|
||||||
|
prompts.append(prompt)
|
||||||
|
|
||||||
sampling_params = {
|
sampling_params = {
|
||||||
"temperature": request.temperature,
|
"temperature": request.temperature,
|
||||||
@@ -1063,6 +1067,10 @@ def v1_chat_generate_request(
|
|||||||
audio_data_list.append(audio_data)
|
audio_data_list.append(audio_data)
|
||||||
modalities_list.append(modalities)
|
modalities_list.append(modalities)
|
||||||
if len(all_requests) == 1:
|
if len(all_requests) == 1:
|
||||||
|
if tokenizer_manager.model_config.is_multimodal:
|
||||||
|
# processor will need text input
|
||||||
|
prompt_kwargs = {"text": prompts[0]}
|
||||||
|
else:
|
||||||
if isinstance(input_ids[0], str):
|
if isinstance(input_ids[0], str):
|
||||||
prompt_kwargs = {"text": input_ids[0]}
|
prompt_kwargs = {"text": input_ids[0]}
|
||||||
else:
|
else:
|
||||||
@@ -1075,6 +1083,10 @@ def v1_chat_generate_request(
|
|||||||
top_logprobs_nums = top_logprobs_nums[0]
|
top_logprobs_nums = top_logprobs_nums[0]
|
||||||
modalities_list = modalities_list[0]
|
modalities_list = modalities_list[0]
|
||||||
lora_paths = lora_paths[0]
|
lora_paths = lora_paths[0]
|
||||||
|
else:
|
||||||
|
if tokenizer_manager.model_config.is_multimodal:
|
||||||
|
# processor will need text input
|
||||||
|
prompt_kwargs = {"text": prompts}
|
||||||
else:
|
else:
|
||||||
if isinstance(input_ids[0], str):
|
if isinstance(input_ids[0], str):
|
||||||
prompt_kwargs = {"text": input_ids}
|
prompt_kwargs = {"text": input_ids}
|
||||||
|
|||||||
@@ -12,7 +12,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Common utilities."""
|
"""Common utilities."""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import builtins
|
import builtins
|
||||||
import ctypes
|
import ctypes
|
||||||
@@ -54,6 +53,7 @@ import torch.distributed
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import triton
|
import triton
|
||||||
import zmq
|
import zmq
|
||||||
|
from decord import VideoReader, cpu
|
||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
from packaging import version as pkg_version
|
from packaging import version as pkg_version
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -513,13 +513,18 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
from scipy.signal import resample
|
from scipy.signal import resample
|
||||||
|
|
||||||
# print(f"loading {audio_file}")
|
|
||||||
# Load audio data
|
# Load audio data
|
||||||
if isinstance(audio_file, bytes):
|
if isinstance(audio_file, bytes):
|
||||||
audio, original_sr = sf.read(BytesIO(audio_file))
|
audio, original_sr = sf.read(BytesIO(audio_file))
|
||||||
elif audio_file.startswith("data:"):
|
elif audio_file.startswith("data:"):
|
||||||
audio_file = audio_file.split(",")[1]
|
audio_file = audio_file.split(",")[1]
|
||||||
audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
|
audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
|
||||||
|
elif audio_file.startswith("http://") or audio_file.startswith("https://"):
|
||||||
|
timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
|
||||||
|
response = requests.get(audio_file, stream=True, timeout=timeout)
|
||||||
|
audio_file = BytesIO(response.content)
|
||||||
|
response.close()
|
||||||
|
audio, original_sr = sf.read(audio_file)
|
||||||
elif isinstance(audio_file, str):
|
elif isinstance(audio_file, str):
|
||||||
audio, original_sr = sf.read(audio_file)
|
audio, original_sr = sf.read(audio_file)
|
||||||
else:
|
else:
|
||||||
@@ -537,6 +542,30 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
|
|||||||
return audio
|
return audio
|
||||||
|
|
||||||
|
|
||||||
|
def encode_video(video_path, frame_count_limit=None):
    """Decode a video file into a list of PIL images.

    Frames are taken at a stride equal to the rounded average FPS of the
    video (roughly one frame per second). If that yields more frames than
    ``frame_count_limit``, the selection is thinned by picking evenly
    spaced entries.

    Args:
        video_path: Filesystem path of the video to read.
        frame_count_limit: Optional cap on the number of frames returned.
            ``None`` means no cap; ``0`` returns an empty list without
            opening the video.

    Returns:
        A list of ``PIL.Image`` frames. The list is empty when the file does
        not exist, when ``frame_count_limit`` is 0, or when no frame index
        ends up selected.
    """
    if not os.path.exists(video_path):
        logger.error(f"Video {video_path} does not exist")
        return []

    if frame_count_limit == 0:
        return []

    def uniform_sample(seq, count):
        # Take the element at the midpoint of each of `count` equal-width
        # buckets laid over `seq`.
        gap = len(seq) / count
        return [seq[int(i * gap + gap / 2)] for i in range(count)]

    vr = VideoReader(video_path, ctx=cpu(0))
    # Clamp the stride to at least 1: an average FPS below 0.5 rounds to 0,
    # and range() raises ValueError on a zero step.
    sample_fps = max(1, round(vr.get_avg_fps()))
    frame_indices = list(range(0, len(vr), sample_fps))
    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
        frame_indices = uniform_sample(frame_indices, frame_count_limit)

    # Nothing selected (e.g. a video with no frames): skip the decoder call.
    if not frame_indices:
        return []

    frames = vr.get_batch(frame_indices).asnumpy()
    return [Image.fromarray(v.astype("uint8")) for v in frames]
|
||||||
|
|
||||||
|
|
||||||
def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
|
def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
|
||||||
image = image_size = None
|
image = image_size = None
|
||||||
|
|
||||||
@@ -1796,3 +1825,12 @@ def retry(
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
time.sleep(delay)
|
time.sleep(delay)
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_nested_list(nested_list):
    """Flatten arbitrarily nested lists into a single flat list.

    Only ``list`` instances are descended into; any other value (including
    tuples) is treated as a leaf. A non-list argument is returned wrapped in
    a one-element list.

    Args:
        nested_list: A possibly nested list, or a single leaf value.

    Returns:
        A new flat list of the leaf values in their original order.
    """
    # Leaf value: wrap it so callers always get a list back.
    if not isinstance(nested_list, list):
        return [nested_list]

    flattened = []
    for element in nested_list:
        flattened.extend(flatten_nested_list(element))
    return flattened
|
||||||
|
|||||||
@@ -155,9 +155,7 @@ class TestOpenAIVisionServer(CustomTestCase):
|
|||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {"url": IMAGE_MAN_IRONING_URL},
|
||||||
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
|
|
||||||
},
|
|
||||||
"modalities": "multi-images",
|
"modalities": "multi-images",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -399,14 +397,14 @@ class TestOpenAIVisionServer(CustomTestCase):
|
|||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": prompt,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"type": "audio_url",
|
"type": "audio_url",
|
||||||
"audio_url": {"url": f"{audio_file_name}"},
|
"audio_url": {"url": f"{audio_file_name}"},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": prompt,
|
||||||
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
@@ -14,7 +15,11 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
|||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.conversation import generate_chat_conv
|
from sglang.srt.conversation import generate_chat_conv
|
||||||
from sglang.srt.managers.mm_utils import embed_mm_inputs
|
from sglang.srt.managers.mm_utils import embed_mm_inputs
|
||||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
from sglang.srt.managers.schedule_batch import (
|
||||||
|
Modality,
|
||||||
|
MultimodalDataItem,
|
||||||
|
MultimodalInputs,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -195,14 +200,35 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
|
|||||||
# sglang
|
# sglang
|
||||||
model = self.get_sglang_model()
|
model = self.get_sglang_model()
|
||||||
input_ids = inputs["input_ids"].to(self.device).flatten()
|
input_ids = inputs["input_ids"].to(self.device).flatten()
|
||||||
|
|
||||||
|
pixel_values = inputs["pixel_values"]
|
||||||
|
tgt_sizes = inputs["tgt_sizes"]
|
||||||
|
pixel_values_flat: List[torch.Tensor] = []
|
||||||
|
tgt_sizes_flat: List[torch.Tensor] = []
|
||||||
|
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
|
||||||
|
# per image
|
||||||
|
if len(pixel_b) != len(tgt_b):
|
||||||
|
raise ValueError(
|
||||||
|
"Inconsistent N lengths, found: "
|
||||||
|
f"{len(pixel_b)} vs {len(tgt_b)}"
|
||||||
|
)
|
||||||
|
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
|
||||||
|
pixel_values_flat += [pixel_n]
|
||||||
|
tgt_sizes_flat += [tgt_n]
|
||||||
sglang_output = embed_mm_inputs(
|
sglang_output = embed_mm_inputs(
|
||||||
mm_input=MultimodalInputs(
|
mm_inputs=MultimodalInputs(
|
||||||
pixel_values=inputs["pixel_values"][0],
|
mm_items=[
|
||||||
tgt_sizes=inputs["tgt_sizes"][0],
|
MultimodalDataItem(
|
||||||
|
pixel_values=pixel_values_flat,
|
||||||
|
tgt_size=tgt_sizes_flat,
|
||||||
|
modality=Modality.IMAGE,
|
||||||
|
pad_value=self.processor.tokenizer.unk_token_id,
|
||||||
|
)
|
||||||
|
]
|
||||||
),
|
),
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
input_embedding=model.get_input_embeddings(),
|
input_embedding=model.get_input_embeddings(),
|
||||||
mm_data_embedding_func=model.get_image_features,
|
image_data_embedding_func=model.get_image_feature,
|
||||||
placeholder_token_ids=[
|
placeholder_token_ids=[
|
||||||
self.processor.tokenizer.unk_token_id,
|
self.processor.tokenizer.unk_token_id,
|
||||||
],
|
],
|
||||||
|
|||||||
Reference in New Issue
Block a user