# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from itertools import islice

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.internlm2 import (
    InternLM2Attention,
    InternLM2ForCausalLM,
    InternLM2MLP,
    InternLM2Model,
)
from vllm.sequence import IntermediateTensors


class InternLM2VEDecoderLayer(nn.Module):
    """InternLM2 decoder layer carrying two feed-forward networks.

    The layer owns ``feed_forward`` and ``feed_forward_ve``. Positions
    selected by ``visual_token_mask`` are sent through ``feed_forward_ve``;
    all remaining positions go through ``feed_forward``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the attention block, both MLPs and the two RMSNorm layers.

        Args:
            config: Model configuration supplying sizes, head counts, the
                activation name, RoPE parameters and the RMSNorm epsilon.
            cache_config: Forwarded unchanged to the attention module.
            quant_config: Forwarded unchanged to attention and both MLPs.
            prefix: Dotted name prefix used to name the sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        # Fall back to 8192 when the config does not declare a maximum.
        max_pos = getattr(config, "max_position_embeddings", 8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_parameters=config.rope_parameters,
            max_position_embeddings=max_pos,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )

        # Both MLPs are built from the same settings; only the prefix differs.
        shared_mlp_args = dict(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.feed_forward = InternLM2MLP(
            **shared_mlp_args,
            prefix=f"{prefix}.feed_forward",
        )
        self.feed_forward_ve = InternLM2MLP(
            **shared_mlp_args,
            prefix=f"{prefix}.feed_forward_ve",
        )

        norm_eps = config.rms_norm_eps
        self.attention_norm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        visual_token_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention followed by the mask-routed feed-forward step.

        Args:
            positions: Position tensor handed to the attention module.
            hidden_states: Input activations for this layer.
            residual: Residual stream from the previous layer, or ``None``
                when there is none yet.
            visual_token_mask: Optional mask choosing which positions use
                ``feed_forward_ve``. It is tiled ``hidden_size`` times along
                dim 1 and used as a boolean index into ``hidden_states``, so
                it is presumably shaped ``(num_tokens, 1)`` -- TODO confirm
                against the caller.

        Returns:
            The ``(hidden_states, residual)`` pair after this layer.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.attention_norm(hidden_states, residual)
        else:
            # No residual yet: the un-normalised input becomes the residual.
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)

        # Without a mask, or with a mask that selects nothing, every
        # position takes the regular MLP.
        if visual_token_mask is None or not visual_token_mask.any():
            return self.feed_forward(hidden_states), residual

        # Expand the mask along dim 1 so it can index individual elements.
        visual_elems = visual_token_mask.repeat(1, self.hidden_size).bool()
        text_elems = ~visual_elems

        # Selected elements are regrouped into rows of width hidden_size,
        # transformed, and written back in place.
        visual_rows = hidden_states[visual_elems].reshape(-1, self.hidden_size)
        hidden_states[visual_elems] = self.feed_forward_ve(visual_rows).flatten()

        # Skip the regular MLP entirely when the mask selected everything.
        if text_elems.any():
            text_rows = hidden_states[text_elems].reshape(-1, self.hidden_size)
            hidden_states[text_elems] = self.feed_forward(text_rows).flatten()

        return hidden_states, residual


class InternLM2VEModel(InternLM2Model):
    """InternLM2 model assembled from ``InternLM2VEDecoderLayer`` layers."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Delegate to ``InternLM2Model`` with the VE decoder layer type."""
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            layer_type=InternLM2VEDecoderLayer,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        visual_token_mask: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's slice of decoder layers.

        Args:
            input_ids: Token ids, embedded on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Position tensor passed to every layer.
            intermediate_tensors: ``"hidden_states"`` and ``"residual"``
                from the previous stage; must be provided on ranks other
                than the first.
            inputs_embeds: Optional pre-computed embeddings used in place
                of the embedding lookup on the first rank.
            visual_token_mask: Optional mask passed through to each layer.

        Returns:
            The normalised hidden states on the last pipeline rank,
            otherwise an ``IntermediateTensors`` holding the hidden states
            and residual for the next stage.
        """
        pp_group = get_pp_group()

        if pp_group.is_first_rank:
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.tok_embeddings(input_ids)
            )
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Only the layers in [start_layer, end_layer) are run here.
        for decoder_layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                residual,
                visual_token_mask=visual_token_mask,
            )

        if pp_group.is_last_rank:
            normed_states, _ = self.norm(hidden_states, residual)
            return normed_states

        return IntermediateTensors(
            {"hidden_states": hidden_states, "residual": residual}
        )


class InternLM2VEForCausalLM(InternLM2ForCausalLM):
    """Causal-LM wrapper that uses ``InternLM2VEModel`` as its backbone."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Delegate to ``InternLM2ForCausalLM`` with the VE model type."""
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            model_type=InternLM2VEModel,
        )