from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.internlm2 import (InternLM2Attention,
                                                  InternLM2ForCausalLM,
                                                  InternLM2MLP,
                                                  InternLM2Model)
from vllm.sequence import IntermediateTensors

from .utils import make_layers, maybe_prefix


class InternLM2VEDecoderLayer(nn.Module):
    """InternLM2 decoder layer with an additional visual-expert FFN."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        # Standard InternLM2 feed-forward, used for text tokens.
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        # Parallel feed-forward with its own weights, used for visual tokens.
        self.feed_forward_ve = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward_ve",
        )
        self.attention_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        visual_token_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        if visual_token_mask is not None and visual_token_mask.any():
            # Broadcast the (num_tokens, 1) mask across the hidden dimension
            # so boolean indexing selects whole token rows.
            visual_token_mask = visual_token_mask.repeat(
                1, self.hidden_size).bool()
            text_token_mask = ~visual_token_mask
            hidden_states[visual_token_mask] = self.feed_forward_ve(
                hidden_states[visual_token_mask].reshape(
                    -1, self.hidden_size)).flatten()
            if text_token_mask.any():
                hidden_states[text_token_mask] = self.feed_forward(
                    hidden_states[text_token_mask].reshape(
                        -1, self.hidden_size)).flatten()
        else:
            # No visual tokens in the batch: plain InternLM2 behavior.
            hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual
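
# Design note: each decoder layer above carries two parallel FFNs. Judging by
# the `feed_forward_ve` weight prefix, "VE" presumably stands for "visual
# expert": tokens flagged in `visual_token_mask` are routed through
# `feed_forward_ve` and the remaining tokens through the shared
# `feed_forward`, while attention weights are shared by both token types.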

class InternLM2VEModel(InternLM2Model):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        # Rebuild the decoder stack with the visual-expert layer variant.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: InternLM2VEDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        visual_token_mask: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.tok_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
                visual_token_mask=visual_token_mask,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class InternLM2VEForCausalLM(InternLM2ForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # Swap the base InternLM2 model for the visual-expert variant.
        self.model = InternLM2VEModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))
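

if __name__ == "__main__":
    # Minimal standalone sketch of the token-routing pattern used in
    # InternLM2VEDecoderLayer.forward. The nn.Linear modules below are
    # hypothetical stand-ins for the real InternLM2MLP experts (which need a
    # full vLLM config to construct); only the masking mechanics are the same.
    with torch.inference_mode():
        hidden_size = 8
        num_tokens = 4
        hidden_states = torch.randn(num_tokens, hidden_size)
        feed_forward = nn.Linear(hidden_size, hidden_size)  # "text expert"
        feed_forward_ve = nn.Linear(hidden_size, hidden_size)  # "visual expert"

        # Tokens 0 and 2 are "visual", tokens 1 and 3 are "text".
        visual_token_mask = torch.tensor([[1.], [0.], [1.], [0.]])
        # Same expansion as in the layer: broadcast the per-token mask across
        # the hidden dimension so boolean indexing selects whole rows.
        mask = visual_token_mask.repeat(1, hidden_size).bool()

        routed = hidden_states.clone()
        routed[mask] = feed_forward_ve(
            hidden_states[mask].reshape(-1, hidden_size)).flatten()
        routed[~mask] = feed_forward(
            hidden_states[~mask].reshape(-1, hidden_size)).flatten()
        assert routed.shape == (num_tokens, hidden_size)
        print("routed hidden states:", routed.shape)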