import math
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
                                             BartParallelLMHead,
                                             BartScaledWordEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .utils import AutoWeightsLoader


class Florence2LanguageModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
        self.encoder = BartEncoder(config,
                                   cache_config=cache_config,
                                   quant_config=quant_config)
        self.decoder = BartDecoder(config,
                                   cache_config=cache_config,
                                   quant_config=quant_config)

        if self.config.tie_word_embeddings:
            self.encoder.embed_tokens.weight = self.shared.weight
            self.decoder.embed_tokens.weight = self.shared.weight

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you provide it.
            positions
                Positions of *decoder* input sequence tokens.
            encoder_input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
            encoder_positions
                Positions of *encoder* input sequence tokens.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Model output torch.Tensor
        """

        encoder_hidden_states = None

        if encoder_input_ids.numel() > 0:
            # Run the encoder only when a non-zero number of encoder tokens
            # is provided as input.
            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
                                                 positions=encoder_positions,
                                                 kv_caches=kv_caches,
                                                 attn_metadata=attn_metadata)

        # Decoder outputs consist of
        # (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            decoder_input_ids=input_ids,
            decoder_positions=positions,
            encoder_hidden_states=encoder_hidden_states,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata)

        return decoder_outputs

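# Note on Florence2LanguageModel.forward above: encoder_input_ids is
# non-empty only during the prefill step of an encoder/decoder request, so
# the encoder runs once per sequence; later decode steps pass an empty
# tensor and the decoder instead attends to the encoder states cached by
# the cross-attention layers. A minimal sketch of that dispatch condition
# (plain tensors standing in for real inputs, for illustration only):
#
#     prefill_ids = torch.tensor([101, 42, 7])       # encoder runs
#     decode_ids = torch.empty(0, dtype=torch.long)  # encoder skipped
#     assert prefill_ids.numel() > 0 and decode_ids.numel() == 0
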
class Florence2LanguageForConditionalGeneration(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

        self.config = config
        self.model = Florence2LanguageModel(vllm_config=vllm_config,
                                            prefix=prefix)
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.vocab_size = config.vocab_size
        self.lm_head = BartParallelLMHead(self.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)

        self.logits_processor = LogitsProcessor(self.vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Output torch.Tensor
        """
        return self.model(input_ids, positions, encoder_input_ids,
                          encoder_positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> SamplerOutput:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "final_logits_bias" in name:
                    continue
                if self.config.tie_word_embeddings and "embed_tokens" in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

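# Illustration of the stacked-parameter mapping applied by load_weights
# above (hypothetical checkpoint name, shown for clarity only): a
# checkpoint entry such as
#
#     "model.encoder.layers.0.self_attn.q_proj.weight"
#
# has "q_proj" rewritten to "qkv_proj", so it is loaded into the "q" shard
# of the fused parameter
#
#     "model.encoder.layers.0.self_attn.qkv_proj.weight"
#
# and the matching k_proj and v_proj entries land in the "k" and "v" shards
# of the same fused tensor.
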
class Florence2ForConditionalGeneration(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

        # TODO(Isotr0py): Add vision backbone
        self.language_model = Florence2LanguageForConditionalGeneration(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=prefix,
        )

    @property
    def sampler(self):
        return self.language_model.sampler

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Output torch.Tensor
        """
        return self.language_model(input_ids, positions, encoder_input_ids,
                                   encoder_positions, kv_caches,
                                   attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        skip_prefixes = [
            "image_projection", "vision_tower", "image_proj_norm",
            "image_pos_embed", "visual_temporal_embed"
        ]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        loader.load_weights(weights)

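if __name__ == "__main__":
    # Minimal usage sketch, not part of the model definition. Assumes a
    # vLLM build that registers this model class; the checkpoint name
    # "microsoft/Florence-2-base" and the "<OD>" (object detection) task
    # prompt are illustrative, following upstream Florence-2 conventions.
    from vllm import LLM, SamplingParams

    llm = LLM(model="microsoft/Florence-2-base", trust_remote_code=True)
    outputs = llm.generate("<OD>", SamplingParams(max_tokens=128))
    print(outputs[0].outputs[0].text)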