refactor: minor refactors regarding multimodal processing (#6187)
This commit is contained in:
@@ -22,7 +22,11 @@ from typing import List, Optional, Set, Union
|
|||||||
import torch
|
import torch
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
from sglang.srt.hf_transformers_utils import (
|
||||||
|
get_config,
|
||||||
|
get_context_length,
|
||||||
|
get_hf_text_config,
|
||||||
|
)
|
||||||
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_bool_env_var, is_hip
|
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||||
@@ -209,7 +213,13 @@ class ModelConfig:
|
|||||||
|
|
||||||
# Cache attributes
|
# Cache attributes
|
||||||
self.hf_eos_token_id = self.get_hf_eos_token_id()
|
self.hf_eos_token_id = self.get_hf_eos_token_id()
|
||||||
self.image_token_id = getattr(self.hf_config, "image_token_id", None)
|
|
||||||
|
config = self.hf_config
|
||||||
|
|
||||||
|
# multimodal
|
||||||
|
self.image_token_id = getattr(config, "image_token_id", None) or getattr(
|
||||||
|
config, "image_token_index", None
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
|
def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
|
||||||
@@ -423,31 +433,6 @@ class ModelConfig:
|
|||||||
self.model_path = client.get_local_dir()
|
self.model_path = client.get_local_dir()
|
||||||
|
|
||||||
|
|
||||||
def get_hf_text_config(config: PretrainedConfig):
|
|
||||||
"""Get the "sub" config relevant to llm for multi modal models.
|
|
||||||
No op for pure text models.
|
|
||||||
"""
|
|
||||||
class_name = config.architectures[0]
|
|
||||||
if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
|
|
||||||
# We support non-hf version of llava models, so we do not want to
|
|
||||||
# read the wrong values from the unused default text_config.
|
|
||||||
# NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
|
|
||||||
# `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
|
|
||||||
setattr(config, "torch_dtype", torch.float16)
|
|
||||||
return config
|
|
||||||
|
|
||||||
if hasattr(config, "text_config"):
|
|
||||||
# The code operates under the assumption that text_config should have
|
|
||||||
# `num_attention_heads` (among others). Assert here to fail early
|
|
||||||
# if transformers config doesn't align with this assumption.
|
|
||||||
assert hasattr(config.text_config, "num_attention_heads")
|
|
||||||
return config.text_config
|
|
||||||
if hasattr(config, "language_config"):
|
|
||||||
return config.language_config
|
|
||||||
else:
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||||
"half": torch.float16,
|
"half": torch.float16,
|
||||||
@@ -537,6 +522,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
|
|||||||
|
|
||||||
|
|
||||||
multimodal_model_archs = [
|
multimodal_model_archs = [
|
||||||
|
"CLIPModel",
|
||||||
"DeepseekVL2ForCausalLM",
|
"DeepseekVL2ForCausalLM",
|
||||||
"Gemma3ForConditionalGeneration",
|
"Gemma3ForConditionalGeneration",
|
||||||
"Grok1VForCausalLM",
|
"Grok1VForCausalLM",
|
||||||
@@ -554,7 +540,6 @@ multimodal_model_archs = [
|
|||||||
"MllamaForConditionalGeneration",
|
"MllamaForConditionalGeneration",
|
||||||
"Qwen2VLForConditionalGeneration",
|
"Qwen2VLForConditionalGeneration",
|
||||||
"Qwen2_5_VLForConditionalGeneration",
|
"Qwen2_5_VLForConditionalGeneration",
|
||||||
"CLIPModel",
|
|
||||||
"KimiVLForConditionalGeneration",
|
"KimiVLForConditionalGeneration",
|
||||||
"InternVLChatModel",
|
"InternVLChatModel",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import warnings
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Type, Union
|
from typing import Dict, Optional, Type, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
@@ -65,6 +66,43 @@ def download_from_hf(model_path: str):
|
|||||||
return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
|
return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_hf_text_config(config: PretrainedConfig):
|
||||||
|
"""Get the "sub" config relevant to llm for multi modal models.
|
||||||
|
No op for pure text models.
|
||||||
|
"""
|
||||||
|
if config.architectures is not None:
|
||||||
|
class_name = config.architectures[0]
|
||||||
|
if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
|
||||||
|
# We support non-hf version of llava models, so we do not want to
|
||||||
|
# read the wrong values from the unused default text_config.
|
||||||
|
# NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
|
||||||
|
# `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
|
||||||
|
setattr(config, "torch_dtype", torch.float16)
|
||||||
|
return config
|
||||||
|
|
||||||
|
if hasattr(config, "text_config"):
|
||||||
|
# The code operates under the assumption that text_config should have
|
||||||
|
# `num_attention_heads` (among others). Assert here to fail early
|
||||||
|
# if transformers config doesn't align with this assumption.
|
||||||
|
assert hasattr(config.text_config, "num_attention_heads")
|
||||||
|
return config.text_config
|
||||||
|
if hasattr(config, "language_config"):
|
||||||
|
return config.language_config
|
||||||
|
if hasattr(config, "thinker_config"):
|
||||||
|
# qwen2.5 omni
|
||||||
|
thinker_config = config.thinker_config
|
||||||
|
if hasattr(thinker_config, "text_config"):
|
||||||
|
setattr(
|
||||||
|
thinker_config.text_config,
|
||||||
|
"torch_dtype",
|
||||||
|
getattr(thinker_config, "torch_dtype", None),
|
||||||
|
)
|
||||||
|
return thinker_config.text_config
|
||||||
|
return thinker_config
|
||||||
|
else:
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
def get_config(
|
def get_config(
|
||||||
model: str,
|
model: str,
|
||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
@@ -80,13 +118,12 @@ def get_config(
|
|||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
||||||
)
|
)
|
||||||
|
text_config = get_hf_text_config(config=config)
|
||||||
|
|
||||||
# FIXME: Pour contents of janus-pro's langauge_config to first-level
|
if isinstance(model, str) and text_config is not None:
|
||||||
if isinstance(model, str) and model.lower().startswith("deepseek-ai/janus-pro"):
|
for key, val in text_config.__dict__.items():
|
||||||
assert hasattr(config, "language_config")
|
if not hasattr(config, key) and getattr(text_config, key, None) is not None:
|
||||||
for key, val in config.language_config.__dict__.items():
|
setattr(config, key, val)
|
||||||
setattr(config, key, val)
|
|
||||||
setattr(config, "architectures", ["MultiModalityCausalLM"])
|
|
||||||
|
|
||||||
if config.model_type in _CONFIG_REGISTRY:
|
if config.model_type in _CONFIG_REGISTRY:
|
||||||
config_class = _CONFIG_REGISTRY[config.model_type]
|
config_class = _CONFIG_REGISTRY[config.model_type]
|
||||||
@@ -99,6 +136,9 @@ def get_config(
|
|||||||
if not hasattr(config, key):
|
if not hasattr(config, key):
|
||||||
setattr(config, key, val)
|
setattr(config, key, val)
|
||||||
|
|
||||||
|
if config.model_type == "multi_modality":
|
||||||
|
config.update({"architectures": ["MultiModalityCausalLM"]})
|
||||||
|
|
||||||
if model_override_args:
|
if model_override_args:
|
||||||
config.update(model_override_args)
|
config.update(model_override_args)
|
||||||
|
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
flatten_batch: bool = False,
|
flatten_batch: bool = False,
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
r"""
|
r"""
|
||||||
Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, s, s)`.
|
Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
|
||||||
Args:
|
Args:
|
||||||
s: sequence length
|
s: sequence length
|
||||||
cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask
|
cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask
|
||||||
|
|||||||
@@ -22,13 +22,15 @@ from dataclasses import dataclass, field
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from sglang.srt.mm_utils import has_valid_data
|
||||||
|
|
||||||
# handle serialization of Image for pydantic
|
# handle serialization of Image for pydantic
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
else:
|
else:
|
||||||
Image = Any
|
Image = Any
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
from sglang.srt.managers.schedule_batch import BaseFinishReason, flatten_nested_list
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
|
|
||||||
|
|
||||||
@@ -104,6 +106,9 @@ class GenerateReqInput:
|
|||||||
bootstrap_port: Optional[Union[List[int], int]] = None
|
bootstrap_port: Optional[Union[List[int], int]] = None
|
||||||
bootstrap_room: Optional[Union[List[int], int]] = None
|
bootstrap_room: Optional[Union[List[int], int]] = None
|
||||||
|
|
||||||
|
def contains_mm_input(self) -> bool:
|
||||||
|
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
|
||||||
|
|
||||||
def normalize_batch_and_arguments(self):
|
def normalize_batch_and_arguments(self):
|
||||||
"""
|
"""
|
||||||
Normalize the batch size and arguments for the request.
|
Normalize the batch size and arguments for the request.
|
||||||
@@ -487,6 +492,9 @@ class EmbeddingReqInput:
|
|||||||
# The modalities of the image data [image, multi-images, video]
|
# The modalities of the image data [image, multi-images, video]
|
||||||
modalities: Optional[List[str]] = None
|
modalities: Optional[List[str]] = None
|
||||||
|
|
||||||
|
def contains_mm_input(self) -> bool:
|
||||||
|
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
|
||||||
|
|
||||||
def normalize_batch_and_arguments(self):
|
def normalize_batch_and_arguments(self):
|
||||||
# at least one of text, input_ids, or image should be provided
|
# at least one of text, input_ids, or image should be provided
|
||||||
if self.text is None and self.input_ids is None and self.image_data is None:
|
if self.text is None and self.input_ids is None and self.image_data is None:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
Multi-modality utils
|
Multi-modality utils
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Callable, List, Optional, Tuple
|
from typing import Callable, List, Optional, Tuple
|
||||||
@@ -41,11 +42,26 @@ class MultiModalityDataPaddingPattern:
|
|||||||
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
|
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
|
||||||
"""In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
|
"""In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
|
||||||
|
|
||||||
|
The padded value in a region enclosed by a token pair with be the same one, as the MultimodalDataItem's pad value
|
||||||
|
|
||||||
This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
|
This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_token_pairs: Optional[List[Tuple[int, int]]],
|
||||||
|
data_start_token_ids: Optional[List[int]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_start_token_ids marks the start of a single multimodal data
|
||||||
|
See Minicpmo's slice_start_id for example
|
||||||
|
"""
|
||||||
self.data_token_id_pairs = data_token_pairs
|
self.data_token_id_pairs = data_token_pairs
|
||||||
|
self.data_start_token_ids = data_start_token_ids or [
|
||||||
|
s for s, _e in data_token_pairs
|
||||||
|
]
|
||||||
|
|
||||||
def pad_input_tokens(
|
def pad_input_tokens(
|
||||||
self, input_ids: List[int], mm_inputs: MultimodalInputs
|
self, input_ids: List[int], mm_inputs: MultimodalInputs
|
||||||
@@ -79,7 +95,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
|||||||
for start_idx, end_idx in zip(start_indices, end_indices):
|
for start_idx, end_idx in zip(start_indices, end_indices):
|
||||||
padded_ids.extend(input_ids[last_idx : start_idx + 1])
|
padded_ids.extend(input_ids[last_idx : start_idx + 1])
|
||||||
|
|
||||||
if input_ids[start_idx] in start_token_ids:
|
if input_ids[start_idx] in self.data_start_token_ids:
|
||||||
data_idx += 1
|
data_idx += 1
|
||||||
mm_inputs.data_offsets += [start_idx]
|
mm_inputs.data_offsets += [start_idx]
|
||||||
|
|
||||||
@@ -170,7 +186,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
|
|||||||
output_ids_tensor[start_idx:end_idx] = pad_value
|
output_ids_tensor[start_idx:end_idx] = pad_value
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Skipping region {i} due to None pad_value.")
|
logger.warning(f"Skipping region {i} due to None pad_value.")
|
||||||
|
|
||||||
return output_ids_tensor.tolist()
|
return output_ids_tensor.tolist()
|
||||||
|
|
||||||
|
|
||||||
@@ -202,7 +217,7 @@ def get_embedding_and_mask(
|
|||||||
num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
|
num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
|
||||||
if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
|
if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Number of tokens in multimodal embedding does not match those in the input text."
|
f"Number of tokens in multimodal embedding does not match those in the input text. "
|
||||||
f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
|
f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
|
||||||
"tokens from multimodal embeddings."
|
"tokens from multimodal embeddings."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,9 +36,21 @@ class BaseMultiModalProcessorOutput:
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class MultimodalSpecialTokens:
|
class MultimodalSpecialTokens:
|
||||||
image_token: Optional[str] = None
|
image_token: Optional[Union[int, str, List[str]]] = None
|
||||||
video_token: Optional[str] = None
|
video_token: Optional[Union[int, str, List[str]]] = None
|
||||||
audio_token: Optional[str] = None
|
audio_token: Optional[Union[int, str, List[str]]] = None
|
||||||
|
|
||||||
|
def convert_to_str(self, token: Union[str, int], processor) -> str:
|
||||||
|
if token is None:
|
||||||
|
return token
|
||||||
|
if isinstance(token, str):
|
||||||
|
return token
|
||||||
|
return processor.tokenizer.convert_ids_to_tokens([token])[0]
|
||||||
|
|
||||||
|
def convert_to_strs(self, processor):
|
||||||
|
self.image_token = self.convert_to_str(self.image_token, processor)
|
||||||
|
self.video_token = self.convert_to_str(self.video_token, processor)
|
||||||
|
self.audio_token = self.convert_to_str(self.audio_token, processor)
|
||||||
|
|
||||||
image_token_regex: Optional[re.Pattern] = None
|
image_token_regex: Optional[re.Pattern] = None
|
||||||
video_token_regex: Optional[re.Pattern] = None
|
video_token_regex: Optional[re.Pattern] = None
|
||||||
@@ -74,6 +86,7 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor):
|
||||||
self.hf_config = hf_config
|
self.hf_config = hf_config
|
||||||
self._processor = _processor
|
self._processor = _processor
|
||||||
|
self.arch = hf_config.architectures[0]
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
# FIXME: not accurate, model and image specific
|
# FIXME: not accurate, model and image specific
|
||||||
self.NUM_TOKEN_PER_FRAME = 330
|
self.NUM_TOKEN_PER_FRAME = 330
|
||||||
@@ -260,19 +273,10 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
"""
|
"""
|
||||||
if not return_text:
|
if not return_text:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
if image_data is None:
|
if image_data is None:
|
||||||
image_data = []
|
image_data = []
|
||||||
if isinstance(multimodal_tokens.image_token, int):
|
|
||||||
multimodal_tokens.image_token = re.compile(
|
multimodal_tokens.convert_to_strs(self._processor)
|
||||||
re.escape(
|
|
||||||
self._processor.tokenizer.convert_ids_to_tokens(
|
|
||||||
multimodal_tokens.image_token
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
multimodal_tokens.image_token = multimodal_tokens.image_token
|
|
||||||
multimodal_tokens_pattern = multimodal_tokens.collect()
|
multimodal_tokens_pattern = multimodal_tokens.collect()
|
||||||
|
|
||||||
if isinstance(prompt, list) and return_text:
|
if isinstance(prompt, list) and return_text:
|
||||||
@@ -332,9 +336,9 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
new_text += text_part
|
new_text += text_part
|
||||||
|
|
||||||
out = BaseMultiModalProcessorOutput(
|
out = BaseMultiModalProcessorOutput(
|
||||||
|
input_text=new_text,
|
||||||
images=images,
|
images=images,
|
||||||
audios=audios,
|
audios=audios,
|
||||||
input_text=new_text,
|
|
||||||
)
|
)
|
||||||
out.normalize()
|
out.normalize()
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import BaseImageProcessorFast
|
|
||||||
|
|
||||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
BaseMultimodalProcessor,
|
BaseMultimodalProcessor,
|
||||||
@@ -21,33 +20,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
self.image_token = "(<image>./</image>)"
|
self.image_token = "(<image>./</image>)"
|
||||||
self.audio_token = "(<audio>./</audio>)"
|
self.audio_token = "(<audio>./</audio>)"
|
||||||
|
|
||||||
def process_data_task(self, input_text, images=None, audios=None):
|
|
||||||
|
|
||||||
if isinstance(images, list) and len(images) == 0:
|
|
||||||
images = None
|
|
||||||
if isinstance(audios, list) and len(audios) == 0:
|
|
||||||
audios = None
|
|
||||||
processor = self._processor
|
|
||||||
args = {}
|
|
||||||
if isinstance(processor, BaseImageProcessorFast):
|
|
||||||
args["device"] = "cuda"
|
|
||||||
result = self._processor.__call__(
|
|
||||||
text=input_text,
|
|
||||||
images=images,
|
|
||||||
audios=audios,
|
|
||||||
return_tensors="pt",
|
|
||||||
chunk_input=True,
|
|
||||||
**args,
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"input_ids": result.input_ids,
|
|
||||||
"pixel_values": getattr(result, "pixel_values", None),
|
|
||||||
"tgt_sizes": getattr(result, "tgt_sizes", None),
|
|
||||||
"audio_features": getattr(result, "audio_features", None),
|
|
||||||
"audio_feature_lens": getattr(result, "audio_feature_lens", None),
|
|
||||||
"audio_bounds": getattr(result, "audio_bounds", None),
|
|
||||||
}
|
|
||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
self,
|
self,
|
||||||
image_data: List[Union[str, bytes]],
|
image_data: List[Union[str, bytes]],
|
||||||
|
|||||||
@@ -324,8 +324,9 @@ class MultimodalInputs:
|
|||||||
video_token_id: Optional[int] = None
|
video_token_id: Optional[int] = None
|
||||||
|
|
||||||
# audio
|
# audio
|
||||||
audio_start_id: Optional[torch.Tensor] = None
|
audio_token_id: Optional[int] = None
|
||||||
audio_end_id: Optional[torch.Tensor] = None
|
audio_start_id: Optional[int] = None
|
||||||
|
audio_end_id: Optional[int] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_dict(obj: dict):
|
def from_dict(obj: dict):
|
||||||
@@ -349,6 +350,7 @@ class MultimodalInputs:
|
|||||||
"slice_end_id",
|
"slice_end_id",
|
||||||
"audio_start_id",
|
"audio_start_id",
|
||||||
"audio_end_id",
|
"audio_end_id",
|
||||||
|
"audio_token_id",
|
||||||
]
|
]
|
||||||
for arg in optional_args:
|
for arg in optional_args:
|
||||||
if arg in obj:
|
if arg in obj:
|
||||||
|
|||||||
@@ -459,14 +459,16 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
input_ids = self.tokenizer.encode(input_text)
|
input_ids = self.tokenizer.encode(input_text)
|
||||||
|
|
||||||
image_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
image_inputs: Optional[Dict] = None
|
||||||
image_data=obj.image_data,
|
if obj.contains_mm_input():
|
||||||
input_text=input_text or input_ids,
|
image_inputs = await self.mm_processor.process_mm_data_async(
|
||||||
request_obj=obj,
|
image_data=obj.image_data,
|
||||||
max_req_input_len=self.max_req_input_len,
|
input_text=input_text or input_ids,
|
||||||
)
|
request_obj=obj,
|
||||||
if image_inputs and "input_ids" in image_inputs:
|
max_req_input_len=self.max_req_input_len,
|
||||||
input_ids = image_inputs["input_ids"]
|
)
|
||||||
|
if image_inputs and "input_ids" in image_inputs:
|
||||||
|
input_ids = image_inputs["input_ids"]
|
||||||
|
|
||||||
self._validate_token_len(obj, input_ids)
|
self._validate_token_len(obj, input_ids)
|
||||||
return self._create_tokenized_object(
|
return self._create_tokenized_object(
|
||||||
|
|||||||
@@ -36,6 +36,16 @@ from io import BytesIO
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from sglang.srt.utils import flatten_nested_list
|
||||||
|
|
||||||
|
|
||||||
|
def has_valid_data(data) -> bool:
|
||||||
|
if data is None:
|
||||||
|
return False
|
||||||
|
if isinstance(data, list):
|
||||||
|
return any(has_valid_data(item) for item in flatten_nested_list(data))
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def select_best_resolution(original_size, possible_resolutions):
|
def select_best_resolution(original_size, possible_resolutions):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1165,7 +1165,7 @@ class ModelRunner:
|
|||||||
def model_is_mrope(self) -> bool:
|
def model_is_mrope(self) -> bool:
|
||||||
"""Detect if the model has "mrope" rope_scaling type.
|
"""Detect if the model has "mrope" rope_scaling type.
|
||||||
mrope requires keep "rope_deltas" between prompt and decoding phases."""
|
mrope requires keep "rope_deltas" between prompt and decoding phases."""
|
||||||
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
|
rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
|
||||||
if rope_scaling is None:
|
if rope_scaling is None:
|
||||||
return False
|
return False
|
||||||
is_mrope_enabled = "mrope_section" in rope_scaling
|
is_mrope_enabled = "mrope_section" in rope_scaling
|
||||||
|
|||||||
@@ -1520,12 +1520,15 @@ class MiniCPMO(MiniCPMBaseModel):
|
|||||||
slice_start_id: int = mm_input.slice_start_id
|
slice_start_id: int = mm_input.slice_start_id
|
||||||
slice_end_id: int = mm_input.slice_end_id
|
slice_end_id: int = mm_input.slice_end_id
|
||||||
|
|
||||||
media_token_pairs = [
|
data_token_pairs = [
|
||||||
(im_start_id, im_end_id),
|
(im_start_id, im_end_id),
|
||||||
(slice_start_id, slice_end_id),
|
(slice_start_id, slice_end_id),
|
||||||
(mm_input.audio_start_id, mm_input.audio_end_id),
|
(mm_input.audio_start_id, mm_input.audio_end_id),
|
||||||
]
|
]
|
||||||
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
|
data_start_token_ids = [im_start_id, mm_input.audio_start_id]
|
||||||
|
pattern = MultiModalityDataPaddingPatternTokenPairs(
|
||||||
|
data_token_pairs=data_token_pairs, data_start_token_ids=data_start_token_ids
|
||||||
|
)
|
||||||
|
|
||||||
return pattern.pad_input_tokens(input_ids, mm_input)
|
return pattern.pad_input_tokens(input_ids, mm_input)
|
||||||
|
|
||||||
|
|||||||
@@ -865,7 +865,6 @@ class MllamaForConditionalGeneration(nn.Module):
|
|||||||
pixel_values = torch.cat(
|
pixel_values = torch.cat(
|
||||||
[item.pixel_values for item in mm_input.mm_items], dim=0
|
[item.pixel_values for item in mm_input.mm_items], dim=0
|
||||||
)
|
)
|
||||||
# max_num_images = max(max_num_images, sum(1 if item.is_image() else 0 for item in mm_input.items))
|
|
||||||
max_num_images = max(max_num_images, pixel_values.shape[1])
|
max_num_images = max(max_num_images, pixel_values.shape[1])
|
||||||
|
|
||||||
max_num_tiles = max(max_num_tiles, pixel_values.shape[2])
|
max_num_tiles = max(max_num_tiles, pixel_values.shape[2])
|
||||||
|
|||||||
@@ -146,6 +146,8 @@ class Qwen2_5_VisionBlock(nn.Module):
|
|||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
projection_size=dim,
|
projection_size=dim,
|
||||||
use_qkv_parallel=True,
|
use_qkv_parallel=True,
|
||||||
|
rotary_embed="normal",
|
||||||
|
proj_bias=True,
|
||||||
qkv_backend=qkv_backend,
|
qkv_backend=qkv_backend,
|
||||||
softmax_in_single_precision=softmax_in_single_precision,
|
softmax_in_single_precision=softmax_in_single_precision,
|
||||||
flatten_batch=flatten_batch,
|
flatten_batch=flatten_batch,
|
||||||
|
|||||||
@@ -147,8 +147,8 @@ class TestVisionChunkedPrefill(CustomTestCase):
|
|||||||
|
|
||||||
def _test_chunked_prefill(self, batches, num_frames):
|
def _test_chunked_prefill(self, batches, num_frames):
|
||||||
# Chunked
|
# Chunked
|
||||||
|
chunked_server_pid = self.launch_server(chunked_prefill_size=1024)
|
||||||
try:
|
try:
|
||||||
chunked_server_pid = self.launch_server(chunked_prefill_size=1024)
|
|
||||||
outputs_chunked = []
|
outputs_chunked = []
|
||||||
for batch, num_frame in zip(batches, num_frames):
|
for batch, num_frame in zip(batches, num_frames):
|
||||||
output_chunked = self.generate_for_video(
|
output_chunked = self.generate_for_video(
|
||||||
|
|||||||
Reference in New Issue
Block a user