Sync from v0.13
This commit is contained in:
139
vllm/model_executor/models/internlm2_ve.py
Normal file
139
vllm/model_executor/models/internlm2_ve.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from itertools import islice
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.internlm2 import (
|
||||
InternLM2Attention,
|
||||
InternLM2ForCausalLM,
|
||||
InternLM2MLP,
|
||||
InternLM2Model,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
class InternLM2VEDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
self.attention = InternLM2Attention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_parameters=config.rope_parameters,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attention",
|
||||
)
|
||||
self.feed_forward = InternLM2MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
)
|
||||
self.feed_forward_ve = InternLM2MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.feed_forward_ve",
|
||||
)
|
||||
self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
visual_token_mask: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.attention_norm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.attention_norm(hidden_states, residual)
|
||||
hidden_states = self.attention(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.ffn_norm(hidden_states, residual)
|
||||
if visual_token_mask is not None and visual_token_mask.any():
|
||||
visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool()
|
||||
text_token_mask = ~visual_token_mask
|
||||
hidden_states[visual_token_mask] = self.feed_forward_ve(
|
||||
hidden_states[visual_token_mask].reshape(-1, self.hidden_size)
|
||||
).flatten()
|
||||
if text_token_mask.any():
|
||||
hidden_states[text_token_mask] = self.feed_forward(
|
||||
hidden_states[text_token_mask].reshape(-1, self.hidden_size)
|
||||
).flatten()
|
||||
else:
|
||||
hidden_states = self.feed_forward(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class InternLM2VEModel(InternLM2Model):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config, prefix=prefix, layer_type=InternLM2VEDecoderLayer
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
visual_token_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.tok_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
visual_token_mask=visual_token_mask,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class InternLM2VEForCausalLM(InternLM2ForCausalLM):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config, prefix=prefix, model_type=InternLM2VEModel
|
||||
)
|
||||
Reference in New Issue
Block a user