add qwen3
This commit is contained in:
257
vllm-v0.6.2/vllm/model_executor/models/florence2.py
Normal file
257
vllm-v0.6.2/vllm/model_executor/models/florence2.py
Normal file
@@ -0,0 +1,257 @@
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
|
||||
BartParallelLMHead,
|
||||
BartScaledWordEmbedding)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .utils import AutoWeightsLoader
|
||||
|
||||
|
||||
class Florence2LanguageModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
|
||||
self.encoder = BartEncoder(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
self.decoder = BartDecoder(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
if self.config.tie_word_embeddings:
|
||||
self.encoder.embed_tokens.weight = self.shared.weight
|
||||
self.decoder.embed_tokens.weight = self.shared.weight
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
input_ids
|
||||
Indices of *decoder* input sequence tokens in the vocabulary.
|
||||
Padding will be ignored by default should you
|
||||
provide it.
|
||||
positions
|
||||
Positions of *decoder* input sequence tokens.
|
||||
encoder_input_ids
|
||||
Indices of *encoder* input sequence tokens in the vocabulary.
|
||||
encoder_positions:
|
||||
Positions of *encoder* input sequence tokens.
|
||||
kv_caches:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Model output torch.Tensor
|
||||
"""
|
||||
|
||||
encoder_hidden_states = None
|
||||
|
||||
if encoder_input_ids.numel() > 0:
|
||||
# Run encoder attention if a non-zero number of encoder tokens
|
||||
# are provided as input
|
||||
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
|
||||
positions=encoder_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
# decoder outputs consists of
|
||||
# (dec_features, past_key_value, dec_hidden, dec_attn)
|
||||
decoder_outputs = self.decoder(
|
||||
decoder_input_ids=input_ids,
|
||||
decoder_positions=positions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
return decoder_outputs
|
||||
|
||||
|
||||
class Florence2LanguageForConditionalGeneration(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
self.config = config
|
||||
self.model = Florence2LanguageModel(vllm_config=vllm_config,
|
||||
prefix=prefix)
|
||||
embed_scale = math.sqrt(
|
||||
config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = BartParallelLMHead(self.vocab_size,
|
||||
config.d_model,
|
||||
embed_scale=embed_scale)
|
||||
|
||||
self.logits_processor = LogitsProcessor(self.vocab_size,
|
||||
config.vocab_size)
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
input_ids
|
||||
torch.Tensor of *decoder* input token ids.
|
||||
positions
|
||||
torch.Tensor of *decoder* position indices.
|
||||
encoder_input_ids
|
||||
torch.Tensor of *encoder* input token ids.
|
||||
encoder_positions
|
||||
torch.Tensor of *encoder* position indices
|
||||
kv_caches:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Output torch.Tensor
|
||||
"""
|
||||
return self.model(input_ids, positions, encoder_input_ids,
|
||||
encoder_positions, kv_caches, attn_metadata)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(self, logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> SamplerOutput:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
if "final_logits_bias" in name:
|
||||
continue
|
||||
if self.config.tie_word_embeddings and "embed_tokens" in name:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
class Florence2ForConditionalGeneration(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
# TODO(Isotr0py): Add vision backbone
|
||||
self.language_model = Florence2LanguageForConditionalGeneration(
|
||||
vllm_config=vllm_config.with_hf_config(config.text_config),
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def sampler(self):
|
||||
return self.language_model.sampler
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
*,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
input_ids
|
||||
torch.Tensor of *decoder* input token ids.
|
||||
positions
|
||||
torch.Tensor of *decoder* position indices.
|
||||
encoder_input_ids
|
||||
torch.Tensor of *encoder* input token ids.
|
||||
encoder_positions
|
||||
torch.Tensor of *encoder* position indices
|
||||
kv_caches:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Output torch.Tensor
|
||||
"""
|
||||
return self.language_model(input_ids, positions, encoder_input_ids,
|
||||
encoder_positions, kv_caches, attn_metadata)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.language_model.compute_logits(hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
skip_prefixes = [
|
||||
'image_projection', "vision_tower", "image_proj_norm",
|
||||
"image_pos_embed", "visual_temporal_embed"
|
||||
]
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
|
||||
loader.load_weights(weights)
|
||||
Reference in New Issue
Block a user