[Model] Support InternVL2_5 on v0.11.0 (#72)

Author: Joeegin
Date: 2026-01-04 16:38:05 +08:00
Committed by: GitHub
Co-authored-by: v_qiaoyijin <v_qiaoyijin@baidu.com>
Parent: 684ce2761e
Commit: ded24f5026

3 changed files with 97 additions and 46 deletions

View File — InternViT vision tower (InternVisionModel)

@@ -29,6 +29,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.vision import run_dp_sharded_vision_model
+
 NORM2FN = {
     'rms_norm': RMSNorm,
     'layer_norm': nn.LayerNorm,
@@ -137,6 +139,7 @@ class InternParallelAttention(nn.Module):
         *,
         num_dummy_heads: int = 0,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
@@ -166,6 +169,7 @@ class InternParallelAttention(nn.Module):
             bias=config.qkv_bias,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv",
+            disable_tp=use_data_parallel,
         )

         self.qk_normalization = config.qk_normalization
@@ -183,6 +187,7 @@ class InternParallelAttention(nn.Module):
             self.embed_dim,
             quant_config=quant_config,
             prefix=f"{prefix}.proj",
+            disable_tp=use_data_parallel,
         )

         self.attn = MultiHeadAttention(self.num_heads_per_partition,
@@ -286,6 +291,7 @@ class InternMLP(nn.Module):
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
@@ -295,12 +301,14 @@ class InternMLP(nn.Module):
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config,
-                                       prefix=f"{prefix}.fc1")
+                                       prefix=f"{prefix}.fc1",
+                                       disable_tp=use_data_parallel)
         self.fc2 = RowParallelLinear(config.intermediate_size,
                                      config.hidden_size,
                                      bias=True,
                                      quant_config=quant_config,
-                                     prefix=f"{prefix}.fc2")
+                                     prefix=f"{prefix}.fc2",
+                                     disable_tp=use_data_parallel)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
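Note: under tensor parallelism, ColumnParallelLinear shards fc1's output features across ranks; with disable_tp=True each rank keeps the full weight, so the vision tower can be replicated and run data-parallel instead. A minimal plain-PyTorch sketch of the two weight layouts (shapes are illustrative, not InternVL's actual sizes):

import torch.nn as nn

hidden_size, intermediate_size, tp_size = 1024, 4096, 4

# Tensor-parallel layout: each rank holds a column shard of fc1,
# i.e. intermediate_size // tp_size output features.
fc1_tp_shard = nn.Linear(hidden_size, intermediate_size // tp_size)

# disable_tp=True layout: every rank keeps the full projection and is
# fed a different slice of the image batch instead.
fc1_replicated = nn.Linear(hidden_size, intermediate_size)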
@@ -319,6 +327,7 @@ class InternVisionEncoderLayer(nn.Module):
         *,
         num_dummy_heads: int = 0,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
@@ -329,11 +338,13 @@ class InternVisionEncoderLayer(nn.Module):
         self.attn = self._init_attn(config,
                                     quant_config,
                                     num_dummy_heads=num_dummy_heads,
-                                    prefix=f"{prefix}.attn")
+                                    prefix=f"{prefix}.attn",
+                                    use_data_parallel=use_data_parallel)

         self.mlp = InternMLP(config,
                              quant_config=quant_config,
-                             prefix=f"{prefix}.mlp")
+                             prefix=f"{prefix}.mlp",
+                             use_data_parallel=use_data_parallel)
         self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
                                              eps=config.layer_norm_eps)
         self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
@@ -351,6 +362,7 @@ class InternVisionEncoderLayer(nn.Module):
         *,
         num_dummy_heads: int,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         # fallback to sdpa attention if tp unavailable
         tp_size = get_tensor_model_parallel_world_size()
@@ -387,6 +399,7 @@ class InternVisionEncoder(nn.Module):
         num_hidden_layers_override: Optional[int] = None,
         num_dummy_heads: int = 0,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
@@ -401,7 +414,8 @@ class InternVisionEncoder(nn.Module):
             InternVisionEncoderLayer(config,
                                      quant_config,
                                      num_dummy_heads=num_dummy_heads,
-                                     prefix=f"{prefix}.layers.{layer_idx}")
+                                     prefix=f"{prefix}.layers.{layer_idx}",
+                                     use_data_parallel=use_data_parallel)
             for layer_idx in range(num_hidden_layers)
         ])
@@ -428,10 +442,12 @@ class InternVisionModel(nn.Module):
         num_hidden_layers_override: Optional[int] = None,
         num_dummy_heads: int = 0,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()

         self.config = config
+        self.use_data_parallel = use_data_parallel

         self.embeddings = InternVisionEmbeddings(config)
         self.encoder = InternVisionEncoder(
@@ -440,6 +456,7 @@ class InternVisionModel(nn.Module):
             num_hidden_layers_override=num_hidden_layers_override,
             num_dummy_heads=num_dummy_heads,
             prefix=f"{prefix}.encoder",
+            use_data_parallel=use_data_parallel,
         )

     def get_input_embeddings(self):
@@ -463,7 +480,11 @@ class InternVisionModel(nn.Module):
             raise ValueError(
                 f'wrong pixel_values size: {pixel_values.shape}')

-        encoder_outputs = self.encoder(inputs_embeds=hidden_states)
+        if self.use_data_parallel:
+            encoder_outputs = run_dp_sharded_vision_model(
+                hidden_states, self.encoder)
+        else:
+            encoder_outputs = self.encoder(inputs_embeds=hidden_states)

         return encoder_outputs
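Note: a rough sketch of what a DP-sharded vision forward does, assuming an initialized torch.distributed process group; the function and variable names below are illustrative, not vLLM's actual run_dp_sharded_vision_model:

import torch
import torch.distributed as dist

def dp_sharded_forward(inputs: torch.Tensor,
                       encoder: torch.nn.Module) -> torch.Tensor:
    """Run `encoder` on this rank's slice of the batch, then all-gather."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    num_items = inputs.shape[0]
    # Pad so the batch splits evenly across ranks.
    pad = -num_items % world_size
    if pad:
        inputs = torch.cat(
            [inputs, inputs.new_zeros((pad, *inputs.shape[1:]))])
    shard = inputs.chunk(world_size, dim=0)[rank]
    out = encoder(inputs_embeds=shard)
    # Collect every rank's outputs and drop the padding rows.
    gathered = [torch.empty_like(out) for _ in range(world_size)]
    dist.all_gather(gathered, out)
    return torch.cat(gathered, dim=0)[:num_items]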
@@ -477,4 +498,4 @@ class InternVisionModel(nn.Module):
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
             loaded_params.add(name)
-        return loaded_params
\ No newline at end of file
+        return loaded_params

View File — InternLM2 language model (InternLM2ForCausalLM / InternLM2ForRewardModel)

@@ -3,6 +3,7 @@
 from collections.abc import Iterable
 from functools import partial
+from itertools import islice
 from typing import Any, Optional, Union

 import torch
@@ -30,10 +31,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP, default_pooling_type
+from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.interfaces_base import default_pooling_type
 from vllm.model_executor.models.utils import (is_pp_missing_parameter,
                                               make_empty_intermediate_tensors_factory, make_layers,
                                               maybe_prefix)
@@ -298,7 +299,7 @@ class InternLM2Model(nn.Module):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for layer in self.layers[self.start_layer:self.end_layer]:
+        for layer in islice(self.layers, self.start_layer, self.end_layer):
             hidden_states, residual = layer(positions, hidden_states, residual)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
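Note: the islice change avoids re-slicing the ModuleList on every forward pass — layers[a:b] constructs a fresh ModuleList, whereas islice lazily walks the existing one. Standalone illustration:

from itertools import islice

import torch.nn as nn

layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(6))

# layers[2:5] would allocate a new ModuleList on every call;
# islice just iterates the existing container.
for layer in islice(layers, 2, 5):
    print(layer)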
@@ -358,10 +359,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.output, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.output, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
@@ -423,13 +422,15 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
                 delattr(self, attr)

         config = vllm_config.model_config.hf_config
-        self.v_head = RowParallelLinear(
-            config.hidden_size,
-            1,
-            bias=False,
-            input_is_parallel=False,
-            prefix=maybe_prefix(prefix, "v_head"),
-        )
+        self.head_dtype = vllm_config.model_config.head_dtype
+
+        self.v_head = RowParallelLinear(config.hidden_size,
+                                        1,
+                                        bias=False,
+                                        input_is_parallel=False,
+                                        params_dtype=self.head_dtype,
+                                        prefix=maybe_prefix(prefix, "v_head"),
+                                        return_bias=False)

         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
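Note: the head_dtype/params_dtype plumbing keeps the scalar reward head in a separate (typically higher-precision) dtype and casts activations into it before projecting. A plain-PyTorch sketch of the same pattern (ScalarRewardHead is illustrative, not vLLM's API):

import torch
import torch.nn as nn

class ScalarRewardHead(nn.Module):
    def __init__(self, hidden_size: int,
                 head_dtype: torch.dtype = torch.float32):
        super().__init__()
        self.head_dtype = head_dtype
        # Keep the head's weight in head_dtype, independent of the
        # dtype the backbone runs in.
        self.proj = nn.Linear(hidden_size, 1, bias=False, dtype=head_dtype)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Cast activations before the final projection, mirroring
        # hidden_states.to(self.head_dtype) in the diff above.
        return self.proj(hidden_states.to(self.head_dtype))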
@@ -446,5 +447,6 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
-        logits, _ = self.v_head(hidden_states)
-        return logits
+        hidden_states = hidden_states.to(self.head_dtype)
+        logits = self.v_head(hidden_states)
+        return logits

View File — InternVL chat model and multimodal processor (InternVLChatModel)

@@ -7,6 +7,7 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+import os
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from typing import Annotated, Any, Literal, Optional, TypeVar, Union
@@ -21,13 +22,13 @@ from transformers import BatchEncoding, PretrainedConfig, TensorType
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
-from .intern_vit import (InternVisionModel, InternVisionPatchModel)
+from vllm.model_executor.models.intern_vit import (InternVisionModel,
+                                                   InternVisionPatchModel)
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargs, NestedTensors)
+                                    MultiModalKwargsItems, NestedTensors)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -36,6 +37,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import set_default_torch_num_threads
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -114,13 +116,26 @@ InternVLVideoInputs = Union[InternVLVideoPixelInputs,
 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
 def build_transform(input_size: int):
     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
-    return T.Compose([
+    transform = T.Compose([
         T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
         T.Resize((input_size, input_size),
                  interpolation=T.InterpolationMode.BICUBIC),
         T.ToTensor(),
         T.Normalize(mean=MEAN, std=STD)
     ])

+    # Image transformation operations (which include tensor computations
+    # on the CPU) can occupy a substantial number of CPU cores, introducing
+    # overhead due to CPU contention. This issue becomes particularly
+    # noticeable when deploying multiple vLLM instances on a single machine.
+    # Therefore, it is necessary to limit the number of threads allocated to
+    # image transformation tasks.
+    num_threads = int(os.environ.get("OMP_NUM_THREADS", "1"))
+
+    def apply(img):
+        with set_default_torch_num_threads(num_threads):
+            return transform(img)
+
+    return apply

 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
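Note: set_default_torch_num_threads (imported above) scopes torch's CPU thread count to the duration of the transform. A minimal reimplementation of that kind of context manager might look like this (limit_torch_threads is a hypothetical name, not vLLM's):

import torch
from contextlib import contextmanager

@contextmanager
def limit_torch_threads(num_threads: int):
    # Temporarily cap torch's intra-op CPU threads, then restore.
    old = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    try:
        yield
    finally:
        torch.set_num_threads(old)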
@@ -796,18 +811,19 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

-        if "image_num_patches" in out_mm_kwargs:
-            image_num_patches = out_mm_kwargs["image_num_patches"]
+        out_mm_data = out_mm_kwargs.get_data()
+        if "image_num_patches" in out_mm_data:
+            image_num_patches = out_mm_data["image_num_patches"]
             assert isinstance(image_num_patches, torch.Tensor)
             image_num_patches = image_num_patches.tolist()
-        elif "image_embeds" in out_mm_kwargs:
+        elif "image_embeds" in out_mm_data:
             # TODO: Use image size information in dictionary embedding inputs
             # to compute num_patches (similar to Qwen2-VL)
-            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
+            image_num_patches = [None] * len(out_mm_data["image_embeds"])
         else:
             image_num_patches = []
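Note: both processors now receive MultiModalKwargsItems and flatten it with get_data() before probing for keys, where they previously indexed out_mm_kwargs directly. The calling convention, with the container mocked (FakeKwargsItems is illustrative only):

class FakeKwargsItems:
    """Stand-in for MultiModalKwargsItems; only the shape of the API matters."""

    def __init__(self, data: dict):
        self._data = data

    def get_data(self) -> dict:
        return self._data

out_mm_kwargs = FakeKwargsItems({"image_num_patches": [1, 7, 4]})
out_mm_data = out_mm_kwargs.get_data()  # new: flatten first, then probe
assert "image_num_patches" in out_mm_data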
@@ -853,9 +869,13 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo):
     def get_video_token(self) -> Optional[str]:
         text_model_type = self.get_hf_config().get_text_config().model_type
-        if text_model_type == "qwen2":
-            return "<|video_pad|>"
-        return None
+        video_token_map = {
+            "qwen2": "<|video_pad|>",
+            "qwen3": "<|video_pad|>",
+            "qwen3_moe": "<|video_pad|>",
+            "gpt_oss": "<|reserved_200000|>",
+        }
+        return video_token_map.get(text_model_type)

     def get_num_frames_with_most_features(
         self,
@@ -965,15 +985,19 @@ class InternVLMultiModalProcessor(
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
-        prompt_repl: list[PromptUpdate] = super()._get_prompt_updates(
-            mm_items, hf_processor_mm_kwargs, out_mm_kwargs)
+        prompt_repl = super()._get_prompt_updates(
+            mm_items=mm_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            out_mm_kwargs=out_mm_kwargs,
+        )

         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

-        if "video_num_patches" in out_mm_kwargs:
-            video_num_patches = out_mm_kwargs["video_num_patches"]
+        out_mm_data = out_mm_kwargs.get_data()
+        if "video_num_patches" in out_mm_data:
+            video_num_patches = out_mm_data["video_num_patches"]
             assert isinstance(video_num_patches, torch.Tensor)
             video_num_patches = video_num_patches.tolist()
         else:
@@ -991,12 +1015,15 @@ class InternVLMultiModalProcessor(
                 video_context_token=hf_processor.video_token)

         if self.info.supports_video:
-            prompt_repl.append(
+            prompt_repl = [
+                *prompt_repl,
                 PromptReplacement(
                     modality="video",
                     target="<video>",
                     replacement=get_video_replacement_internvl,
-                ))
+                )
+            ]

         return prompt_repl
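Note: the append-to-rebuild change follows from the base method now being typed to return Sequence[PromptUpdate] — a generic Sequence (e.g. a tuple) has no .append, so the override constructs a new list. Minimal illustration:

from typing import Sequence

def extend_updates(updates: Sequence[str], extra: str) -> list[str]:
    # Sequence may be a tuple; [*updates, extra] works either way,
    # while updates.append(extra) would not.
    return [*updates, extra]

print(extend_updates(("image", "video"), "audio"))  # ['image', 'video', 'audio']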
@@ -1007,6 +1034,8 @@ class InternVLMultiModalProcessor(
 class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
                         SupportsLoRA):

+    supports_encoder_tp_data = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
@@ -1025,6 +1054,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
         self._patch_quant_config(config, quant_config)

         image_size = config.force_image_size or config.vision_config.image_size
@@ -1092,7 +1122,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
                 quant_config=quant_config,
                 num_hidden_layers_override=num_hidden_layers,
                 prefix=prefix,
-            )
+                use_data_parallel=self.use_data_parallel)
         else:
             return InternVisionPatchModel(config.vision_config)
@@ -1368,10 +1398,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
@@ -1392,4 +1420,4 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         return MultiModelKeys.from_string_field(
             language_model="language_model",
             connector="mlp1",
-            tower_model="vision_model")
\ No newline at end of file
+            tower_model="vision_model")
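Note: taken together, supports_encoder_tp_data plus the use_data_parallel plumbing let the InternVL vision tower run data-parallel across the tensor-parallel ranks. A hedged usage sketch — the mm_encoder_tp_mode argument and the model id come from vLLM's docs and Hugging Face, not from this diff, so treat the exact names as assumptions:

from vllm import LLM

llm = LLM(
    model="OpenGVLab/InternVL2_5-8B",
    trust_remote_code=True,
    tensor_parallel_size=4,
    mm_encoder_tp_mode="data",  # replicate the ViT; shard images across ranks
)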