add qwen3

This commit is contained in:
Chranos
2026-02-04 17:22:39 +08:00
parent d1c0f68ab4
commit 8511fe8530
1932 changed files with 300426 additions and 0 deletions

View File

@@ -0,0 +1,23 @@
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
SupportsPP, has_inner_state, supports_lora,
supports_multimodal, supports_pp)
from .interfaces_base import (VllmModelForEmbedding,
VllmModelForTextGeneration, is_embedding_model,
is_text_generation_model)
from .registry import ModelRegistry
__all__ = [
"ModelRegistry",
"VllmModelForEmbedding",
"is_embedding_model",
"VllmModelForTextGeneration",
"is_text_generation_model",
"HasInnerState",
"has_inner_state",
"SupportsLoRA",
"supports_lora",
"SupportsMultiModal",
"supports_multimodal",
"SupportsPP",
"supports_pp",
]

View File

@@ -0,0 +1,563 @@
"""Inference-only Snowflake Arctic model."""
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig, DeepSpeedFPParameter)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.arctic import ArcticConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
class ArcticMLP(nn.Module):
def __init__(self,
config: ArcticConfig,
layer_id: int,
expert_id: int = -1,
is_residual_mlp: bool = False,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True):
super().__init__()
self.hidden_size = config.hidden_size
self.expert_id = expert_id
self.layer_id = layer_id
self.ffn_dim = config.intermediate_size if not is_residual_mlp \
else self.hidden_size
self.w13 = MergedColumnParallelLinear(self.hidden_size,
[self.ffn_dim] * 2,
bias=False,
quant_config=quant_config)
self.w2 = RowParallelLinear(self.ffn_dim,
self.hidden_size,
bias=False,
reduce_results=reduce_results,
quant_config=quant_config)
if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, hidden_states):
gate_up, _ = self.w13(hidden_states)
hidden_states = self.act_fn(gate_up)
hidden_states, _ = self.w2(hidden_states)
return hidden_states
class ArcticMoE(nn.Module):
"""
Model-parallel implementation of Arctic MoE Layer.
"""
def __init__(self,
config: ArcticConfig,
layer_id: int,
tp_size: Optional[int] = None,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True):
super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
self.hidden_size = config.hidden_size
self.num_experts = config.num_local_experts
self.layer_id = layer_id
self.top_k = config.num_experts_per_tok
self.intermediate_size = config.intermediate_size // self.tp_size
self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0
self.is_quant = isinstance(quant_config, DeepSpeedFPConfig)
self.reduce_results = reduce_results
# Some other parameters
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if not self.is_moe_layer:
self.mlp = ArcticMLP(config,
layer_id=layer_id,
quant_config=quant_config,
reduce_results=reduce_results)
else:
self.gate = ReplicatedLinear(self.hidden_size,
self.num_experts,
bias=False,
params_dtype=self.params_dtype,
quant_config=quant_config)
if self.is_quant:
self.ws = DeepSpeedFPParameter(
torch.Size((self.num_experts, 2 * self.intermediate_size,
self.hidden_size)),
params_dtype=params_dtype,
quant_config=quant_config,
)
self.w2s = DeepSpeedFPParameter(
torch.Size((self.num_experts, self.hidden_size,
self.intermediate_size)),
params_dtype=params_dtype,
quant_config=quant_config,
)
else:
self.ws = nn.Parameter(
torch.empty(self.num_experts,
2 * self.intermediate_size,
self.hidden_size,
device="cuda",
dtype=self.params_dtype))
self.w2s = nn.Parameter(
torch.empty(self.num_experts,
self.hidden_size,
self.intermediate_size,
device="cuda",
dtype=self.params_dtype))
set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2s, {
"weight_loader": self.weight_loader,
})
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.ds_dequantize() if self.is_quant else param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
if weight_name.endswith("w1.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w3.weight"):
param_data[expert_id,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
if self.is_quant:
param.ds_quantize_(param_data)
def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
do_normalize = self.top_k > 1
topk_weights, topk_ids = fused_topk(hidden_states,
router_logits,
self.top_k,
renormalize=do_normalize)
# topk_ids: (num_tokens, k)
if self.is_quant:
if 2 * num_tokens <= self.num_experts:
# If much fewer tokens than experts, use selective dequantize.
ws_dequantized = self.ws.ds_selective_dequantize(
topk_ids.flatten())
w2s_dequantized = self.w2s.ds_selective_dequantize(
topk_ids.flatten())
# We gathered the experts to the tokens so update the mapping.
topk_ids = torch.arange(
0,
topk_ids.numel(),
device=topk_ids.device,
).reshape(topk_ids.shape)
else:
ws_dequantized = self.ws.ds_dequantize()
w2s_dequantized = self.w2s.ds_dequantize()
final_hidden_states = fused_experts(
hidden_states,
ws_dequantized if self.is_quant else self.ws,
w2s_dequantized if self.is_quant else self.w2s,
topk_weights,
topk_ids,
inplace=True)
if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
def forward(self, hidden_states: torch.Tensor):
if self.is_moe_layer:
final_hidden_states = self.local_moe_fused(hidden_states)
else:
final_hidden_states = self.mlp(hidden_states)
return final_hidden_states
class ArcticAttention(nn.Module):
def __init__(
self,
config: ArcticConfig,
layer_idx: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
assert self.total_num_kv_heads % tp_size == 0
else:
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = self.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
reduce_results=True,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class ArcticDecoderLayer(nn.Module):
def __init__(
self,
config: ArcticConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0
self.use_residual = config.use_residual and is_moe_layer
self.self_attn = ArcticAttention(config,
layer_idx,
cache_config,
quant_config=quant_config)
self.block_sparse_moe = ArcticMoE(
config,
layer_id=layer_idx,
quant_config=quant_config,
reduce_results=(not self.use_residual))
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
if self.use_residual:
self.residual_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.residual_mlp = ArcticMLP(config,
layer_id=layer_idx,
is_residual_mlp=True,
reduce_results=False)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual_input = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual_input + hidden_states
residual_attn = hidden_states
if self.use_residual:
hidden_states = self.residual_layernorm(hidden_states)
hidden_states = self.residual_mlp(hidden_states)
residual_mlp = hidden_states
hidden_states = self.post_attention_layernorm(residual_input)
hidden_states = self.block_sparse_moe(hidden_states)
hidden_states = residual_mlp + hidden_states
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states = residual_attn + hidden_states
else:
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.block_sparse_moe(hidden_states)
hidden_states = residual_attn + hidden_states
return hidden_states
@support_torch_compile
class ArcticModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=self.vocab_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: ArcticDecoderLayer(config, int(
prefix.split(".")[-1]), cache_config, quant_config),
prefix=f"{prefix}.layers")
self._attn_implementation = config._attn_implementation
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm(hidden_states)
return hidden_states
class ArcticForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.model = ArcticModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.num_experts = config.num_local_experts
self.num_experts_per_tok = config.num_experts_per_tok
self.unpadded_vocab_size = config.vocab_size
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
mlp_params_mapping: List[Tuple[str, str, int]] = []
expert_params_mapping: List[Tuple[str, str, int]] = []
num_layers = self.config.num_hidden_layers
for layer in range(num_layers):
mlp_params_mapping.append(
(f"layers.{layer}.residual_mlp.w13.weight",
f"layers.{layer}.residual_mlp.w1.weight", 0))
mlp_params_mapping.append(
(f"layers.{layer}.residual_mlp.w13.weight",
f"layers.{layer}.residual_mlp.w3.weight", 1))
if layer % 2 == 0:
# MLP layers
mlp_params_mapping.append(
(f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
f"layers.{layer}.block_sparse_moe.mlp.w1.weight", 0))
mlp_params_mapping.append(
(f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
f"layers.{layer}.block_sparse_moe.mlp.w3.weight", 1))
else:
# MoE layers
for expert_id in range(self.config.num_local_experts):
expert_params_mapping.append(
("ws", f"experts.{expert_id}.w1.weight", expert_id))
expert_params_mapping.append(
("w2s", f"experts.{expert_id}.w2.weight", expert_id))
expert_params_mapping.append(
("ws", f"experts.{expert_id}.w3.weight", expert_id))
params_dict = dict(self.named_parameters())
logger.info(
"It will take ~10 minutes loading from the 16-bit weights. "
"Alternatively, use the prequantized 8-bit weights of arctic "
"and set load-format to `sharded_state` will accelerate loading.")
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name, shard_id in mlp_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name, shard_id \
in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
expert_id=shard_id)
break
else:
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,467 @@
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
class BaiChuanMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class BaiChuanAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
hidden_size: int,
num_heads: int,
position_embedding: str,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
)
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.postion_embedding = position_embedding
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
# pylint: disable=invalid-name
self.W_pack = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
# Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI":
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes,
quant_config=quant_config)
else:
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
)
self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class BaiChuanDecoderLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = BaiChuanAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
position_embedding=position_embedding,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class BaiChuanModel(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
position_embedding: str = "ROPE",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: BaiChuanDecoderLayer(config, position_embedding,
cache_config, quant_config),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual,
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"W_pack": ["W_pack"],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"W_pack",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
position_embedding: str = "ROPE",
):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = BaiChuanModel(vllm_config=vllm_config,
prefix=prefix,
position_embedding=position_embedding)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if name == "lm_head.weight":
# Unlike Baichuan, Baichuan2 normalizes the head weights.
# Refer to:
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
# Distinguish between Baichuan and Baichuan2 by checking the
# vocab size. This is suggested by
# https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
is_baichuan2 = self.config.vocab_size == 125696
if is_baichuan2:
loaded_weight = torch.nn.functional.normalize(
loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 13B and Baichuan2 7B/13B.
NOTE: the class name has a lower case 'c'.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
if config.hidden_size == 4096: # baichuan2 7b
super().__init__(vllm_config=vllm_config,
prefix=prefix,
position_embedding="ROPE")
else: # baichuan 13b, baichuan2 13b
super().__init__(vllm_config=vllm_config,
prefix=prefix,
position_embedding="ALIBI")
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 7B.
NOTE: the class name has an upper case 'C'.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
position_embedding="ROPE")

View File

@@ -0,0 +1,998 @@
# Derived from BART implementation posted on HuggingFace; license below:
#
# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model."""
import math
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import BartConfig
from transformers.utils import logging
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
logger = logging.get_logger(__name__)
def get_bsz_seq_len(input_ids):
shp = input_ids.shape
ndim = len(shp)
if ndim == 1:
return 1, input_ids.numel()
else:
return shp[:2]
class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int):
# Bart is set up so that if padding_idx is
# specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately.
# Other models don't have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)
def forward(
self,
positions: torch.Tensor,
attn_type: AttentionType,
) -> torch.Tensor:
"""`input_ids' shape is expected to be [bsz x seqlen]."""
assert attn_type != AttentionType.ENCODER_DECODER
return super().forward(positions + self.offset)
class BartScaledWordEmbedding(VocabParallelEmbedding):
"""
This module overrides VocabParallelEmbedding's
forward by multiplying with embeddings scale.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
embed_scale: float = 1.0):
super().__init__(num_embeddings, embedding_dim)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
return super().forward(input_ids) * self.embed_scale
class BartParallelLMHead(ParallelLMHead):
"""
This module overrides ParallelLMHead's
forward by dividing by embeddings scale,
yielding effectively the inverse of
BartScaledWordEmbedding
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
embed_scale: float = 1.0):
super().__init__(num_embeddings, embedding_dim)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
return super().forward(input_ids) / self.embed_scale
class BartEncoderAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.embed_dim = embed_dim
self.total_num_heads = num_heads
self.total_num_kv_heads = self.total_num_heads
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads}).")
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor:
"""Input shape: Batch x Time x Channel"""
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER)
output, _ = self.out_proj(attn_output)
return output
class BartDecoderSelfAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.embed_dim = embed_dim
self.total_num_heads = num_heads
self.total_num_kv_heads = self.total_num_heads
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads}).")
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor:
"""Input shape: Batch x Time x Channel"""
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.DECODER)
output, _ = self.out_proj(attn_output)
return output
class BartCrossAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.embed_dim = embed_dim
self.total_num_heads = num_heads
self.total_num_kv_heads = self.total_num_heads
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads}).")
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
decoder_hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Input shape: Batch x Time x Channel"""
# (afeldman-nm 2024/07/22) TODO:
# Need a more efficient solution for q/k/v
qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
if encoder_hidden_states is None:
k = None
v = None
else:
qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
_, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER_DECODER)
output, _ = self.out_proj(attn_output)
return output
class BartEncoderLayer(nn.Module):
def __init__(
self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = BartEncoderAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
config=config,
cache_config=cache_config,
quant_config=quant_config)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.activation_fn = get_act_fn(config.activation_function)
ffn_hidden_size = self.embed_dim
ffn_intermediate_size = config.encoder_ffn_dim
ffn_has_bias = True
self.fc1 = ColumnParallelLinear(
ffn_hidden_size,
ffn_intermediate_size,
bias=ffn_has_bias,
quant_config=quant_config,
)
self.act = get_act_fn("gelu")
self.fc2 = RowParallelLinear(
ffn_intermediate_size,
ffn_hidden_size,
bias=ffn_has_bias,
quant_config=quant_config,
)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor:
r"""
Args:
hidden_states
torch.Tensor of *encoder* input embeddings.
kv_cache:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Encoder layer output torch.Tensor
"""
residual = hidden_states
hidden_states = self.self_attn(hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
fc1_out, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(fc1_out)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any()
or torch.isnan(hidden_states).any()):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states,
min=-clamp_value,
max=clamp_value)
return hidden_states
class BartDecoderLayer(nn.Module):
def __init__(
self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = BartDecoderSelfAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
config=config,
cache_config=cache_config,
quant_config=quant_config)
self.activation_fn = get_act_fn(config.activation_function)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
'''
afeldman-nm: personally I would call this "cross-attention",
however I left the name as "encoder_attn" to maintain consistency
with the name of the pretrained weights.
'''
self.encoder_attn = BartCrossAttention(
self.embed_dim,
config.decoder_attention_heads,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
ffn_hidden_size = self.embed_dim
ffn_intermediate_size = config.encoder_ffn_dim
ffn_has_bias = True
self.fc1 = ColumnParallelLinear(
ffn_hidden_size,
ffn_intermediate_size,
bias=ffn_has_bias,
quant_config=quant_config,
)
self.fc2 = RowParallelLinear(
ffn_intermediate_size,
ffn_hidden_size,
bias=ffn_has_bias,
quant_config=quant_config,
)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
decoder_hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Args:
decoder_hidden_states
torch.Tensor of *decoder* input embeddings.
kv_cache:
KV cache tensor
attn_metadata:
vLLM Attention metadata structure
encoder_hidden_states
torch.Tensor of *encoder* input embeddings.
Returns:
Decoder layer output torch.Tensor
"""
residual = decoder_hidden_states
# Self Attention
hidden_states = self.self_attn(hidden_states=decoder_hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Cross-Attention Block
residual = hidden_states
hidden_states = self.encoder_attn(
decoder_hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
encoder_hidden_states=encoder_hidden_states,
)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# Fully Connected
residual = hidden_states
fc1_out, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(fc1_out)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class BartEncoder(nn.Module):
"""
Transformer encoder consisting of *config.encoder_layers*
self attention layers. Each layer is a [`BartEncoderLayer`].
Args:
config: BartConfig
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None):
super().__init__()
self.cache_config = cache_config
self.quant_config = quant_config
self.lora_config = lora_config
embed_dim = config.d_model
self.max_source_positions = config.max_position_embeddings
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
embed_dim,
embed_scale=embed_scale)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
embed_dim,
)
self.layers = nn.ModuleList(
[BartEncoderLayer(config,cache_config,quant_config) \
for _ in range(config.encoder_layers)])
self.layernorm_embedding = nn.LayerNorm(embed_dim)
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *encoder* input sequence tokens.
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Decoder output torch.Tensor
"""
# retrieve input_ids and inputs_embeds
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(
positions,
AttentionType.ENCODER,
)
embed_pos = embed_pos.to(inputs_embeds.device)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
for idx, encoder_layer in enumerate(self.layers):
hidden_states = encoder_layer(
hidden_states=hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
)
return hidden_states
class BartDecoder(nn.Module):
"""
Transformer decoder consisting of *config.decoder_layers* layers.
Each layer is a [`BartDecoderLayer`]
Args:
config: BartConfig
embed_tokens (nn.Embedding): output embedding
"""
def __init__(
self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None,
):
super().__init__()
self.cache_config = cache_config
self.quant_config = quant_config
self.lora_config = lora_config
self.max_target_positions = config.max_position_embeddings
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
config.d_model,
embed_scale=embed_scale)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
)
self.layers = nn.ModuleList(
[BartDecoderLayer(config,cache_config,quant_config) \
for _ in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
def forward(self, decoder_input_ids: torch.Tensor,
decoder_positions: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata) -> torch.Tensor:
r"""
Args:
decoder_input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
decoder_positions
Positions of *decoder* input sequence tokens.
encoder_hidden_states:
Tensor of encoder output embeddings
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Decoder output torch.Tensor
"""
inputs_embeds = self.embed_tokens(decoder_input_ids)
# embed positions
embed_pos = self.embed_positions(
decoder_positions,
AttentionType.DECODER,
)
embed_pos = embed_pos.to(inputs_embeds.device)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
# decoder layers
for idx, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
decoder_hidden_states=hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
encoder_hidden_states=encoder_hidden_states,
)
return hidden_states
class BartModel(nn.Module):
_tied_weights_keys = [
"encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.encoder = BartEncoder(config,
cache_config,
quant_config=quant_config)
self.decoder = BartDecoder(config,
cache_config,
quant_config=quant_config)
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *decoder* input sequence tokens.
encoder_input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
encoder_positions:
Positions of *encoder* input sequence tokens.
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Model output torch.Tensor
"""
encoder_hidden_states = None
if encoder_input_ids.numel() > 0:
# Run encoder attention if a non-zero number of encoder tokens
# are provided as input
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
positions=encoder_positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata)
# decoder outputs consists of
# (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
decoder_input_ids=input_ids,
decoder_positions=positions,
encoder_hidden_states=encoder_hidden_states,
kv_caches=kv_caches,
attn_metadata=attn_metadata)
return decoder_outputs
class BartForConditionalGeneration(nn.Module):
base_model_prefix = "model"
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
lora_config = vllm_config.lora_config
# currently all existing BART models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.config = config
self.model = BartModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
self.lm_head = BartParallelLMHead(config.vocab_size,
config.d_model,
embed_scale=embed_scale)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
*,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor,
**kwargs,
) -> torch.Tensor:
r"""
Args:
input_ids
torch.Tensor of *decoder* input token ids.
positions
torch.Tensor of *decoder* position indices.
encoder_input_ids
torch.Tensor of *encoder* input token ids.
encoder_positions
torch.Tensor of *encoder* position indices
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Output torch.Tensor
"""
return self.model(input_ids, positions, encoder_input_ids,
encoder_positions, kv_caches, attn_metadata)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
stacked_params_mapping = {
"q_proj": {
"param_name": "qkv_proj",
"shard_id": "q",
},
"k_proj": {
"param_name": "qkv_proj",
"shard_id": "k",
},
"v_proj": {
"param_name": "qkv_proj",
"shard_id": "v",
},
}
params_mapping = {
"beta": "bias",
"gamma": "weight",
"LayerNorm": "layernorm",
}
def _rename_key(self, key: str):
prefix = f"{self.base_model_prefix}."
key = key[len(prefix):] if key.startswith(prefix) else key
for src, dst in self.params_mapping.items():
key = key.replace(src, dst)
return key
def _rename_stacked_param(
self,
name: str,
) -> Tuple[str, Optional[str]]:
for key, mapping in self.stacked_params_mapping.items():
if key in name:
name = name.replace(key, mapping["param_name"])
return name, mapping["shard_id"]
return name, None
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
model_params_dict = dict(self.model.named_parameters())
top_params_dict = dict(self.named_parameters())
weights_tuple_list = list(weights)
shared_embedding_weight = None
shared_embedding_shard_id = None
for name, loaded_weight in weights_tuple_list:
name = self._rename_key(name)
name, shard_id = self._rename_stacked_param(name)
if ('shared.weight' in name
or 'encoder.embed_tokens.weight' in name
or 'decoder.embed_tokens.weight' in name
or 'lm_head.weight' in name):
assert shared_embedding_weight is None, (
"Conflicting embedding weights.")
shared_embedding_weight = loaded_weight
shared_embedding_shard_id = shard_id
else:
# Skip the specific downstream task weight.
if name.startswith('cls.'):
continue
# use Pooler instead.
if name.startswith('pooler.'):
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in model_params_dict:
continue
param = model_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if shard_id:
weight_loader(param, loaded_weight, shard_id)
else:
weight_loader(param, loaded_weight)
# Assign shared weight values
encoder_in_param = model_params_dict['encoder.embed_tokens.weight']
encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader",
default_weight_loader)
decoder_in_param = model_params_dict['decoder.embed_tokens.weight']
decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader",
default_weight_loader)
lm_head_in_param = top_params_dict['lm_head.weight']
lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
default_weight_loader)
assert shared_embedding_weight is not None
if shared_embedding_shard_id:
encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
shared_embedding_shard_id)
decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
shared_embedding_shard_id)
lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
shared_embedding_shard_id)
else:
encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)

View File

@@ -0,0 +1,428 @@
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import BertConfig
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from .utils import maybe_prefix
class BertEmbedding(nn.Module):
def __init__(self, config: BertConfig):
super().__init__()
self.size = config.hidden_size
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.position_embeddings = VocabParallelEmbedding(
config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = VocabParallelEmbedding(
config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
raise ValueError("Only 'absolute' position_embedding_type" +
" is supported")
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()
# Input embeddings.
inputs_embeds = self.word_embeddings(input_ids)
# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
# Token type embeddings. (TODO: move off hotpath?)
token_type_embeddings = self.token_type_embeddings(
torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device))
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
embeddings = self.LayerNorm(embeddings)
return embeddings
class BertEncoder(nn.Module):
def __init__(self,
config: BertConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.layer = nn.ModuleList([
BertLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])
def forward(
self,
hidden_states: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
for i in range(len(self.layer)):
layer = self.layer[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
return hidden_states
class BertLayer(nn.Module):
def __init__(self,
config: BertConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.attention = BertAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
layer_norm_eps=config.layer_norm_eps,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attention")
self.intermediate = BertIntermediate(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.intermediate")
self.output = BertOutput(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
layer_norm_eps=config.layer_norm_eps,
quant_config=quant_config,
prefix=f"{prefix}.output")
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
):
attn_output = self.attention(hidden_states, kv_cache, attn_metadata)
intermediate_output = self.intermediate(attn_output)
output = self.output(intermediate_output, attn_output)
return output
class BertAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
layer_norm_eps: float,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.self = BertSelfAttention(hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.output")
self.output = BertSelfOutput(hidden_size=hidden_size,
layer_norm_eps=layer_norm_eps,
quant_config=quant_config,
prefix=f"{prefix}.output")
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
self_output = self.self(hidden_states, kv_cache, attn_metadata)
return self.output(self_output, hidden_states)
class BertSelfAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = self.total_num_heads
self.head_dim = self.hidden_size // self.total_num_heads
assert self.head_dim * self.total_num_heads == self.hidden_size
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
hidden_size=self.hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj")
self.attn = Attention(num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER_ONLY)
return output
class BertSelfOutput(nn.Module):
def __init__(self,
hidden_size: int,
layer_norm_eps: float,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.dense = RowParallelLinear(input_size=hidden_size,
output_size=hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.dense")
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
def forward(self, hidden_states: torch.Tensor,
input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.dense(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertIntermediate(nn.Module):
def __init__(self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.dense = ColumnParallelLinear(input_size=hidden_size,
output_size=intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.dense")
self.intermediate_act_fn = get_act_fn(hidden_act)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self,
hidden_size: int,
intermediate_size: int,
layer_norm_eps: float,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.dense = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.dense")
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
def forward(self, hidden_states: torch.Tensor,
input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.dense(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertModel(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
embedding_class: type = BertEmbedding):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.embeddings = embedding_class(config)
self.encoder = BertEncoder(config,
cache_config,
quant_config,
prefix=f"{prefix}.encoder")
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embeddings(input_ids=input_ids,
position_ids=position_ids)
return self.encoder(hidden_states, kv_caches, attn_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "query", "q"),
("qkv_proj", "key", "k"),
("qkv_proj", "value", "v"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "pooler" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
class BertEmbeddingModel(nn.Module):
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of BertModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
pooler_config = vllm_config.model_config.pooler_config
self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = self._build_pooler(pooler_config)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(input_ids=input_ids,
position_ids=positions,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
attn_metadata=attn_metadata)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.model.load_weights(weights)
def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> BertModel:
return BertModel(vllm_config=vllm_config,
prefix=prefix,
embedding_class=BertEmbedding)
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return Pooler.from_config_with_defaults(pooler_config,
pooling_type=PoolingType.CLS,
normalize=True,
softmax=False)

View File

@@ -0,0 +1,480 @@
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
from typing import Iterable, Optional, Tuple, Union
import torch
import torch.nn as nn
from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData
try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
return image_size // patch_size
def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
grid_length = get_blip_patch_grid_length(image_size=image_size,
patch_size=patch_size)
return grid_length * grid_length
def get_blip_image_feature_size(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
return get_blip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)
def get_max_blip_image_tokens(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
return get_blip_image_feature_size(hf_config)
def dummy_seq_data_for_blip(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_blip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
)
def dummy_image_for_blip(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
width = height = hf_config.image_size
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}
def input_processor_for_blip(
model_config: ModelConfig,
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
inputs: DecoderOnlyInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None:
image_feature_size = get_blip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": ranges})
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module):
def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=3,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
)
self.num_patches = get_blip_num_patches(image_size=self.image_size,
patch_size=self.patch_size)
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Parameter(
torch.randn(1, self.num_positions, self.embed_dim))
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(
dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
position_embeds = self.position_embedding.to(target_dtype)
embeddings = embeddings + position_embeds[:, :embeddings.size(1), :]
return embeddings
class BlipParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: Union[BlipVisionConfig, Blip2VisionConfig],
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
self.projection = RowParallelLinear(
self.embed_dim,
self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.projection",
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
qkv_states, _ = self.qkv(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.projection(out)
return attn_output, None
class BlipFallbackAttention(nn.Module):
"""Fallback Blip attention for older transformers versions"""
def __init__(self, config: BlipVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim)
self.projection = nn.Linear(self.embed_dim, self.embed_dim)
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, None]:
bsz, tgt_len, embed_dim = hidden_states.size()
qkv = self.qkv(hidden_states)
qkv = qkv.reshape(bsz, tgt_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
out = (attn @ v).transpose(1, 2).reshape(bsz, tgt_len, embed_dim)
out = self.projection(out)
return out, None
class BlipMLP(nn.Module):
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class BlipEncoderLayer(nn.Module):
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
# fallback to sdpa attention if tp unavailable
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = BlipParallelAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else:
# Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend
self.self_attn = BlipFallbackAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = BlipMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class BlipEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self
attention layers. Each layer is a [`BlipEncoderLayer`].
Args:
config: BlipConfig
"""
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
BlipEncoderLayer(config=config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])
def forward(self, inputs_embeds: torch.Tensor):
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)
return hidden_states
class BlipVisionModel(nn.Module):
config_class = BlipVisionConfig
main_input_name = "pixel_values"
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__()
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
self.config = config
self.embeddings = BlipVisionEmbeddings(config)
self.encoder = BlipEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
)
num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
# If possible, skip post_layernorm to conserve memory
if require_post_norm is None:
require_post_norm = len(self.encoder.layers) == num_hidden_layers
if require_post_norm:
self.post_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
else:
self.post_layernorm = None
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
hidden_states = self.embeddings(pixel_values)
hidden_states = self.encoder(inputs_embeds=hidden_states)
if self.post_layernorm is None:
return hidden_states
return self.post_layernorm(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters())
layer_count = len(self.encoder.layers)
for name, loaded_weight in weights:
# post_layernorm is not needed in BlipVisionModel
if (name.startswith("post_layernorm")
and self.post_layernorm is None):
continue
# omit layers when num_hidden_layers_override is set
if name.startswith("encoder.layers"):
layer_idx = int(name.split(".")[2])
if layer_idx >= layer_count:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,697 @@
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
apply_chunking_to_forward)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData
from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265
class Blip2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class Blip2ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
class Blip2QFormerMultiHeadAttention(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
is_cross_attention: bool = False,
) -> None:
super().__init__()
self.config = config
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of "
f"the number of attention heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = (config.hidden_size //
config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.scaling = self.attention_head_size**-0.5
self.query = nn.Linear(config.hidden_size, self.all_head_size)
if is_cross_attention:
kv_hidden_size = config.encoder_hidden_size
else:
kv_hidden_size = config.hidden_size
self.key = nn.Linear(kv_hidden_size, self.all_head_size)
self.value = nn.Linear(kv_hidden_size, self.all_head_size)
self.position_embedding_type = getattr(config,
"position_embedding_type",
"absolute")
if self.position_embedding_type != "absolute":
raise NotImplementedError("Unsupported position_embedding_type: "
f"{self.position_embedding_type}")
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
x = x.view(*x.size()[:-1], self.num_attention_heads,
self.attention_head_size)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
):
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
key_layer = self.transpose_for_scores(
self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(
self.value(encoder_hidden_states))
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
mixed_query_layer = self.query(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
attention_scores = torch.matmul(query_layer,
key_layer.transpose(-1, -2))
attention_probs = torch.softmax(attention_scores * self.scaling,
dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs_dropped = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs_dropped, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
context_layer = context_layer.view(*context_layer.size()[:-2],
self.all_head_size)
return context_layer
class Blip2QFormerSelfOutput(nn.Module):
def __init__(self, config: Blip2QFormerConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(
self,
hidden_states: torch.Tensor,
input_tensor: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class Blip2QFormerAttention(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
is_cross_attention: bool = False,
) -> None:
super().__init__()
self.attention = Blip2QFormerMultiHeadAttention(
config,
quant_config=quant_config,
cache_config=cache_config,
is_cross_attention=is_cross_attention,
)
self.output = Blip2QFormerSelfOutput(config)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.Tensor]:
self_output = self.attention(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
)
attention_output = self.output(self_output, hidden_states)
return attention_output
class Blip2QFormerIntermediate(nn.Module):
def __init__(self, config: Blip2QFormerConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = get_act_fn(config.hidden_act)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class Blip2QFormerOutput(nn.Module):
def __init__(self, config: Blip2QFormerConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(
self,
hidden_states: torch.Tensor,
input_tensor: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class Blip2QFormerLayer(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
layer_idx: int,
) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = Blip2QFormerAttention(config,
quant_config=quant_config,
cache_config=cache_config)
self.layer_idx = layer_idx
if layer_idx % config.cross_attention_frequency == 0:
self.crossattention = Blip2QFormerAttention(
config,
quant_config=quant_config,
cache_config=cache_config,
is_cross_attention=True)
self.has_cross_attention = True
else:
self.has_cross_attention = False
self.intermediate_query = Blip2QFormerIntermediate(config)
self.output_query = Blip2QFormerOutput(config)
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
query_length: int,
):
attention_output = self.attention(hidden_states)
if query_length > 0:
query_attention_output = attention_output[:, :query_length, :]
if self.has_cross_attention:
query_attention_output = self.crossattention(
query_attention_output,
encoder_hidden_states=encoder_hidden_states,
)
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk_query,
self.chunk_size_feed_forward,
self.seq_len_dim,
query_attention_output,
)
if attention_output.shape[1] > query_length:
layer_output_text = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output[:, query_length:, :],
)
layer_output = torch.cat([layer_output, layer_output_text],
dim=1)
else:
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output,
)
return layer_output
def feed_forward_chunk(self,
attention_output: torch.Tensor) -> torch.Tensor:
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
def feed_forward_chunk_query(
self, attention_output: torch.Tensor) -> torch.Tensor:
intermediate_output = self.intermediate_query(attention_output)
layer_output = self.output_query(intermediate_output, attention_output)
return layer_output
class Blip2QFormerEncoder(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([
Blip2QFormerLayer(config,
quant_config=quant_config,
cache_config=cache_config,
layer_idx=layer_idx)
for layer_idx in range(config.num_hidden_layers)
])
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
query_length: int,
) -> torch.Tensor:
for i in range(self.config.num_hidden_layers):
layer_module = self.layer[i]
hidden_states = layer_module(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
query_length=query_length,
)
return hidden_states
# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025
class Blip2QFormerModel(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
) -> None:
super().__init__()
self.config = config
self.layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.encoder = Blip2QFormerEncoder(config,
quant_config=quant_config,
cache_config=cache_config)
def forward(
self,
query_embeds: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
) -> torch.Tensor:
query_length = query_embeds.shape[1]
embedding_output = self.layernorm(query_embeds)
embedding_output = self.dropout(embedding_output)
sequence_output = self.encoder(
embedding_output,
encoder_hidden_states=encoder_hidden_states,
query_length=query_length,
)
return sequence_output
def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
return hf_config.num_query_tokens
def get_max_blip2_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config
if isinstance(vision_config, Blip2VisionConfig):
return get_max_blip_image_tokens(vision_config)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def dummy_seq_data_for_blip2(
hf_config: Blip2Config,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_blip2_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
), {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
seq_data, ranges = dummy_seq_data_for_blip2(
hf_config,
seq_len,
num_images,
image_token_id=BLIP2_IMAGE_TOKEN_ID,
)
if isinstance(vision_config, Blip2VisionConfig):
mm_data = dummy_image_for_blip(vision_config, num_images)
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
hf_config = ctx.get_hf_config(Blip2Config)
image_feature_size = get_blip2_image_feature_size(hf_config)
# The original model places image tokens at the front
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
new_token_ids += inputs["prompt_token_ids"]
new_prompt = inputs.get("prompt")
if new_prompt is not None:
new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
# TODO: Optionally initializes this for supporting embeddings.
self.vision_model = BlipVisionModel(config.vision_config, quant_config)
self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens,
config.qformer_config.hidden_size))
self.qformer = Blip2QFormerModel(config.qformer_config,
cache_config=cache_config,
quant_config=quant_config)
self.language_projection = nn.Linear(
config.qformer_config.hidden_size,
config.text_config.hidden_size,
bias=True,
)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Blip2ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
return Blip2ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)
return Blip2ImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _image_pixels_to_features(self, vision_model: BlipVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_model(pixel_values)
return image_features
def _process_image_pixels(self,
inputs: Blip2ImagePixelInputs) -> torch.Tensor:
assert self.vision_model is not None
pixel_values = inputs["data"]
return self._image_pixels_to_features(self.vision_model, pixel_values)
def _process_image_input(self,
image_input: Blip2ImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_features = self._process_image_pixels(image_input)
query_tokens = self.query_tokens.expand(image_features.shape[0], -1,
-1)
query_output = self.qformer(
query_embeds=query_tokens,
encoder_hidden_states=image_features,
)
return self.language_projection(query_output)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> Union[SamplerOutput, IntermediateTensors]:
"""Run forward pass for BLIP-2.
One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted image embeddings.
Concretely, consider a text prompt:
`"Question: What's the content of the image? Answer:"`.
Tokenizer outputs:
`[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.
To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends
dummy tokens (denoted as `50265`), resulting in:
`[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.
We insert 32 tokens since it corresponds to the number of query
embeddings outputted by the Q-Former and inputted to the language model.
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each input image.
See also:
:class:`Blip2ImageInputs`
"""
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
BLIP2_IMAGE_TOKEN_ID)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights)

View File

@@ -0,0 +1,362 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The vLLM team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import BloomConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
class BloomAttention(nn.Module):
def __init__(
self,
config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
self.total_num_heads = config.n_head
self.head_dim = self.hidden_size // self.total_num_heads
assert self.head_dim * self.total_num_heads == self.hidden_size
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
bias=True,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
)
# Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
del position_ids # Unused.
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.dense(attn_output)
return output
class BloomMLP(nn.Module):
def __init__(
self,
config: BloomConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(
hidden_size,
4 * hidden_size,
quant_config=quant_config,
)
self.gelu_impl = get_act_fn("gelu")
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
quant_config=quant_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.dense_h_to_4h(x)
x = self.gelu_impl(x)
x, _ = self.dense_4h_to_h(x)
return x
class BloomBlock(nn.Module):
def __init__(
self,
config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config, cache_config,
quant_config)
self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config, quant_config)
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Layer norm post the self attention.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self attention.
attention_output = self.self_attention(
position_ids=position_ids,
hidden_states=layernorm_output,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
attention_output = attention_output + residual
layernorm_output = self.post_attention_layernorm(attention_output)
# Get residual
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# MLP.
output = self.mlp(layernorm_output) + residual
return output
@support_torch_compile
class BloomModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.embed_dim = config.hidden_size
# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size,
self.embed_dim,
)
self.word_embeddings_layernorm = nn.LayerNorm(
self.embed_dim, eps=config.layer_norm_epsilon)
# Transformer blocks
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: BloomBlock(config, cache_config, quant_config),
prefix=f"{prefix}.h")
# Final Layer Norm
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states
class BloomForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.transformer = BloomModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.word_embeddings
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if name == "lm_head.weight":
continue
if not name.startswith("transformer."):
name = "transformer." + name
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
# NOTE: BLOOM's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,688 @@
# Adapted from
# https://github.com/THUDM/GLM-4
"""Inference-only ChatGLM model compatible with THUDM weights."""
from argparse import Namespace
from array import array
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
import torch
from PIL import Image
from torch import nn
from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
def calculate_image_placeholder(vision_config):
return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2
def mm_input_mapper_for_glmv(
ctx: InputContext,
data: MultiModalData[object],
) -> Dict:
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
if tokenizer is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")
try:
raw_batch_data = tokenizer.apply_chat_template(
conversation=[{
"role": "user",
"image": data
}],
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True).data
except Exception:
logger.error("Failed to process image (%s)", data)
raise
pixel_values = raw_batch_data['images']
return MultiModalKwargs({'pixel_values': pixel_values})
def merge_glm_vision_embeddings(
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
boi_token_id: int,
eoi_token_id: int,
) -> torch.Tensor:
boi_positions = (input_ids == boi_token_id).nonzero(as_tuple=True)[0]
eoi_positions = (input_ids == eoi_token_id).nonzero(as_tuple=True)[0]
mask = torch.zeros_like(input_ids, dtype=torch.bool)
for boi_pos, eoi_pos in zip(boi_positions, eoi_positions):
assert boi_pos < eoi_pos
mask[boi_pos:eoi_pos + 1] = True
inputs_embeds[mask] = vision_embeddings.view(-1,
vision_embeddings.shape[-1])
return inputs_embeds
class GLMImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
def get_max_glmv_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(ChatGLMConfig)
vision_config = getattr(hf_config, 'vision_config', None)
if vision_config is None:
return 1
elif isinstance(vision_config, dict):
return calculate_image_placeholder(vision_config)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def dummy_data_for_glmv(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]) -> DummyData:
hf_config = ctx.get_hf_config(ChatGLMConfig)
vision_config = getattr(hf_config, 'vision_config', None)
if vision_config is None:
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
seq_data = SequenceData(token_ids)
return DummyData(seq_data, None)
elif isinstance(vision_config, dict):
image_size = vision_config["image_size"]
image_placeholder_length = calculate_image_placeholder(vision_config)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [hf_config.boi_token_id] +
[0] * image_placeholder_length +
[hf_config.eoi_token_id])
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0] * (seq_len - image_placeholder_length - 2))
seq_data = SequenceData(token_ids)
mm_data = {
"image": Image.new("RGB", (image_size, image_size), color=0)
}
return DummyData(seq_data, mm_data)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def find_all_positions(input_ids: List[int], target: int) -> List[int]:
return [index for index, value in enumerate(input_ids) if value == target]
def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
hf_config = ctx.get_hf_config(ChatGLMConfig)
vision_config = getattr(hf_config, 'vision_config', None)
if vision_config is None:
return inputs
elif isinstance(vision_config, dict):
image_placeholder_length = calculate_image_placeholder(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
input_ids = inputs["prompt_token_ids"]
tokenizer = cached_get_tokenizer(
ctx.model_config.model,
trust_remote_code=ctx.model_config.trust_remote_code)
try:
raw_batch_data = tokenizer.apply_chat_template(
conversation=[{
"role": "user",
"image": multi_modal_data["image"],
"content": inputs['prompt'],
}],
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True,
).data
except Exception:
logger.error("Failed to process content (%s)", inputs['prompt'])
raise
input_ids = raw_batch_data['input_ids'][0].tolist()
boi_token_id = hf_config.boi_token_id
eoi_token_id = hf_config.eoi_token_id
boi_positions = find_all_positions(input_ids, boi_token_id)
eoi_positions = find_all_positions(input_ids, eoi_token_id)
assert len(boi_positions) == len(eoi_positions)
new_input_ids = []
final_processed_position = 0
final_processed_position = 0
for boi_position, eoi_position in zip(boi_positions, eoi_positions):
assert boi_position < eoi_position
new_input_ids.extend(input_ids[final_processed_position:boi_position +
1])
new_input_ids.extend([input_ids[boi_position + 1]] *
image_placeholder_length)
final_processed_position = eoi_position
new_input_ids.extend(input_ids[final_processed_position:])
prompt = inputs.get("prompt")
if prompt is None:
prompt = tokenizer.decode(new_input_ids)
return token_inputs(
prompt_token_ids=new_input_ids,
prompt=prompt,
multi_modal_data=multi_modal_data,
)
class GLMAttention(nn.Module):
def __init__(
self,
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.multi_query_attention = config.multi_query_attention
self.total_num_kv_heads = (config.multi_query_group_num
if config.multi_query_attention else
config.num_attention_heads)
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
)
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
rope_ratio = getattr(config, "rope_ratio", 1.0)
max_positions = getattr(config, "seq_length", 8192)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim // 2,
max_position=max_positions,
base=10000 * rope_ratio,
is_neox_style=False,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn(
q,
k,
v,
kv_cache,
attn_metadata,
)
attn_output, _ = self.dense(context_layer)
return attn_output
class GLMMLP(nn.Module):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(
self,
config: ChatGLMConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.add_bias = config.add_bias_linear
# Project to 4h.
self.dense_h_to_4h = MergedColumnParallelLinear(
config.hidden_size,
[config.ffn_hidden_size] * 2,
bias=config.add_bias_linear,
quant_config=quant_config,
)
self.activation_func = SiluAndMul()
# Project back to h.
self.dense_4h_to_h = RowParallelLinear(
config.ffn_hidden_size,
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
)
def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output, _ = self.dense_4h_to_h(intermediate_parallel)
return output
class GLMBlock(nn.Module):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(
self,
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
self.fp32_residual_connection = config.fp32_residual_connection
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Layernorm on the input data.
self.input_layernorm = layer_norm_func(config.hidden_size,
eps=config.layernorm_epsilon)
# Self attention.
self.self_attention = GLMAttention(config, cache_config, quant_config)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
self.post_attention_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon)
# MLP
self.mlp = GLMMLP(config, quant_config)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output = self.self_attention(
hidden_states=layernorm_output,
position_ids=position_ids,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
layernorm_input = residual + attention_output
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
output = self.mlp(layernorm_output) + residual
return output
class GLMTransformer(nn.Module):
"""Transformer class."""
def __init__(
self,
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.post_layer_norm = config.post_layer_norm
# Number of layers.
self.num_layers = config.num_layers
# Transformer layers.
self.start_layer, self.end_layer, self.layers = make_layers(
self.num_layers,
lambda prefix: GLMBlock(config, cache_config, quant_config),
prefix=f"{prefix}.layers",
)
if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Final layer norm before output.
self.final_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
hidden_states=hidden_states,
position_ids=position_ids,
kv_cache=kv_caches[i - self.start_layer],
attn_metadata=attn_metadata,
)
# Final layer norm.
if get_pp_group().is_last_rank and self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
class ChatGLMModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
config.hidden_size,
quant_config=quant_config)
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, cache_config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size,
quant_config=quant_config)
vision_config_flag = getattr(config, 'vision_config', None)
if vision_config_flag is not None:
self.vision_config = Namespace(**config.vision_config)
self.vision = EVA2CLIPModel(self.config, quant_config)
else:
self.vision = None
self.make_empty_intermediate_tensors = (
self.encoder.make_empty_intermediate_tensors)
def _parse_and_validate_image_input(
self, **kwargs: object) -> GLMImagePixelInputs:
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is not None and self.vision is not None:
if isinstance(pixel_values, torch.Tensor):
if pixel_values.ndim > 2:
pixel_values = torch.concat(list(pixel_values))
elif isinstance(pixel_values, list):
return torch.concat(pixel_values)
else:
raise TypeError("""pixel_values must be a torch.Tensor
or a list of torch.Tensor
""")
return GLMImagePixelInputs(pixel_values=pixel_values)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> torch.Tensor:
if intermediate_tensors is None:
inputs_embeds = self.embedding(input_ids)
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input["pixel_values"] is not None:
pixel_values = image_input["pixel_values"].to(
dtype=inputs_embeds.dtype)
image_embeds = self.vision(pixel_values)
boi_token_id = self.config.boi_token_id
eoi_token_id = self.config.eoi_token_id
inputs_embeds = merge_glm_vision_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
vision_embeddings=image_embeds,
boi_token_id=boi_token_id,
eoi_token_id=eoi_token_id)
else:
inputs_embeds = intermediate_tensors["hidden_states"]
# Run encoder.
hidden_states = self.encoder(
hidden_states=inputs_embeds,
position_ids=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
SupportsMultiModal):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
}
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.lora_config = lora_config
self.multimodal_config = multimodal_config
self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)
self.transformer = ChatGLMModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
if self.config.tie_word_embeddings:
self.transformer.output_layer.weight = (
self.transformer.embedding.weight)
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = get_sampler()
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
**kwargs)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Merge two ColumnParallelLinear into one MergedColumnParallelLinear
merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
"transformer.vision.linear_proj.merged_proj.weight": {
"transformer.vision.linear_proj.gate_proj.weight": None,
"transformer.vision.linear_proj.dense_h_to_4h.weight": None,
}
}
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
is_weight_to_be_merge = False
for _, merged_weight_dict in merged_weights_dict.items():
if name in merged_weight_dict:
assert merged_weight_dict[name] is None
merged_weight_dict[name] = loaded_weight
is_weight_to_be_merge = True
if is_weight_to_be_merge:
continue
if "rotary_pos_emb.inv_freq" in name:
continue
if "word_embeddings" in name:
name = name.replace(".word_embeddings", "")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
for combined_name, merged_weight_dict in merged_weights_dict.items():
if combined_name in params_dict:
param = params_dict[combined_name]
combined_weight = torch.cat(list(merged_weight_dict.values()),
dim=0)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, combined_weight)

View File

@@ -0,0 +1,585 @@
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData
# FIXME(chenxiaobing|wangchao2): import xformers will intialize device, which
# will cause ray status error.
from vllm.platforms import current_platform
if current_platform.is_mlu():
USE_XFORMERS_OPS = False
else:
try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
return image_size // patch_size
def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
grid_length = get_clip_patch_grid_length(image_size=image_size,
patch_size=patch_size)
return grid_length * grid_length
def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
return get_clip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size) + 1
def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
return get_clip_image_feature_size(hf_config)
def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
mm_key: str = "image"):
if image_feature_size_override is None:
image_feature_size = get_clip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
), {
mm_key:
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_clip(
hf_config: CLIPVisionConfig,
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
width = height = hf_config.image_size
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}
def dummy_video_for_clip(
hf_config: CLIPVisionConfig,
num_frames: int,
num_videos: int = 1,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
pil_frame = dummy_image_for_clip(
hf_config,
num_images=1,
image_width_override=image_width_override,
image_height_override=image_height_override)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
video_data = [mm_data_per_video] * num_videos
mm_data = {"video": video_data}
return mm_data
def input_processor_for_clip(
model_config: ModelConfig,
hf_config: CLIPVisionConfig,
inputs: DecoderOnlyInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None:
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_feature_size = get_clip_image_feature_size(hf_config)
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": ranges})
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):
def __init__(self, config: CLIPVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = get_clip_num_patches(image_size=self.image_size,
patch_size=self.patch_size)
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim)
self.register_buffer("position_ids",
torch.arange(self.num_positions).expand((1, -1)),
persistent=False)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(
dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
class CLIPParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.out_proj(out)
return attn_output, None
class CLIPFallbackAttention(nn.Module):
"""
Fallback CLIP attention implementation when xformers is not available
or num_heads is not divisible by tp_size.
"""
def __init__(self, config: CLIPVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(self, hidden_states: torch.Tensor) -> tuple:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
# Get Q, K, V
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Reshape for multi-head attention
query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
# Compute attention scores
attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scale
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
# Apply dropout if needed
if self.dropout > 0.0:
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
# Apply attention to values
attn_output = torch.matmul(attn_weights, value_states)
# Reshape back
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, tgt_len, embed_dim)
# Output projection
attn_output = self.out_proj(attn_output)
return attn_output, None
class CLIPMLP(nn.Module):
def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class CLIPEncoderLayer(nn.Module):
def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = CLIPParallelAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else:
self.self_attn = CLIPFallbackAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class CLIPEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self
attention layers. Each layer is a [`CLIPEncoderLayer`].
Args:
config: CLIPConfig
"""
def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
CLIPEncoderLayer(config=config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])
def forward(self, inputs_embeds: torch.Tensor):
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)
return hidden_states
class CLIPVisionTransformer(nn.Module):
def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = CLIPVisionEmbeddings(config)
# NOTE: This typo of "layrnorm" is not fixed on purpose to match
# the original transformers code and name of the model weights.
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = CLIPEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
)
num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
# If possible, skip post_layernorm to conserve memory
if require_post_norm is None:
require_post_norm = len(self.encoder.layers) == num_hidden_layers
if require_post_norm:
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
else:
self.post_layernorm = None
def forward(
self,
pixel_values: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
hidden_states = self.encoder(inputs_embeds=hidden_states)
if self.post_layernorm is None:
return hidden_states
return self.post_layernorm(hidden_states)
class CLIPVisionModel(nn.Module):
config_class = CLIPVisionConfig
main_input_name = "pixel_values"
def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__()
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
self.vision_model = CLIPVisionTransformer(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model",
)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return self.vision_model(pixel_values)
@property
def device(self):
return next(self.parameters()).device
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights:
# post_layernorm is not needed in CLIPVisionModel
if (name.startswith("vision_model.post_layernorm")
and self.vision_model.post_layernorm is None):
continue
# omit layers when num_hidden_layers_override is set
if name.startswith("vision_model.encoder.layers"):
layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,437 @@
# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name,
row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
variance_epsilon)
hidden_states = weight.to(torch.float32) * hidden_states
return hidden_states.to(input_dtype)
class LayerNorm(nn.Module):
def __init__(self, param_shape=None, eps=1e-5):
super().__init__()
self.weight = nn.Parameter(torch.ones(param_shape))
self.variance_epsilon = eps
set_weight_attrs(self.weight,
{"weight_loader": row_parallel_weight_loader})
def forward(self, hidden_states, residuals=None):
hidden_states = layer_norm_func(hidden_states, self.weight,
self.variance_epsilon)
return hidden_states, residuals
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
def __init__(
self,
config: CohereConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_up_proj = MergedColumnParallelLinear(
self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=False,
quant_config=quant_config,
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class CohereAttention(nn.Module):
def __init__(
self,
config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
tp_size = get_tensor_model_parallel_world_size()
self.config = config
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.total_num_heads = config.num_attention_heads
self.num_heads = self.total_num_heads // tp_size
self.head_dim = self.hidden_size // self.total_num_heads
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = getattr(
config, "model_max_length", None) or getattr(
config, "max_position_embeddings", 8192)
self.rope_theta = config.rope_theta
self.rope_scaling = getattr(config, "rope_scaling", None)
self.use_qk_norm = getattr(config, "use_qk_norm", False)
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
rope_scaling=self.rope_scaling,
is_neox_style=False,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads,
self.head_dim),
eps=config.layer_norm_eps)
self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
self.head_dim),
eps=config.layer_norm_eps)
def _apply_qk_norm(self, q, k):
q = q.view(*q.shape[:-1], -1, self.head_dim)
k = k.view(*k.shape[:-1], -1, self.head_dim)
q, _ = self.q_norm(q)
k, _ = self.k_norm(k)
q = q.view(*q.shape[:-2], -1)
k = k.view(*k.shape[:-2], -1)
return q, k
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class CohereDecoderLayer(nn.Module):
def __init__(self,
config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = CohereAttention(config,
cache_config,
quant_config=quant_config)
self.mlp = CohereMLP(config, quant_config=quant_config)
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states_attention = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states_mlp = self.mlp(hidden_states)
# Add everything together
hidden_states = residual + hidden_states_attention + hidden_states_mlp
return hidden_states, residual
@support_torch_compile
class CohereModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: CohereDecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
self.norm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
]
embedding_modules = {"embed_tokens": "input_embeddings"}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
# currently all existing command R models have `tie_word_embeddings`
# enabled
assert config.tie_word_embeddings
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.quant_config = quant_config
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
scale=config.logit_scale)
self.model = CohereModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
is_not_lora = hasattr(self.model.embed_tokens, 'weight')
if is_not_lora:
logits = self.logits_processor(self.model.embed_tokens,
hidden_states, sampling_metadata)
else:
logits = self.logits_processor(self.model.embed_tokens.base_layer,
hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)

View File

@@ -0,0 +1,437 @@
from typing import Iterable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dbrx import DbrxConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class DbrxRouter(nn.Module):
"""A Router implementation for DBRX that returns logits for each expert
per token.
"""
def __init__(
self,
config: DbrxConfig,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.ffn_config.moe_num_experts
self.d_model = config.d_model
self.layer = ReplicatedLinear(
self.d_model,
self.num_total_experts,
bias=False,
params_dtype=params_dtype,
quant_config=None,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
router_logits, _ = self.layer(hidden_states)
return router_logits
class DbrxExperts(FusedMoE):
def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__(
num_experts=config.ffn_config.moe_num_experts,
top_k=config.ffn_config.moe_top_k,
hidden_size=config.d_model,
intermediate_size=config.ffn_config.ffn_hidden_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=get_tensor_model_parallel_world_size(),
)
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.d_model = config.d_model
self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
self.tp_size)
# Define custom weight loader for dbrx model
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
# DBRX uses GLU for each experts.
# GLU has 3 linear layers: w1, v1 and w2.
if weight_name.endswith("w1"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
if weight_name.endswith("v1"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:,
shard_size:2 * shard_size, :] = loaded_weight[:,
shard, :]
if weight_name.endswith("w2"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
).transpose(1, 2)
param_data[:] = loaded_weight[:, :, shard]
class DbrxMoE(nn.Module):
"""A tensor-parallel MoE implementation for DBRX.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.d_model = config.d_model
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.router = DbrxRouter(config, self.params_dtype)
self.experts = DbrxExperts(config=config,
quant_config=quant_config,
params_dtype=self.params_dtype)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.d_model)
# router_logits: (num_tokens, n_experts)
router_logits = self.router(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape)
class DbrxAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.total_num_heads = config.n_heads
self.head_dim = self.d_model // self.total_num_heads
self.total_num_kv_heads = config.attn_config.kv_n_heads
self.clip_qkv = config.attn_config.clip_qkv
self.rope_theta = config.attn_config.rope_theta
self.max_position = config.max_seq_len
# pylint: disable=invalid-name
self.Wqkv = QKVParallelLinear(
self.d_model,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
self.d_model,
self.d_model,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
tp_world_size = get_tensor_model_parallel_world_size()
self.tp_size = tp_world_size
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
hidden_states, _ = self.out_proj(attn_output)
return hidden_states
class DbrxFusedNormAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.attn = DbrxAttention(config, cache_config, quant_config)
self.norm_1 = nn.LayerNorm(self.d_model)
self.norm_2 = nn.LayerNorm(self.d_model)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm_1(hidden_states)
x = self.attn(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + x
residual = hidden_states
hidden_states = self.norm_2(hidden_states)
return hidden_states, residual
class DbrxBlock(nn.Module):
def __init__(
self,
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
quant_config)
self.ffn = DbrxMoE(config, quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states, residual = self.norm_attn_norm(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = self.ffn(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
class DbrxModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.wte = VocabParallelEmbedding(
config.vocab_size,
config.d_model,
)
self.start_layer, self.end_layer, self.blocks = make_layers(
config.n_layers,
lambda prefix: DbrxBlock(config, cache_config, quant_config),
prefix=f"{prefix}.blocks",
)
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules():
if hasattr(module, "bias") and isinstance(module.bias,
nn.Parameter):
# Remove the bias term in Linear and LayerNorm.
module.register_parameter("bias", None)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.d_model))
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids)
else:
assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
block = self.blocks[i]
hidden_states = block(
position_ids,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm_f(hidden_states)
return hidden_states
class DbrxForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
if config.tie_word_embeddings:
raise ValueError(
"tie_word_embeddings is not supported for Dbrx models.")
self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
self.lm_head = ParallelLMHead(
config.vocab_size,
config.d_model,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = [(
"w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
f"mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, weight_name)
break
else:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,118 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 DeciAI Research Team. All rights reserved.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on MistralAI GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights."""
from typing import Iterable, Tuple
import torch
from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM
from .utils import is_pp_missing_parameter
class DeciLMForCausalLM(LlamaForCausalLM):
"""
Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct.
Based on the llama executor.
The main difference is that DeciLM uses Variable Grouped Query Attention.
The constant number of GQA heads in the decoder is overridden with a value
per layer.
Usually, in the HuggingFace implementation, instead of
"config.num_key_value_heads", we use
"config.num_key_value_heads_per_layer[i]" which varies.
Currently, PagedAttention does not work well with variable GQA, so we
normalize the weights upon loading, and use uniform GQA with the max value
instead.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer")
super().__init__(vllm_config=vllm_config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "k_proj" in name or "v_proj" in name:
loaded_weight = self._degroup_weight(loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
hidden_size = self.config.hidden_size
head_size = self.config.hidden_size // self.config.num_attention_heads
target_num_kv_heads = self.config.num_key_value_heads
num_kv_heads = loaded_weight.shape[0] // head_size
n_repeats = target_num_kv_heads / num_kv_heads
assert n_repeats == int(n_repeats)
n_repeats = int(n_repeats)
loaded_weight = loaded_weight.view(num_kv_heads, head_size,
hidden_size)
loaded_weight = torch.repeat_interleave(loaded_weight,
repeats=n_repeats,
dim=0)
loaded_weight = loaded_weight.reshape(target_num_kv_heads * head_size,
hidden_size)
return loaded_weight

View File

@@ -0,0 +1,477 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class DeepseekMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class DeepseekMoE(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.n_routed_experts = config.n_routed_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.n_routed_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.n_routed_experts}.")
self.experts = nn.ModuleList([
DeepseekMLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False)
for idx in range(self.n_routed_experts)
])
self.pack_params()
self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts,
bias=False,
quant_config=None)
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
self.shared_experts = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
)
def pack_params(self):
w1 = []
w2 = []
for expert in self.experts:
w1.append(expert.gate_up_proj.weight)
w2.append(expert.down_proj.weight)
self.w1 = torch._utils._flatten_dense_tensors(w1)
w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
for data, param in zip(w1s, w1):
param.data = data
self.w1 = self.w1.view(len(w1), *w1s[0].shape)
self.w2 = torch._utils._flatten_dense_tensors(w2)
w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
for data, param in zip(w2s, w2):
param.data = data
self.w2 = self.w2.view(len(w2), *w2s[0].shape)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.config.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob,
inplace=True)
if self.config.n_shared_experts is not None:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
class DeepseekAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class DeepseekDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = DeepseekAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
else:
self.mlp = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class DeepseekModel(nn.Module):
fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config,
quant_config=quant_config),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class DeepseekForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = DeepseekModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,613 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeepseekV2 model."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class DeepseekV2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class DeepseekV2MoE(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.routed_scaling_factor = config.routed_scaling_factor
if self.tp_size > config.n_routed_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.n_routed_experts}.")
if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now.")
self.experts = FusedMoE(num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts")
self.gate = ReplicatedLinear(config.hidden_size,
config.n_routed_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate")
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
import math
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
class DeepseekV2Attention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int,
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_proj")
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj")
else:
self.q_proj = ColumnParallelLinear(self.hidden_size,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj")
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj")
# O projection.
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
# self.attn = Attention(self.num_heads,
# self.qk_head_dim,
# self.scaling,
# num_kv_heads=self.num_heads)
# TODO, support head_size 192
self.attn = Attention(self.num_local_heads,
256,
self.scaling,
num_kv_heads=self.num_local_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
self.qk_head_dim)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
dim=-1)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_a, _ = latent_cache.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a.contiguous())
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads,
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim:] = k_pe
q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = attn_output.view(
-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
-1, self.num_local_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output)
return output
class DeepseekV2DecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
prefix: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.self_attn = DeepseekV2Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=config.q_lora_rank
if hasattr(config, "q_lora_rank") else None,
kv_lora_rank=config.kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekV2MoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class DeepseekV2Model(nn.Module):
fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(
config,
prefix,
cache_config=cache_config,
quant_config=quant_config,
),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = DeepseekV2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,174 @@
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
class EAGLE(nn.Module):
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
Reference implementation: https://github.com/SafeAILab/EAGLE
Differences from reference implementation:
1. In reference, LlamaDecoderLayer implementation doesn't have
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427)
but we do as HF implementation also does.
2. We allow any decoder layer to be used in EAGLE whereas in reference
decoder layer is fixed to be LlamaDecoderLayer.
3. We have an optional token_map which reduces draft vocab to most
frequently used tokens to give some additional speed-up by reducing
sampling overhead. This is disabled unless the checkpoint file has
explicit token_map tensor and config has an optional attribute
truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute."""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
architectures = getattr(self.config.model, "architectures", [])
model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
self.model = model_cls(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.fc = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size,
bias=getattr(self.config, "eagle_fc_bias", False))
self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=self.truncated_vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.truncated_vocab_size,
logit_scale)
# Token map is a idx to token mapping to reduce the vocab size for
# the draft model. Using smaller vocab size for draft, containing
# only most frequent tokens reduces the speculation overhead. This
# doesn't affect the acceptance rate much and thus gives more speed
# -up. By default, this is disabled and is only used if the EAGLE
# checkpoint file has token_map tensor.
self.token_map = None
@property
def sampler(self):
return self.model.sampler
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
tok_embeds = self.model.model.embed_tokens(input_ids)
inputs_embeds = self.fc(
torch.cat([tok_embeds, previous_hidden_states], dim=-1))
inputs_embeds[positions == 0] = 0 # masking inputs at position=0
hidden_states = self.model.model(
input_ids=None,
inputs_embeds=inputs_embeds,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
if self.token_map is not None:
_logits = logits
logits = -torch.inf * torch.ones(
size=(*_logits.shape[:-1], self.orig_vocab_size),
device=_logits.device,
dtype=_logits.dtype)
logits[..., self.token_map] = _logits
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# This implementation is incompitable with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
# due to missing lm_head weights and its config being that of a
# Llama model. Here's a compatible version with the same weights:
# https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
# Also, here's an example script for converting trained EAGLE
# checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
model_weights = {}
for name, loaded_weight in weights:
if name == "token_map":
if self.config.truncated_vocab_size < self.config.vocab_size:
self.token_map = nn.Parameter(loaded_weight,
requires_grad=False)
elif name.startswith("fc.weight"):
weight_loader = getattr(self.fc.weight, "weight_loader",
default_weight_loader)
weight_loader(self.fc.weight, loaded_weight)
elif name.startswith("fc.bias"):
if self.fc.bias is not None:
weight_loader = getattr(self.fc.bias, "weight_loader",
default_weight_loader)
weight_loader(self.fc.bias, loaded_weight)
else:
raise ValueError("Found bias in the loaded weights "
"but the model config doesn't have bias")
elif name.startswith("model.lm_head.") or name.startswith(
"model.model."):
model_weights[name.split("model.", 1)[-1]] = loaded_weight
elif name.startswith("lm_head.") or name.startswith("model."):
model_weights[name] = loaded_weight
else:
model_weights[f"model.{name}"] = loaded_weight
lm_head_weight = model_weights.pop("lm_head.weight")
if self.token_map is not None and\
lm_head_weight.shape[0] > self.token_map.shape[0]:
lm_head_weight = lm_head_weight[self.token_map]
weight_loader = getattr(self.lm_head.weight, "weight_loader",
default_weight_loader)
weight_loader(self.lm_head.weight, lm_head_weight)
self.model.load_weights(model_weights.items())

View File

@@ -0,0 +1,601 @@
# Adapted from
# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/modeling_exaone.py
# Copyright 2024 The LG U+ CTO AI Tech Lab.
# Copyright 2021 The LG AI Research EXAONE Lab
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Exaone model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.exaone import ExaoneConfig
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class ExaoneGatedMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.c_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.c_proj(x)
return x
class ExaoneAttention(nn.Module):
def __init__(
self,
config: ExaoneConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
is_neox_style = True
if quant_config is not None and quant_config.get_name() == "gguf":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.out_proj(attn_output)
return output
class ExaoneBlockAttention(nn.Module):
def __init__(
self,
config: ExaoneConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.attention = ExaoneAttention(
config=config,
hidden_size=hidden_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=bias,
cache_config=cache_config,
prefix=prefix,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
return self.attention(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
class ExaoneDecoderLayer(nn.Module):
def __init__(
self,
config: ExaoneConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
self.attn = ExaoneBlockAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
cache_config=cache_config,
prefix=f"{prefix}.attn",
)
self.mlp = ExaoneGatedMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.activation_function,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
else:
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.ln_2(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class ExaoneModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.wte = config.vocab_size
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.wte = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
else:
self.wte = PPMissingLayer()
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: ExaoneDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
),
prefix=f"{prefix}.h",
)
if get_pp_group().is_last_rank:
self.ln_f = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
else:
self.ln_f = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states
class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"c_fc_0",
"c_fc_1",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"out_proj",
"gate_up_proj",
"c_proj",
"wte",
"lm_head",
]
embedding_modules = {
"wte": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"c_fc_0": ("gate_up_proj", 0),
"c_fc_1": ("gate_up_proj", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.transformer = ExaoneModel(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head.weight = self.transformer.wte.weight
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return model_output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".c_fc_0", 0),
(".gate_up_proj", ".c_fc_1", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path,
tp_rank,
tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type,
):
if not isinstance(self.transformer.h[layer_idx], nn.Identity):
layer_self_attn = self.transformer.h[layer_idx].attn
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")

View File

@@ -0,0 +1,509 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Falcon model."""
import math
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import RWConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
FalconConfig = Union[HF_FalconConfig, RWConfig]
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(1,
1 + 2 * num_remaining_heads,
2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
class FalconAttention(nn.Module):
def __init__(
self,
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.head_dim = self.hidden_size // self.total_num_heads
assert self.head_dim * self.total_num_heads == self.hidden_size
self.new_decoder_architecture = config.new_decoder_architecture
self.multi_query = config.multi_query
if self.new_decoder_architecture:
self.total_num_kv_heads = config.num_kv_heads
elif self.multi_query:
self.total_num_kv_heads = 1
else:
self.total_num_kv_heads = self.total_num_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.bias,
skip_bias_add=True,
quant_config=quant_config,
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=config.bias,
skip_bias_add=True,
quant_config=quant_config,
reduce_results=self.reduce_row_parallel_results)
self.use_rotary = config.rotary
self.use_alibi = config.alibi
assert not (self.use_rotary and self.use_alibi), (
"Rotary and alibi are mutually exclusive.")
if self.use_rotary:
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config,
"max_position_embeddings", 8192)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
quant_config=quant_config)
elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
self.inv_norm_factor)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
self.attn = Attention(self.num_heads,
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
alibi_slopes=alibi_slopes,
quant_config=quant_config)
else:
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, bias = self.query_key_value(hidden_states)
if bias is not None:
qkv += bias
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_rotary:
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output, bias = self.dense(attn_output)
return attn_output, bias
class FalconMLP(nn.Module):
def __init__(
self,
config: FalconConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
4 * hidden_size,
bias=config.bias,
skip_bias_add=True,
quant_config=quant_config)
self.act = get_act_fn("gelu")
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
bias=config.bias,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results,
quant_config=quant_config)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
x, bias = self.dense_h_to_4h(x)
if bias is not None:
x += bias
x = self.act(x)
x, bias = self.dense_4h_to_h(x)
return x, bias
class FalconDecoderLayer(nn.Module):
def __init__(
self,
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config, cache_config,
quant_config)
self.mlp = FalconMLP(config, quant_config)
self.config = config
if (not hasattr(config, "num_ln_in_parallel_attn")):
config.num_ln_in_parallel_attn = None
if (config.num_ln_in_parallel_attn is None
and config.new_decoder_architecture):
config.num_ln_in_parallel_attn = 2
if not config.parallel_attn:
self.post_attention_layernorm = LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.input_layernorm = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
else:
if config.num_ln_in_parallel_attn == 2:
# The layer norm before self-attention
self.ln_attn = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
# The layer norm before the MLP
self.ln_mlp = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
else:
self.input_layernorm = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
if self.config.num_ln_in_parallel_attn == 2:
attention_layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
attention_layernorm_out = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = self.self_attention(
positions=positions,
hidden_states=attention_layernorm_out,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
if self.reduce_row_parallel_results and attention_bias is not None:
attention_output += attention_bias
if not self.config.new_decoder_architecture:
if self.config.parallel_attn:
mlp_layernorm_out = attention_layernorm_out
else:
residual += attention_output
mlp_layernorm_out = self.post_attention_layernorm(residual)
if (self.config.new_decoder_architecture and self.config.parallel_attn
and self.config.num_ln_in_parallel_attn == 1):
mlp_layernorm_out = attention_layernorm_out
# MLP.
mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
if self.reduce_row_parallel_results and mlp_bias is not None:
mlp_output += mlp_bias
if not self.reduce_row_parallel_results:
# When MLP and Attention layers are parallel, we can use
# only one all-reduce operator to reduce the results from
# both MLP and Attention layers.
mlp_output += attention_output
mlp_output = tensor_model_parallel_all_reduce(mlp_output)
if attention_bias is not None:
mlp_output += attention_bias
if mlp_bias is not None:
mlp_output += mlp_bias
output = mlp_output + residual
return output
@support_torch_compile
class FalconModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.use_alibi = config.alibi
# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size,
self.embed_dim,
)
# Transformer blocks
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: FalconDecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.h")
# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.word_embeddings(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states
class FalconForCausalLM(nn.Module, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {}
default_bitsandbytes_target_modules = [
".query_key_value.",
".dense.",
".dense_h_to_4h.",
".dense_4h_to_h.",
]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.transformer = FalconModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
# only Falcon-11B doesn't share lm_head weight with word embeddings
# and previous Falcon model doesn't have tie_word_embeddings config
# so we set tie_word_embeddings to True by default
self.tie_word_embeddings = (config.tie_word_embeddings
if config.tie_word_embeddings is not None
else True)
if self.tie_word_embeddings:
self.lm_head = self.transformer.word_embeddings
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
total_num_heads = self.config.num_attention_heads
if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads
elif self.config.multi_query:
total_num_kv_heads = 1
else:
total_num_kv_heads = total_num_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if name == "lm_head.weight" and self.tie_word_embeddings:
# Falcon uses tied embeddings except Falcon-11b.
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
output_dim = getattr(param, "output_dim", None)
loaded_weight_shape = loaded_weight.shape
if output_dim is not None:
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] +
(total_num_kv_heads, num_query_heads_per_kv_head + 2,
-1) + loaded_weight_shape[output_dim + 1:])
wq = loaded_weight.narrow(
output_dim + 1, 0,
num_query_heads_per_kv_head).reshape(
*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
wk = loaded_weight.narrow(
output_dim + 1, num_query_heads_per_kv_head,
1).reshape(*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
wv = loaded_weight.narrow(
output_dim + 1, num_query_heads_per_kv_head + 1,
1).reshape(*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,257 @@
import math
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
BartParallelLMHead,
BartScaledWordEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import AutoWeightsLoader
class Florence2LanguageModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
self.encoder = BartEncoder(config,
cache_config=cache_config,
quant_config=quant_config)
self.decoder = BartDecoder(config,
cache_config=cache_config,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.encoder.embed_tokens.weight = self.shared.weight
self.decoder.embed_tokens.weight = self.shared.weight
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *decoder* input sequence tokens.
encoder_input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
encoder_positions:
Positions of *encoder* input sequence tokens.
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Model output torch.Tensor
"""
encoder_hidden_states = None
if encoder_input_ids.numel() > 0:
# Run encoder attention if a non-zero number of encoder tokens
# are provided as input
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
positions=encoder_positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata)
# decoder outputs consists of
# (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
decoder_input_ids=input_ids,
decoder_positions=positions,
encoder_hidden_states=encoder_hidden_states,
kv_caches=kv_caches,
attn_metadata=attn_metadata)
return decoder_outputs
class Florence2LanguageForConditionalGeneration(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
self.model = Florence2LanguageModel(vllm_config=vllm_config,
prefix=prefix)
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
self.vocab_size = config.vocab_size
self.lm_head = BartParallelLMHead(self.vocab_size,
config.d_model,
embed_scale=embed_scale)
self.logits_processor = LogitsProcessor(self.vocab_size,
config.vocab_size)
self.sampler = get_sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs,
) -> torch.Tensor:
r"""
Args:
input_ids
torch.Tensor of *decoder* input token ids.
positions
torch.Tensor of *decoder* position indices.
encoder_input_ids
torch.Tensor of *encoder* input token ids.
encoder_positions
torch.Tensor of *encoder* position indices
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Output torch.Tensor
"""
return self.model(input_ids, positions, encoder_input_ids,
encoder_positions, kv_caches, attn_metadata)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> SamplerOutput:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if "final_logits_bias" in name:
continue
if self.config.tie_word_embeddings and "embed_tokens" in name:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
class Florence2ForConditionalGeneration(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
# TODO(Isotr0py): Add vision backbone
self.language_model = Florence2LanguageForConditionalGeneration(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=prefix,
)
@property
def sampler(self):
return self.language_model.sampler
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
*,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor,
**kwargs,
) -> torch.Tensor:
r"""
Args:
input_ids
torch.Tensor of *decoder* input token ids.
positions
torch.Tensor of *decoder* position indices.
encoder_input_ids
torch.Tensor of *encoder* input token ids.
encoder_positions
torch.Tensor of *encoder* position indices
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Output torch.Tensor
"""
return self.language_model(input_ids, positions, encoder_input_ids,
encoder_positions, kv_caches, attn_metadata)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
skip_prefixes = [
'image_projection', "vision_tower", "image_proj_norm",
"image_pos_embed", "visual_temporal_embed"
]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader.load_weights(weights)

View File

@@ -0,0 +1,359 @@
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from array import array
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict
import torch
import torch.nn as nn
import torch.utils.checkpoint
from PIL import Image
from transformers import FuyuImageProcessor
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920
class FuyuImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""
Shape:
(batch_size, num_patches, patch_size_x * patch_size_y * num_channels)
"""
def _calculate_num_image_tokens(
height: int,
width: int,
) -> Tuple[int, int]:
"""
calculate number of image tokens needed for a given image size
The expected Fuyu image prompts is in format:
(image_token * ncols + newline_token) * nrows
args:
image_size: Tuple[int, int] - (width, height) of the image
returns:
ncols: int - number of image tokens in x direction
nrows: int - number of image tokens in y direction
"""
ncol = math.ceil(width / 30)
nrow = math.ceil(height / 30)
return ncol, nrow
def get_max_fuyu_image_feature_size():
return _calculate_num_image_tokens(
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
def get_max_fuyu_image_tokens(ctx: InputContext):
ncol, nrow = get_max_fuyu_image_feature_size()
return (ncol + 1) * nrow
def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
ncol, nrow = get_max_fuyu_image_feature_size()
image_feature_size = get_max_fuyu_image_tokens(ctx)
image_token_ids = (
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol +
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids), {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_fuyu(
num_images: int,
*,
image_width: int,
image_height: int,
):
image = Image.new("RGB", (image_width, image_height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
seq_data, ranges = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
mm_data = dummy_image_for_fuyu(num_images,
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
return DummyData(seq_data, mm_data, ranges)
def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
data: List[Image.Image]):
image_encoding = image_processor.preprocess(data, return_tensors="pt")
batch_images = torch.stack([img[0] for img in image_encoding["images"]
]).unsqueeze(1)
image_unpadded_heights = torch.tensor(
image_encoding["image_unpadded_heights"])
image_unpadded_widths = torch.tensor(
image_encoding["image_unpadded_widths"])
batch_size = len(image_encoding["images"])
image_present = torch.ones(batch_size, 1, 1)
model_image_input = image_processor.preprocess_with_tokenizer_info(
image_input=batch_images,
image_present=image_present,
image_unpadded_h=image_unpadded_heights,
image_unpadded_w=image_unpadded_widths,
image_placeholder_id=_IMAGE_TOKEN_ID,
image_newline_id=_NEWLINE_TOKEN_ID,
variable_sized=True,
)
return model_image_input
def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
image_data = multi_modal_data["image"]
new_multi_modal_data = {}
image_list = image_data if isinstance(image_data, list) else [image_data]
# process image data
if is_list_of(image_list, Image.Image):
# Fuyu's image_processor can also finish token padding
image_processor: FuyuImageProcessor = cached_get_image_processor(
model_config.model)
model_image_input = _fuyu_image_preprocess(image_processor, image_data)
image_patches = torch.cat([
image_patch[0]
for image_patch in model_image_input["image_patches"]
])
new_multi_modal_data["image"] = image_patches
elif is_list_of(image_list, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
# process prompts
prompt = inputs.get("prompt")
prompt_token_ids = inputs["prompt_token_ids"]
tokenizer = cached_get_tokenizer(model_config.model)
# dim0 is batch_size, dim1 is subseq_size which will always be 1
image_input_ids: List[List[
torch.Tensor]] = model_image_input["image_input_ids"]
image_input_ids = image_input_ids[0][0].tolist()
bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]
new_prompt = prompt + "\x04"
new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
1:] + boa_token
return token_inputs(prompt=new_prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=new_multi_modal_data)
def input_mapper_for_fuyu(ctx: InputContext, data: object):
model_config = ctx.model_config
data_list = data if isinstance(data, list) else [data]
if is_list_of(data_list, Image.Image):
# Fuyu's image_processor can also finish token padding
image_processor: FuyuImageProcessor = cached_get_image_processor(
model_config.model)
model_image_input = _fuyu_image_preprocess(image_processor, data_list)
data = torch.stack([
image_patch[0]
for image_patch in model_image_input["image_patches"]
])
# image has been processed with prompt in input processor
return MultiModalKwargs({"pixel_values": data})
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.text_config.vocab_size
self.image_token_id = _IMAGE_TOKEN_ID
self.image_feature_size = config.patch_size**2 * config.num_channels
self.vision_embed_tokens = ColumnParallelLinear(
self.image_feature_size,
config.hidden_size,
quant_config=quant_config,
gather_output=True,
)
self.language_model = PersimmonForCausalLM(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@property
def sampler(self):
return self.language_model.sampler
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.patch_size
num_channels = self.config.num_channels
expected_dims = num_channels * h * w
def _validate_shape(d: torch.Tensor):
actual_dims = d.size(-1)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f" per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data.to(self.vision_embed_tokens.weight.dtype)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image patches. "
f"Got type: {type(pixel_values)}")
return FuyuImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
)
return None
def _process_image_input(
self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
assert self.vision_embed_tokens is not None
vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
return vision_embeddings
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
):
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
else:
inputs_embeds = None
hidden_states = self.language_model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.language_model.logits_processor(
self.language_model.lm_head, hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.language_model.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights)

View File

@@ -0,0 +1,466 @@
# Copyright 2023 The vLLM team.
# Copyright (c) Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from functools import lru_cache
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
@lru_cache(maxsize=None)
def _get_gemma_act_fn(
hidden_act: Optional[str],
hidden_activation: Optional[str],
) -> nn.Module:
if hidden_activation is None:
if hidden_act is not None:
logger.warning(
"Gemma's activation function was incorrectly set to exact GeLU "
"in the config JSON file when it was initially released. "
"Changing the activation function to approximate GeLU "
"(`gelu_pytorch_tanh`). If you want to use the legacy "
"`%s`, edit the config JSON to set "
"`hidden_activation=%s` instead of `hidden_act`. "
"See https://github.com/huggingface/transformers/pull/29402 "
"for more details.", hidden_act, hidden_act)
return GeluAndMul(approximate="tanh")
elif hidden_activation == "gelu_pytorch_tanh":
return GeluAndMul(approximate="tanh")
elif hidden_activation == "gelu":
return GeluAndMul(approximate="none")
else:
raise ValueError(f"Activation function {hidden_act} is not "
"supported for Gemma models.")
class GemmaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: Optional[str] = None,
hidden_activation: Optional[str] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class GemmaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
max_position_embeddings: int = 8192,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=self.rope_theta,
is_neox_style=True,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class GemmaDecoderLayer(nn.Module):
def __init__(
self,
config: GemmaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GemmaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings,
rope_theta=config.rope_theta,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = GemmaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
hidden_activation=getattr(config, "hidden_activation", None),
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class GemmaModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: GemmaDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Normalize the embedding by sqrt(hidden_size)
# The normalizer's data type should be downcasted to the model's
# data type such as bfloat16, not float32.
# See https://github.com/huggingface/transformers/pull/29402
normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer))
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states *= self.normalizer
residual = None
else:
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.lora_config = lora_config
self.quant_config = quant_config
self.model = GemmaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.embed_tokens, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
logger.warning(
"Some weights are not initialized from checkpoints: %s",
unloaded_params)

View File

@@ -0,0 +1,500 @@
# Copyright 2024 The vLLM team.
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import Gemma2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
class Gemma2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
hidden_activation: str,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
raise ValueError(
"Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
"function. Please set `hidden_act` and `hidden_activation` to "
"`gelu_pytorch_tanh`.")
self.act_fn = GeluAndMul(approximate="tanh")
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class Gemma2Attention(nn.Module):
def __init__(self,
layer_idx: int,
config: Gemma2Config,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
max_position_embeddings: int,
rope_theta: float,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
attn_logits_soft_cap: Optional[float] = None) -> None:
super().__init__()
self.layer_idx = layer_idx
self.config = config
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = config.query_pre_attn_scalar**-0.5
self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.attention_bias,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=config.attention_bias,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=self.rope_theta,
is_neox_style=True,
)
# FIXME(woosuk): While Gemma 2 uses sliding window attention for every
# odd layer, vLLM currently ignores it and uses global attention for
# all layers.
use_sliding_window = (layer_idx % 2 == 1
and config.sliding_window is not None)
del use_sliding_window # Unused.
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
logits_soft_cap=attn_logits_soft_cap)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class Gemma2DecoderLayer(nn.Module):
def __init__(
self,
layer_idx: int,
config: Gemma2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Gemma2Attention(
layer_idx=layer_idx,
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings,
rope_theta=config.rope_theta,
cache_config=cache_config,
quant_config=quant_config,
attn_logits_soft_cap=config.attn_logit_softcapping,
)
self.hidden_size = config.hidden_size
self.mlp = Gemma2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
hidden_activation=config.hidden_activation,
quant_config=quant_config,
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, residual = self.pre_feedforward_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
return hidden_states, residual
@support_torch_compile
class Gemma2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[
-1]), config, cache_config, quant_config),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Normalize the embedding by sqrt(hidden_size)
# The normalizer's data type should be downcasted to the model's
# data type such as bfloat16, not float32.
# See https://github.com/huggingface/transformers/pull/29402
normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer))
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_tokens(input_ids)
hidden_states *= self.normalizer
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
logger.warning(
"Some weights are not initialized from checkpoints: %s",
unloaded_params)
class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
del lora_config # Unused.
super().__init__()
self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.quant_config = quant_config
self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor(
config.vocab_size, soft_cap=config.final_logit_softcapping)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.embed_tokens, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
loader.load_weights(weights)
class Gemma2EmbeddingModel(nn.Module, SupportsPP):
"""
A model that uses Gemma2 with additional embedding functionalities.
This class encapsulates the Gemma2Model and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of Gemma2Model used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
return self.model(input_ids, positions, kv_caches, attn_metadata,
intermediate_tensors, inputs_embeds)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.model.load_weights(weights)

View File

@@ -0,0 +1,297 @@
# Adapted from
# https://github.com/THUDM/GLM-4
"""Inference-only GLM-4v model visual encoder compatible with THUDM weights."""
from argparse import Namespace
from typing import Optional
import torch
from torch import nn
from torch.nn import LayerNorm
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class PatchEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
self.proj = nn.Conv2d(config.in_channels,
config.hidden_size,
kernel_size=config.patch_size,
stride=config.patch_size)
self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
self.position_embedding = nn.Embedding(config.num_positions,
config.hidden_size)
def forward(self, images: torch.Tensor) -> torch.Tensor:
"""
Parameters:
images : torch.Tensor
Input image tensor with shape (B, C, H, W)
Returns:
torch.Tensor
Transformed tensor with shape (B, L, D)
"""
images = images.to(self.proj.weight.device)
x = self.proj(images)
x = x.flatten(2).transpose(1, 2)
cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x += self.position_embedding.weight.unsqueeze(0)
return x
class Attention(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_rank = config.num_heads // self.tp_size
self.head_dim = config.hidden_size // config.num_heads
self.scale = self.head_dim**-0.5
self.query_key_value = QKVParallelLinear(
config.hidden_size,
self.head_dim,
config.num_heads,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
quant_config=quant_config,
)
self.output_dropout = torch.nn.Dropout(config.dropout_prob)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, L, _ = x.shape
qkv, _ = self.query_key_value(x) # B, L, 3 * H * D
q, k, v = qkv.chunk(3, dim=-1)
q = q.reshape(B, L, self.num_heads_per_rank,
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
k = k.reshape(B, L, self.num_heads_per_rank,
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
v = v.reshape(B, L, self.num_heads_per_rank,
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
out = torch.nn.functional.scaled_dot_product_attention(q,
k,
v,
attn_mask=None,
dropout_p=0.,
is_causal=False)
output, _ = self.dense(out.transpose(1, 2).view(B, L, -1))
output = self.output_dropout(output)
return output
class MLP(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.fc1(x)
x = self.activation_fn(x)
x, _ = self.fc2(x)
return x
class TransformerLayer(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.input_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = Attention(config, quant_config=quant_config)
self.mlp = MLP(config, quant_config=quant_config)
self.post_attention_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(self, hidden_states):
attention_input = hidden_states
attention_output = self.input_layernorm(
self.attention(attention_input))
hidden_states = attention_input + attention_output
mlp_input = hidden_states
mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
output = mlp_input + mlp_output
return output
class Transformer(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.layers = nn.ModuleList([
TransformerLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
def forward(self, hidden_states):
for layer_module in self.layers:
hidden_states = layer_module(hidden_states)
return hidden_states
class GLU(nn.Module):
def __init__(
self,
config,
in_features,
quant_config: Optional[QuantizationConfig] = None,
):
"""
The original implementation is the same as:
```python
self.dense_h_to_4h = ColumnParallelLinear(
config.hidden_size,
config.ffn_hidden_size,
bias=False,
quant_config=quant_config
)
self.gate_proj = ColumnParallelLinear(
config.hidden_size,
config.ffn_hidden_size,
bias=False,
quant_config=quant_config
)
```
```
gate_proj_output, _ = self.gate_proj(x)
dense_h_to_4h_output, _ = self.dense_h_to_4h(x)
x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1)
```
We merge two ColumnParallelLinear into one MergedColumnParallelLinear:
```
self.merged_proj = MergedColumnParallelLinear(
config.hidden_size,
[config.ffn_hidden_size] * 2,
bias=False,
quant_config=quant_config
)
```
```
x, _ = self.merged_proj(x)
```
"""
super().__init__()
self.linear_proj = ReplicatedLinear(in_features,
config.hidden_size,
bias=False,
quant_config=quant_config)
self.norm1 = nn.LayerNorm(config.hidden_size)
self.act1 = nn.GELU()
self.act2 = SiluAndMul()
self.merged_proj = MergedColumnParallelLinear(
config.hidden_size, [config.ffn_hidden_size] * 2,
bias=False,
quant_config=quant_config)
self.dense_4h_to_h = RowParallelLinear(config.ffn_hidden_size,
config.hidden_size,
bias=False,
quant_config=quant_config)
def forward(self, x):
x, _ = self.linear_proj(x)
x = self.act1(self.norm1(x))
x, _ = self.merged_proj(x)
x = self.act2(x)
x, _ = self.dense_4h_to_h(x)
return x
class EVA2CLIPModel(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
vision_config = Namespace(**config.vision_config)
self.patch_embedding = PatchEmbedding(vision_config)
self.transformer = Transformer(vision_config,
quant_config=quant_config)
self.linear_proj = GLU(config,
in_features=config.hidden_size,
quant_config=quant_config)
self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
out_channels=config.hidden_size,
kernel_size=2,
stride=2)
self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.scaling_factor = vision_config.scaling_factor
def forward(self, images: torch.Tensor) -> torch.Tensor:
"""
Parameters:
images : torch.Tensor
Input image tensor with shape (B, C, H, W)
Returns:
torch.Tensor
Transformed tensor with shape (B, L, D)
"""
x = self.patch_embedding(images)
x = self.transformer(x)
x = x[:, 1:]
b, s, h = x.shape
grid_size = int(s**0.5)
x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
x = self.conv(x)
x = x.flatten(2).transpose(1, 2)
x = self.linear_proj(x)
boi = self.boi.expand(x.shape[0], -1, -1)
eoi = self.eoi.expand(x.shape[0], -1, -1)
x = torch.cat((boi, x, eoi), dim=1)
x = x / self.scaling_factor
return x

View File

@@ -0,0 +1,327 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import GPT2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class GPT2Attention(nn.Module):
def __init__(
self,
config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim**-0.5
self.c_attn = QKVParallelLinear(
self.hidden_size,
self.head_dim,
total_num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_attn",
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scale,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output, _ = self.c_proj(attn_output)
return attn_output
class GPT2MLP(nn.Module):
def __init__(
self,
intermediate_size: int,
config: GPT2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_fc",
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
self.act = get_act_fn(config.activation_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.c_proj(hidden_states)
return hidden_states
class GPT2Block(nn.Module):
def __init__(
self,
config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim,
config,
quant_config,
prefix=f"{prefix}.mlp")
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# residual connection
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
return hidden_states
@support_torch_compile
class GPT2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
assert not config.add_cross_attention
assert not config.scale_attn_by_inverse_layer_idx
assert not config.reorder_and_upcast_attn
self.embed_dim = config.hidden_size
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: GPT2Block(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h")
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPT2LMHeadModel(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.transformer = GPT2Model(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.wte(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if not name.startswith("transformer."):
name = "transformer." + name
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,337 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPTBigCodeAttention(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
self.tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % self.tensor_model_parallel_world_size == 0
self.num_heads = (total_num_heads //
self.tensor_model_parallel_world_size)
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim**-0.5
self.multi_query = config.multi_query
if self.multi_query:
total_num_kv_heads = 1
self.num_kv_heads = 1
else:
total_num_kv_heads = total_num_heads
self.num_kv_heads = self.num_heads
self.kv_dim = self.head_dim * self.num_kv_heads
self.c_attn = QKVParallelLinear(
self.hidden_size,
self.head_dim,
total_num_heads,
total_num_kv_heads,
bias=True,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scale,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split(
[
self.hidden_size // self.tensor_model_parallel_world_size,
self.kv_dim, self.kv_dim
],
dim=-1,
)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output, _ = self.c_proj(attn_output)
return attn_output
class GPTBigMLP(nn.Module):
def __init__(
self,
intermediate_size: int,
config: GPTBigCodeConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=True,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
quant_config=quant_config,
)
self.act = get_act_fn(config.activation_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.c_proj(hidden_states)
return hidden_states
class GPTBigCodeBlock(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config, cache_config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# residual connection
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
return hidden_states
@support_torch_compile
class GPTBigCodeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
assert not config.add_cross_attention
self.embed_dim = config.hidden_size
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.wte = VocabParallelEmbedding(self.vocab_size,
self.embed_dim,
org_num_embeddings=config.vocab_size)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config),
prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {"c_attn": ["c_attn"]}
supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]
embedding_modules = {
"wte": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
prefix=prefix)
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(
self.transformer.vocab_size,
self.transformer.embed_dim,
org_num_embeddings=self.config.vocab_size)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "lm_head.weight" in name:
continue
if ".attn.bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
# TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
weight_loader(param, loaded_weight, 'q')
weight_loader(param, loaded_weight, 'k')
weight_loader(param, loaded_weight, 'v')
else:
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,320 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import GPTJConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class GPTJAttention(nn.Module):
def __init__(
self,
config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
self.qkv_proj = QKVParallelLinear(
config.hidden_size,
self.head_size,
self.total_num_heads,
bias=False,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=False,
quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
scaling = self.head_size**-0.5
assert getattr(config, "rotary", True)
assert config.rotary_dim % 2 == 0
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.rotary_emb = get_rope(
self.head_size,
rotary_dim=config.rotary_dim,
max_position=max_position_embeddings,
base=rope_theta,
is_neox_style=False,
)
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output, _ = self.out_proj(attn_output)
return attn_output
class GPTJMLP(nn.Module):
def __init__(
self,
intermediate_size: int,
config: GPTJConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.n_embd
self.fc_in = ColumnParallelLinear(
hidden_size,
intermediate_size,
quant_config=quant_config,
)
self.fc_out = RowParallelLinear(
intermediate_size,
hidden_size,
quant_config=quant_config,
)
self.act = get_act_fn(config.activation_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc_in(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.fc_out(hidden_states)
return hidden_states
class GPTJBlock(nn.Module):
def __init__(
self,
config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
inner_dim = (4 * config.n_embd
if config.n_inner is None else config.n_inner)
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config, cache_config, quant_config)
self.mlp = GPTJMLP(inner_dim, config, quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
mlp_output = self.mlp(hidden_states)
hidden_states = attn_output + mlp_output + residual
return hidden_states
@support_torch_compile
class GPTJModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_dim = config.n_embd
self.wte = VocabParallelEmbedding(
config.vocab_size,
self.embed_dim,
)
self.start_layer, self.end_layer, self.h = make_layers(
config.n_layer,
lambda prefix: GPTJBlock(config, cache_config, quant_config),
prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPTJForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
assert not config.tie_word_embeddings
self.transformer = GPTJModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
self.lm_head = ParallelLMHead(
config.vocab_size,
config.n_embd,
bias=True,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "attn.bias" in name or "attn.masked_bias" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,327 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class GPTNeoXAttention(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
self.bias = getattr(config, "attention_bias", True)
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.query_key_value = QKVParallelLinear(
config.hidden_size,
self.head_size,
self.total_num_heads,
bias=self.bias,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=self.bias,
quant_config=quant_config,
)
scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.rotary_emb = get_rope(
self.head_size,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
base=rope_theta,
)
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.dense(attn_output)
return output
class GPTNeoXMLP(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.dense_h_to_4h = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
)
self.dense_4h_to_h = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
)
self.act = get_act_fn(config.hidden_act)
def forward(self, hidden_states):
hidden_states, _ = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.dense_4h_to_h(hidden_states)
return hidden_states
class GPTNeoXLayer(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config, cache_config, quant_config)
self.mlp = GPTNeoXMLP(config, quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
attn_input = self.input_layernorm(hidden_states)
attn_output = self.attention(
position_ids=position_ids,
hidden_states=attn_input,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
if self.use_parallel_residual:
# pseudocode:
# x = x + attn(ln1(x)) + mlp(ln2(x))
mlp_input = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(mlp_input)
hidden_states = mlp_output + attn_output + hidden_states
else:
# pseudocode:
# x = x + attn(ln1(x))
# x = x + mlp(ln2(x))
attn_output = attn_output + hidden_states
mlp_input = self.post_attention_layernorm(attn_output)
mlp_output = self.mlp(mlp_input)
hidden_states = mlp_output + attn_output
return hidden_states
@support_torch_compile
class GPTNeoXModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_in = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: GPTNeoXLayer(config, cache_config, quant_config),
prefix=f"{prefix}.layers",
)
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_in(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "gpt_neox"))
self.embed_out = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.gpt_neox.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.embed_out, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using OpenRLHF may include
# these tensors in the checkpoint. Skip them.
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
# NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,540 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM Granite model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import GraniteConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, make_layers,
maybe_prefix)
class GraniteMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class GraniteAttention(nn.Module):
def __init__(
self,
config: GraniteConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = config.attention_multiplier
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class GraniteDecoderLayer(nn.Module):
def __init__(
self,
config: GraniteConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.residual_multiplier = config.residual_multiplier
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
self.self_attn = GraniteAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = GraniteMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states * self.residual_multiplier
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states * self.residual_multiplier
return hidden_states
@support_torch_compile
class GraniteModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: GraniteDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
hidden_states *= self.config.embedding_multiplier
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = self.norm(hidden_states)
return hidden_states
class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"lm_head"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = GraniteModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
logit_scale = getattr(config, "logit_scale", 1.0)
if hasattr(config, "logits_scaling"):
logit_scale /= config.logits_scaling
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
scale=logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return model_output
def compute_logits(
self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type):
if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")

View File

@@ -0,0 +1,443 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GraniteMoe model."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers.models.granitemoe import GraniteMoeConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from . import mixtral
from .interfaces import SupportsLoRA, SupportsPP
from .utils import make_layers, maybe_prefix
class GraniteMoeMoE(nn.Module):
"""A tensor-parallel MoE implementation for GraniteMoe that shards each
expert across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = ""):
super().__init__()
self.hidden_size = hidden_size
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(hidden_size,
num_experts,
bias=False,
params_dtype=params_dtype,
quant_config=None,
prefix=f"{prefix}.gate")
self.experts = FusedMoE(num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
prefix=f"{prefix}.experts")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape)
class GraniteMoeAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
attention_multiplier: Optional[float] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = (attention_multiplier if attention_multiplier
is not None else self.head_dim**-1)
self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class GraniteMoeDecoderLayer(nn.Module):
def __init__(
self,
config: GraniteMoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = GraniteMoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
attention_multiplier=config.attention_multiplier)
self.block_sparse_moe = GraniteMoeMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.block_sparse_moe")
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.residual_multiplier = config.residual_multiplier
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states * self.residual_multiplier
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states * self.residual_multiplier
return hidden_states
@support_torch_compile
class GraniteMoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.embedding_multiplier = config.embedding_multiplier
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: GraniteMoeDecoderLayer(
config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> torch.Tensor:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
hidden_states *= self.embedding_multiplier
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = self.norm(hidden_states)
return hidden_states
class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
"layer",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = GraniteMoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
scale=1 /
self.config.logits_scaling)
self.sampler = get_sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
new_weights = {}
for n, p in weights:
if n.endswith('.block_sparse_moe.input_linear.weight'):
for e in range(p.size(0)):
w1_name = n.replace(
'.block_sparse_moe.input_linear.weight',
".block_sparse_moe.experts.%d.w1.weight" % e)
w3_name = n.replace(
'.block_sparse_moe.input_linear.weight',
".block_sparse_moe.experts.%d.w3.weight" % e)
w1_param, w3_param = p[e].chunk(2, dim=0)
assert w1_name not in new_weights
assert w3_name not in new_weights
new_weights[w1_name] = w1_param
new_weights[w3_name] = w3_param
elif n.endswith('.block_sparse_moe.output_linear.weight'):
for e in range(p.size(0)):
w2_name = n.replace(
'.block_sparse_moe.output_linear.weight',
".block_sparse_moe.experts.%d.w2.weight" % e)
w2_param = p[e]
assert w2_name not in new_weights
new_weights[w2_name] = w2_param
elif n.endswith('.block_sparse_moe.router.layer.weight'):
gate_name = n.replace('.block_sparse_moe.router.layer.weight',
".block_sparse_moe.gate.weight")
assert gate_name not in new_weights
new_weights[gate_name] = p
elif n == 'lm_head.weight' and self.config.tie_word_embeddings:
pass
else:
new_weights[n] = p
mixtral.MixtralForCausalLM.load_weights(self, new_weights.items())

View File

@@ -0,0 +1,400 @@
# adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py
# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py
# --------------------------------------------------------
# H2OVL-Mississippi
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from functools import partial
from typing import List, Optional, Tuple
import torch
from PIL import Image
from transformers import PretrainedConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.utils import is_list_of
from .intern_vit import InternVisionModel
from .internvl import (IMG_CONTEXT, IMG_END, IMG_START, InternVLChatModel,
InternVLInputPipeline, build_transform,
find_closest_aspect_ratio, get_internvl_num_patches)
# modified to include blocks generated in second pass
def calculate_num_blocks(
orig_width: int,
orig_height: int,
min_num: int,
max_num: int,
image_size: int,
use_thumbnail: bool,
prior_aspect_ratio=None,
) -> Tuple[int, int, int, Tuple[int, int]]:
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set((i, j) for n in range(min_num, max_num + 1)
for i in range(1, n + 1) for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# if prior_aspect_ratio is provided, filter the target ratios
if prior_aspect_ratio is not None:
target_ratios = [
ratio for ratio in target_ratios if prior_aspect_ratio[0] %
ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0
]
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
target_ratios, orig_width,
orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# add thumbnail image if num_blocks > 1
if use_thumbnail and blocks > 1:
blocks += 1
return blocks, target_width, target_height, target_aspect_ratio
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
# refactored to handle prior_aspect_ratio as optional
def dynamic_preprocess(
image: Image.Image,
min_num: int,
max_num: int,
image_size: int,
use_thumbnail: bool,
prior_aspect_ratio: Optional[Tuple[int, int]] = None,
) -> Tuple[List[Image.Image], Tuple[int, int]]:
orig_width, orig_height = image.size
# calculate the number of blocks based on prior aspect ratio if available
blocks, target_width, target_height, target_aspect_ratio = (
calculate_num_blocks(
orig_width,
orig_height,
min_num,
max_num,
image_size,
use_thumbnail=False,
prior_aspect_ratio=prior_aspect_ratio,
))
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images, target_aspect_ratio
def load_image(
image: Image.Image,
input_size=448,
min_num=1,
max_num=6,
use_thumbnail=True,
prior_aspect_ratio: Optional[Tuple[int, int]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
transform = build_transform(input_size=input_size)
images, target_aspect_ratio = dynamic_preprocess(
image,
image_size=input_size,
use_thumbnail=use_thumbnail,
min_num=min_num,
max_num=max_num,
prior_aspect_ratio=prior_aspect_ratio,
)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values, target_aspect_ratio
# refactored to use the combined load_image function
def image_to_pixel_values(
image: Image.Image,
input_size: int,
min_num: int,
max_num: int,
use_thumbnail: bool,
use_MSAC: bool,
) -> torch.Tensor:
# when MSAC is turned on, we need to process the image twice
if use_MSAC:
# first pass
pixel_values, target_aspect_ratio = load_image(
image,
input_size=input_size,
min_num=min_num,
max_num=max_num,
use_thumbnail=True,
)
# second pass
pixel_values2, _ = load_image(
image,
input_size=input_size,
min_num=min_num,
max_num=max_num,
prior_aspect_ratio=target_aspect_ratio,
)
# combine pixel values
pixel_values = torch.cat(
[pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0)
else:
pixel_values, _ = load_image(
image,
input_size=input_size,
min_num=min_num,
max_num=max_num,
use_thumbnail=use_thumbnail,
)
return pixel_values
def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
max_dynamic_patch: Optional[int] = None,
use_MSAC: Optional[bool] = None):
image_size = hf_config.vision_config.image_size
min_num = hf_config.min_dynamic_patch
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
if use_MSAC is None:
use_MSAC = hf_config.use_msac
use_thumbnail = hf_config.use_thumbnail
return partial(
image_to_pixel_values,
input_size=image_size,
min_num=min_num,
max_num=max_dynamic_patch,
use_thumbnail=use_thumbnail,
use_MSAC=use_MSAC,
)
def get_max_internvl_image_tokens(ctx: InputContext,
*,
max_dynamic_patch: Optional[int] = None):
"""
Calculate the maximum number of tokens with/without MSAC and thumbnail
"""
hf_config = ctx.get_hf_config()
use_thumbnail = hf_config.use_thumbnail
use_MSAC = hf_config.use_msac
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
num_patches = get_internvl_num_patches(hf_config)
coefficient = 2 if use_MSAC else 1
num_blocks = coefficient * max_dynamic_patch + (1 if use_thumbnail else 0)
return num_blocks * num_patches
class H2OVLInputPipeline(InternVLInputPipeline):
"""
Input pipeline for processing image and text data for the H2OVL model.
"""
def input_processor(
self,
ctx: InputContext,
inputs: DecoderOnlyInputs,
*,
max_dynamic_patch: Optional[int] = None,
) -> DecoderOnlyInputs:
# get multi_modal_data
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config()
use_MSAC = hf_config.use_msac
image_data = multi_modal_data["image"]
num_patches = get_internvl_num_patches(hf_config)
image_pixel_values_mapper = image_to_pixel_values_wrapper(
hf_config, max_dynamic_patch=max_dynamic_patch)
# single image
if isinstance(image_data, Image.Image):
pixel_values = image_pixel_values_mapper(image_data,
use_MSAC=use_MSAC)
num_blocks = pixel_values.shape[0]
image_feature_sizes = [num_blocks * num_patches]
pixel_values = pixel_values.unsqueeze(0)
# multi images
elif is_list_of(image_data, Image.Image):
# Do not use MSAC for multi images
image_feature_sizes = []
pixel_values = [
image_pixel_values_mapper(image, use_MSAC=False)
for image in image_data
]
for pixel_value in pixel_values:
num_blocks = pixel_value.shape[0]
image_feature_sizes.append(num_blocks * num_patches)
# image embeddings as input
elif isinstance(image_data, torch.Tensor):
_, image_feature_size, _ = image_data.shape
image_feature_sizes = [image_feature_size]
pixel_values = None
# multi-image image embeddings
elif is_list_of(image_data, torch.Tensor):
image_feature_sizes = []
for image_embed in image_data:
_, image_feature_size, _ = image_embed.shape
image_feature_sizes.append(image_feature_size)
pixel_values = None
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
prompt = inputs.get("prompt")
prompt_token_ids = inputs["prompt_token_ids"]
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
num_patches)
new_prompt_token_ids = tokenizer.encode(new_prompt)
# Wrap image processing in input_processor to avoid duplication
image_token_id = tokenizer.encode(
self.img_context_token,
add_special_tokens=False,
return_tensors="pt",
)[0]
# Update multi_modal_data to return
if pixel_values is not None:
multi_modal_data = {
"image": {
"pixel_values": pixel_values,
"image_token_id": image_token_id,
}
}
else:
multi_modal_data = {"image": {"image_embeds": image_data}}
return token_inputs(
prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data,
)
def input_mapper(
self,
ctx: InputContext,
data: object,
*,
max_dynamic_patch: Optional[int] = None,
) -> MultiModalKwargs:
# NOTE: Preprocessing for the image data is done in the
# 'input_processor' function during actual inference.
if isinstance(data, dict):
return MultiModalKwargs(data)
# The section below is only used with dummy data during
# memory profiling.
hf_config = ctx.get_hf_config()
image_pixel_values_mapper = image_to_pixel_values_wrapper(
hf_config, max_dynamic_patch)
if isinstance(data, Image.Image):
pixel_values = image_pixel_values_mapper(data)
pixel_values = pixel_values.unsqueeze(0)
elif is_list_of(data, Image.Image):
hf_config.use_msac = False
pixel_values = [image_pixel_values_mapper(img) for img in data]
else:
return MultiModalKwargs({"image_embeds": data})
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
image_token_id = tokenizer.encode(
self.img_context_token,
add_special_tokens=False,
return_tensors="pt",
)[0]
return MultiModalKwargs({
"pixel_values": pixel_values,
"image_token_id": image_token_id
})
input_pipeline = H2OVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class H2OVLChatModel(InternVLChatModel):
def _init_vision_model(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
*,
is_mono: bool,
prefix: str,
):
if not is_mono:
vision_feature_layer = config.select_layer
if vision_feature_layer < 0:
num_hidden_layers = (config.vision_config.num_hidden_layers +
vision_feature_layer + 1)
else:
num_hidden_layers = vision_feature_layer + 1
return InternVisionModel(
config.vision_config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
prefix=prefix,
)
else:
msg = "Monolith mode is not applicable to H2OVL"
raise NotImplementedError(msg)

View File

@@ -0,0 +1,703 @@
# coding=utf-8
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://github.com/Tencent/Tencent-Hunyuan-Large/blob/main/License.docx
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only HunYuan model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import re
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader,
maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
from .utils import (PPMissingLayer, is_pp_missing_parameter, make_layers,
maybe_prefix)
class HunYuanMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
reduce_results=reduce_results)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class HunYuanSparseMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
self.experts = FusedMoE(num_experts=config.num_experts,
top_k=config.moe_topk,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
reduce_results=False,
renormalize=True if config.moe_topk>1 else False,
quant_config=quant_config)
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
bias=False,
quant_config=None)
if config.use_mixed_mlp_moe > 0:
self.shared_mlp = HunYuanMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size * config.num_shared_expert,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
)
else:
self.shared_mlp = None
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_mlp is not None:
shared_output = self.shared_mlp(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(orig_shape)
class HunYuanAttention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
attention_type: str = "self",
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.use_qk_norm = config.use_qk_norm
self.attention_type = attention_type
if attention_type == "self":
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
elif attention_type == "cross":
self.q_proj = ColumnParallelLinear(
hidden_size,
hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
else:
raise RuntimeError("Not support attnention type")
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
is_neox_style = True
if quant_config is not None and quant_config.get_name() == "gguf":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
if self.use_qk_norm:
self.query_layernorm = RMSNorm(self.head_dim,
eps=config.rms_norm_eps)
self.key_layernorm = RMSNorm(self.head_dim,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
kv_states: Optional[Tuple[torch.Tensor]] = None,
) -> torch.Tensor:
if self.attention_type == "self":
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
ori_k = k
if self.use_qk_norm:
q = self.query_layernorm(q.view(-1, self.num_heads, self.head_dim).contiguous())
k = self.key_layernorm(k.view(-1, self.num_kv_heads, self.head_dim).contiguous())
elif self.attention_type == "cross":
assert kv_states is not None
ori_k, v = kv_states # use last layer kv,
k = ori_k
q, _ = self.q_proj(hidden_states)
k_tmp = torch.empty_like(k) # Todo: reduant rotary embedding
q, _ = self.rotary_emb(positions, q, k_tmp)
if self.use_qk_norm:
q = self.query_layernorm(q.view(-1, self.num_heads, self.head_dim).contiguous())
k = self.key_layernorm(k.view(-1, self.num_kv_heads, self.head_dim).contiguous())
else:
raise RuntimeError("Not support attnention type")
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output, (ori_k, v)
class HunYuanDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
layer_id: int = -1,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
cla_factor = getattr(config, "cla_share_factor", 1)
attention_type = "cross" \
if layer_id >= 0 and layer_id % cla_factor != 0 else "self"
self.self_attn = HunYuanAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
attention_type=attention_type,
)
if getattr(config, "num_experts", None):
self.mlp = HunYuanSparseMoeBlock(config=config,
quant_config=quant_config)
else:
self.mlp = HunYuanMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
kv_states: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states, ori_kv_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
kv_states=kv_states,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual, ori_kv_states
class HunYuanModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: HunYuanDecoderLayer(config=config,
layer_id=int(
prefix.split(".")[-1]),
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
cla_factor = getattr(self.config, "cla_share_factor", 1)
prev_kv_states = None
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual, kv_states = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
# kv_caches[(i - self.start_layer) // cla_factor],
attn_metadata,
residual,
prev_kv_states,
)
if (i - self.start_layer) % cla_factor == 0:
prev_kv_states = kv_states
else:
prev_kv_states = None
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class HunYuanForCausalLM(nn.Module, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"lm_head"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = HunYuanModel(config,
cache_config,
quant_config,
lora_config=lora_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = Sampler()
else:
self.lm_head = PPMissingLayer()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return model_output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
cla_factor = getattr(self.config, "cla_share_factor", 1)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
if getattr(self.config, "num_experts", None):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
else:
expert_params_mapping = {}
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
if "mlp.experts" in name:
continue
# cross layer only have q_proj, skip qkv pack
if weight_name == ".q_proj":
match = re.search(r'layers\.\d+', name)
if match:
layer_id = int(match.group(0).split('.')[-1])
if cla_factor > 1 and layer_id % cla_factor != 0:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
if "mlp.gate.wg." in name:
name = name.replace("wg.", "")
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type):
if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")

View File

@@ -0,0 +1,354 @@
# adapted from https://github.com/huggingface/transformers/blob/v4.43.2/src/transformers/models/idefics2/modeling_idefics2.py
# Copyright 2024 The vLLM team.
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Idefics2 model."""
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers.models.idefics2.configuration_idefics2 import (
Idefics2Config, Idefics2VisionConfig)
from xformers import ops as xops
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
class Idefics2VisionEmbeddings(nn.Module):
"""
This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings
` to enable images of variable
resolution.
The modifications are adapted from [Patch n' Pack: NaViT, a Vision
Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
which allows treating images in their native aspect ratio and without the
need to resize them to the same fixed size. In particular, we start from the
original pre-trained SigLIP model(which uses images of fixed-size square
images) and adapt it by training on images of variable resolutions.
"""
def __init__(self, config: Idefics2VisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.num_patches_per_side = self.image_size // self.patch_size
self.num_patches = self.num_patches_per_side**2
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim)
def forward(self,
pixel_values: torch.FloatTensor,
patch_attention_mask: torch.BoolTensor,
tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
max_nb_patches_h, max_nb_patches_w = (
max_im_h // self.patch_size,
max_im_w // self.patch_size,
)
boundaries = torch.arange(1 / self.num_patches_per_side, 1.0,
1 / self.num_patches_per_side)
position_ids = torch.full(size=(batch_size,
max_nb_patches_h * max_nb_patches_w),
fill_value=0)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
if tgt_sizes is not None:
nb_patches_h = tgt_sizes[batch_idx][0]
nb_patches_w = tgt_sizes[batch_idx][1]
else:
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h,
boundaries,
right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w,
boundaries,
right=True)
pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side +
bucket_coords_w).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings
class Idefics2VisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" # noqa: E501
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
batch_size, q_len, _ = hidden_states.size()
qkv, _ = self.qkv_proj(
hidden_states
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
query_states = query_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
# see: https://facebookresearch.github.io/xformers/components/ops.html
out = xops.memory_efficient_attention_forward(
query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale,
)
out = out.view(batch_size, q_len, -1)
attn_output, _ = self.out_proj(out)
return attn_output
class Idefics2VisionMLP(nn.Module):
def __init__(
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class Idefics2EncoderLayer(nn.Module):
def __init__(
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Idefics2VisionAttention(config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = Idefics2VisionMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Idefics2Encoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention
layers. Each layer is a
[`Idefics2EncoderLayer`].
Args:
config: Idefics2Config
"""
def __init__(
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.layers = nn.ModuleList([
Idefics2EncoderLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])
def forward(
self,
inputs_embeds: torch.Tensor,
) -> torch.Tensor:
r"""
Args:
inputs_embeds (torch.Tensor):
Optionally, instead of passing `input_ids` you can choose to
directly pass an embedded representation.
This is useful if you want more control over how to convert
`input_ids` indices into associated vectorsthan the model's
internal embedding lookup matrix.
"""
hidden_states = inputs_embeds
for encoder_layer in self.layers:
layer_outputs = encoder_layer(hidden_states)
hidden_states = layer_outputs
return hidden_states
class Idefics2VisionTransformer(nn.Module):
def __init__(
self,
config: Idefics2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
embed_dim = config.hidden_size
self.config = config
self.embeddings = Idefics2VisionEmbeddings(config)
self.encoder = Idefics2Encoder(config,
quant_config=quant_config,
prefix=f"{prefix}.encoder")
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
def get_input_embeddings(self):
return self.embeddings
def forward(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
tgt_sizes: Optional[torch.IntTensor] = None,
) -> torch.Tensor:
hidden_states = self.embeddings(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
tgt_sizes=tgt_sizes)
encoder_outputs = self.encoder(hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,765 @@
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Idefics3 model compatible with HuggingFace weights."""
import math
from typing import (Dict, Iterable, List, Literal, Mapping, NamedTuple,
Optional, Tuple, TypedDict, Union)
import torch
import torch.utils.checkpoint
from PIL import Image
from torch import nn
# Temporary solution for transformers below 4.46.0.
from transformers import PretrainedConfig as Idefics3Config
from transformers import ProcessorMixin as Idefics3ImageProcessor
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.image import cached_get_image_processor
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of
# yapf: disable
from .idefics2_vision_model import (
Idefics2VisionTransformer as Idefics3VisionTransformer)
# yapf: enable
from .interfaces import SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
logger = init_logger(__name__)
class Idefics3ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
"""
pixel_attention_mask: Optional[torch.BoolTensor]
class Idefics3ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""
Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
class Idefics3ProcessorSize(NamedTuple):
"""Hashable wrapper for unhashable `size` dict of Idefics3Processor."""
# NOTE: cached_get_processor/cached_get_image_processor uses lru_cache,
# we need to use NamedTuple instead of TypedDict to avoid hashing issues.
longest_edge: int
def __contains__(self, key: str) -> bool:
return key in self._asdict() and getattr(self, key) is not None
def __getitem__(self, key: str) -> int:
return getattr(self, key)
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
def get_mm_processor_kwargs(size: Optional[Dict[str, int]] = None) -> Dict:
mm_processor_kwargs = {}
if size:
mm_processor_kwargs["size"] = Idefics3ProcessorSize(**size)
return mm_processor_kwargs
def input_mapper_for_idefics3(
ctx: InputContext,
data: object,
*,
size: Optional[Dict[str, int]] = None,
):
model_config = ctx.model_config
mm_processor_kwargs = get_mm_processor_kwargs(size)
image_processor = cached_get_image_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")
if isinstance(data, Image.Image):
images = [[data]]
elif is_list_of(data, Image.Image):
images = [data]
else:
raise TypeError(f"Invalid image type: {type(data)}")
try:
batch_data = image_processor(images,
return_tensors="pt",
return_row_col_info=True).data
except Exception:
logger.error("Failed to process image (%s)", data)
raise
return MultiModalKwargs(batch_data)
def _resize_output_size(height: int,
width: int,
max_len: Optional[int] = None,
min_len: Optional[int] = 1,
max_size: Optional[int] = None) -> Tuple[int, int]:
# Set default value for max_len if not provided
max_len = max(height, width) if max_len is None else max_len
aspect_ratio = width / height
# Handle the maximum size constraint
if max_size is not None:
max_len = min(max_len, max_size)
# Adjust dimensions according to the aspect ratio
if width >= height:
width = max_len
height = int(width / aspect_ratio)
else:
height = max_len
width = int(height * aspect_ratio)
# Ensure both width and height are even (if needed)
height += 1 if height % 2 != 0 else 0
width += 1 if width % 2 != 0 else 0
# Ensure dimensions are not smaller than the minimum length
height = max(height, min_len)
width = max(width, min_len)
return height, width
def _get_resize_output_image_size(
image_size: Tuple[int, int],
resolution_max_side: int,
max_image_size: int = 1820,
) -> Tuple[int, int]:
if resolution_max_side > max_image_size:
raise ValueError(
"`resolution_max_side` cannot be larger than `max_image_size`")
height, width = image_size
# Find the output size, when rescaling the longest edge to max_len and
# preserving the aspect ratio
height, width = _resize_output_size(height,
width,
max_len=resolution_max_side)
return height, width
def _prompt_split_image(image_seq_len: int, image_rows: int, image_cols: int,
fake_token_around_image: str, image_token: str,
global_img_token: str) -> str:
"""
Prompt with expanded image tokens for when the image is split
into patches.
"""
text_split_images = ""
for n_h in range(image_rows):
for n_w in range(image_cols):
text_split_images += (fake_token_around_image +
f"<row_{n_h + 1}_col_{n_w + 1}>" +
image_token * image_seq_len)
text_split_images += "\n"
text_split_images += "\n" + _prompt_single_image(
image_seq_len=image_seq_len,
fake_token_around_image=fake_token_around_image,
image_token=image_token,
global_img_token=global_img_token)
return text_split_images
def _prompt_single_image(image_seq_len: int, fake_token_around_image: str,
image_token: str, global_img_token: str):
"""Prompt with expanded image tokens for a single image."""
return (fake_token_around_image + global_img_token +
image_token * image_seq_len + fake_token_around_image)
def _get_image_prompt_string(image_rows: int, image_cols: int,
image_seq_len: int, fake_token_around_image: str,
image_token: str, global_img_token: str):
if image_rows == 0 and image_cols == 0:
return _prompt_single_image(
image_seq_len=image_seq_len,
fake_token_around_image=fake_token_around_image,
image_token=image_token,
global_img_token=global_img_token,
)
return _prompt_split_image(image_seq_len, image_rows, image_cols,
fake_token_around_image, image_token,
global_img_token)
def input_processor_for_idefics3(ctx: InputContext,
inputs: DecoderOnlyInputs,
*,
size: Optional[Dict[str, int]] = None):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
mm_processor_kwargs = get_mm_processor_kwargs(size)
processor = cached_get_processor(model_config.model, **mm_processor_kwargs)
image_processor = processor.image_processor
tokenizer = processor.tokenizer
size = image_processor.size['longest_edge']
max_image_size = image_processor.max_image_size['longest_edge']
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_list = [image_data]
elif is_list_of(image_data, Image.Image):
image_list = image_data
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
image_rows = []
image_cols = []
for image in image_list:
height, width = _get_resize_output_image_size(image.size, size)
rows = math.ceil(height / max_image_size)
cols = math.ceil(width / max_image_size)
image_rows.append(rows)
image_cols.append(cols)
image_rows = [image_rows]
image_cols = [image_cols]
n_images_in_text = []
text = inputs.get("prompt")
if text is not None:
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, "
"or a list of strings")
fake_image_token = processor.fake_image_token.content
image_token = processor.image_token.content
global_img_token = processor.global_image_tag
prompt_strings = []
for sample, sample_rows, sample_cols in zip(text, image_rows,
image_cols):
n_images_in_text.append(sample.count(image_token))
# Replace the image token with fake tokens around the expanded
# image token sequence of length `image_seq_len`
image_prompt_strings = []
for n_rows, n_cols in zip(sample_rows, sample_cols):
image_prompt_string = _get_image_prompt_string(
n_rows,
n_cols,
processor.image_seq_len,
image_token=image_token,
fake_token_around_image=fake_image_token,
global_img_token=global_img_token,
)
image_prompt_strings.append(image_prompt_string)
split_sample = sample.split(image_token)
if len(split_sample) == 0:
raise ValueError(
"The image token should be present in the text.")
# Place in the image prompt strings where the image tokens are
sample = split_sample[0]
for i, image_prompt_string in enumerate(image_prompt_strings):
sample += image_prompt_string + split_sample[i + 1]
prompt_strings.append(sample)
prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids
return token_inputs(
prompt_token_ids=prompt_token_ids,
prompt=prompt_strings[0],
multi_modal_data=multi_modal_data,
)
def _get_max_num_image_patch(image_processor: Idefics3ImageProcessor) -> int:
size = image_processor.size['longest_edge']
max_image_size = image_processor.max_image_size['longest_edge']
resized_height, resized_width = size, size
grid_h = resized_height // max_image_size
grid_w = resized_width // max_image_size
return (grid_h * grid_w + 1)
def get_max_idefics3_image_tokens(ctx: InputContext,
*,
size: Optional[Dict[str,
int]] = None) -> int:
model_config = ctx.model_config
mm_processor_kwargs = get_mm_processor_kwargs(size)
processor = cached_get_processor(model_config.model, **mm_processor_kwargs)
image_seq_len = processor.image_seq_len
image_processor = processor.image_processor
max_num_image_patches = _get_max_num_image_patch(image_processor)
return max_num_image_patches * image_seq_len
def dummy_data_for_idefics3(
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
*,
size: Optional[Dict[str, int]] = None) -> DummyData:
hf_config = ctx.get_hf_config()
num_images = mm_counts["image"]
mm_processor_kwargs = get_mm_processor_kwargs(size)
processor = cached_get_processor(ctx.model_config.model,
**mm_processor_kwargs)
max_num_image_patches = _get_max_num_image_patch(processor.image_processor)
image_seq_len = processor.image_seq_len
max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images
if seq_len - max_llm_image_tokens < 0:
raise RuntimeError(
f"Idefics3 cannot process {num_images} images in a prompt, "
"please increase max_model_len or reduce image limit by "
"--limit-mm-per-prompt.")
seq_data = SequenceData.from_prompt_token_counts(
(hf_config.image_token_id, max_llm_image_tokens),
(0, seq_len - max_llm_image_tokens))
width = height = hf_config.vision_config.image_size
image = Image.new("RGB", (width, height), color=0)
mm_data = {"image": [image] if num_images == 1 else [image] * num_images}
return DummyData(seq_data, mm_data)
class Idefics3SimpleMLP(nn.Module):
def __init__(
self,
config: Idefics3Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
input_size = config.vision_config.hidden_size * (config.scale_factor**
2)
output_size = config.text_config.hidden_size
self.proj = ReplicatedLinear(
input_size,
output_size,
bias=False,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "proj"),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
out, _ = self.proj(x)
return out
class Idefics3Connector(nn.Module):
def __init__(
self,
config: Idefics3Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.scale_factor = config.scale_factor
self.modality_projection = Idefics3SimpleMLP(
config,
quant_config,
prefix=maybe_prefix(prefix, "modality_projection"),
)
def pixel_shuffle(self,
x: torch.Tensor,
scale_factor: int = 2) -> torch.Tensor:
bsz, seq, embed_dim = x.size()
height = width = int(seq**0.5)
x = x.view(bsz, height, width, embed_dim)
x = x.view(bsz, height, int(width / scale_factor),
embed_dim * scale_factor)
x = x.permute(0, 2, 1, 3)
x = x.reshape(
bsz,
int(width / scale_factor),
int(height / scale_factor),
embed_dim * (scale_factor**2),
)
x = x.permute(0, 2, 1, 3)
x = x.reshape(bsz, int(seq / (scale_factor**2)),
embed_dim * (scale_factor**2))
return x
def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor:
image_hidden_states = self.pixel_shuffle(image_hidden_states,
self.scale_factor)
image_hidden_states = self.modality_projection(image_hidden_states)
return image_hidden_states
class Idefics3Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = self.config.text_config.pad_token_id
self.vocab_size = self.config.text_config.vocab_size
self.vision_model = Idefics3VisionTransformer(
config.vision_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"))
self.connector = Idefics3Connector(
config,
quant_config,
prefix=maybe_prefix(prefix, "connector"),
)
self.text_model = LlamaModel(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "text_model"),
)
self.image_seq_len = int(
((config.vision_config.image_size //
config.vision_config.patch_size)**2) / (config.scale_factor**2))
self.image_token_id = self.config.image_token_id
def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("num_patches", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)
if pixel_values is None and image_embeds is None:
return None
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Idefics3ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
)
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return Idefics3ImagePixelInputs(type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(pixel_values,
concat=True)),
pixel_attention_mask=flatten_bn(
pixel_attention_mask,
concat=True))
raise AssertionError("This line should be unreachable.")
def _image_pixels_to_features(
self,
pixel_values: torch.Tensor,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values.to(
dtype=self.vision_model.embeddings.patch_embedding.weight.dtype
) # fp16 compatibility
pixel_values = pixel_values.view(batch_size * num_images,
*pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(pixel_values.size(0), pixel_values.size(2),
pixel_values.size(3)),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(
batch_size * num_images, *pixel_attention_mask.shape[2:])
pixel_attention_mask = pixel_attention_mask[
real_images_inds].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1,
size=patch_size,
step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2,
size=patch_size,
step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
return image_hidden_states
def _process_image_pixels(
self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor:
assert self.vision_model is not None
pixel_values = inputs["data"]
pixel_attention_mask = inputs["pixel_attention_mask"]
return self._image_pixels_to_features(pixel_values,
pixel_attention_mask)
def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_features = self._process_image_pixels(image_input)
return self.connector(image_features)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.text_model.get_input_embeddings(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
else:
inputs_embeds = self.text_model.get_input_embeddings(input_ids)
input_ids = None
hidden_states = self.text_model(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_idefics3)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_idefics3_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3)
@INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
# vision_model
"fc1",
"fc2",
"out_proj",
# text_model
"qkv_proj", # same name with vision encoder
"o_proj",
"gate_up_proj",
"down_proj",
]
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
# vision_model
".fc1.",
".fc2.",
".out_proj.",
# connector
".proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.model = Idefics3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.image_token_id = self.config.image_token_id
self.lm_head = ParallelLMHead(
config.text_config.vocab_size,
config.text_config.hidden_size,
quant_config=quant_config,
)
if self.config.text_config.tie_word_embeddings:
self.lm_head.weight = self.model.text_model.wte.weight
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
**kwargs,
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="model.text_model",
connector="model.connector",
tower_model="model.vision_model")

View File

@@ -0,0 +1,352 @@
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Protocol, Type, Union, overload, runtime_checkable)
import torch
from typing_extensions import TypeIs
from vllm.logger import init_logger
from vllm.utils import supports_kw
if TYPE_CHECKING:
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
@runtime_checkable
class SupportsMultiModal(Protocol):
"""The interface required for all multi-modal models."""
supports_multimodal: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports multi-modal inputs.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
def __init__(self, *, multimodal_config: "MultiModalConfig") -> None:
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsMultiModalType(Protocol):
supports_multimodal: Literal[True]
def __call__(self, *, multimodal_config: "MultiModalConfig") -> None:
...
@overload
def supports_multimodal(
model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
...
@overload
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
...
def supports_multimodal(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
if isinstance(model, type):
return isinstance(model, _SupportsMultiModalType)
return isinstance(model, SupportsMultiModal)
@runtime_checkable
class SupportsLoRA(Protocol):
"""The interface required for all models that support LoRA."""
supports_lora: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports LoRA.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
packed_modules_mapping: ClassVar[Dict[str, List[str]]]
supported_lora_modules: ClassVar[List[str]]
embedding_modules: ClassVar[Dict[str, str]]
embedding_padding_modules: ClassVar[List[str]]
# lora_config is None when LoRA is not enabled
def __init__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
supports_lora: Literal[True]
packed_modules_mapping: Dict[str, List[str]]
supported_lora_modules: List[str]
embedding_modules: Dict[str, str]
embedding_padding_modules: List[str]
def __call__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
...
@overload
def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
...
@overload
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
...
def supports_lora(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
result = _supports_lora(model)
if not result:
lora_attrs = (
"packed_modules_mapping",
"supported_lora_modules",
"embedding_modules",
"embedding_padding_modules",
)
missing_attrs = tuple(attr for attr in lora_attrs
if not hasattr(model, attr))
if getattr(model, "supports_lora", False):
if missing_attrs:
logger.warning(
"The model (%s) sets `supports_lora=True`, "
"but is missing LoRA-specific attributes: %s",
model,
missing_attrs,
)
else:
if not missing_attrs:
logger.warning(
"The model (%s) contains all LoRA-specific attributes, "
"but does not set `supports_lora=True`.", model)
return result
def _supports_lora(model: Union[Type[object], object]) -> bool:
if isinstance(model, type):
return isinstance(model, _SupportsLoRAType)
return isinstance(model, SupportsLoRA)
@runtime_checkable
class SupportsPP(Protocol):
"""The interface required for all models that support pipeline parallel."""
supports_pp: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports pipeline parallel.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
def make_empty_intermediate_tensors(
self,
batch_size: int,
dtype: torch.dtype,
device: torch.device,
) -> "IntermediateTensors":
"""Called when PP rank > 0 for profiling purposes."""
...
def forward(
self,
*,
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
"""
Accept :class:`IntermediateTensors` when PP rank > 0.
Return :class:`IntermediateTensors` only for the last PP rank.
"""
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsPPType(Protocol):
supports_pp: Literal[True]
def make_empty_intermediate_tensors(
self,
batch_size: int,
dtype: torch.dtype,
device: torch.device,
) -> "IntermediateTensors":
...
def forward(
self,
*,
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
...
@overload
def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]:
...
@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
...
def supports_pp(
model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
supports_attributes = _supports_pp_attributes(model)
supports_inspect = _supports_pp_inspect(model)
if supports_attributes and not supports_inspect:
logger.warning(
"The model (%s) sets `supports_pp=True`, but does not accept "
"`intermediate_tensors` in its `forward` method", model)
if not supports_attributes:
pp_attrs = ("make_empty_intermediate_tensors", )
missing_attrs = tuple(attr for attr in pp_attrs
if not hasattr(model, attr))
if getattr(model, "supports_pp", False):
if missing_attrs:
logger.warning(
"The model (%s) sets `supports_pp=True`, "
"but is missing PP-specific attributes: %s",
model,
missing_attrs,
)
else:
if not missing_attrs:
logger.warning(
"The model (%s) contains all PP-specific attributes, "
"but does not set `supports_pp=True`.", model)
return supports_attributes and supports_inspect
def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
if isinstance(model, type):
return isinstance(model, _SupportsPPType)
return isinstance(model, SupportsPP)
def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
model_forward = getattr(model, "forward", None)
if not callable(model_forward):
return False
return supports_kw(model_forward, "intermediate_tensors")
@runtime_checkable
class HasInnerState(Protocol):
"""The interface required for all models that has inner state."""
has_inner_state: ClassVar[Literal[True]] = True
"""
A flag that indicates this model has inner state.
Models that has inner state usually need access to the scheduler_config
for max_num_seqs, etc. True for e.g. both Mamba and Jamba.
"""
def __init__(self,
*,
scheduler_config: Optional["SchedulerConfig"] = None) -> None:
...
@runtime_checkable
class _HasInnerStateType(Protocol):
has_inner_state: ClassVar[Literal[True]]
def __init__(self,
*,
scheduler_config: Optional["SchedulerConfig"] = None) -> None:
...
@overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
...
@overload
def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
...
def has_inner_state(
model: Union[Type[object], object]
) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
if isinstance(model, type):
return isinstance(model, _HasInnerStateType)
return isinstance(model, HasInnerState)
@runtime_checkable
class IsAttentionFree(Protocol):
"""The interface required for all models like Mamba that lack attention,
but do have state whose size is constant wrt the number of tokens."""
is_attention_free: ClassVar[Literal[True]] = True
"""
A flag that indicates this model has no attention.
Used for block manager and attention backend selection.
True for Mamba but not Jamba.
"""
def __init__(self) -> None:
...
@runtime_checkable
class _IsAttentionFreeType(Protocol):
is_attention_free: ClassVar[Literal[True]]
def __init__(self) -> None:
...
@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
...
@overload
def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]:
...
def is_attention_free(
model: Union[Type[object], object]
) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
if isinstance(model, type):
return isinstance(model, _IsAttentionFreeType)
return isinstance(model, IsAttentionFree)

View File

@@ -0,0 +1,175 @@
from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union,
overload, runtime_checkable)
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger
from vllm.utils import supports_kw
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
logger = init_logger(__name__)
# The type of HF config
C_co = TypeVar("C_co", bound=PretrainedConfig, covariant=True)
# The type of hidden states
# Currently, T = torch.Tensor for all models except for Medusa
# which has T = List[torch.Tensor]
T = TypeVar("T", default=torch.Tensor)
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags
# for the base interfaces to avoid breaking OOT registration for existing models
# that don't inherit from the base interface classes
@runtime_checkable
class VllmModel(Protocol[C_co, T_co]):
def __init__(
self,
vllm_config: "VllmConfig",
prefix: str = "",
) -> None:
...
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
) -> T_co:
...
def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
model_init = model.__init__
return supports_kw(model_init, "vllm_config")
def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
model_forward = getattr(model, "forward", None)
if not callable(model_forward):
return False
vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata")
missing_kws = tuple(kw for kw in vllm_kws
if not supports_kw(model_forward, kw))
if missing_kws and (isinstance(model, type)
and issubclass(model, nn.Module)):
logger.warning(
"The model (%s) is missing "
"vLLM-specific keywords from its initializer: %s",
model,
missing_kws,
)
return len(missing_kws) == 0
@overload
def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]:
...
@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]:
...
def is_vllm_model(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]:
return _check_vllm_model_init(model) and _check_vllm_model_forward(model)
@runtime_checkable
class VllmModelForTextGeneration(VllmModel[C_co, T], Protocol[C_co, T]):
def compute_logits(
self,
hidden_states: T,
sampling_metadata: "SamplingMetadata",
) -> Optional[T]:
"""Return `None` if TP rank > 0."""
...
def sample(
self,
logits: T,
sampling_metadata: "SamplingMetadata",
) -> "SamplerOutput":
"""Only called on TP rank 0."""
...
@overload
def is_text_generation_model(
model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]:
...
@overload
def is_text_generation_model(
model: object) -> TypeIs[VllmModelForTextGeneration]:
...
def is_text_generation_model(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForTextGeneration]],
TypeIs[VllmModelForTextGeneration]]:
if not is_vllm_model(model):
return False
if isinstance(model, type):
return isinstance(model, VllmModelForTextGeneration)
return isinstance(model, VllmModelForTextGeneration)
@runtime_checkable
class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
def pooler(
self,
hidden_states: T,
pooling_metadata: "PoolingMetadata",
) -> "PoolerOutput":
"""Only called on TP rank 0."""
...
@overload
def is_embedding_model(
model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]:
...
@overload
def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]:
...
def is_embedding_model(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]:
if not is_vllm_model(model):
return False
if isinstance(model, type):
return isinstance(model, VllmModelForEmbedding)
return isinstance(model, VllmModelForEmbedding)

View File

@@ -0,0 +1,478 @@
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from functools import partial
from typing import Iterable, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
NORM2FN = {
'rms_norm': RMSNorm,
'layer_norm': nn.LayerNorm,
}
class InternVisionEmbeddings(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.patch_embedding = nn.Conv2d(in_channels=3,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size)
self.num_patches = (self.image_size // self.patch_size)**2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Parameter(
torch.randn(1, self.num_positions, self.embed_dim))
def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int):
target_dtype = pos_embed.dtype
pos_embed = pos_embed.float().reshape(
1, self.image_size // self.patch_size,
self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
pos_embed = F.interpolate(pos_embed,
size=(H, W),
mode='bicubic',
align_corners=False)
return pos_embed.reshape(1, -1, H * W).permute(0, 2,
1).to(target_dtype)
def _get_position_embedding(self, H: int, W: int) -> torch.Tensor:
position_embedding = self.position_embedding
if self.num_patches == H * W:
return position_embedding
return torch.cat(
[
position_embedding[:, :1, :],
self._get_pos_embed(position_embedding[:, 1:, :], H, W),
],
dim=1,
)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(
target_dtype)) # shape = [*, channel, width, height]
batch_size, _, height, width = patch_embeds.shape
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1,
-1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
position_embedding = self._get_position_embedding(height, width)
embeddings = embeddings + position_embedding.to(target_dtype)
return embeddings
class InternVisionPatchModel(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.embeddings = InternVisionEmbeddings(config)
def get_input_embeddings(self):
return self.embeddings
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
pixel_embeds: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
if pixel_values is None and pixel_embeds is None:
raise ValueError(
'You have to specify pixel_values or pixel_embeds')
if pixel_embeds is not None:
hidden_states = pixel_embeds
elif pixel_values is not None:
if pixel_values.ndim == 4:
hidden_states = self.embeddings(pixel_values)
else:
raise ValueError(
f'wrong pixel_values size: {pixel_values.shape}')
return hidden_states
class InternParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_dummy_heads: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f'embed_dim must be divisible by num_heads '
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
self.tp_size)
self.scale = self.head_dim**-0.5
self.qkv = QKVParallelLinear(
self.embed_dim,
self.head_dim,
num_dummy_heads + self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.k_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.proj = RowParallelLinear(
self.dummy_dim,
self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous())
k = tensor_model_parallel_all_gather(k.contiguous())
q = self.q_norm.forward_native(q)
k = self.k_norm.forward_native(k)
if self.tp_size > 1:
splitter = partial(split_tensor_along_last_dim,
num_partitions=self.tp_size)
q = splitter(q)[self.tp_rank]
k = splitter(k)[self.tp_rank]
return q, k
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, _ = x.shape
qkv, _ = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
if self.qk_normalization:
q, k = self._apply_qk_norm(q, k)
q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
x = x.view(B, N, -1)
x, _ = self.proj(x)
return x
class InternSdpaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: PretrainedConfig,
*,
num_dummy_heads: int = 0,
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f'embed_dim must be divisible by num_heads '
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(self.embed_dim,
3 * self.dummy_dim,
bias=config.qkv_bias)
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.k_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
q = q.view(B, N, self.num_heads, self.head_dim)
k = k.view(B, N, self.num_heads, self.head_dim)
v = v.view(B, N, self.num_heads, self.head_dim)
if self.qk_normalization:
B_, N_, H_, D_ = q.shape
q = self.q_norm.forward_native(q.flatten(-2,
-1)).view(B_, N_, H_, D_)
k = self.k_norm.forward_native(k.flatten(-2,
-1)).view(B_, N_, H_, D_)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
x = x.transpose(1, 2).view(B, N, -1)
x = self.proj(x)
return x
class InternMLP(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class InternVisionEncoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_dummy_heads: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type
self.attn = self._init_attn(config,
quant_config,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.attn")
self.mlp = InternMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps)
self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps)
self.ls1 = nn.Parameter(config.initializer_factor *
torch.ones(self.embed_dim))
self.ls2 = nn.Parameter(config.initializer_factor *
torch.ones(self.embed_dim))
def _init_attn(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
*,
num_dummy_heads: int,
prefix: str = "",
):
# fallback to sdpa attention if tp unavailable
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
return InternParallelAttention(config,
quant_config=quant_config,
num_dummy_heads=num_dummy_heads,
prefix=prefix)
return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
def forward(
self,
hidden_states: torch.Tensor,
):
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states)) * self.ls1
hidden_states = hidden_states + self.mlp(
self.norm2(hidden_states)) * self.ls2
return hidden_states
class InternVisionEncoder(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
prefix: str = "",
):
super().__init__()
self.config = config
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
InternVisionEncoderLayer(config,
quant_config,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])
def forward(self, inputs_embeds: torch.Tensor):
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)
return hidden_states
class InternVisionModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.embeddings = InternVisionEmbeddings(config)
self.encoder = InternVisionEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.encoder",
)
def get_input_embeddings(self):
return self.embeddings
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
pixel_embeds: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
if pixel_values is None and pixel_embeds is None:
raise ValueError(
'You have to specify pixel_values or pixel_embeds')
if pixel_embeds is not None:
hidden_states = pixel_embeds
elif pixel_values is not None:
if pixel_values.ndim == 4:
hidden_states = self.embeddings(pixel_values)
else:
raise ValueError(
f'wrong pixel_values size: {pixel_values.shape}')
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
return encoder_outputs
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,399 @@
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class InternLM2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.w2 = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.w2",
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.w2(x)
return x
class InternLM2Attention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.total_num_heads = num_heads
assert self.total_num_heads % self.tp_size == 0
self.num_heads = self.total_num_heads // self.tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= self.tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % self.tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert self.tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.key_value_groups = int(self.num_heads / self.num_kv_heads)
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.wqkv = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wqkv",
)
self.wo = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wo",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def split_qkv(self, qkv: torch.Tensor):
seq_len = qkv.shape[0]
if self.tp_size > 1:
qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
qkv = tensor_model_parallel_all_gather(qkv)
qkv = torch.split(qkv, qkv_map, dim=-1)
qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
qkv = torch.cat(qkv, dim=-1)
qkv = qkv.view(seq_len, self.total_num_kv_heads,
self.key_value_groups + 2, self.head_dim)
q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
q = q.reshape(seq_len, self.q_size * self.tp_size)
k = k.reshape(seq_len, self.kv_size * self.tp_size)
v = v.reshape(seq_len, self.kv_size * self.tp_size)
if self.tp_size > 1:
splitter = partial(split_tensor_along_last_dim,
num_partitions=self.tp_size)
q = splitter(q)[self.tp_rank]
k = splitter(k)[self.tp_rank]
v = splitter(v)[self.tp_rank]
return q, k, v
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.wqkv(hidden_states)
q, k, v = self.split_qkv(qkv)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.wo(attn_output)
return output
class InternLMDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.attention = InternLM2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attention",
)
self.feed_forward = InternLM2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
self.attention_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.attention_norm(hidden_states)
else:
hidden_states, residual = self.attention_norm(
hidden_states, residual)
hidden_states = self.attention(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.ffn_norm(hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
@support_torch_compile
class InternLM2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.tok_embeddings = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: InternLMDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.tok_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.tok_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class InternLM2ForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = InternLM2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.output = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "output"))
if self.config.tie_word_embeddings:
self.output.weight = self.model.tok_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.output, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "w1", 0),
("gate_up_proj", "w3", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,165 @@
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.internlm2 import (InternLM2Attention,
InternLM2ForCausalLM,
InternLM2MLP, InternLM2Model)
from vllm.sequence import IntermediateTensors
from .utils import make_layers, maybe_prefix
class InternLM2VEDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.attention = InternLM2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attention",
)
self.feed_forward = InternLM2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
self.feed_forward_ve = InternLM2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward_ve",
)
self.attention_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
visual_token_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.attention_norm(hidden_states)
else:
hidden_states, residual = self.attention_norm(
hidden_states, residual)
hidden_states = self.attention(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.ffn_norm(hidden_states, residual)
if visual_token_mask is not None and visual_token_mask.any():
visual_token_mask = visual_token_mask.repeat(
1, self.hidden_size).bool()
text_token_mask = ~visual_token_mask
hidden_states[visual_token_mask] = self.feed_forward_ve(
hidden_states[visual_token_mask].reshape(
-1, self.hidden_size)).flatten()
if text_token_mask.any():
hidden_states[text_token_mask] = self.feed_forward(
hidden_states[text_token_mask].reshape(
-1, self.hidden_size)).flatten()
else:
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
class InternLM2VEModel(InternLM2Model):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: InternLM2VEDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
visual_token_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.tok_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
visual_token_mask=visual_token_mask,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class InternLM2VEForCausalLM(InternLM2ForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.model = InternLM2VEModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

View File

@@ -0,0 +1,668 @@
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import re
from functools import cached_property, partial
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.quantization import (AWQConfig,
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches)
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
class InternVLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""
Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
"""
class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
InternVLImageInputs = Union[InternVLImagePixelInputs,
InternVLImageEmbeddingInputs]
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size),
interpolation=T.InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
max_num: int, image_size: int,
use_thumbnail: bool) -> Tuple[int, int, int]:
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set((i, j) for n in range(min_num, max_num + 1)
for i in range(1, n + 1) for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
target_ratios, orig_width,
orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# add thumbnail image if num_blocks > 1
if use_thumbnail and blocks > 1:
blocks += 1
return blocks, target_width, target_height
def calculate_num_blocks_wrapper(hf_config: PretrainedConfig,
max_dynamic_patch: Optional[int] = None):
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
min_num = hf_config.min_dynamic_patch
image_size = hf_config.vision_config.image_size
use_thumbnail = hf_config.use_thumbnail
return partial(calculate_num_blocks,
min_num=min_num,
max_num=max_dynamic_patch,
image_size=image_size,
use_thumbnail=use_thumbnail)
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
image_size: int,
use_thumbnail: bool) -> List[Image.Image]:
orig_width, orig_height = image.size
# calculate the number of blocks without thumbnail
blocks, target_width, target_height = calculate_num_blocks(
orig_width,
orig_height,
min_num,
max_num,
image_size,
use_thumbnail=False)
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = ((i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
max_num: int, use_thumbnail: bool) -> torch.Tensor:
transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image,
min_num=min_num,
max_num=max_num,
image_size=input_size,
use_thumbnail=use_thumbnail)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values
def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
max_dynamic_patch: Optional[int] = None):
image_size = hf_config.vision_config.image_size
min_num = hf_config.min_dynamic_patch
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
use_thumbnail = hf_config.use_thumbnail
return partial(image_to_pixel_values,
input_size=image_size,
min_num=min_num,
max_num=max_dynamic_patch,
use_thumbnail=use_thumbnail)
def get_internvl_num_patches(hf_config: PretrainedConfig):
vision_config = hf_config.vision_config
downsample_ratio = hf_config.downsample_ratio
image_size = vision_config.image_size
patch_size = vision_config.patch_size
return int(
get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
(downsample_ratio**2))
def get_max_internvl_image_tokens(ctx: InputContext,
*,
max_dynamic_patch: Optional[int] = None):
hf_config = ctx.get_hf_config()
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
use_thumbnail = hf_config.use_thumbnail
if use_thumbnail and max_dynamic_patch > 1:
max_dynamic_patch += 1
num_patches = get_internvl_num_patches(hf_config)
return num_patches * max_dynamic_patch
def get_max_internvl_image_size(ctx: InputContext,
*,
max_dynamic_patch: Optional[int] = None):
hf_config = ctx.get_hf_config()
image_size = hf_config.vision_config.image_size
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
use_thumbnail = hf_config.use_thumbnail
if use_thumbnail and max_dynamic_patch > 1:
max_dynamic_patch += 1
width = image_size * max_dynamic_patch
height = image_size
return width, height
class InternVLInputPipeline:
def __init__(
self,
img_start_token: str,
img_end_token: str,
img_context_token: str,
) -> None:
super().__init__()
self.img_start_token = img_start_token
self.img_end_token = img_end_token
self.img_context_token = img_context_token
def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
return (self.img_start_token + self.img_context_token * feature_size +
self.img_end_token)
def _expand_image_prompt(
self,
prompt: str,
feature_sizes: List[int],
num_patches: int,
) -> str:
image_idx = sorted(
map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
new_prompt = prompt
for idx, feature_size in enumerate(feature_sizes, start=1):
image_prompt = self._create_image_prompt(feature_size, num_patches)
if not image_idx:
image_prompt = f"Image-{idx}: {image_prompt}"
new_prompt = new_prompt.replace('<image>', image_prompt, 1)
return new_prompt
def input_processor(
self,
ctx: InputContext,
inputs: DecoderOnlyInputs,
*,
max_dynamic_patch: Optional[int] = None,
) -> DecoderOnlyInputs:
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config()
image_data = multi_modal_data["image"]
num_patches = get_internvl_num_patches(hf_config)
num_blocks_calculator = calculate_num_blocks_wrapper(
hf_config, max_dynamic_patch)
if isinstance(image_data, Image.Image):
width, height = image_data.size
num_blocks, _, _ = num_blocks_calculator(width, height)
image_feature_sizes = [num_blocks * num_patches]
elif is_list_of(image_data, Image.Image):
image_feature_sizes = []
for image in image_data:
width, height = image.size
num_blocks, _, _ = num_blocks_calculator(width, height)
image_feature_sizes.append(num_blocks * num_patches)
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
image_feature_sizes = [image_feature_size]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
prompt = inputs.get("prompt")
prompt_token_ids = inputs["prompt_token_ids"]
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
num_patches)
new_prompt_token_ids = tokenizer.encode(new_prompt)
return token_inputs(prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)
def input_mapper(
self,
ctx: InputContext,
data: object,
*,
max_dynamic_patch: Optional[int] = None,
):
hf_config = ctx.get_hf_config()
image_pixel_values_mapper = image_to_pixel_values_wrapper(
hf_config, max_dynamic_patch)
if isinstance(data, Image.Image):
data = image_pixel_values_mapper(data)
# Add an N dimension for number of images per prompt (currently 1).
data = data.unsqueeze(0)
elif is_list_of(data, Image.Image):
# we can't stack here because images may have different num_patches
data = [image_pixel_values_mapper(img) for img in data]
else:
return MultiModalKwargs({"image_embeds": data})
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
image_token_id = tokenizer.encode(self.img_context_token,
add_special_tokens=False,
return_tensors="pt")[0]
return MultiModalKwargs({
"pixel_values": data,
"image_token_id": image_token_id
})
def dummy_data(
self,
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
*,
max_dynamic_patch: Optional[int] = None,
):
num_images = mm_counts["image"]
hf_config = ctx.get_hf_config()
image_feature_size = get_max_internvl_image_tokens(
ctx, max_dynamic_patch=max_dynamic_patch)
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
seq_data, ranges = dummy_seq_data_for_clip(
hf_config.vision_config,
seq_len,
num_images,
image_token_id=tokenizer.encode(self.img_context_token,
add_special_tokens=False)[0],
image_feature_size_override=image_feature_size,
)
max_image_width, max_image_height = get_max_internvl_image_size(
ctx, max_dynamic_patch=max_dynamic_patch)
mm_data = dummy_image_for_clip(
hf_config.vision_config,
num_images,
image_width_override=max_image_width,
image_height_override=max_image_height,
)
return DummyData(seq_data, mm_data, ranges)
input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self._patch_quant_config(config, quant_config)
image_size = config.force_image_size or config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.patch_size = patch_size
self.num_image_token = int(
(image_size // patch_size)**2 * (config.downsample_ratio**2))
self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version
self.llm_arch_name = config.text_config.architectures[0]
self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
self.vision_model = self._init_vision_model(
config,
quant_config=quant_config,
is_mono=self.is_mono,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
self.mlp1 = self._init_mlp1(config)
self.img_context_token_id = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _patch_quant_config(self, config: PretrainedConfig,
quant_config: QuantizationConfig):
# the awq models from OpenGVLab missing `modules_to_not_convert`
# patch the quant_config to add `modules_to_not_convert` back
if isinstance(quant_config, AWQConfig):
text_config = config.text_config
llm_quant_config = getattr(text_config, "quantization_config",
None)
if (not quant_config.modules_to_not_convert) and \
(llm_quant_config is not None):
quant_config.modules_to_not_convert.append("vision_model")
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _init_vision_model(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
*,
is_mono: bool,
prefix: str,
):
if not is_mono:
vision_feature_layer = config.select_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
return InternVisionModel(
config.vision_config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
prefix=prefix,
)
else:
return InternVisionPatchModel(config.vision_config)
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
return nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
llm_hidden_size),
nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size),
)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
int(c / (scale_factor * scale_factor)))
if self.ps_version == 'v1':
pass
else:
x = x.permute(0, 2, 1, 3).contiguous()
return x
def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
vit_embeds = self.vision_model(pixel_values=pixel_values)
vit_embeds = vit_embeds[:, 1:, :]
h = w = int(vit_embeds.shape[1]**0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(vit_embeds,
scale_factor=self.downsample_ratio)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
vit_embeds.shape[-1])
vit_embeds = self.mlp1(vit_embeds)
return vit_embeds
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f" per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[InternVLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_token_id = kwargs.pop("image_token_id", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return InternVLImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
)
self.img_context_token_id = image_token_id[0]
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# We need to flatten (B, N, P) to (B*N*P),
# so we call flatten_bn twice.
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(flatten_bn(pixel_values), concat=True)),
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: InternVLImageInputs,
) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["data"])
return image_embeds
def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
if self.is_mono:
visual_token_mask = (
input_ids == self.img_context_token_id).reshape(-1, 1)
else:
visual_token_mask = None
return visual_token_mask
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> Union[SamplerOutput, IntermediateTensors]:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
visual_token_mask = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)
visual_token_mask = self._get_visual_token_mask(input_ids)
input_ids = None
else:
inputs_embeds = None
visual_token_mask = None
forward_kwargs = {
"input_ids": input_ids,
"positions": positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
"intermediate_tensors": intermediate_tensors,
"inputs_embeds": inputs_embeds,
}
if self.is_mono:
forward_kwargs.update({"visual_token_mask": visual_token_mask})
hidden_states = self.language_model.model(**forward_kwargs)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights)

View File

@@ -0,0 +1,374 @@
# Adapted from
# https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights
# reserved.
# Copyright 2023 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Jais model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import JAISConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class SwiGLUActivation(nn.Module):
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
return x1 * nn.functional.silu(x2)
def _get_alibi_slopes(n):
def get_slopes_power_of_2(n):
start = 2**(-(2**-(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes(
2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
class JAISAttention(nn.Module):
def __init__(
self,
config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
if hasattr(config, "scale_qk_dot_by_d"):
config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d
self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5
self.scale = self.head_dim**-self.attn_scale_power
self.c_attn = QKVParallelLinear(
self.hidden_size,
self.head_dim,
total_num_heads,
bias=True,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
)
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end]
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scale,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output, _ = self.c_proj(attn_output)
return attn_output
class JAISMLP(nn.Module):
def __init__(
self,
intermediate_size: int,
config: JAISConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.swiglu = config.activation_function == "swiglu"
self.c_fc = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=True,
quant_config=quant_config,
)
self.c_fc2 = (ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=True,
quant_config=quant_config,
) if self.swiglu else None)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
quant_config=quant_config,
)
self.act = SwiGLUActivation()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.swiglu:
hidden_states2, _ = self.c_fc2(hidden_states)
hidden_states, _ = self.c_fc(hidden_states)
hidden_states = (self.act(hidden_states, hidden_states2)
if self.swiglu else self.act(hidden_states))
hidden_states, _ = self.c_proj(hidden_states)
return hidden_states
class JAISBlock(nn.Module):
def __init__(
self,
config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = JAISAttention(config, cache_config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = JAISMLP(inner_dim, config, quant_config)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# residual connection
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
return hidden_states
@support_torch_compile
class JAISModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
assert not config.add_cross_attention
assert not config.scale_attn_by_inverse_layer_idx
assert not config.reorder_and_upcast_attn
self.embed_dim = config.hidden_size
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = (nn.Embedding(config.max_position_embeddings,
self.embed_dim)
if config.position_embedding_type != "alibi" else None)
if hasattr(config, "embeddings_scale"):
self.embeddings_scale = config.embeddings_scale
else:
self.embeddings_scale = config.mup_embeddings_scale
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: JAISBlock(config=config,
cache_config=cache_config,
quant_config=quant_config),
prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[IntermediateTensors, torch.Tensor]:
if get_pp_group().is_first_rank:
inputs_embeds = self.wte(input_ids)
if self.wpe is not None:
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
hidden_states = inputs_embeds
hidden_states *= torch.tensor(float(self.embeddings_scale),
dtype=hidden_states.dtype)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states
class JAISLMHeadModel(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.transformer = JAISModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale
else:
self.output_logits_scale = (config.mup_output_alpha *
config.mup_width_scale)
self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
scale=self.output_logits_scale)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[IntermediateTensors, torch.Tensor]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if "relative_pe" in name:
continue
if not name.startswith("transformer."):
name = "transformer." + name
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,532 @@
"""Inference-only Jamba model."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import JambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)
from .interfaces import HasInnerState, SupportsLoRA
from .utils import maybe_prefix
KVCache = Tuple[torch.Tensor, torch.Tensor]
class JambaMoE(nn.Module):
def __init__(self,
config: JambaConfig,
num_experts: Optional[int] = None,
top_k: Optional[int] = None,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.num_total_experts = num_experts or config.num_experts
self.top_k = top_k or config.num_experts_per_tok
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
if self.num_total_experts > 1:
self.router = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
quant_config=None,
params_dtype=params_dtype)
self.experts = FusedMoE(self.num_total_experts,
self.top_k,
self.hidden_size,
self.intermediate_size,
tp_size=tp_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=False,
use_grouped_topk=False,
quant_config=quant_config)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (batch * sequence_length, n_experts)
if self.num_total_experts > 1:
router_logits, _ = self.router(hidden_states)
else:
router_logits = torch.ones((hidden_states.shape[0], 1),
device=hidden_states.device,
dtype=hidden_states.dtype)
hidden_states = self.experts(hidden_states, router_logits)
return hidden_states.view(orig_shape)
class JambaMLP(JambaMoE):
def __init__(self,
config: JambaConfig,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__(config,
num_experts=1,
top_k=1,
params_dtype=params_dtype,
tp_size=tp_size,
quant_config=quant_config)
class JambaMambaDecoderLayer(nn.Module):
def __init__(self,
config: JambaConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.mamba = MambaMixer(hidden_size= config.hidden_size,
ssm_state_size = config.mamba_d_state,
conv_kernel_size = config.mamba_d_conv,
intermediate_size = config.mamba_expand *\
config.hidden_size,
time_step_rank = config.mamba_dt_rank,
use_conv_bias = config.mamba_conv_bias,
use_bias = config.mamba_proj_bias,
use_rms_norm=True,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act)
num_experts = config.layers_num_experts[layer_idx]
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
**kwargs,
):
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.mamba(hidden_states, attn_metadata,
mamba_cache_params)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
class JambaAttentionDecoderLayer(nn.Module):
def __init__(
self,
config: JambaConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
config.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
config.hidden_size,
bias=False,
quant_config=quant_config)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
)
num_experts = config.layers_num_experts[layer_idx]
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def self_attention(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
**kwargs,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
**kwargs,
):
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attention(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
ALL_DECODER_LAYER_TYPES = {
"attention": JambaAttentionDecoderLayer,
"mamba": JambaMambaDecoderLayer
}
class JambaModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
decoder_layers = []
for i in range(config.num_hidden_layers):
layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
decoder_layers.append(
layer_class(config,
layer_idx=i,
cache_config=cache_config,
quant_config=quant_config))
self.layers = nn.ModuleList(decoder_layers)
self.final_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
kv_cache = None
layer_mamba_cache_params = None
if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
self.config.attn_layer_period]
if isinstance(layer, JambaMambaDecoderLayer):
current_state_layer = i - (1 +
(i - self.config.attn_layer_offset)
// self.config.attn_layer_period)
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
current_state_layer)
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=residual,
mamba_cache_params=layer_mamba_cache_params)
hidden_states, _ = self.final_layernorm(hidden_states, residual)
return hidden_states
class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
"Jamba currently does not support prefix caching"
super().__init__()
self.config = config
self.scheduler_config = scheduler_config
self.model = JambaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs):
if self.mamba_cache is None:
max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)
layers_type = self.config.layers_block_type
num_mamba_layers = sum(
[layer_type == "mamba" for layer_type in layers_type])
self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
*self._get_mamba_cache_shape())
(
mamba_cache_tensors,
state_indices_tensor,
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
**kwargs)
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
mamba_cache_tensors[1],
state_indices_tensor)
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_params)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
conv_state_shape = (
self.config.mamba_expand * hidden_size // world_size,
self.config.mamba_d_conv - 1,
)
temporal_state_shape = (
self.config.mamba_expand * hidden_size // world_size,
self.config.mamba_d_state,
)
return conv_state_shape, temporal_state_shape
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "A_log" in name:
name = name.replace("A_log", "A")
if ".self_attn." in name:
name = name.replace(".self_attn", "")
if "feed_forward" in name and not _is_moe_layer(name):
## map MLP layers to expert with ID=0
name = name.replace("feed_forward", "feed_forward.experts.0")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if 'experts' in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for (
param_name,
weight_name,
expert_id,
shard_id,
) in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def _is_moe_layer(name: str):
return any(
[experts_name in name for experts_name in [
"experts",
"router",
]])

View File

@@ -0,0 +1,694 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
x, _ = self.gate_up_proj(x)
x = self.act_fn(x)
x, _ = self.down_proj(x)
return x
class LlamaAttention(nn.Module):
def __init__(
self,
config: LlamaConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
is_neox_style = True
if quant_config is not None and quant_config.get_name() == "gguf":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
self.self_attn = LlamaAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class LlamaModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: LlamaDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type):
if not isinstance(self.layers[layer_idx], nn.Identity):
layer_self_attn = self.layers[layer_idx].self_attn
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"lm_head"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings"
}
embedding_padding_modules = ["lm_head"]
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping = {
"layers": "model.layers",
"attention": "self_attn",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",
"wo": "o_proj",
"attention_norm": "input_layernorm",
"feed_forward": "mlp",
"w1": "gate_proj",
"w2": "down_proj",
"w3": "up_proj",
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm"
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.config = config
self.lora_config = lora_config
self.model = LlamaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else
lora_config.lora_vocab_padding_size),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.embed_tokens)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.STEP,
normalize=False,
softmax=False)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
return model_output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
logits = self.compute_logits(hidden_states, None)
return self._pooler(logits, pooling_metadata)
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
loader.load_weights(
self.maybe_remap_mistral(name, loaded_weight)
for name, loaded_weight in weights)
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)
# This function is used to remap the mistral format as
# used by Mistral and Llama <=2
def maybe_remap_mistral(
self,
name: str,
loaded_weight: torch.Tensor,
) -> Tuple[str, torch.Tensor]:
def permute(w: torch.Tensor, n_heads: int):
attn_in = self.config.head_dim * n_heads
attn_out = self.config.hidden_size
return w.view(n_heads, attn_in // n_heads // 2, 2,
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
mapping = self.mistral_mapping
modules = name.split(".")
# rotary embeds should be sliced
if "wk" in modules:
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads)
elif "wq" in modules:
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads)
for item in modules:
if item in mapping and mapping[item] not in name:
name = name.replace(item, mapping[item])
return name, loaded_weight
class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
"""
A model that uses Llama with additional embedding functionalities.
This class encapsulates the LlamaModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of LlamaModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
pooler_config = vllm_config.model_config.pooler_config
self.model = LlamaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
return self.model(input_ids, positions, kv_caches, attn_metadata,
intermediate_tensors, inputs_embeds)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.model.load_weights(weights)
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)
# LRUCacheWorkerLoRAManager instantiation requires model config.
@property
def config(self):
return self.model.config

View File

@@ -0,0 +1,552 @@
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol,
Tuple, TypedDict, Union)
import torch
import torch.nn as nn
from PIL import Image
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
PretrainedConfig, SiglipVisionConfig)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_max_clip_image_tokens,
input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
dummy_seq_data_for_pixtral_hf,
get_max_pixtral_hf_image_tokens,
input_processor_for_pixtral_hf)
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
Note that `height` or `width` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
def __init__(self, vision_hidden_size: int, text_hidden_size: int,
projector_hidden_act: str):
super().__init__()
self.linear_1 = nn.Linear(vision_hidden_size,
text_hidden_size,
bias=True)
self.act = get_act_fn(projector_hidden_act)
self.linear_2 = nn.Linear(text_hidden_size,
text_hidden_size,
bias=True)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
def get_max_llava_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
num_image_tokens = get_max_clip_image_tokens(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
num_image_tokens = get_max_siglip_image_tokens(vision_config)
elif isinstance(vision_config, PixtralVisionConfig):
num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
strategy = hf_config.vision_feature_select_strategy
if strategy == "default":
return num_image_tokens - 1
elif strategy == "full":
return num_image_tokens
else:
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
image_feature_size = get_max_llava_image_tokens(ctx)
if isinstance(vision_config, CLIPVisionConfig):
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(vision_config, num_images)
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig):
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_siglip(vision_config, num_images)
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, PixtralVisionConfig):
seq_data, ranges = dummy_seq_data_for_pixtral_hf(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_feature_size = get_max_llava_image_tokens(ctx)
elif is_list_of(image_data, Image.Image):
image_feature_size = [get_max_llava_image_tokens(ctx)
] * len(image_data)
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, SiglipVisionConfig):
return input_processor_for_siglip(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, PixtralVisionConfig):
# We ignore image_feature_size_override since we have non-uniform
# image sizes for Pixtral
return input_processor_for_pixtral_hf(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
class LlavaLikeConfig(Protocol):
vision_config: PretrainedConfig
vision_feature_layer: int
def init_vision_tower_for_llava(
hf_config: LlavaLikeConfig,
quant_config: Optional[QuantizationConfig],
*,
require_post_norm: Optional[bool] = None,
prefix: str = "",
):
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = hf_config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
)
elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel(
vision_config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
)
elif isinstance(vision_config, PixtralVisionConfig):
return PixtralHFVisionModel(
vision_config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
# NOTE: These are special cases for Pixtral-12B in the HF-format
# https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
if (config.text_config.architectures is None
and config.text_config.model_type == "mistral"):
config.text_config.architectures = ["MistralForCausalLM"]
if (config.projector_hidden_act is None
and config.vision_config.hidden_act == "gelu"):
config.projector_hidden_act = "gelu"
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"))
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _validate_image_sizes(self, images: List[torch.Tensor],
sizes: List[torch.Tensor]) -> List[torch.Tensor]:
if not isinstance(sizes, list):
sizes = [sizes]
total_images = sum(size.numel() // 2 for size in sizes)
if total_images != len(images):
raise ValueError("Mismatch in number of images. "
f"Expected {total_images}, got {len(images)}")
img_idx = 0
for size in sizes:
# Flatten the size tensor to a list of (height, width) pairs
size = size.view(-1, 2).tolist()
for expected_h, expected_w in size:
if img_idx >= len(images):
raise ValueError("Ran out of images before sizes. "
f"{img_idx} >= {len(images)}")
img = images[img_idx]
if img.shape[-2:] != (expected_h, expected_w):
raise ValueError(
"Image size mismatch. Expected "
f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
if img.shape[-3] != 3:
raise ValueError("Image channel mismatch. Expected 3, "
f"got {img.shape[-3]}")
img_idx += 1
return images
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# Case for models like PixtralHF that have dynamic image sizes
# so we need to produce a list of tensors
if image_sizes is not None:
images = pixel_values
def flatten_to_3d_tensors(item):
if isinstance(item, torch.Tensor):
if item.dim() >= 3:
return [t for t in item.view(-1, *item.shape[-3:])]
else:
raise ValueError(
f"Unexpected tensor dimension: {item.dim()}")
elif isinstance(item, list):
return [
t for subitem in item
for t in flatten_to_3d_tensors(subitem)
]
else:
raise ValueError(f"Unexpected type: {type(item)}")
# Restructure the batched images into a list of lists of images
images = flatten_to_3d_tensors(pixel_values)
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_image_sizes(images, image_sizes),
)
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
)
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return LlavaImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
)
raise AssertionError("This line should be unreachable.")
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
PixtralHFVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
)
def _process_image_pixels(self,
inputs: LlavaImagePixelInputs) -> torch.Tensor:
assert self.vision_tower is not None
pixel_values = inputs["data"]
return self._image_pixels_to_features(self.vision_tower, pixel_values)
def _process_image_input(self,
image_input: LlavaImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)
def process_mm_inputs(self, **kwargs):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
vision_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if vision_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LLaVA-1.5.
One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted image embeddings.
Concretely, consider a text prompt:
`"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.
Tokenizer outputs:
`[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends
additional image tokens (denoted as `32000`), resulting in:
`[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
29901]`.
We insert 575 tokens so that including the original image token in the
input, there are a total of 576 (24 * 24) image tokens, which
corresponds to the number of image tokens inputted to the language
model, i.e. the number of image tokens outputted by the visual encoder.
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each input image.
See also:
:class:`LlavaImageInputs`
"""
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.process_mm_inputs(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights)

View File

@@ -0,0 +1,679 @@
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
# Conditional import for transformers compatibility
try:
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)
except ImportError:
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""Fallback implementation"""
height, width = image_size
best_resolution = None
for pinpoint in grid_pinpoints:
if pinpoint[0] >= height and pinpoint[1] >= width:
if best_resolution is None or (pinpoint[0] * pinpoint[1] < best_resolution[0] * best_resolution[1]):
best_resolution = pinpoint
if best_resolution is None:
best_resolution = grid_pinpoints[-1]
return (best_resolution[0] // patch_size, best_resolution[1] // patch_size)
def unpad_image(tensor, original_size):
"""Fallback implementation"""
return tensor
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import LlavaMultiModalProjector, init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
init_vllm_registered_model, maybe_prefix)
class LlavaNextImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
image_sizes: NotRequired[torch.Tensor]
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
"""
class LlavaNextImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
LlavaNextImageEmbeddingInputs]
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
def _get_llava_next_num_unpadded_features(
original_height: int,
original_width: int,
npatches: int,
num_patch_height: int,
num_patch_width: int,
) -> Tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
current_height -= 2 * padding
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
current_width -= 2 * padding
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def get_llava_next_image_feature_size(
hf_config: LlavaNextConfig,
*,
input_height: int,
input_width: int,
) -> int:
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
num_patches = get_clip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
)
base_feature_size = get_clip_image_feature_size(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
num_patches = get_siglip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
)
base_feature_size = get_siglip_image_feature_size(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
strategy = hf_config.vision_feature_select_strategy
if strategy == "default":
base_feature_size -= 1
elif strategy == "full":
pass
else:
raise ValueError(f"Unexpected select feature strategy: {strategy}")
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(input_height, input_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_config.image_size,
)
(
unpadded_feature_size,
newline_feature_size,
) = _get_llava_next_num_unpadded_features(input_height, input_width,
num_patches, num_patch_height,
num_patch_width)
return unpadded_feature_size + newline_feature_size + base_feature_size
def get_max_llava_next_image_tokens(ctx: InputContext):
"""Compute the max feature size for all possible image grid pinpoints."""
return _get_pinpoint_with_largest_features(ctx)[0]
def _get_pinpoint_with_largest_features(
ctx: InputContext) -> Tuple[int, Tuple[int, int]]:
"""Get the grid pinpoint with the largest features & its feature size."""
hf_config = ctx.get_hf_config(LlavaNextConfig)
largest_feature_size = 0
largest_feature_pinpoint = None
for (height, width) in hf_config.image_grid_pinpoints:
feat_size = get_llava_next_image_feature_size(
hf_config,
input_height=height,
input_width=width,
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = (height, width)
if not largest_feature_size or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
return largest_feature_size, largest_feature_pinpoint
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
image_feature_size, pinpoint = _get_pinpoint_with_largest_features(ctx)
max_feat_height, max_feat_width = pinpoint
if isinstance(vision_config, CLIPVisionConfig):
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
vision_config,
num_images,
image_width_override=max_feat_width,
image_height_override=max_feat_height,
)
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig):
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_siglip(
vision_config,
num_images,
image_width_override=max_feat_width,
image_height_override=max_feat_height,
)
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_for_llava_next(ctx: InputContext,
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
width, height = image_data.size
image_feature_size = get_llava_next_image_feature_size(
hf_config,
input_height=height,
input_width=width,
)
elif is_list_of(image_data, Image.Image):
image_feature_size = [
get_llava_next_image_feature_size(hf_config,
input_height=img.height,
input_width=img.width)
for img in image_data
]
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, SiglipVisionConfig):
return input_processor_for_siglip(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"))
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
# The same model class supports both language generation and embedding
# because the architecture name is the same
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, )
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
f"The expected shape of image sizes per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("num_patches", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_sizes, (torch.Tensor, list)):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(flatten_bn(pixel_values)),
image_sizes=self._validate_image_sizes(
flatten_bn(image_sizes, concat=True)),
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeds. "
f"Got type: {type(image_embeds)}")
return LlavaNextImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
)
raise AssertionError("This line should be unreachable.")
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
)
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
patch_embeddings: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "flat":
return patch_embeddings.flatten(0, 1)
if strategy.startswith("spatial"):
height = width = self.config.vision_config.image_size \
// self.config.vision_config.patch_size
base_patch_embeds = patch_embeddings[0]
if height * width != base_patch_embeds.shape[0]:
raise ValueError(
"The number of patches is not consistent with the "
"image size.")
if patch_embeddings.shape[0] > 1:
other_patch_embeds = patch_embeddings[1:]
# Move to CPU to avoid floating-point errors
orig_height, orig_width = image_size.tolist()
# image_aspect_ratio == "anyres"
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
(orig_height, orig_width),
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
num_patches = num_patch_height * num_patch_width
# Image patches might be padded for batch processing
other_patch_embeds = other_patch_embeds[:num_patches] \
.view(num_patch_height, num_patch_width, height, width, -1)
if "unpad" in strategy:
other_patch_embeds = other_patch_embeds \
.permute(4, 0, 2, 1, 3).contiguous() \
.flatten(1, 2).flatten(2, 3)
other_patch_embeds = unpad_image(other_patch_embeds,
(orig_height, orig_width))
other_patch_embeds = torch.cat((
other_patch_embeds,
self.image_newline[:, None, None] \
.expand(*other_patch_embeds.shape[:-1], 1) \
.to(other_patch_embeds.device),
), dim=-1)
other_patch_embeds = other_patch_embeds \
.flatten(1, 2).transpose(0, 1)
else:
other_patch_embeds = other_patch_embeds \
.permute(0, 2, 1, 3, 4).contiguous() \
.flatten(0, 3)
merged_patch_embeddings = torch.cat(
(base_patch_embeds, other_patch_embeds), dim=0)
else:
if "unpad" in strategy:
merged_patch_embeddings = torch.cat(
(base_patch_embeds,
self.image_newline[None] \
.to(base_patch_embeds.device)
), dim=0)
else:
merged_patch_embeddings = base_patch_embeds
return merged_patch_embeddings
raise ValueError(f"Unexpected patch merge strategy: {strategy}")
def _process_image_pixels(
self,
inputs: LlavaNextImagePixelInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
assert self.vision_tower is not None
pixel_values = inputs["data"]
if isinstance(pixel_values, torch.Tensor):
b, num_patches, c, h, w = pixel_values.shape
stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
stacked_patch_embeddings = self.multi_modal_projector(
stacked_image_features)
return stacked_patch_embeddings.view(
b, num_patches, *stacked_patch_embeddings.shape[1:])
num_patches_per_batch = [v.shape[0] for v in pixel_values]
stacked_pixel_values = torch.cat(pixel_values)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
return [
self.multi_modal_projector(image_features) for image_features in
torch.split(stacked_image_features, num_patches_per_batch)
]
def _process_image_input(
self,
image_input: LlavaNextImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
if image_input["type"] == "image_embeds":
return [image_input["data"]]
patch_embeddings = self._process_image_pixels(image_input)
image_sizes = image_input.get("image_sizes")
if image_sizes is None:
batch_size = len(image_input["data"])
vision_config = self.config.vision_config
default_height = default_width = vision_config.image_size
image_sizes = torch.as_tensor([[default_height, default_width]
for _ in range(batch_size)])
return [
self._merge_image_patch_embeddings(image_sizes[i],
patch_features_batch,
strategy="spatial_unpad")
for i, patch_features_batch in enumerate(patch_embeddings)
]
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LlaVA-NeXT.
One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted image embeddings.
Concretely, consider a text prompt:
`"A chat between a curious human and an artificial intelligence
assistant. The assistant gives helpful, detailed, and polite answers to
the human's questions.
USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.
Tokenizer outputs:
`[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
9047, 13566, 29901]`.
To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends
additional image tokens (denoted as `32000`), resulting in:
`[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
319, 1799, 9047, 13566, 29901]`.
Unlike in LLaVA-1.5, the number of image tokens inputted to the language
model depends on the original size of the input image. Including the
original image token in the input, the required number of image tokens
is given by :func:`get_llava_next_image_feature_size`.
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each grid patch for each input image.
image_sizes: The original `(height, width)` for each input image.
See also:
:class:`LlavaNextImageInputs`
"""
if intermediate_tensors is not None:
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
inputs_embeds = embed_multimodal(
input_ids,
self.config.image_token_index,
self.language_model.model.get_input_embeddings,
lambda _: self._process_image_input(image_input),
)
else:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights)

View File

@@ -0,0 +1,454 @@
import math
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import numpy as np
import torch
import torch.nn as nn
from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
SiglipVisionConfig)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip)
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
# For profile run
_MAX_FRAMES_PER_VIDEO = 32
_MAX_NUM_VIDEOS = 1
class LlavaNextVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, num_frames, num_channels, height, width)`
Note that `num_frames` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
Note that it only supports one video input for one batch.
"""
def get_llava_next_video_frame_feature_size(
hf_config: LlavaNextVideoConfig) -> int:
# Support both CLIPVisionConfig and SiglipVisionConfig
image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
spatial_pool_stride = hf_config.spatial_pool_stride
return int((image_size / patch_size / spatial_pool_stride)**2)
def _get_max_llm_tokens(ctx: InputContext) -> int:
"""
Calculated from the maximum video frames under the context length
constraints of the language model.
"""
hf_text_config = ctx.model_config.hf_text_config
model_config = ctx.model_config
max_tokens = model_config.max_model_len
rope_scaling = model_config.rope_scaling
if rope_scaling:
rope_scaling_factor = hf_text_config.rope_scaling["factor"]
else:
rope_scaling_factor = 1
max_tokens *= rope_scaling_factor
return max_tokens
def get_max_llava_next_video_tokens(ctx: InputContext) -> int:
# Currently set to 32 frames
# TODO: max_tokens = _get_max_llm_tokens(ctx)
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
return _MAX_FRAMES_PER_VIDEO * tokens_per_frame
def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
vision_config = hf_config.vision_config
# TODO: support multiple videos
num_videos = mm_counts["video"]
if num_videos != _MAX_NUM_VIDEOS:
raise NotImplementedError(
f"Only {_MAX_NUM_VIDEOS} videos are supported")
# TODO: support configuring the number of frames
frames_per_video = _MAX_FRAMES_PER_VIDEO
# num_images = num_videos * frames_per_video
# fills the sequence with as longer video data as possible
tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
video_feature_size = frames_per_video * tokens_per_frame
if isinstance(vision_config, CLIPVisionConfig):
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
mm_key="video",
)
pil_frame = dummy_image_for_clip(vision_config, num_images=1)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
mm_data = {"video": mm_data_per_video}
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig):
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
mm_key="video",
)
pil_frame = dummy_image_for_siglip(vision_config, num_images=1)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
mm_data = {"video": mm_data_per_video}
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_for_llava_next_video(ctx: InputContext,
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "video" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
video_data = multi_modal_data["video"]
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
vision_config = hf_config.vision_config
if isinstance(video_data, np.ndarray):
# Supports both CLIP and Siglip
num_frames = video_data.shape[0]
frame_feature_size = \
get_llava_next_video_frame_feature_size(hf_config)
video_feature_size = num_frames * frame_feature_size
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
)
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"video": ranges})
elif is_list_of(video_data, np.ndarray):
raise NotImplementedError(
"Processing multiple videos is not supported")
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
# adopted from transformers modeling_llava_next_video.py
class LlavaNextVideoPooler(nn.Module):
def __init__(self, config):
super().__init__()
mode = config.spatial_pool_mode
stride = config.spatial_pool_stride
image_size = config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.image_size = image_size // patch_size**2
if mode == "average":
self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride)
elif mode == "max":
self.pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
else:
# TODO: Support Conv2d pooling layer, need to load weights
raise ValueError(
f"Unknown pooling mode: {mode}. Expected [`average`, `max`]")
def forward(self, image_features):
ori_width = int(
math.sqrt(image_features.shape[1] * self.image_size //
self.image_size))
ori_height = int(ori_width * self.image_size // self.image_size)
batch_size, _, dim = image_features.shape
image_features_spatial = image_features \
.view(batch_size, ori_height, ori_height, dim) \
.permute(0, 3, 1, 2)
image_features_spatial = self.pool(image_features_spatial)
return image_features_spatial.flatten(2).transpose(1, 2).contiguous()
class LlavaNextMultiModalProjector(nn.Module):
def __init__(self, vision_hidden_size: int, text_hidden_size: int,
projector_hidden_act: str):
super().__init__()
self.linear_1 = nn.Linear(vision_hidden_size,
text_hidden_size,
bias=True)
self.act = get_act_fn(projector_hidden_act)
self.linear_2 = nn.Linear(text_hidden_size,
text_hidden_size,
bias=True)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
@MULTIMODAL_REGISTRY.register_input_mapper("video")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"video", get_max_llava_next_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"))
self.vision_resampler = LlavaNextVideoPooler(config)
self.multi_modal_projector = LlavaNextMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_video_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[2:])
if actual_dims != expected_dims:
expected_expr = ("num_frames", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values in each video frame "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
"""
A legal video input should have the following dimensions:
{
"pixel_values_videos" :
List[b, Tensor(nb_frames, nb_channels, height, width)]
}
"""
pixel_values = kwargs.pop("pixel_values_videos", None)
if pixel_values is None:
return None
if not (is_list_of(pixel_values,
(torch.Tensor)) # different shape videos
or isinstance(pixel_values,
torch.Tensor)): # same shape videos
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaNextVideoPixelInputs(
type="pixel_values_videos",
data=pixel_values,
)
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _video_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
image_features = self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
)
image_features = self.vision_resampler(image_features)
image_features = self.multi_modal_projector(image_features)
return image_features
def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):
assert self.vision_tower is not None
video_pixels = inputs["data"]
if isinstance(video_pixels, torch.Tensor):
# TODO: support multiple videos per input
b, num_videos, num_frames, c, h, w = video_pixels.shape
assert (num_videos == 1)
stacked_pixels = video_pixels.view(b * num_videos * num_frames, c,
h, w)
stacked_embeddings = self._video_pixels_to_features(
self.vision_tower, stacked_pixels)
return stacked_embeddings.view(b, num_frames,
*stacked_embeddings.shape[1:])
elif is_list_of(video_pixels, torch.Tensor):
frames_per_videos = [v.shape[0] for v in video_pixels]
stacked_pixels = torch.cat(video_pixels, dim=0)
stacked_embeddings = self._video_pixels_to_features(
self.vision_tower, stacked_pixels)
return torch.split(stacked_embeddings, frames_per_videos, dim=0)
else:
raise ValueError(
f"Unsupported type of video input {type(video_pixels)}")
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LlaVA-NeXT-Video.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values_videos: Pixels in each frames for each input videos.
"""
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is not None:
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = self.language_model \
.model.get_input_embeddings(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
# This model doesn't support images for now
ignore_unexpected_prefixes=["image_newline"],
)
loader.load_weights(weights)

View File

@@ -0,0 +1,912 @@
import math
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import (CLIPVisionConfig, LlavaOnevisionConfig,
SiglipVisionConfig)
# Conditional import for transformers compatibility
try:
from transformers.models.llava_onevision.modeling_llava_onevision import (
get_anyres_image_grid_shape, unpad_image)
except ImportError:
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""Fallback implementation"""
height, width = image_size
best_resolution = None
for pinpoint in grid_pinpoints:
if pinpoint[0] >= height and pinpoint[1] >= width:
if best_resolution is None or (pinpoint[0] * pinpoint[1] < best_resolution[0] * best_resolution[1]):
best_resolution = pinpoint
if best_resolution is None:
best_resolution = grid_pinpoints[-1]
return (best_resolution[0] // patch_size, best_resolution[1] // patch_size)
def unpad_image(tensor, original_size):
"""Fallback implementation"""
return tensor
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
dummy_video_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
class LlavaOnevisionVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, num_videos, num_frames, num_channels, height, width)`
Note that `num_videos` may be different for each batch, and 'num_frames'
may be different for each video, in which case the data is passed as a
list instead of a batched tensor.
"""
class LlavaOnevisionImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
image_sizes: NotRequired[torch.Tensor]
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
"""
class LlavaOnevisionImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs,
LlavaOnevisionImageEmbeddingInputs]
LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs,
LlavaOnevisionVideoPixelInputs]
def _get_llava_onevision_image_unppaded_feature_size(height, width, patches,
scale_height,
scale_width):
current_height = patches * scale_height
current_width = patches * scale_width
original_aspect_ratio = width / height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
new_height = int(height * (current_width / width))
padding = (current_height - new_height) // 2
current_height -= padding * 2
else:
new_width = int(width * (current_height / height))
padding = (current_width - new_width) // 2
current_width -= padding * 2
unpadded_features = current_height * current_width
newline_features = current_height
ratio = math.sqrt(current_height * current_width / (9 * patches**2))
if ratio > 1.1:
unpadded_features = int(current_height // ratio) * int(
current_width // ratio)
newline_features = int(current_height // ratio)
return (unpadded_features, newline_features)
def get_llava_onevision_image_feature_size(
hf_config: LlavaOnevisionConfig,
*,
input_height: int,
input_width: int,
) -> int:
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
num_patches = get_clip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
)
base_feature_size = get_clip_image_feature_size(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
num_patches = get_siglip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
)
base_feature_size = get_siglip_image_feature_size(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
strategy = hf_config.vision_feature_select_strategy
if strategy == "default":
base_feature_size -= 1
elif strategy == "full":
pass
else:
raise ValueError(f"Unexpected select feature strategy: {strategy}")
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(input_height, input_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_config.image_size,
)
(
unpadded_feature_size,
newline_feature_size,
) = _get_llava_onevision_image_unppaded_feature_size(
input_height, input_width, num_patches, num_patch_height,
num_patch_width)
return unpadded_feature_size + newline_feature_size + base_feature_size
def get_max_llava_onevision_image_tokens(ctx: InputContext):
return get_llava_onevision_image_feature_size(
ctx.get_hf_config(LlavaOnevisionConfig),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
def get_llava_onevision_video_frame_feature_size(
hf_config: LlavaOnevisionConfig) -> int:
# Support both CLIPVisionConfig and SiglipVisionConfig
image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
spatial_pool_stride = hf_config.spatial_pool_stride if hasattr(
hf_config, "spatial_pool_stride") else 2
height = width = image_size // patch_size
return math.ceil(height / spatial_pool_stride) * math.ceil(
width / spatial_pool_stride)
def get_llava_onevision_video_tokens(ctx: InputContext,
num_frames: int) -> int:
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
# TODO: support configuring (not supported by HF right now)
num_token_image_newline = 1
tokens_per_frame = get_llava_onevision_video_frame_feature_size(hf_config)
video_feature_size = num_frames * tokens_per_frame + num_token_image_newline
return video_feature_size
def get_max_llava_onevision_video_tokens(ctx: InputContext) -> int:
return get_llava_onevision_video_tokens(ctx, _MAX_FRAMES_PER_VIDEO)
def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
num_videos = mm_counts["video"]
# TODO: support configuring the number of frames
num_frames = _MAX_FRAMES_PER_VIDEO
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
if isinstance(vision_config, CLIPVisionConfig):
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
mm_key="video")
mm_data = dummy_video_for_clip(vision_config,
num_frames=num_frames,
num_videos=num_videos)
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig):
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
mm_key="video")
mm_data = dummy_video_for_siglip(vision_config,
num_frames=num_frames,
num_videos=num_videos)
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_when_multimodal_input_image(ctx: InputContext,
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
width, height = image_data.size
image_feature_size = get_llava_onevision_image_feature_size(
hf_config,
input_height=height,
input_width=width,
)
elif is_list_of(image_data, Image.Image):
image_feature_size = [
get_llava_onevision_image_feature_size(hf_config,
input_height=img.height,
input_width=img.width)
for img in image_data
]
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, SiglipVisionConfig):
return input_processor_for_siglip(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_when_multimodal_input_video(ctx: InputContext,
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data:
return inputs
video_data = multi_modal_data["video"]
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
if isinstance(video_data, np.ndarray):
# Supports both CLIP and Siglip
num_frames = video_data.shape[0]
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
)
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"video": ranges})
elif is_list_of(video_data, np.ndarray):
video_feature_size = []
for video in video_data:
num_frames = video.shape[0]
video_feature_size.append(
get_llava_onevision_video_tokens(ctx, num_frames))
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
)
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"video": ranges})
else:
raise TypeError(f"Invalid video type: {type(video_data)}")
msg = f"Unsupported video type: {type(video_data)}"
raise NotImplementedError(msg)
def input_processor_for_llava_onevision(ctx: InputContext,
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or ("video" not in multi_modal_data
and "image" not in multi_modal_data):
return inputs
if "image" in multi_modal_data:
return input_processor_when_multimodal_input_image(ctx, inputs)
if "video" in multi_modal_data:
return input_processor_when_multimodal_input_video(ctx, inputs)
msg = "Unsupported multi data type"
raise NotImplementedError(msg)
class LlavaOnevisionMultiModalProjector(nn.Module):
def __init__(self, config: LlavaOnevisionConfig):
super().__init__()
self.linear_1 = nn.Linear(config.vision_config.hidden_size,
config.text_config.hidden_size,
bias=True)
self.act = get_act_fn(config.projector_hidden_act)
self.linear_2 = nn.Linear(config.text_config.hidden_size,
config.text_config.hidden_size,
bias=True)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_input_mapper("video")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"image", get_max_llava_onevision_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"video", get_max_llava_onevision_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"))
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, )
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
f"The expected shape of image sizes per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _validate_image_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("num_patches", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_sizes, (torch.Tensor, list)):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
return LlavaOnevisionImagePixelInputs(
type="pixel_values",
data=self._validate_image_pixel_values(
flatten_bn(pixel_values)),
image_sizes=self._validate_image_sizes(
flatten_bn(image_sizes, concat=True)),
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeds. "
f"Got type: {type(image_embeds)}")
return LlavaOnevisionImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
)
raise AssertionError("This line should be unreachable.")
def _validate_video_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[2:])
if actual_dims != expected_dims:
expected_expr = ("num_frames", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values in each video frame "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_video_input(
self,
**kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]:
"""
A legal video input should have the following dimensions:
{
"pixel_values_videos" :
List[b, Tensor(nb_frames, nb_channels, height, width)]
}
"""
pixel_values = kwargs.pop("pixel_values_videos", None)
if pixel_values is None:
return None
if not (is_list_of(pixel_values,
(torch.Tensor)) # different shape videos
or isinstance(pixel_values,
torch.Tensor)): # same shape videos
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaOnevisionVideoPixelInputs(
type="pixel_values_videos",
data=pixel_values,
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
if "pixel_values" in kwargs:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if "pixel_values_videos" in kwargs:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
return modalities
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
)
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
def _merge_image_patch_embeddings(self,
image_size: torch.Tensor,
patch_embeddings: torch.Tensor,
*,
image_newline=None,
vision_aspect_ratio="anyres_max_9",
strategy: str) -> torch.Tensor:
if strategy == "flat":
return patch_embeddings.flatten(0, 1)
if strategy.startswith("spatial"):
height = width = self.config.vision_config.image_size \
// self.config.vision_config.patch_size
base_patch_embeds = patch_embeddings[0]
if height * width != base_patch_embeds.shape[0]:
raise ValueError(
"The number of patches is not consistent with the "
"image size.")
if patch_embeddings.shape[0] > 1:
other_patch_embeds = patch_embeddings[1:]
# Move to CPU to avoid floating-point errors
orig_height, orig_width = image_size.tolist()
# image_aspect_ratio == "anyres"
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
(orig_height, orig_width),
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
num_patches = num_patch_height * num_patch_width
# Image patches might be padded for batch processing
other_patch_embeds = other_patch_embeds[:num_patches] \
.view(num_patch_height, num_patch_width, height, width, -1)
if "unpad" in strategy:
other_patch_embeds = other_patch_embeds \
.permute(4, 0, 2, 1, 3).contiguous() \
.flatten(1, 2).flatten(2, 3)
other_patch_embeds = unpad_image(other_patch_embeds,
(orig_height, orig_width))
max_num_patches = int(
vision_aspect_ratio.removeprefix("anyres_max_"))
channels, curr_height, curr_width = other_patch_embeds.shape
ratio = math.sqrt(curr_height * curr_width /
(max_num_patches * height**2))
if ratio > 1.1:
other_patch_embeds = other_patch_embeds[None]
other_patch_embeds = nn.functional.interpolate(
other_patch_embeds, [
int(curr_height // ratio),
int(curr_width // ratio)
],
mode="bilinear")[0]
if image_newline is not None:
other_patch_embeds = torch.cat(
(
other_patch_embeds,
image_newline[:, None, None] \
.expand(*other_patch_embeds.shape[:-1], 1) \
.to(other_patch_embeds.device),
),
dim=-1)
other_patch_embeds = other_patch_embeds \
.flatten(1, 2).transpose(0, 1)
else:
other_patch_embeds = other_patch_embeds \
.permute(0, 2, 1, 3, 4).contiguous() \
.flatten(0, 3)
merged_patch_embeddings = torch.cat(
(base_patch_embeds, other_patch_embeds), dim=0)
else:
if "unpad" in strategy:
merged_patch_embeddings = torch.cat(
(base_patch_embeds,
self.image_newline[None] \
.to(base_patch_embeds.device)
), dim=0)
else:
merged_patch_embeddings = base_patch_embeds
return merged_patch_embeddings
raise ValueError(f"Unexpected patch merge strategy: {strategy}")
def _process_image_pixels(
self,
inputs: LlavaOnevisionImagePixelInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
assert self.vision_tower is not None
pixel_values = inputs["data"]
if isinstance(pixel_values, torch.Tensor):
b, num_patches, c, h, w = pixel_values.shape
stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
stacked_patch_embeddings = self.multi_modal_projector(
stacked_image_features)
return stacked_patch_embeddings.view(
b, num_patches, *stacked_patch_embeddings.shape[1:])
num_patches_per_batch = [v.shape[0] for v in pixel_values]
stacked_pixel_values = torch.cat(pixel_values)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
return [
self.multi_modal_projector(image_features) for image_features in
torch.split(stacked_image_features, num_patches_per_batch)
]
def _process_image_input(
self,
image_input: LlavaOnevisionImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
if image_input["type"] == "image_embeds":
return [image_input["data"]]
patch_embeddings = self._process_image_pixels(image_input)
image_sizes = image_input.get("image_sizes")
if image_sizes is None:
batch_size = len(image_input["data"])
vision_config = self.config.vision_config
default_height = default_width = vision_config.image_size
image_sizes = torch.as_tensor([[default_height, default_width]
for _ in range(batch_size)])
return [
self._merge_image_patch_embeddings(
image_sizes[i],
patch_features_batch,
image_newline=self.image_newline,
strategy="spatial_unpad")
for i, patch_features_batch in enumerate(patch_embeddings)
]
def _add_image_newline(
self,
video_features: torch.Tensor,
videos: int = 1,
frames: int = 1,
strategy: str = "one_token",
) -> torch.Tensor:
if strategy == "one_token":
video_features = video_features.reshape(
videos, frames * video_features.shape[1], -1)
image_newline = self.image_newline[None, None, :].repeat(
videos, 1, 1).to(video_features.device)
video_features = torch.cat((video_features, image_newline), dim=1)
return video_features
raise ValueError(f"Unexpected video newline strategy: {strategy}")
def _video_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
video_features = vision_tower(pixel_values)
video_features = self._select_image_features(
video_features,
strategy=self.config.vision_feature_select_strategy,
)
video_features = self.multi_modal_projector(video_features)
video_features = self.apply_pooling(video_features)
return video_features
def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
assert self.vision_tower is not None
video_pixels = inputs["data"]
if isinstance(video_pixels, torch.Tensor):
b, num_videos, frames, c, h, w = video_pixels.shape
pixel_values = video_pixels.view(b * num_videos * frames, c, h, w)
stacked_embeddings = self._video_pixels_to_features(
self.vision_tower, pixel_values)
stacked_embeddings = self._add_image_newline(stacked_embeddings,
videos=b * num_videos,
frames=frames,
strategy="one_token")
return stacked_embeddings
elif is_list_of(video_pixels, torch.Tensor):
stacked_embeddings = []
for video_pixel in video_pixels:
num_videos, frames, c, h, w = video_pixel.shape
pixel_values = video_pixel.view(num_videos * frames, c, h, w)
embeddings = self._video_pixels_to_features(
self.vision_tower, pixel_values)
embeddings = self._add_image_newline(embeddings,
videos=num_videos,
frames=frames,
strategy="one_token")
stacked_embeddings.append(embeddings)
return stacked_embeddings
else:
raise ValueError(
f"Unsupported type of video input {type(video_pixels)}")
def apply_pooling(self, image_features, stride=2):
vision_config = self.config.vision_config
height = width = vision_config.image_size // vision_config.patch_size
batch_frames, _, dim = image_features.shape
image_features = image_features.view(batch_frames, height, width, -1)
image_features = image_features.permute(0, 3, 1, 2)
# TODO support other pooling types config
height, width = image_features.shape[2:]
scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)]
image_feature = nn.functional.interpolate(image_features,
size=scaled_shape,
mode='bilinear')
image_feature = image_feature.permute(0, 2, 3, 1)
image_feature = image_feature.view(batch_frames, -1, dim)
return image_feature
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LlaVA-Onevision.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values_videos: Pixels in each frames for each input videos.
"""
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if modalities:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
if "images" in modalities:
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
if "videos" in modalities:
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights)

View File

@@ -0,0 +1,247 @@
"""PyTorch MAMBA model."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import MambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)
from .utils import maybe_prefix
KVCache = Tuple[torch.Tensor, torch.Tensor]
class MambaDecoderLayer(nn.Module):
def __init__(self,
config: MambaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.is_falcon_mamba = config.model_type == "falcon_mamba"
mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
self.mixer = MambaMixer(hidden_size=config.hidden_size,
ssm_state_size=config.state_size,
conv_kernel_size=config.conv_kernel,
intermediate_size=config.intermediate_size,
time_step_rank=config.time_step_rank,
use_conv_bias=config.use_conv_bias,
use_bias=config.use_bias,
use_rms_norm=self.is_falcon_mamba,
rms_norm_eps=mixer_rms_eps,
activation=config.hidden_act)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
**kwargs,
):
if residual is None:
residual = hidden_states
hidden_states = self.norm(hidden_states)
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, attn_metadata,
mamba_cache_params)
return hidden_states, residual
class MambaModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embeddings = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
decoder_layers = []
for i in range(config.num_hidden_layers):
decoder_layers.append(
MambaDecoderLayer(config,
cache_config=cache_config,
quant_config=quant_config))
self.layers = nn.ModuleList(decoder_layers)
self.norm_f = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
) -> torch.Tensor:
hidden_states = self.embeddings(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx(i))
hidden_states, _ = self.norm_f(hidden_states, residual)
return hidden_states
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
"Mamba does not support prefix caching"
super().__init__()
self.config = config
self.scheduler_config = scheduler_config
self.backbone = MambaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "backbone"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
if config.tie_word_embeddings:
self.lm_head = self.backbone.embeddings
else:
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs):
if self.mamba_cache is None:
max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)
self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, self.config.num_hidden_layers,
max_batch_size, *self._get_mamba_cache_shape())
(
mamba_cache_tensors,
state_indices_tensor,
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
**kwargs)
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
mamba_cache_tensors[1],
state_indices_tensor)
hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_params)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
conv_state_shape = (
self.config.intermediate_size // world_size,
self.config.conv_kernel - 1,
)
temporal_state_shape = (
self.config.intermediate_size // world_size,
self.config.state_size,
)
return conv_state_shape, temporal_state_shape
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "A_log" in name:
name = name.replace("A_log", "A")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,158 @@
from dataclasses import dataclass
from typing import Dict, List
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.utils import PAD_SLOT_ID
@dataclass
class MambaCacheParams:
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()
state_indices_tensor: torch.Tensor = torch.Tensor()
def at_layer_idx(self, layer_idx):
return MambaCacheParams(self.conv_state[layer_idx],
self.ssm_state[layer_idx],
self.state_indices_tensor)
class MambaCacheManager:
def __init__(self, dtype, num_mamba_layers, max_batch_size,
conv_state_shape, temporal_state_shape):
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
conv_state_shape,
dtype=dtype,
device="cuda")
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=dtype,
device="cuda")
self.mamba_cache = (conv_state, temporal_state)
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
def current_run_tensors(self, input_ids: torch.Tensor,
attn_metadata: AttentionMetadata, **kwargs):
"""
Return the tensors for the current run's conv and ssm state.
"""
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
mamba_cache_tensors = self.mamba_cache
else:
# CUDA graph capturing runs
(mamba_cache_tensors,
state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]
return (mamba_cache_tensors, state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
return (self.mamba_cache, state_indices_tensor)
def _copy_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
finished_requests_ids) -> int:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
if cur_rid in finished_requests_ids:
# set as pad, do not allocate destination index
return PAD_SLOT_ID
elif cur_rid not in self.mamba_cache_indices_mapping:
destination_index = self.free_cache_indices.pop()
self.mamba_cache_indices_mapping[cur_rid] = {
seq_id: destination_index
}
return destination_index
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened, so we copy the
# existing cache into the siblings seq_ids caches
index_exists = next(iter(seq_ids2indices.values()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index
return destination_index
else:
# already exists
return self.mamba_cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]) -> List[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping:
for seq_id in self.mamba_cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.mamba_cache_indices_mapping[req_id][seq_id])
self.mamba_cache_indices_mapping.pop(req_id)

View File

@@ -0,0 +1,185 @@
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
class ResidualBlock(nn.Module):
def __init__(self, hidden_size: int, num_layers: int) -> None:
super().__init__()
self.layers = nn.ModuleList([
nn.Linear(hidden_size, hidden_size, bias=False)
for _ in range(num_layers)
])
self.act = nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
x = x + self.act(layer(x))
return x
class Medusa(nn.Module):
"""This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774
Reference implementation: https://github.com/FasterDecoding/Medusa
Differences from reference implementation:
1. Currently this only supports generating proposals from top-1 tokens.
2. We have an optional token_map which reduces draft vocab to most
frequently used tokens to give some additional speed-up by reducing
sampling overhead. This is disabled unless the checkpoint file has
explicit token_map tensor and config has an optional attribute
truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute."""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
config = vllm_config.model_config.hf_config
super().__init__()
self.config = config
self.blocks = nn.ModuleList([
ResidualBlock(hidden_size=self.config.hidden_size,
num_layers=self.config.num_hidden_layers)
for _ in range(self.config.num_heads)
])
self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size
self.lm_heads = nn.ModuleList([
ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=self.truncated_vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
) for _ in range(self.config.num_heads)
])
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.truncated_vocab_size,
logit_scale)
# Token map is a idx to token mapping to reduce the vocab size for
# the draft model. Using smaller vocab size for draft, containing
# only most frequent tokens reduces the speculation overhead. This
# doesn't affect the acceptance rate much and thus gives more speed
# -up. By default, this is disabled and is only used if the EAGLE
# checkpoint file has token_map tensor.
self.token_map = None
def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
return [block(hidden_states) for block in self.blocks]
def compute_logits(
self, hidden_states: List[torch.Tensor],
sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
logits_lst: List[torch.Tensor] = []
for hs, lm_head in zip(hidden_states, self.lm_heads):
_logits = self.logits_processor(lm_head, hs, sampling_metadata)
if _logits is None:
# _logits should only be None on rank > 0, in which case
# it should remain true for every lm_head
assert len(logits_lst) == 0
continue
if self.token_map is None:
logits_lst.append(_logits)
else:
logits_lst.append(-torch.inf * torch.ones(
size=(*_logits.shape[:-1], self.orig_vocab_size),
device=_logits.device,
dtype=_logits.dtype))
logits_lst[-1][..., self.token_map] = _logits
return logits_lst
def sample(
self,
logits: List[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]:
logits = torch.stack(logits, dim=0).float()
logprobs = torch.log_softmax(logits, dim=-1)
token_ids = logits.argmax(-1) # support only top-1 for now
probs = torch.softmax(logits, dim=-1)
token_id_list = []
token_prob_list = []
token_logprob_list = []
for idx, seq_group in enumerate(sampling_metadata.seq_groups):
token_id_list.append(token_ids[:, seq_group.sample_indices])
token_prob_list.append(probs[:, seq_group.sample_indices])
token_logprob_list.append(logprobs[:, seq_group.sample_indices])
outputs: List[Optional[SamplerOutput]] = []
for idx in range(len(sampling_metadata.seq_groups)):
outputs.append(
SamplerOutput(
outputs=None,
sampled_token_probs=token_prob_list[idx].squeeze(1),
logprobs=token_logprob_list[idx].squeeze(1),
sampled_token_ids=token_id_list[idx].squeeze(1),
))
return outputs
def generate_proposals(
self,
previous_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]:
return self.sample(
logits=self.compute_logits(
hidden_states=self.forward(previous_hidden_states),
sampling_metadata=sampling_metadata,
),
sampling_metadata=sampling_metadata,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
weights_map = {}
for name, loaded_weight in weights:
name = name.replace("medusa_heads.", "")
if name == "token_map":
if self.truncated_vocab_size < self.orig_vocab_size:
self.token_map = nn.Parameter(loaded_weight,
requires_grad=False)
elif name in params_dict:
weights_map[name] = loaded_weight
for name, loaded_weight in weights_map.items():
if "lm_head" in name and self.token_map is not None and\
loaded_weight.shape[0] > self.token_map.shape[0]:
loaded_weight = loaded_weight[self.token_map]
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if self.token_map is not None:
self.token_map.to(device=self.lm_heads[0].weight.device)
assert (self.truncated_vocab_size
== self.orig_vocab_size) or (self.token_map is not None)

View File

@@ -0,0 +1,603 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import math
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class MiniCPMMoE(nn.Module):
"""A tensor-parallel MoE implementation that shards each expert
across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
):
super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
self.num_total_experts = num_experts
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // self.tp_size
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.gate = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
quant_config=None)
self.ws = nn.Parameter(
torch.empty(self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
device="cuda",
dtype=self.params_dtype))
self.w2s = nn.Parameter(
torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
device="cuda",
dtype=self.params_dtype))
set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2s, {
"weight_loader": self.weight_loader,
})
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
if weight_name.endswith("w1.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w3.weight"):
param_data[expert_id,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.ws,
self.w2s,
router_logits,
self.top_k,
renormalize=True,
inplace=True)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
class MiniCPMMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
hidden_act_param: float,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act == "silu":
self.act_fn = SiluAndMul()
elif hidden_act == "fatrelu":
self.act_fn = FatreluAndMul(threshold=hidden_act_param)
else:
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu and fatrelu are supported for now.")
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class MiniCPMAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
# set rope as fp32 instead of bf16
self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache(
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
orig_dtype = q.dtype
q, k = q.float(), k.float()
q, k = self.rotary_emb(positions, q, k)
q, k = q.to(orig_dtype), k.to(orig_dtype)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class MiniCPMDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.cache_config = cache_config
self.quant_config = quant_config
self.hidden_size = config.hidden_size
self.rope_theta = getattr(config, "rope_theta", 10000)
self.rope_scaling = getattr(config, "rope_scaling", None)
self.max_position_embeddings = getattr(config,
"max_position_embeddings", 8192)
self._init_attn_block()
self._init_ffn_block()
def _init_attn_block(self):
self.input_layernorm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
self.self_attn = MiniCPMAttention(
hidden_size=self.hidden_size,
num_heads=self.config.num_attention_heads,
num_kv_heads=self.config.num_key_value_heads,
rope_theta=self.rope_theta,
rope_scaling=self.rope_scaling,
max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config,
quant_config=self.quant_config,
)
def _init_ffn_block(self):
self.post_attention_layernorm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
self.num_experts = getattr(self.config, "num_experts", 0)
if self.num_experts == 0:
self.mlp = MiniCPMMLP(
hidden_size=self.hidden_size,
intermediate_size=self.config.intermediate_size,
hidden_act=self.config.hidden_act,
hidden_act_param=getattr(self.config, "hidden_act_param", 0.),
quant_config=self.quant_config,
)
else:
self.mlp = MiniCPMMoE(
num_experts=self.config.num_experts,
top_k=self.config.num_experts_per_tok,
hidden_size=self.config.hidden_size,
intermediate_size=self.config.intermediate_size)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states * \
(self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states * \
(self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))
return hidden_states, None
@support_torch_compile
class MiniCPMModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.cache_config = cache_config
self.quant_config = quant_config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self._init_layers(prefix, config, cache_config, quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], self.config.hidden_size))
def _init_layers(
self,
prefix: str,
config: PretrainedConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MiniCPMDecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
embedding = self.embed_tokens(input_ids)
return embedding * self.config.scale_emb
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = self.norm(hidden_states)
return hidden_states
class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.prefix = prefix
self.vllm_config = vllm_config
self.config = config
self.lora_config = lora_config
self.cache_config = cache_config
self.quant_config = quant_config
self.num_experts = getattr(self.config, "num_experts", 0)
self._init_model(vllm_config=vllm_config, prefix=prefix)
unpadded_vocab_size = config.vocab_size
if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
self.scale_width = self.config.hidden_size / self.config.dim_model_base
self.logits_processor = LogitsProcessor(unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = MiniCPMModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
hidden_states = hidden_states / self.scale_width
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = [
# (param_name, weight_name, expert_id)
("ws" if weight_name in ["w1", "w3"] else "w2s",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.num_experts)
for weight_name in ["w1", "w2", "w3"]
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name, expert_id in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,243 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2024 The ModelBest team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM3 model compatible with HuggingFace weights."""
from typing import Any, Dict, Optional
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
MiniCPMForCausalLM,
MiniCPMModel)
from .utils import make_layers, maybe_prefix
class MiniCPM3Attention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int,
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert self.num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config)
self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size,
self.kv_lora_rank +
self.qk_rope_head_dim,
bias=False,
quant_config=quant_config)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config)
# O projection.
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config)
self.rotary_emb = get_rope(
self.qk_rope_head_dim,
rotary_dim=self.qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_local_heads,
self.qk_head_dim,
self.scaling,
num_kv_heads=self.num_local_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
q, _ = self.q_a_proj(hidden_states)
q = self.q_a_layernorm(q)
q, _ = self.q_b_proj(q)
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
dim=-1)
latent_cache, _ = self.kv_a_proj_with_mqa(hidden_states)
kv_a, _ = latent_cache.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a.contiguous())
kv, _ = self.kv_b_proj(kv_a)
kv = kv.view(-1, self.num_local_heads,
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:]
q_pe, k_pe = self.rotary_emb(
positions,
q_pe.reshape(-1, self.num_local_heads * self.qk_rope_head_dim),
k_pe.reshape(-1, self.qk_rope_head_dim))
q_pe = q_pe.view(-1, self.num_local_heads, self.qk_rope_head_dim)
k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim:] = k_pe
q = q.reshape(-1, self.num_local_heads * self.qk_head_dim)
k = k.view(-1, self.num_local_heads * self.qk_head_dim)
v = torch.nn.functional.pad(
v, [0, self.qk_head_dim - self.v_head_dim],
value=0).view(-1, self.num_local_heads * self.qk_head_dim)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = attn_output.view(
-1, self.num_local_heads,
self.qk_head_dim)[..., :self.v_head_dim].reshape(
-1, self.num_local_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output)
return output
class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
def _init_attn_block(self):
self.input_layernorm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
self.self_attn = MiniCPM3Attention(
config=self.config,
hidden_size=self.hidden_size,
num_heads=self.config.num_attention_heads,
qk_nope_head_dim=self.config.qk_nope_head_dim,
qk_rope_head_dim=self.config.qk_rope_head_dim,
v_head_dim=self.config.v_head_dim,
q_lora_rank=self.config.q_lora_rank,
kv_lora_rank=self.config.kv_lora_rank,
rope_theta=self.rope_theta,
rope_scaling=self.rope_scaling,
max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config,
quant_config=self.quant_config,
)
class MiniCPM3Model(MiniCPMModel):
def _init_layers(
self,
prefix: str,
config: PretrainedConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MiniCPM3DecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
packed_modules_mapping = {
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"kv_a_proj_with_mqa",
"q_a_proj",
"q_b_proj",
"kv_b_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
# `embedding_modules` and `embedding_padding_modules`
# are inherited from MiniCPMForCausalLM
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = MiniCPM3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,468 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class MixtralMoE(nn.Module):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = ""):
super().__init__()
self.hidden_size = hidden_size
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(hidden_size,
num_experts,
bias=False,
params_dtype=params_dtype,
quant_config=None,
prefix=f"{prefix}.gate")
self.experts = FusedMoE(num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
prefix=f"{prefix}.experts")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape)
class MixtralAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MixtralAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.block_sparse_moe = MixtralMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.block_sparse_moe")
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.block_sparse_moe(hidden_states)
return hidden_states, residual
@support_torch_compile
class MixtralModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MixtralDecoderLayer(
config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "embed_tokens", "lm_head", "w1", "w2", "w3",
"gate"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts)
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,438 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class MixtralMLP(nn.Module):
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.num_experts = num_experts
self.ffn_dim = intermediate_size
self.hidden_dim = hidden_size
self.w1 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
quant_config=quant_config)
self.w2 = ReplicatedLinear(self.ffn_dim,
self.hidden_dim,
bias=False,
quant_config=quant_config)
self.w3 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
quant_config=quant_config)
# TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
w1_out, _ = self.w1(hidden_states)
w1_out = self.act_fn(w1_out)
w3_out, _ = self.w3(hidden_states)
current_hidden_states = w1_out * w3_out
current_hidden_states, _ = self.w2(current_hidden_states)
return current_hidden_states
class MixtralMoE(nn.Module):
def __init__(
self,
config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.num_total_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.num_total_experts}.")
# Split experts equally between ranks
self.expert_indicies = np.array_split(range(
self.num_total_experts), self.tp_size)[self.rank].tolist()
if not self.expert_indicies:
raise ValueError(
f"Rank {self.rank} has no experts assigned to it.")
self.experts = nn.ModuleList([
MixtralMLP(self.num_total_experts,
config.hidden_size,
config.intermediate_size,
quant_config=quant_config)
if idx in self.expert_indicies else None
for idx in range(self.num_total_experts)
])
self.gate = ReplicatedLinear(config.hidden_size,
self.num_total_experts,
bias=False,
quant_config=None)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights,
self.top_k,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
final_hidden_states = None
for expert_idx in self.expert_indicies:
expert_layer = self.experts[expert_idx]
expert_mask = (selected_experts == expert_idx)
expert_weights = (routing_weights * expert_mask).sum(dim=-1,
keepdim=True)
current_hidden_states = expert_layer(hidden_states).mul_(
expert_weights)
if final_hidden_states is None:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)
return tensor_model_parallel_all_reduce(final_hidden_states).view(
num_tokens, hidden_dim)
class MixtralAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MixtralAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config)
self.block_sparse_moe = MixtralMoE(config=config,
quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.block_sparse_moe(hidden_states)
return hidden_states, residual
class MixtralModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MixtralDecoderLayer(
config, cache_config, quant_config=quant_config),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class MixtralForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if ("block_sparse_moe.experts." in name
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,198 @@
import math
from typing import Iterable, List, Tuple
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
SQRT2 = 2**0.5
class MLPSpeculatorLayerNorm(nn.Module):
"""
A L2 normalization implementation
...
Args
----
normalized_shape : int
Dimensionality of input data (size of final tensor axis)
eps : float
Safety term to prevent division by zero. Make sure the chosen value
fits in the range of your encoding scheme
(i.e. fp16 requires eps >= 6e-8).
elementwise_scale_and_shift : bool
Include a learned scaling and shift term after normalization.
"""
def __init__(
self,
normalized_shape,
eps=1e-06,
elementwise_scale_and_shift=True,
):
super().__init__()
self.elementwise_scale_and_shift = elementwise_scale_and_shift
if self.elementwise_scale_and_shift:
self.weight = nn.Parameter(torch.empty(normalized_shape))
self.bias = nn.Parameter(torch.empty(normalized_shape))
self.eps = eps
def forward(self, x):
xf = x
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
x = xf.type_as(x)
if self.elementwise_scale_and_shift:
x = self.weight * x
x = x + self.bias
return x
class MLPSpeculator(nn.Module):
"""
An implementation of the speculative models introduced in
"Accelerating Production LLMs with Combined Token/Embedding
Speculators"
https://arxiv.org/pdf/2404.19124
Trained speculators of this type are available on HF hub at:
https://huggingface.co/ibm-fms and https://huggingface.co/ibm-granite
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
self.n_predict = config.n_predict
self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim
self.inner_dim = config.inner_dim if config.inner_dim != 0 \
else config.emb_dim
self.max_speculative_tokens = config.num_lookahead_tokens
self.tie_weights = config.tie_weights
self.scale_input = config.scale_input
if self.tie_weights:
assert (
self.n_predict >
1), "You cannot tie weights between stages when only 1 exists"
embedding = VocabParallelEmbedding(
config.vocab_size,
self.inner_dim,
org_num_embeddings=config.vocab_size)
self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)
# the initial projection from the base model may
# have a different size, so that stays separate.
proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
self.proj = nn.ModuleList([proj_first] + [proj_tied] *
(self.max_speculative_tokens - 1))
head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
self.head = nn.ModuleList([head] * self.max_speculative_tokens)
ln = MLPSpeculatorLayerNorm(self.inner_dim,
elementwise_scale_and_shift=True)
self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)
else:
self.emb = nn.ModuleList([
VocabParallelEmbedding(config.vocab_size,
self.inner_dim,
org_num_embeddings=config.vocab_size)
for _ in range(self.max_speculative_tokens)
])
self.proj = nn.ModuleList([
nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
self.inner_dim,
bias=False)
for i in range(self.max_speculative_tokens)
])
self.head = nn.ModuleList([
ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
for _ in range(self.max_speculative_tokens)
])
self.ln = nn.ModuleList([
MLPSpeculatorLayerNorm(self.inner_dim,
elementwise_scale_and_shift=True)
for _ in range(self.max_speculative_tokens)
])
if self.scale_input:
self.ln0 = MLPSpeculatorLayerNorm(
self.emb_dim, elementwise_scale_and_shift=False)
self.state_weight = 0.5**(0.5 / config.n_predict)
self.emb_weight = math.sqrt(
(1 - self.state_weight**2) * (self.inner_dim / 2))
self.activation = nn.GELU()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
config.vocab_size, 1.0)
self.sampler = get_sampler()
def generate_proposals(
self,
input_ids: torch.Tensor,
previous_hidden_states: torch.Tensor,
num_predict_tokens: int,
sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]:
if num_predict_tokens > self.max_speculative_tokens:
raise ValueError(f"Max speculative tokens for model is "
f"{self.max_speculative_tokens}, but "
f"{num_predict_tokens} were requested")
# b x 1 x d
previous_hidden_states = previous_hidden_states.unsqueeze(1)
if self.scale_input:
previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
# b x 1
last_tokens = input_ids.unsqueeze(1)
next_tokens = []
for head_index in range(num_predict_tokens):
# Project and predict
z = self.emb[head_index](last_tokens) # b k d
states = self.proj[head_index](previous_hidden_states)
# Weighted add of state_weight*state and emb_weight*z
# Let subsequent LN take care of denominator
# state_weight is close to 1, so shouldn't be any precision issues
states.add_(z, alpha=self.emb_weight / self.state_weight)
states = self.activation(self.ln[head_index](states)) # b k d
previous_hidden_states = states
# TODO: not yet supporting top_k_tokens_per_head
states = states.flatten(0, 1)
logits = self.logits_processor(self.head[head_index], states,
sampling_metadata)
output = self.sampler(logits, sampling_metadata)
last_tokens = output.sampled_token_ids
next_tokens.append(output)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
param = params_dict.get(name.replace("speculator.", ""))
if param is not None:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,69 @@
# Adapted from
# https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py
from dataclasses import dataclass, field
from typing import List, Union
@dataclass
class ModelKeys:
model_type: str = None
module_list: str = None
embedding: str = None
mlp: str = None
down_proj: str = None
attention: str = None
o_proj: str = None
q_proj: str = None
k_proj: str = None
v_proj: str = None
qkv_proj: str = None
qk_proj: str = None
qa_proj: str = None
qb_proj: str = None
kva_proj: str = None
kvb_proj: str = None
output: str = None
@dataclass
class MultiModelKeys(ModelKeys):
language_model: List[str] = field(default_factory=list)
connector: List[str] = field(default_factory=list)
# vision tower and audio tower
tower_model: List[str] = field(default_factory=list)
generator: List[str] = field(default_factory=list)
@staticmethod
def from_string_field(language_model: Union[str, List[str]] = None,
connector: Union[str, List[str]] = None,
tower_model: Union[str, List[str]] = None,
generator: Union[str, List[str]] = None,
**kwargs) -> 'MultiModelKeys':
def to_list(value):
if value is None:
return []
return [value] if isinstance(value, str) else list(value)
return MultiModelKeys(language_model=to_list(language_model),
connector=to_list(connector),
tower_model=to_list(tower_model),
generator=to_list(generator),
**kwargs)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,326 @@
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import Iterable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.mpt import MPTConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
def _get_alibi_slopes(
total_num_heads: int,
alibi_bias_max: int,
) -> torch.Tensor:
next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
m = m.mul(alibi_bias_max / next_power_of_2)
slopes = 1.0 / torch.pow(2, m)
if next_power_of_2 != total_num_heads:
slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
return slopes
class MPTAttention(nn.Module):
def __init__(
self,
config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.total_num_heads = config.n_heads
self.head_dim = self.d_model // self.total_num_heads
self.clip_qkv = config.attn_config["clip_qkv"]
self.qk_ln = config.attn_config["qk_ln"]
self.alibi_bias_max = config.attn_config["alibi_bias_max"]
if "kv_n_heads" in config.attn_config:
self.total_num_kv_heads = config.attn_config['kv_n_heads']
else:
self.total_num_kv_heads = self.total_num_heads
assert not config.attn_config["prefix_lm"]
assert config.attn_config["alibi"]
# pylint: disable=invalid-name
self.Wqkv = QKVParallelLinear(
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=not config.no_bias,
quant_config=quant_config,
)
if self.qk_ln:
self.q_ln = nn.LayerNorm(self.d_model)
self.k_ln = nn.LayerNorm(self.d_model)
self.out_proj = RowParallelLinear(
self.d_model,
self.d_model,
bias=not config.no_bias,
quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
# Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads,
self.alibi_bias_max)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
self.head_dim = self.d_model // self.total_num_heads
scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
del position_ids # unused.
qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.qk_ln:
q = self.q_ln(q)
k = self.k_ln(k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.out_proj(attn_output)
return output
class MPTMLP(nn.Module):
def __init__(
self,
config: MPTConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.d_model
expansion_ratio = config.expansion_ratio
intermediate_size = expansion_ratio * hidden_size
self.up_proj = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=not config.no_bias,
quant_config=quant_config,
)
self.act = get_act_fn("gelu")
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=not config.no_bias,
quant_config=quant_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.up_proj(x)
x = self.act(x)
x, _ = self.down_proj(x)
return x
class MPTBlock(nn.Module):
def __init__(
self,
config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.d_model
self.norm_1 = nn.LayerNorm(hidden_size)
self.attn = MPTAttention(config, cache_config, quant_config)
self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config, quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
x = self.norm_1(hidden_states)
x = self.attn(
position_ids=position_ids,
hidden_states=x,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = hidden_states + x
x = self.norm_2(hidden_states)
x = self.ffn(x)
hidden_states = hidden_states + x
return hidden_states
@support_torch_compile
class MPTModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
assert config.embedding_fraction == 1.0
assert config.norm_type == "low_precision_layernorm"
self.wte = VocabParallelEmbedding(
config.vocab_size,
config.d_model,
)
self.start_layer, self.end_layer, self.blocks = make_layers(
config.n_layers,
lambda prefix: MPTBlock(config, cache_config, quant_config),
prefix=f"{prefix}.blocks")
self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias:
for module in self.modules():
if hasattr(module, "bias") and isinstance(
module.bias, nn.Parameter):
# Remove the bias term in Linear and LayerNorm.
module.register_parameter("bias", None)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.d_model))
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
block = self.blocks[i]
hidden_states = block(
position_ids,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm_f(hidden_states)
return hidden_states
class MPTForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
assert config.tie_word_embeddings
self.quant_config = quant_config
self.transformer = MPTModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "transformer"))
self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,519 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Nemotron model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import NemotronConfig
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
# The architecture is pretty similar to Llama, with these changes:
# - There is no gate_proj, just up_proj
# - Normal LayerNorm (with a +1 to the weights) instead of RMSNorm
# - Squared ReLU instead of SwiGLU
# - Adds a partial_rotary_factor to RoPE
def _cast_if_autocast_enabled(*args):
if not torch.is_autocast_enabled():
return args
else:
return torch.cuda.amp.autocast_mode._cast(
args, torch.get_autocast_gpu_dtype())
class NemotronLayerNorm1P(nn.LayerNorm):
def __init__(self,
normalized_shape: Union[int, List[int], torch.Size],
eps: float = 1e-5,
elementwise_affine: bool = True,
bias: bool = True,
device=None,
dtype=None):
super().__init__(normalized_shape, eps, elementwise_affine, bias,
device, dtype)
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if residual is not None:
x = x + residual
residual = x
args = _cast_if_autocast_enabled(x, self.normalized_shape,
self.weight + 1, self.bias, self.eps)
with torch.cuda.amp.autocast(enabled=False):
x = torch.nn.functional.layer_norm(*args)
return x if residual is None else (x, residual)
class NemotronMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.up_proj = ColumnParallelLinear(input_size=hidden_size,
output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.up_proj")
self.down_proj = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
self.act_fn = get_act_fn(hidden_act)
def forward(self, x):
up, _ = self.up_proj(x)
x = self.act_fn(up)
x, _ = self.down_proj(x)
return x
class NemotronAttention(nn.Module):
def __init__(
self,
config: NemotronConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.partial_rotary_factor = config.partial_rotary_factor
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
partial_rotary_factor=self.partial_rotary_factor,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class NemotronDecoderLayer(nn.Module):
def __init__(
self,
config: NemotronConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
self.self_attn = NemotronAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = NemotronMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.input_layernorm = NemotronLayerNorm1P(config.hidden_size,
eps=config.norm_eps)
self.post_attention_layernorm = NemotronLayerNorm1P(
config.hidden_size, eps=config.norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class NemotronModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: NemotronDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = NemotronLayerNorm1P(config.hidden_size,
eps=config.norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "up_proj", "down_proj", "embed_tokens", "lm_head"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
assert isinstance(config, NemotronConfig)
self.config = config
self.lora_config = lora_config
self.model = NemotronModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return model_output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,88 @@
# adapted from https://huggingface.co/nvidia/NVLM-D-72B/blob/main/modeling_nvlm_d.py
# --------------------------------------------------------
# NVLM-D
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from typing import Optional
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from .intern_vit import InternVisionModel
from .internvl import (InternVLChatModel, InternVLInputPipeline,
get_max_internvl_image_tokens)
IMG_START = '<|vision_start|>'
IMG_END = '<|vision_end|>'
IMG_CONTEXT = '<|vision_pad|>'
class NVLMInputPipeline(InternVLInputPipeline):
def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
tile_pos_identifiers = ([f"<tile_{i}>"
for i in range(1, num_patches)] +
["<tile_global_thumbnail>"])
context_size = feature_size // num_patches
return '<Image>' + ''.join(
tile_pos_identifier + self.img_context_token * context_size
for tile_pos_identifier in tile_pos_identifiers) + '</Image>'
input_pipeline = NVLMInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class NVLM_D_Model(InternVLChatModel):
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
vit_hidden_size = config.vision_config.hidden_size
llm_intermediate_size = config.text_config.intermediate_size
llm_hidden_size = config.text_config.hidden_size
return nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
llm_intermediate_size,
bias=False),
nn.GELU(),
nn.Linear(llm_intermediate_size, llm_hidden_size, bias=False),
)
def _init_vision_model(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
*,
is_mono: bool,
prefix: str,
):
if not is_mono:
vision_feature_layer = config.select_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
# We added additional dummy heads to the original num of heads to
# make the number of heads divisible by 8.
return InternVisionModel(
config.vision_config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
num_dummy_heads=7,
prefix=prefix,
)
else:
msg = "Monolith mode is not applicable to NVLM_D"
raise NotImplementedError(msg)

View File

@@ -0,0 +1,397 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
# Copyright 2024 The vLLM team.
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import OlmoConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class OlmoAttention(nn.Module):
"""
This is the attention block where the output is computed as
``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(
self,
config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
self.total_num_heads = config.num_attention_heads
assert self.hidden_size % self.total_num_heads == 0
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = self.hidden_size // self.total_num_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.clip_qkv = config.clip_qkv
# Attention input projection. Projects x -> (q, k, v)
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
bias=config.attention_bias,
quant_config=quant_config,
)
# Rotary embeddings.
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
)
self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling,
cache_config=cache_config,
quant_config=quant_config)
# Attention output projection.
self.o_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=config.attention_bias,
quant_config=quant_config,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class OlmoMLP(nn.Module):
"""
This is the MLP block where the output is computed as
``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(
self,
config: OlmoConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
# Feed-forward input projection.
self.gate_up_proj = MergedColumnParallelLinear(
self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
)
# Activation function.
self.act_fn = SiluAndMul()
# Feed-forward output projection.
self.down_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=False,
quant_config=quant_config,
)
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class OlmoDecoderLayer(nn.Module):
"""
This is a typical transformer block where the output is
computed as ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(self,
config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
# Attention block.
self.self_attn = OlmoAttention(config, cache_config, quant_config)
# MLP block.
self.mlp = OlmoMLP(config, quant_config)
# LayerNorm
self.input_layernorm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False,
bias=False)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False,
bias=False)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Attention block.
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(positions, hidden_states, kv_cache,
attn_metadata)
hidden_states = hidden_states + residual
# MLP block.
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
@support_torch_compile
class OlmoModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config
),
prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False,
bias=False)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
if get_pp_group().is_first_rank:
# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
hidden_states = inputs_embeds
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
# Apply blocks one-by-one.
for i in range(self.start_layer, self.end_layer):
# shape: (batch_size, seq_len, d_model)
hidden_states = self.layers[i](
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
# Apply final layer norm.
# shape: (batch_size, seq_len or 1, d_model)
hidden_states = self.norm(hidden_states)
return hidden_states
class OlmoForCausalLM(nn.Module, SupportsPP):
"""
Extremely barebones HF model wrapper.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.model = OlmoModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,445 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class OlmoeMoE(nn.Module):
"""A tensor-parallel MoE implementation for Olmoe that shards each expert
across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = ""):
super().__init__()
self.hidden_size = hidden_size
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(hidden_size,
num_experts,
bias=False,
quant_config=None)
self.experts = FusedMoE(num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=True,
renormalize=False,
quant_config=quant_config,
tp_size=tp_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
return final_hidden_states.view(orig_shape)
class OlmoeAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 4096,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.q_norm = RMSNorm(hidden_size, eps=1e-5)
self.k_norm = RMSNorm(hidden_size, eps=1e-5)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=True,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class OlmoeDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
4096)
self.self_attn = OlmoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
self.mlp = OlmoeMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class OlmoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OlmoeDecoderLayer(config, int(
prefix.split(".")[-1]), cache_config, quant_config),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class OlmoeForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = OlmoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,433 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import OPTConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class OPTLearnedPositionalEmbedding(nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int):
# OPT is set up so that if padding_idx is specified then offset the
# embedding ids by 2 and adjust num_embeddings appropriately. Other
# models don't have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)
def forward(self, positions: torch.Tensor):
return super().forward(positions + self.offset)
class OPTAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.embed_dim = embed_dim
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
total_num_heads = num_heads
assert num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
embed_dim,
self.head_dim,
total_num_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.out_proj(attn_output)
return output
class OPTDecoderLayer(nn.Module):
def __init__(
self,
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.self_attn = OPTAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
bias=config.enable_bias,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.do_layer_norm_before = config.do_layer_norm_before
self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
self.fc1 = ColumnParallelLinear(
self.embed_dim,
config.ffn_dim,
bias=config.enable_bias,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.activation_fn = get_act_fn(config.activation_function)
self.fc2 = RowParallelLinear(
config.ffn_dim,
self.embed_dim,
bias=config.enable_bias,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
self.final_layer_norm = nn.LayerNorm(
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class OPTDecoder(nn.Module):
def __init__(
self,
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.word_embed_proj_dim,
)
# Positional embeddings are replicated (not sharded).
self.embed_positions = OPTLearnedPositionalEmbedding(
config.max_position_embeddings, config.hidden_size)
# Project out & in will be replicated if they exist.
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = ReplicatedLinear(config.hidden_size,
config.word_embed_proj_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.project_out")
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.project_in")
else:
self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to
# keep backward compatibility with checkpoints that have been fine-tuned
# before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm(
config.hidden_size,
elementwise_affine=config.layer_norm_elementwise_affine)
else:
self.final_layer_norm = None
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OPTDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
pos_embeds = self.embed_positions(positions)
if self.project_in is not None:
inputs_embeds, _ = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
hidden_states, _ = self.project_out(hidden_states)
return hidden_states
@support_torch_compile
class OPTModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.decoder = OPTDecoder(config,
cache_config,
quant_config,
prefix=f"{prefix}.decoder")
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.decoder.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
return self.decoder(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
class OPTForCausalLM(nn.Module, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
default_bitsandbytes_target_modules = [
".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = OPTModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if self.config.tie_word_embeddings:
self.lm_head = self.model.decoder.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.word_embed_proj_dim)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "lm_head.weight" in name and self.config.tie_word_embeddings:
continue
if name.startswith("decoder."):
name = "model." + name
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,358 @@
# Adapted from
# https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py
# Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class OrionMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class OrionAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class OrionDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = OrionAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
self.mlp = OrionMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
@support_torch_compile
class OrionModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OrionDecoderLayer(
config,
cache_config,
quant_config,
),
prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory([
"hidden_states",
], config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
})
hidden_states = self.norm(hidden_states)
return hidden_states
class OrionForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = OrionModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,300 @@
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
from torch import nn
from transformers import PaliGemmaConfig
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
logger = init_logger(__name__)
class PaliGemmaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class PaliGemmaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
PaliGemmaImageEmbeddingInputs]
def get_max_paligemma_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
vision_config = hf_config.vision_config
return get_max_siglip_image_tokens(vision_config)
def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
)
mm_data = dummy_image_for_siglip(vision_config, num_images)
return DummyData(seq_data, mm_data, ranges)
def input_processor_for_paligemma(ctx: InputContext,
inputs: DecoderOnlyInputs):
"""
The correct prompt format needs to be:
'<image>' * image_feature_size + '<bos>' + prompt + '\n'
See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
""" # noqa
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(PaliGemmaConfig)
tokenizer = cached_get_tokenizer(model_config.tokenizer)
image_feature_size = hf_config.text_config.num_image_tokens
image_token_str = tokenizer.decode(hf_config.image_token_index)
bos_token = tokenizer.decode(hf_config.bos_token_id)
image_token_str_pad = image_token_str * image_feature_size
image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
orig_prompt = inputs.get("prompt")
orig_prompt_ids = inputs.get("prompt_token_ids")
if orig_prompt is not None and image_token_str in orig_prompt:
logger.warning(
"The image token '%s' was detected in the prompt and "
"will be removed. Please follow the proper prompt format"
" documented on HuggingFace.", image_token_str)
orig_prompt = orig_prompt.replace(image_token_str, "")
orig_prompt_ids.remove(hf_config.image_token_index)
new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
class PaliGemmaMultiModalProjector(nn.Module):
def __init__(self, vision_hidden_size: int, projection_dim: int):
super().__init__()
self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear(image_features)
return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.vision_tower = SiglipVisionModel(config.vision_config,
quant_config,
prefix=maybe_prefix(
prefix, "vision_tower"))
self.multi_modal_projector = PaliGemmaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
projection_dim=config.vision_config.projection_dim)
self.quant_config = quant_config
config.text_config.architectures = ["GemmaForCausalLM"]
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.scale *= logit_scale
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@property
def sampler(self):
return self.language_model.sampler
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
return PaliGemmaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)
return PaliGemmaImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _image_pixels_to_features(
self,
vision_tower: SiglipVisionModel,
pixel_values: torch.Tensor,
) -> torch.Tensor:
target_dtype = vision_tower.get_input_embeddings().weight.dtype
image_features = vision_tower(pixel_values.to(dtype=target_dtype))
return image_features
def _process_image_input(
self,
image_input: PaliGemmaImageInputs,
) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None
pixel_values = image_input["data"]
image_features = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)
return self.multi_modal_projector(image_features)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object) -> Union[SamplerOutput, IntermediateTensors]:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
parsed_image_input = self._parse_and_validate_image_input(**kwargs)
if parsed_image_input is not None:
vision_embeddings = self._process_image_input(
parsed_image_input)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
vision_embeddings = vision_embeddings * (
self.config.hidden_size**-0.5)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights)

View File

@@ -0,0 +1,354 @@
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/persimmon/modeling_persimmon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PersimmonConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class PersimmonMLP(nn.Module):
def __init__(self,
config: PersimmonConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
quant_config=quant_config)
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
config.hidden_size,
quant_config=quant_config)
self.act = get_act_fn(config.hidden_act)
def forward(self, hidden_states) -> torch.Tensor:
hidden_states, _ = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.dense_4h_to_h(hidden_states)
return hidden_states
class PersimmonAttention(nn.Module):
def __init__(self,
config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
tensor_parallel_world_size = get_tensor_model_parallel_world_size()
self.hidden_size = config.hidden_size
self.total_num_heads = config.num_attention_heads
self.num_heads = self.total_num_heads // tensor_parallel_world_size
self.head_dim = self.hidden_size // self.total_num_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.partial_rotary_factor = config.partial_rotary_factor
self.is_causal = True
assert (self.head_dim * self.total_num_heads) == self.hidden_size
assert self.total_num_heads % tensor_parallel_world_size == 0
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
bias=True,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=True,
quant_config=quant_config,
)
self.is_qk_layernorm = config.qk_layernorm
if self.is_qk_layernorm:
self.q_layernorm = nn.LayerNorm(self.head_dim)
self.k_layernorm = nn.LayerNorm(self.head_dim)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=int(self.partial_rotary_factor * self.head_dim),
max_position=self.max_position_embeddings,
base=self.rope_theta,
)
self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling,
cache_config=cache_config,
quant_config=quant_config)
def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
# [seq_length, hidden_size] -> [seq_length, num_heads, head_dim]
seq_length = x.shape[0]
return x.view(seq_length, self.num_heads, self.head_dim)
def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
# [seq_length, num_heads, head_dim] -> [seq_length, hidden_size]
seq_length = x.shape[0]
return x.view(seq_length, self.num_heads * self.head_dim)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# [seq_length, 3 x hidden_size]
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.is_qk_layernorm:
# [seq_length, num_heads, head_dim]
q = self._split_heads(q)
k = self._split_heads(k)
q = self.q_layernorm(q)
k = self.k_layernorm(k)
q = self._merge_heads(q)
k = self._merge_heads(k)
q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.dense(attn_output)
return output
class PersimmonDecoderLayer(nn.Module):
def __init__(self,
config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = PersimmonAttention(config=config,
cache_config=cache_config,
quant_config=quant_config)
self.mlp = PersimmonMLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = hidden_states + residual
outputs = hidden_states
return outputs
@support_torch_compile
class PersimmonModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: PersimmonDecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_tokens(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
hidden_states = self.layers[i](
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
class PersimmonForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
self.vocab_size = config.vocab_size
self.model = PersimmonModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=False)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
):
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
# copy from vllm/model_executor/models/bloom.py
# NOTE: Persimmon's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,373 @@
# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PhiConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class PhiAttention(nn.Module):
def __init__(self,
config: PhiConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
# pylint: disable=C0103
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_size,
self.total_num_heads,
bias=True,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
quant_config=quant_config,
)
scaling = self.head_size**-0.5
rotary_dim = int(config.partial_rotary_factor *
(config.hidden_size // config.num_attention_heads))
assert rotary_dim % 2 == 0
# pylint: disable=C0301
# Refer to:
# https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
rope_theta = getattr(config, "rope_theta", 10000.0)
max_position_embeddings = getattr(config, "max_position_embeddings",
2048)
self.rotary_emb = get_rope(
self.head_size,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
base=rope_theta,
)
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.dense(attn_output)
return output
class PhiMLP(nn.Module):
def __init__(self,
config: PhiConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
n_inner = getattr(config, "n_inner", None)
n_inner = n_inner if n_inner is not None else 4 * config.hidden_size
self.fc1 = ColumnParallelLinear(
config.hidden_size,
n_inner,
quant_config=quant_config,
)
self.fc2 = RowParallelLinear(
n_inner,
config.hidden_size,
quant_config=quant_config,
)
self.act = get_act_fn(config.hidden_act)
def forward(self, hidden_states):
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class PhiLayer(nn.Module):
def __init__(self,
config: PhiConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.self_attn = PhiAttention(config, cache_config, quant_config)
self.mlp = PhiMLP(config, quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_outputs = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = attn_outputs + feed_forward_hidden_states + residual
return hidden_states
@support_torch_compile
class PhiModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: PhiLayer(config, cache_config, quant_config),
prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
]
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"dense",
"fc1",
"fc2",
]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
default_bitsandbytes_target_modules = [
".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense."
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
# lm_head use bias, cannot share word embeddings
assert not config.tie_word_embeddings
self.lora_config = lora_config
self.quant_config = quant_config
self.model = PhiModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v")
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# pylint: disable=E1136
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

Some files were not shown because too many files have changed in this diff Show More