[Model] Support InternVL2_5 on v0.11.0 (#72)
Co-authored-by: v_qiaoyijin <v_qiaoyijin@baidu.com>
This commit is contained in:
@@ -29,6 +29,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from vllm.model_executor.models.vision import run_dp_sharded_vision_model
|
||||
|
||||
NORM2FN = {
|
||||
'rms_norm': RMSNorm,
|
||||
'layer_norm': nn.LayerNorm,
|
||||
@@ -137,6 +139,7 @@ class InternParallelAttention(nn.Module):
|
||||
*,
|
||||
num_dummy_heads: int = 0,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -166,6 +169,7 @@ class InternParallelAttention(nn.Module):
|
||||
bias=config.qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
|
||||
self.qk_normalization = config.qk_normalization
|
||||
@@ -183,6 +187,7 @@ class InternParallelAttention(nn.Module):
|
||||
self.embed_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.proj",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
|
||||
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||
@@ -286,6 +291,7 @@ class InternMLP(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -295,12 +301,14 @@ class InternMLP(nn.Module):
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc1")
|
||||
prefix=f"{prefix}.fc1",
|
||||
disable_tp=use_data_parallel)
|
||||
self.fc2 = RowParallelLinear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc2")
|
||||
prefix=f"{prefix}.fc2",
|
||||
disable_tp=use_data_parallel)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
@@ -319,6 +327,7 @@ class InternVisionEncoderLayer(nn.Module):
|
||||
*,
|
||||
num_dummy_heads: int = 0,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -329,11 +338,13 @@ class InternVisionEncoderLayer(nn.Module):
|
||||
self.attn = self._init_attn(config,
|
||||
quant_config,
|
||||
num_dummy_heads=num_dummy_heads,
|
||||
prefix=f"{prefix}.attn")
|
||||
prefix=f"{prefix}.attn",
|
||||
use_data_parallel=use_data_parallel)
|
||||
|
||||
self.mlp = InternMLP(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
prefix=f"{prefix}.mlp",
|
||||
use_data_parallel=use_data_parallel)
|
||||
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
|
||||
@@ -351,6 +362,7 @@ class InternVisionEncoderLayer(nn.Module):
|
||||
*,
|
||||
num_dummy_heads: int,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
# fallback to sdpa attention if tp unavailable
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -387,6 +399,7 @@ class InternVisionEncoder(nn.Module):
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
num_dummy_heads: int = 0,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -401,7 +414,8 @@ class InternVisionEncoder(nn.Module):
|
||||
InternVisionEncoderLayer(config,
|
||||
quant_config,
|
||||
num_dummy_heads=num_dummy_heads,
|
||||
prefix=f"{prefix}.layers.{layer_idx}")
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel)
|
||||
for layer_idx in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
@@ -428,10 +442,12 @@ class InternVisionModel(nn.Module):
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
num_dummy_heads: int = 0,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.use_data_parallel = use_data_parallel
|
||||
|
||||
self.embeddings = InternVisionEmbeddings(config)
|
||||
self.encoder = InternVisionEncoder(
|
||||
@@ -440,6 +456,7 @@ class InternVisionModel(nn.Module):
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
num_dummy_heads=num_dummy_heads,
|
||||
prefix=f"{prefix}.encoder",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
@@ -463,7 +480,11 @@ class InternVisionModel(nn.Module):
|
||||
raise ValueError(
|
||||
f'wrong pixel_values size: {pixel_values.shape}')
|
||||
|
||||
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
|
||||
if self.use_data_parallel:
|
||||
encoder_outputs = run_dp_sharded_vision_model(
|
||||
hidden_states, self.encoder)
|
||||
else:
|
||||
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
|
||||
|
||||
return encoder_outputs
|
||||
|
||||
@@ -477,4 +498,4 @@ class InternVisionModel(nn.Module):
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
return loaded_params
|
||||
Reference in New Issue
Block a user