# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from itertools import islice from typing import Optional, Union import torch from torch import nn from transformers import PretrainedConfig from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.internlm2 import (InternLM2Attention, InternLM2ForCausalLM, InternLM2MLP, InternLM2Model) from vllm.sequence import IntermediateTensors class InternLM2VEDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.attention = InternLM2Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attention", ) self.feed_forward = InternLM2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, prefix=f"{prefix}.feed_forward", ) self.feed_forward_ve = InternLM2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, prefix=f"{prefix}.feed_forward_ve", ) self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], visual_token_mask: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states hidden_states = self.attention_norm(hidden_states) else: hidden_states, residual = self.attention_norm( hidden_states, residual) hidden_states = self.attention( positions=positions, hidden_states=hidden_states, ) # Fully Connected hidden_states, residual = self.ffn_norm(hidden_states, residual) if visual_token_mask is not None and visual_token_mask.any(): visual_token_mask = visual_token_mask.repeat( 1, self.hidden_size).bool() text_token_mask = ~visual_token_mask hidden_states[visual_token_mask] = self.feed_forward_ve( hidden_states[visual_token_mask].reshape( -1, self.hidden_size)).flatten() if text_token_mask.any(): hidden_states[text_token_mask] = self.feed_forward( hidden_states[text_token_mask].reshape( -1, self.hidden_size)).flatten() else: hidden_states = self.feed_forward(hidden_states) return hidden_states, residual class InternLM2VEModel(InternLM2Model): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=InternLM2VEDecoderLayer) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, visual_token_mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.tok_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, residual, visual_token_mask=visual_token_mask, ) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, "residual": residual }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states class InternLM2VEForCausalLM(InternLM2ForCausalLM): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix, model_type=InternLM2VEModel)