From 31d6dee5c4b8741455650d9981ec87d78bd1193e Mon Sep 17 00:00:00 2001 From: Zijian Date: Thu, 12 Jun 2025 02:47:25 +0800 Subject: [PATCH] Support VILA models (#6106) --- python/sglang/bench_serving.py | 2 +- python/sglang/srt/configs/model_config.py | 1 + python/sglang/srt/conversation.py | 6 + .../multimodal_processors/base_processor.py | 4 +- .../managers/multimodal_processors/vila.py | 85 +++++ python/sglang/srt/models/vila.py | 305 ++++++++++++++++++ test/srt/test_vision_openai_server_b.py | 19 ++ 7 files changed, 419 insertions(+), 3 deletions(-) create mode 100644 python/sglang/srt/managers/multimodal_processors/vila.py create mode 100644 python/sglang/srt/models/vila.py diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index db213a9e3..5057f3e15 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -399,7 +399,7 @@ async def async_request_sglang_generate( # NOTE: Some completion API might have a last # usage summary response without a token so we # want to check a token was generated - if data["text"]: + if "text" in data and data["text"]: timestamp = time.perf_counter() generated_text = data["text"] output_len = data["meta_info"]["completion_tokens"] diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 0b641d344..b52ae3957 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -578,6 +578,7 @@ multimodal_model_archs = [ "KimiVLForConditionalGeneration", "InternVLChatModel", "Phi4MMForCausalLM", + "VILAForConditionalGeneration", ] diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 5320fa2d7..ec5765a1f 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -983,3 +983,9 @@ def match_devstral(model_path: str): def match_phi_4_mm(model_path: str): if "phi-4-multimodal" in model_path.lower(): return "phi-4-mm" + + +@register_conv_template_matching_function +def match_vila(model_path: str): + if re.search(r"vila", model_path, re.IGNORECASE): + return "chatml" diff --git a/python/sglang/srt/managers/multimodal_processors/base_processor.py b/python/sglang/srt/managers/multimodal_processors/base_processor.py index dae9c2b75..618f66a2f 100644 --- a/python/sglang/srt/managers/multimodal_processors/base_processor.py +++ b/python/sglang/srt/managers/multimodal_processors/base_processor.py @@ -146,7 +146,7 @@ class BaseMultimodalProcessor(ABC): request_obj, max_req_input_len, **kwargs, - ): + ) -> Optional[Dict[str, Any]]: pass def get_estimated_frames_list(self, image_data): @@ -261,7 +261,7 @@ class BaseMultimodalProcessor(ABC): def load_mm_data( self, - prompt: str, + prompt: str | List[int], multimodal_tokens: MultimodalSpecialTokens, max_req_input_len: int, image_data: Optional[list] = None, diff --git a/python/sglang/srt/managers/multimodal_processors/vila.py b/python/sglang/srt/managers/multimodal_processors/vila.py new file mode 100644 index 000000000..53f224dc7 --- /dev/null +++ b/python/sglang/srt/managers/multimodal_processors/vila.py @@ -0,0 +1,85 @@ +from typing import Any, Dict, List, Optional, Type, cast + +import torch.nn as nn +from transformers.configuration_utils import PretrainedConfig +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from sglang.srt.managers.io_struct import ( + EmbeddingReqInput, + GenerateReqInput, + ImageDataItem, +) +from 
sglang.srt.managers.multimodal_processors.base_processor import ( + BaseMultimodalProcessor, + MultimodalSpecialTokens, +) +from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem +from sglang.srt.models.vila import VILAForConditionalGeneration +from sglang.srt.server_args import ServerArgs + + +class VILAProcessor(ProcessorMixin): + """A stub class for the VILA processor.""" + + tokenizer: PreTrainedTokenizerBase + + +class VILAMultimodalProcessor(BaseMultimodalProcessor): + models: List[Type[nn.Module]] = [VILAForConditionalGeneration] + + _processor: VILAProcessor + + def __init__( + self, + hf_config: PretrainedConfig, + server_args: ServerArgs, + _processor: VILAProcessor, + ) -> None: + super().__init__(hf_config, server_args, _processor) + + async def process_mm_data_async( + self, + image_data: Optional[ImageDataItem | List[ImageDataItem]], + input_text: str | List[int], + request_obj: GenerateReqInput | EmbeddingReqInput, + max_req_input_len: int, + **kwargs, + ) -> Optional[Dict[str, Any]]: + if not image_data: + return None + + if not isinstance(image_data, list): + image_data = [image_data] + + mm_data = self.load_mm_data( + prompt=input_text, + multimodal_tokens=MultimodalSpecialTokens( + image_token=self._processor.tokenizer.image_token + ), + max_req_input_len=max_req_input_len, + image_data=image_data, + ) + + inputs = self.process_mm_data( + input_text=mm_data.input_text, + images=mm_data.images, + ) + + image_offsets = self.get_mm_items_offset( + input_ids=inputs.input_ids[0], + mm_token_id=cast(int, self._processor.tokenizer.image_token_id), + ) + + mm_items: List[MultimodalDataItem] = [ + MultimodalDataItem( + modality=Modality.IMAGE, + image_offsets=image_offsets, + pixel_values=inputs.pixel_values, + ) + ] + + return dict( + input_ids=inputs.input_ids[0].tolist(), + mm_items=mm_items, + ) diff --git a/python/sglang/srt/models/vila.py b/python/sglang/srt/models/vila.py new file mode 100644 index 000000000..8672f6982 --- /dev/null +++ b/python/sglang/srt/models/vila.py @@ -0,0 +1,305 @@ +import logging +from typing import Any, Dict, Iterable, List, Optional, Tuple, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel + +import sglang.srt.managers.mm_utils as mm_utils +import sglang.srt.model_loader.weight_utils as weight_utils +import sglang.srt.utils as utils +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.pooler import Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens +from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.qwen2 import Qwen2ForCausalLM + +logger = logging.getLogger(__name__) + + +##### BEGIN COPY configuration.py ##### + + +class VILAConfig(PretrainedConfig): + # Class attributes. 
+ model_type: str = "vila" + sub_configs: Dict[str, PretrainedConfig] = { + "text_config": Qwen2Config(), + "vision_config": SiglipVisionConfig(), + } + _auto_class: Optional[str] = "AutoConfig" + + # Configuration for sub-modules. + text_config: Qwen2Config = Qwen2Config() + vision_config: SiglipVisionConfig = SiglipVisionConfig() + + # Model configuration. + hidden_size: int + image_token_id: int + mm_hidden_size: int + mm_projector_type: str + mm_vision_select_feature: str + mm_vision_select_layer: int + video_token_id: int + + def __init__( + self, + text_config: Optional[Dict[str, Any]] = None, + vision_config: Optional[Dict[str, Any]] = None, + *, + hidden_size: int = 1536, + image_token_id: int = 151649, + mm_hidden_size: int = 1152, + mm_projector_type: str = "mlp_downsample_3x3_fix", + mm_vision_select_feature: str = "cls_patch", + mm_vision_select_layer: int = -2, + video_token_id: int = 151650, + **kwargs, + ): + super().__init__(**kwargs) + + self.text_config = Qwen2Config(**text_config) if text_config else Qwen2Config() + self.vision_config = ( + SiglipVisionConfig(**vision_config) + if vision_config + else SiglipVisionConfig() + ) + + self.hidden_size = hidden_size + self.image_token_id = image_token_id + self.mm_hidden_size = mm_hidden_size + self.mm_projector_type = mm_projector_type + self.mm_vision_select_feature = mm_vision_select_feature + self.mm_vision_select_layer = mm_vision_select_layer + self.video_token_id = video_token_id + + +##### END COPY configuration.py ##### + +##### BEGIN COPY modeling_vila.py ##### + + +class DownSample3x3BlockFix(nn.Module): + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size). + + Returns: + The output tensor of shape (batch_size, image_pad_len, mm_hidden_size * 9). 
+ """ + + batch_size, sequence_length, hidden_size = x.shape + + feat_size = int(sequence_length**0.5) + if feat_size**2 != sequence_length: + raise ValueError( + f"Cannot take square root: sequence_length {sequence_length} is not a perfect square" + ) + + features = x.reshape(batch_size, feat_size, feat_size, hidden_size) + + pad_after = (3 - feat_size % 3) % 3 + if pad_after > 0: + features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after)) + feat_size = feat_size + pad_after + + features = features.reshape( + batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size + ) + features = features.permute(0, 1, 3, 2, 4, 5).contiguous() + features = features.reshape(batch_size, -1, 9 * hidden_size) + + return features + + +class MultimodalProjector(nn.Module): + layers: nn.Sequential + + def __init__( + self, + config: VILAConfig, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + if config.mm_projector_type == "mlp_downsample_3x3_fix": + self.layers = nn.Sequential( + DownSample3x3BlockFix(), + nn.LayerNorm(config.mm_hidden_size * 9), + nn.Linear( + config.mm_hidden_size * 9, + config.mm_hidden_size * 3, + ), + nn.GELU(), + nn.LayerNorm(config.vision_config.hidden_size * 3), + nn.Linear(config.vision_config.hidden_size * 3, config.hidden_size), + nn.GELU(), + nn.Linear(config.hidden_size, config.hidden_size), + ) + else: + raise NotImplementedError( + f"Unsupported mm_projector_type: {config.mm_projector_type}" + ) + + self.layers.type(config.torch_dtype) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size). + + Returns: + The output tensor of shape (batch_size, image_pad_len, hidden_size). 
+ """ + + return self.layers(x.to(device=self.device, dtype=self.dtype)) + + +##### END COPY modeling_vila.py ##### + + +class VILAForConditionalGeneration(nn.Module): + config: VILAConfig + quant_config: Optional[QuantizationConfig] + + logits_processor: LogitsProcessor + pooler: Pooler + + llm: Qwen2ForCausalLM + mm_projector: MultimodalProjector + vision_tower: SiglipVisionModel + + def __init__( + self, + config: VILAConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.quant_config = quant_config + + self.logits_processor = LogitsProcessor(config) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + self.llm = Qwen2ForCausalLM( + config=config.text_config, + quant_config=quant_config, + prefix=utils.add_prefix("llm", prefix), + ) + self.mm_projector = MultimodalProjector(config) + self.vision_tower = SiglipVisionModel(config.vision_config) + + @property + def dtype(self) -> torch.dtype: + return self.config.torch_dtype + + def forward( + self, + input_ids: Tensor, + positions: Tensor, + forward_batch: ForwardBatch, + get_embedding: bool = False, + ) -> LogitsProcessorOutput: + output = mm_utils.general_mm_embed_routine( + input_ids=input_ids, + forward_batch=forward_batch, + language_model=self.llm, + image_data_embedding_func=self.get_image_feature, + get_embedding=get_embedding, + positions=positions, + ) + + return cast(LogitsProcessorOutput, output) + + def get_image_feature(self, mm_input: List[MultimodalDataItem]) -> Tensor: + pixel_values = cast(Tensor, mm_input[0].pixel_values) + + ##### BEGIN COPY modeling_vila.py ##### + + vision_tower_output: BaseModelOutputWithPooling = self.vision_tower.__call__( + pixel_values.to( + device=self.vision_tower.device, dtype=self.vision_tower.dtype + ), + output_hidden_states=True, + ) + + mm_projector_input = self._vision_tower_output_to_mm_projector_input( + vision_tower_output + ) + + image_embedding: Tensor = self.mm_projector.__call__( + mm_projector_input.to( + device=self.mm_projector.device, dtype=self.mm_projector.dtype + ) + ) + + ##### END COPY modeling_vila.py ##### + + return image_embedding + + def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> None: + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in weights: + if name.startswith("llm."): + self.llm.load_weights([(name[len("llm.") :], loaded_weight)]) + else: + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", weight_utils.default_weight_loader + ) + weight_loader(param, loaded_weight) + + def pad_input_ids( + self, + input_ids: List[int], + image_inputs: MultimodalInputs, + ) -> List[int]: + pattern = MultiModalityDataPaddingPatternMultimodalTokens( + token_ids=[self.config.image_token_id], + ) + + return pattern.pad_input_tokens(input_ids, image_inputs) + + ##### BEGIN COPY modeling_vila.py ##### + + def _vision_tower_output_to_mm_projector_input( + self, + vision_tower_output: BaseModelOutputWithPooling, + ) -> Tensor: + assert vision_tower_output.hidden_states is not None + + selected_layer_hidden_states = vision_tower_output.hidden_states[ + self.config.mm_vision_select_layer + ] + + if self.config.mm_vision_select_feature == "cls_patch": + return selected_layer_hidden_states + else: + raise NotImplementedError( + f"Unsupported mm_vision_select_feature: {self.config.mm_vision_select_feature}" + ) + + ##### END COPY modeling_vila.py ##### + + +EntryClass = [VILAForConditionalGeneration] diff 
--git a/test/srt/test_vision_openai_server_b.py b/test/srt/test_vision_openai_server_b.py index 2d05b0688..42427814c 100644 --- a/test/srt/test_vision_openai_server_b.py +++ b/test/srt/test_vision_openai_server_b.py @@ -222,5 +222,24 @@ class TestPhi4MMServer(TestOpenAIVisionServer): pass +class TestVILAServer(TestOpenAIVisionServer): + @classmethod + def setUpClass(cls): + cls.model = "AndyZijianZhang/NVILA-Lite-2B" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=[ + "--trust-remote-code", + "--context-length=65536", + ], + ) + cls.base_url += "/v1" + + if __name__ == "__main__": unittest.main()
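Usage note (not part of the patch): the new TestVILAServer exercises the model end-to-end through the OpenAI-compatible endpoint. Below is a minimal client-side sketch of the same flow, assuming a server launched locally with the flags the test uses; the port and image URL are placeholders, and the API key matches the test's value.

    # Launch, mirroring the test setup (hypothetical local port):
    #   python -m sglang.launch_server --model-path AndyZijianZhang/NVILA-Lite-2B \
    #       --trust-remote-code --context-length 65536 --port 30000 --api-key sk-123456
    import openai

    client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")

    response = client.chat.completions.create(
        model="AndyZijianZhang/NVILA-Lite-2B",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        # Placeholder image URL; any reachable image works.
                        "type": "image_url",
                        "image_url": {"url": "https://example.com/cat.png"},
                    },
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ],
        temperature=0,
        max_tokens=128,
    )
    print(response.choices[0].message.content)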
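Shape note (not part of the patch): DownSample3x3BlockFix folds each 3x3 neighborhood of the vision tower's patch grid into a single token, which is why the first projector layer expects mm_hidden_size * 9 inputs. The following standalone sketch repeats the same reshape arithmetic to make the docstring's shape contract concrete, assuming a 27x27 patch grid (729 tokens) and the default mm_hidden_size of 1152 from VILAConfig; the function name is illustrative only.

    import torch
    import torch.nn.functional as F

    def downsample_3x3(x: torch.Tensor) -> torch.Tensor:
        """Group each 3x3 patch neighborhood into one token of 9x the width."""
        batch_size, sequence_length, hidden_size = x.shape
        feat_size = int(sequence_length**0.5)
        assert feat_size**2 == sequence_length, "expects a square patch grid"

        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)

        # Pad the bottom/right of the grid so both spatial dims divide by 3.
        pad_after = (3 - feat_size % 3) % 3
        if pad_after > 0:
            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
            feat_size += pad_after

        # (B, H, W, C) -> (B, H/3, 3, W/3, 3, C) -> (B, H/3, W/3, 3, 3, C) -> (B, HW/9, 9C)
        features = features.reshape(
            batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size
        )
        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
        return features.reshape(batch_size, -1, 9 * hidden_size)

    x = torch.randn(1, 729, 1152)
    print(downsample_3x3(x).shape)  # torch.Size([1, 81, 10368])

For the 27x27 case the spatial dims already divide by 3, so no padding is applied. The projector then consumes these 9 * 1152-wide tokens; its later LayerNorm is sized from vision_config.hidden_size, which equals mm_hidden_size (1152) under the default config, so the two spellings are interchangeable there.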