First commit
This commit is contained in:
23
vllm/model_executor/models/__init__.py
Normal file
23
vllm/model_executor/models/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
|
||||
SupportsPP, has_inner_state, supports_lora,
|
||||
supports_multimodal, supports_pp)
|
||||
from .interfaces_base import (VllmModelForEmbedding,
|
||||
VllmModelForTextGeneration, is_embedding_model,
|
||||
is_text_generation_model)
|
||||
from .registry import ModelRegistry
|
||||
|
||||
__all__ = [
|
||||
"ModelRegistry",
|
||||
"VllmModelForEmbedding",
|
||||
"is_embedding_model",
|
||||
"VllmModelForTextGeneration",
|
||||
"is_text_generation_model",
|
||||
"HasInnerState",
|
||||
"has_inner_state",
|
||||
"SupportsLoRA",
|
||||
"supports_lora",
|
||||
"SupportsMultiModal",
|
||||
"supports_multimodal",
|
||||
"SupportsPP",
|
||||
"supports_pp",
|
||||
]
|
||||
BIN
vllm/model_executor/models/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/arctic.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/arctic.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/baichuan.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/baichuan.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/bart.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/bart.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/blip.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/blip.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/blip2.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/blip2.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/bloom.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/bloom.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/chameleon.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/chameleon.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/chatglm.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/chatglm.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/clip.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/clip.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/commandr.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/commandr.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/dbrx.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/dbrx.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/decilm.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/decilm.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/deepseek.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/deepseek.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/eagle.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/eagle.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/exaone.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/exaone.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/falcon.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/falcon.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/fuyu.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/fuyu.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/gemma.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/gemma.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/gemma2.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/gemma2.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/glm4.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/glm4.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/gpt2.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/gpt2.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/gpt_j.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/gpt_j.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/gpt_neox.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/gpt_neox.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/granite.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/granite.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/internlm2.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/internlm2.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/internvl.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/internvl.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/jais.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/jais.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/jamba.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/jamba.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/llama.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/llama.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/llava.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/llava.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/mamba.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/mamba.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/medusa.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/medusa.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/minicpm.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/minicpm.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/minicpm3.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/minicpm3.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/minicpmv.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/minicpmv.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/mixtral.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/mixtral.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/mllama.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/mllama.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/molmo.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/molmo.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/mpt.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/mpt.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/nemotron.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/nemotron.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/nvlm_d.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/nvlm_d.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/olmo.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/olmo.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/olmoe.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/olmoe.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/opt.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/opt.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/orion.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/orion.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/paligemma.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/paligemma.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/persimmon.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/persimmon.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/phi.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/phi.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/phi3.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/phi3.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/phi3v.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/phi3v.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/phimoe.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/phimoe.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/pixtral.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/pixtral.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/qwen.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/qwen.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/qwen2.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/qwen2.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/qwen2_moe.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/qwen2_moe.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/qwen2_rm.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/qwen2_rm.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/qwen2_vl.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/qwen2_vl.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/qwen3.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/qwen3.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/qwen3_moe.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/qwen3_moe.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/registry.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/registry.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/siglip.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/siglip.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/solar.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/solar.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/stablelm.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/stablelm.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/ultravox.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/ultravox.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/utils.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/models/__pycache__/xverse.cpython-310.pyc
Normal file
BIN
vllm/model_executor/models/__pycache__/xverse.cpython-310.pyc
Normal file
Binary file not shown.
562
vllm/model_executor/models/arctic.py
Normal file
562
vllm/model_executor/models/arctic.py
Normal file
@@ -0,0 +1,562 @@
|
||||
"""Inference-only Snowflake Arctic model."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.deepspeedfp import (
|
||||
DeepSpeedFPConfig, DeepSpeedFPParameter)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.arctic import ArcticConfig
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ArcticMLP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: ArcticConfig,
|
||||
layer_id: int,
|
||||
expert_id: int = -1,
|
||||
is_residual_mlp: bool = False,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True):
|
||||
super(ArcticMLP, self).__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.expert_id = expert_id
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.ffn_dim = config.intermediate_size if not is_residual_mlp \
|
||||
else self.hidden_size
|
||||
|
||||
self.w13 = MergedColumnParallelLinear(self.hidden_size,
|
||||
[self.ffn_dim] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config)
|
||||
self.w2 = RowParallelLinear(self.ffn_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
reduce_results=reduce_results,
|
||||
quant_config=quant_config)
|
||||
if config.hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
gate_up, _ = self.w13(hidden_states)
|
||||
hidden_states = self.act_fn(gate_up)
|
||||
hidden_states, _ = self.w2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ArcticMoE(nn.Module):
|
||||
"""
|
||||
Model-parallel implementation of Arctic MoE Layer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: ArcticConfig,
|
||||
layer_id: int,
|
||||
tp_size: Optional[int] = None,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True):
|
||||
super(ArcticMoE, self).__init__()
|
||||
|
||||
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_experts = config.num_local_experts
|
||||
self.layer_id = layer_id
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.intermediate_size = config.intermediate_size // self.tp_size
|
||||
|
||||
self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0
|
||||
self.is_quant = isinstance(quant_config, DeepSpeedFPConfig)
|
||||
self.reduce_results = reduce_results
|
||||
# Some other parameters
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
if not self.is_moe_layer:
|
||||
self.mlp = ArcticMLP(config,
|
||||
layer_id=layer_id,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results)
|
||||
else:
|
||||
self.gate = ReplicatedLinear(self.hidden_size,
|
||||
self.num_experts,
|
||||
bias=False,
|
||||
params_dtype=self.params_dtype,
|
||||
quant_config=quant_config)
|
||||
if self.is_quant:
|
||||
self.ws = DeepSpeedFPParameter(
|
||||
torch.Size((self.num_experts, 2 * self.intermediate_size,
|
||||
self.hidden_size)),
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.w2s = DeepSpeedFPParameter(
|
||||
torch.Size((self.num_experts, self.hidden_size,
|
||||
self.intermediate_size)),
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
self.ws = nn.Parameter(
|
||||
torch.empty(self.num_experts,
|
||||
2 * self.intermediate_size,
|
||||
self.hidden_size,
|
||||
device="cuda",
|
||||
dtype=self.params_dtype))
|
||||
self.w2s = nn.Parameter(
|
||||
torch.empty(self.num_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
device="cuda",
|
||||
dtype=self.params_dtype))
|
||||
set_weight_attrs(self.ws, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.w2s, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
|
||||
weight_name: str, expert_id: int):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
param_data = param.ds_dequantize() if self.is_quant else param.data
|
||||
shard_size = self.intermediate_size
|
||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||
if weight_name.endswith("w1.weight"):
|
||||
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
|
||||
if weight_name.endswith("w3.weight"):
|
||||
param_data[expert_id,
|
||||
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
|
||||
if weight_name.endswith("w2.weight"):
|
||||
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
||||
if self.is_quant:
|
||||
param.ds_quantize_(param_data)
|
||||
|
||||
def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
do_normalize = self.top_k > 1
|
||||
topk_weights, topk_ids = fused_topk(hidden_states,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=do_normalize)
|
||||
# topk_ids: (num_tokens, k)
|
||||
if self.is_quant:
|
||||
if 2 * num_tokens <= self.num_experts:
|
||||
# If much fewer tokens than experts, use selective dequantize.
|
||||
ws_dequantized = self.ws.ds_selective_dequantize(
|
||||
topk_ids.flatten())
|
||||
w2s_dequantized = self.w2s.ds_selective_dequantize(
|
||||
topk_ids.flatten())
|
||||
# We gathered the experts to the tokens so update the mapping.
|
||||
topk_ids = torch.arange(
|
||||
0,
|
||||
topk_ids.numel(),
|
||||
device=topk_ids.device,
|
||||
).reshape(topk_ids.shape)
|
||||
else:
|
||||
ws_dequantized = self.ws.ds_dequantize()
|
||||
w2s_dequantized = self.w2s.ds_dequantize()
|
||||
|
||||
final_hidden_states = fused_experts(
|
||||
hidden_states,
|
||||
ws_dequantized if self.is_quant else self.ws,
|
||||
w2s_dequantized if self.is_quant else self.w2s,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=True)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
return final_hidden_states.view(num_tokens, hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor):
|
||||
if self.is_moe_layer:
|
||||
final_hidden_states = self.local_moe_fused(hidden_states)
|
||||
else:
|
||||
final_hidden_states = self.mlp(hidden_states)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
class ArcticAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ArcticConfig,
|
||||
layer_idx: Optional[int] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = config.num_key_value_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = self.hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(self.hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
reduce_results=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
base=int(self.rope_theta),
|
||||
is_neox_style=True,
|
||||
)
|
||||
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class ArcticDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ArcticConfig,
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = config.hidden_size
|
||||
is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0
|
||||
self.use_residual = config.use_residual and is_moe_layer
|
||||
self.self_attn = ArcticAttention(config,
|
||||
layer_idx,
|
||||
cache_config,
|
||||
quant_config=quant_config)
|
||||
self.block_sparse_moe = ArcticMoE(
|
||||
config,
|
||||
layer_id=layer_idx,
|
||||
quant_config=quant_config,
|
||||
reduce_results=(not self.use_residual))
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
if self.use_residual:
|
||||
self.residual_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.residual_mlp = ArcticMLP(config,
|
||||
layer_id=layer_idx,
|
||||
is_residual_mlp=True,
|
||||
reduce_results=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual_input = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = residual_input + hidden_states
|
||||
|
||||
residual_attn = hidden_states
|
||||
if self.use_residual:
|
||||
hidden_states = self.residual_layernorm(hidden_states)
|
||||
hidden_states = self.residual_mlp(hidden_states)
|
||||
residual_mlp = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(residual_input)
|
||||
hidden_states = self.block_sparse_moe(hidden_states)
|
||||
hidden_states = residual_mlp + hidden_states
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
hidden_states = residual_attn + hidden_states
|
||||
else:
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.block_sparse_moe(hidden_states)
|
||||
hidden_states = residual_attn + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ArcticModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ArcticConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=self.vocab_size)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: ArcticDecoderLayer(config, int(
|
||||
prefix.split(".")[-1]), cache_config, quant_config),
|
||||
prefix=f"{prefix}.layers")
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(["hidden_states"],
|
||||
config.hidden_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(positions, hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ArcticForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(self,
|
||||
config: ArcticConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = ArcticModel(config, cache_config, quant_config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.num_experts = config.num_local_experts
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: Optional[torch.Tensor],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
mlp_params_mapping: List[Tuple[str, str, int]] = []
|
||||
expert_params_mapping: List[Tuple[str, str, int]] = []
|
||||
num_layers = self.config.num_hidden_layers
|
||||
|
||||
for layer in range(num_layers):
|
||||
mlp_params_mapping.append(
|
||||
(f"layers.{layer}.residual_mlp.w13.weight",
|
||||
f"layers.{layer}.residual_mlp.w1.weight", 0))
|
||||
mlp_params_mapping.append(
|
||||
(f"layers.{layer}.residual_mlp.w13.weight",
|
||||
f"layers.{layer}.residual_mlp.w3.weight", 1))
|
||||
if layer % 2 == 0:
|
||||
# MLP layers
|
||||
mlp_params_mapping.append(
|
||||
(f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
|
||||
f"layers.{layer}.block_sparse_moe.mlp.w1.weight", 0))
|
||||
mlp_params_mapping.append(
|
||||
(f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
|
||||
f"layers.{layer}.block_sparse_moe.mlp.w3.weight", 1))
|
||||
else:
|
||||
# MoE layers
|
||||
for expert_id in range(self.config.num_local_experts):
|
||||
expert_params_mapping.append(
|
||||
("ws", f"experts.{expert_id}.w1.weight", expert_id))
|
||||
expert_params_mapping.append(
|
||||
("w2s", f"experts.{expert_id}.w2.weight", expert_id))
|
||||
expert_params_mapping.append(
|
||||
("ws", f"experts.{expert_id}.w3.weight", expert_id))
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
logger.info(
|
||||
"It will take ~10 minutes loading from the 16-bit weights. "
|
||||
"Alternatively, use the prequantized 8-bit weights of arctic "
|
||||
"and set load-format to `sharded_state` will accelerate loading.")
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for param_name, weight_name, shard_id in mlp_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for param_name, weight_name, shard_id \
|
||||
in expert_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
expert_id=shard_id)
|
||||
break
|
||||
else:
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
463
vllm/model_executor/models/baichuan.py
Normal file
463
vllm/model_executor/models/baichuan.py
Normal file
@@ -0,0 +1,463 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
||||
base = torch.tensor(
|
||||
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != total_num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2,
|
||||
total_num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(start=1,
|
||||
end=1 + 2 * num_remaining_heads,
|
||||
step=2,
|
||||
dtype=torch.int32)
|
||||
slopes = torch.cat(
|
||||
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
return slopes
|
||||
|
||||
|
||||
class BaiChuanMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class BaiChuanAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
position_embedding: str,
|
||||
rope_theta: float = 10000,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
|
||||
)
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = (self.total_num_heads //
|
||||
tensor_model_parallel_world_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.postion_embedding = position_embedding
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
self.W_pack = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
# Create the alibi slopes and slice them.
|
||||
if self.postion_embedding == "ALIBI":
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
head_end = (tp_rank + 1) * self.num_heads
|
||||
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
|
||||
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
|
||||
|
||||
scaling = self.head_dim**-0.5
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
scaling,
|
||||
alibi_slopes=alibi_slopes,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.W_pack(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
if self.postion_embedding != "ALIBI":
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class BaiChuanDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
position_embedding: str,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = BaiChuanAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
position_embedding=position_embedding,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.mlp = BaiChuanMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class BaiChuanModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
position_embedding: str,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: BaiChuanDecoderLayer(config, position_embedding,
|
||||
cache_config, quant_config),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual,
|
||||
})
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"W_pack": ["W_pack"],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
"W_pack",
|
||||
"o_proj",
|
||||
"gate_up_proj",
|
||||
"down_proj",
|
||||
]
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
position_embedding: str,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = BaiChuanModel(config, position_embedding, cache_config,
|
||||
quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if name == "lm_head.weight":
|
||||
# Unlike Baichuan, Baichuan2 normalizes the head weights.
|
||||
# Refer to:
|
||||
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
|
||||
# Distinguish between Baichuan and Baichuan2 by checking the
|
||||
# vocab size. This is suggested by
|
||||
# https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
|
||||
is_baichuan2 = self.config.vocab_size == 125696
|
||||
if is_baichuan2:
|
||||
loaded_weight = torch.nn.functional.normalize(
|
||||
loaded_weight)
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
"""Baichuan 13B and Baichuan2 7B/13B."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
if config.hidden_size == 4096: # baichuan2 7b
|
||||
super().__init__(config, "ROPE", cache_config, quant_config,
|
||||
lora_config)
|
||||
else: # baichuan 13b, baichuan2 13b
|
||||
super().__init__(config, "ALIBI", cache_config, quant_config,
|
||||
lora_config)
|
||||
|
||||
|
||||
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
"""Baichuan 7B."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
super().__init__(config, "ROPE", cache_config, quant_config,
|
||||
lora_config)
|
||||
1003
vllm/model_executor/models/bart.py
Normal file
1003
vllm/model_executor/models/bart.py
Normal file
File diff suppressed because it is too large
Load Diff
423
vllm/model_executor/models/blip.py
Normal file
423
vllm/model_executor/models/blip.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""Minimal implementation of BlipVisionModel intended to be only used
|
||||
within a vision language model."""
|
||||
from typing import Iterable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from transformers import Blip2VisionConfig, BlipVisionConfig
|
||||
from transformers.models.blip.modeling_blip import BlipAttention
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import LLMInputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
USE_XFORMERS_OPS = True
|
||||
except ImportError:
|
||||
USE_XFORMERS_OPS = False
|
||||
|
||||
USE_XFORMERS_OPS = True
|
||||
|
||||
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
assert image_size % patch_size == 0
|
||||
return image_size // patch_size
|
||||
|
||||
|
||||
def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
|
||||
grid_length = get_blip_patch_grid_length(image_size=image_size,
|
||||
patch_size=patch_size)
|
||||
return grid_length * grid_length
|
||||
|
||||
|
||||
def get_blip_image_feature_size(
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
|
||||
return get_blip_num_patches(image_size=hf_config.image_size,
|
||||
patch_size=hf_config.patch_size)
|
||||
|
||||
|
||||
def get_max_blip_image_tokens(
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
|
||||
return get_blip_image_feature_size(hf_config)
|
||||
|
||||
|
||||
def dummy_seq_data_for_blip(
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
|
||||
seq_len: int,
|
||||
num_images: int,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
):
|
||||
if image_feature_size_override is None:
|
||||
image_feature_size = get_blip_image_feature_size(hf_config)
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
return SequenceData.from_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
)
|
||||
|
||||
|
||||
def dummy_image_for_blip(
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
|
||||
num_images: int,
|
||||
*,
|
||||
image_width_override: Optional[int] = None,
|
||||
image_height_override: Optional[int] = None,
|
||||
):
|
||||
width = height = hf_config.image_size
|
||||
if image_width_override is not None:
|
||||
width = image_width_override
|
||||
if image_height_override is not None:
|
||||
height = image_height_override
|
||||
|
||||
image = Image.new("RGB", (width, height), color=0)
|
||||
return {"image": image if num_images == 1 else [image] * num_images}
|
||||
|
||||
|
||||
def input_processor_for_blip(
|
||||
model_config: ModelConfig,
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
|
||||
llm_inputs: LLMInputs,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
|
||||
if image_feature_size_override is None:
|
||||
image_feature_size = get_blip_image_feature_size(hf_config)
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
llm_inputs.get("prompt"),
|
||||
llm_inputs["prompt_token_ids"],
|
||||
placeholder_token_id=image_token_id,
|
||||
repeat_count=image_feature_size,
|
||||
)
|
||||
|
||||
# NOTE: Create a defensive copy of the original inputs
|
||||
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
|
||||
class BlipVisionEmbeddings(nn.Module):
|
||||
|
||||
def __init__(self, config: BlipVisionConfig):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=3,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
)
|
||||
|
||||
self.num_patches = get_blip_num_patches(image_size=self.image_size,
|
||||
patch_size=self.patch_size)
|
||||
self.num_positions = self.num_patches + 1
|
||||
|
||||
self.position_embedding = nn.Parameter(
|
||||
torch.randn(1, self.num_positions, self.embed_dim))
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(
|
||||
dtype=target_dtype)) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
|
||||
position_embeds = self.position_embedding.to(target_dtype)
|
||||
embeddings = embeddings + position_embeds[:, :embeddings.size(1), :]
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class BlipParallelAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
"embed_dim must be divisible by num_heads "
|
||||
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads}).")
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.qkv = QKVParallelLinear(
|
||||
self.embed_dim,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
bias=config.qkv_bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.projection = RowParallelLinear(
|
||||
self.embed_dim,
|
||||
self.embed_dim,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
):
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
qkv_states, _ = self.qkv(hidden_states)
|
||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||
query_states = query_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
key_states = key_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
value_states = value_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
query_states = query_states.transpose(1,2)
|
||||
key_states = key_states.transpose(1,2)
|
||||
value_states = value_states.transpose(1,2)
|
||||
attn = (query_states @ (key_states * self.scale).permute(0,1,3,2)).float()
|
||||
attn = torch.softmax(attn, dim=-1).to(value_states.dtype)
|
||||
out = attn @ value_states
|
||||
out = out.transpose(1,2).reshape(bsz, tgt_len, -1)
|
||||
# TODO: use use F.attention when supported
|
||||
# out = xops.memory_efficient_attention_forward(query_states,
|
||||
# key_states,
|
||||
# value_states,
|
||||
# p=self.dropout,
|
||||
# scale=self.scale)
|
||||
# out = out.view(bsz, tgt_len, -1)
|
||||
attn_output, _ = self.projection(out)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
class BlipMLP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
self.fc1 = ColumnParallelLinear(config.hidden_size,
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
self.fc2 = RowParallelLinear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states, _ = self.fc2(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BlipEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
|
||||
# fallback to sdpa attention if tp unavailable
|
||||
num_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
|
||||
self.self_attn = BlipParallelAttention(config,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
# Blip doesn't have SDPA attention implemented in transformers
|
||||
# use eager attention instead for cpu backend
|
||||
self.self_attn = BlipAttention(config)
|
||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = BlipMLP(config, quant_config=quant_config)
|
||||
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BlipEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder consisting of `config.num_hidden_layers` self
|
||||
attention layers. Each layer is a [`BlipEncoderLayer`].
|
||||
|
||||
Args:
|
||||
config: BlipConfig
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
if num_hidden_layers_override is None:
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
else:
|
||||
num_hidden_layers = num_hidden_layers_override
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
BlipEncoderLayer(config=config, quant_config=quant_config)
|
||||
for _ in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(self, inputs_embeds: torch.Tensor):
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BlipVisionModel(nn.Module):
|
||||
config_class = BlipVisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def __init__(self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
super().__init__()
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
num_heads = config.num_attention_heads
|
||||
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
|
||||
|
||||
self.config = config
|
||||
|
||||
self.embeddings = BlipVisionEmbeddings(config)
|
||||
self.encoder = BlipEncoder(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
)
|
||||
|
||||
if len(self.encoder.layers) > config.num_hidden_layers:
|
||||
raise ValueError(
|
||||
f"The original encoder only has {config.num_hidden_layers} "
|
||||
f"layers, but you requested {len(self.encoder.layers)} layers."
|
||||
)
|
||||
elif len(self.encoder.layers) == config.num_hidden_layers:
|
||||
self.post_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
else:
|
||||
# post_layernorm is unused when we extract intermediate features
|
||||
# In this case, we can skip it to conserve memory
|
||||
self.post_layernorm = None
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.encoder(inputs_embeds=hidden_states)
|
||||
|
||||
if self.post_layernorm is None:
|
||||
return hidden_states
|
||||
|
||||
return self.post_layernorm(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
] if self.shard_weight else []
|
||||
params_dict = dict(self.named_parameters())
|
||||
layer_count = len(self.encoder.layers)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
# post_layernorm is not needed in BlipVisionModel
|
||||
if (name.startswith("post_layernorm")
|
||||
and self.post_layernorm is None):
|
||||
continue
|
||||
|
||||
# omit layers when num_hidden_layers_override is set
|
||||
if name.startswith("encoder.layers"):
|
||||
layer_idx = int(name.split(".")[2])
|
||||
if layer_idx >= layer_count:
|
||||
continue
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
690
vllm/model_executor/models/blip2.py
Normal file
690
vllm/model_executor/models/blip2.py
Normal file
@@ -0,0 +1,690 @@
|
||||
from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
|
||||
apply_chunking_to_forward)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
|
||||
from .blip import (BlipVisionModel, dummy_image_for_blip,
|
||||
get_max_blip_image_tokens)
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
merge_multimodal_embeddings)
|
||||
|
||||
# We use this internally as placeholders since there is no image token
|
||||
# defined on the HuggingFace repo
|
||||
BLIP2_IMAGE_TOKEN = "<image>"
|
||||
BLIP2_IMAGE_TOKEN_ID = 50265
|
||||
|
||||
|
||||
class Blip2ImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
|
||||
|
||||
|
||||
class Blip2ImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
data: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
|
||||
|
||||
`hidden_size` must match the hidden size of language model backbone.
|
||||
"""
|
||||
|
||||
|
||||
Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
|
||||
|
||||
|
||||
class Blip2QFormerMultiHeadAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Blip2QFormerConfig,
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
is_cross_attention: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
f"The hidden size ({config.hidden_size}) is not a multiple of "
|
||||
f"the number of attention heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = (config.hidden_size //
|
||||
config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
if is_cross_attention:
|
||||
kv_hidden_size = config.encoder_hidden_size
|
||||
else:
|
||||
kv_hidden_size = config.hidden_size
|
||||
self.key = nn.Linear(kv_hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(kv_hidden_size, self.all_head_size)
|
||||
|
||||
self.position_embedding_type = getattr(config,
|
||||
"position_embedding_type",
|
||||
"absolute")
|
||||
if self.position_embedding_type != "absolute":
|
||||
raise NotImplementedError("Unsupported position_embedding_type: "
|
||||
f"{self.position_embedding_type}")
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
x = x.view(*x.size()[:-1], self.num_attention_heads,
|
||||
self.attention_head_size)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(
|
||||
self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(
|
||||
self.value(encoder_hidden_states))
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
attention_scores = torch.matmul(query_layer,
|
||||
key_layer.transpose(-1, -2))
|
||||
attention_probs = torch.softmax(attention_scores * self.scaling,
|
||||
dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs_dropped = self.dropout(attention_probs)
|
||||
|
||||
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
context_layer = context_layer.view(*context_layer.size()[:-2],
|
||||
self.all_head_size)
|
||||
|
||||
return context_layer
|
||||
|
||||
|
||||
class Blip2QFormerSelfOutput(nn.Module):
|
||||
|
||||
def __init__(self, config: Blip2QFormerConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Blip2QFormerAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Blip2QFormerConfig,
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
is_cross_attention: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.attention = Blip2QFormerMultiHeadAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
is_cross_attention=is_cross_attention,
|
||||
)
|
||||
|
||||
self.output = Blip2QFormerSelfOutput(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
self_output = self.attention(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
attention_output = self.output(self_output, hidden_states)
|
||||
|
||||
return attention_output
|
||||
|
||||
|
||||
class Blip2QFormerIntermediate(nn.Module):
|
||||
|
||||
def __init__(self, config: Blip2QFormerConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.intermediate_act_fn = get_act_fn(config.hidden_act)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Blip2QFormerOutput(nn.Module):
|
||||
|
||||
def __init__(self, config: Blip2QFormerConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Blip2QFormerLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Blip2QFormerConfig,
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
layer_idx: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = Blip2QFormerAttention(config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config)
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
if layer_idx % config.cross_attention_frequency == 0:
|
||||
self.crossattention = Blip2QFormerAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
is_cross_attention=True)
|
||||
self.has_cross_attention = True
|
||||
else:
|
||||
self.has_cross_attention = False
|
||||
|
||||
self.intermediate_query = Blip2QFormerIntermediate(config)
|
||||
self.output_query = Blip2QFormerOutput(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
query_length: int,
|
||||
):
|
||||
attention_output = self.attention(hidden_states)
|
||||
|
||||
if query_length > 0:
|
||||
query_attention_output = attention_output[:, :query_length, :]
|
||||
|
||||
if self.has_cross_attention:
|
||||
query_attention_output = self.crossattention(
|
||||
query_attention_output,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk_query,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
query_attention_output,
|
||||
)
|
||||
|
||||
if attention_output.shape[1] > query_length:
|
||||
layer_output_text = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
attention_output[:, query_length:, :],
|
||||
)
|
||||
layer_output = torch.cat([layer_output, layer_output_text],
|
||||
dim=1)
|
||||
else:
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
attention_output,
|
||||
)
|
||||
|
||||
return layer_output
|
||||
|
||||
def feed_forward_chunk(self,
|
||||
attention_output: torch.Tensor) -> torch.Tensor:
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
def feed_forward_chunk_query(
|
||||
self, attention_output: torch.Tensor) -> torch.Tensor:
|
||||
intermediate_output = self.intermediate_query(attention_output)
|
||||
layer_output = self.output_query(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
class Blip2QFormerEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Blip2QFormerConfig,
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.layer = nn.ModuleList([
|
||||
Blip2QFormerLayer(config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
layer_idx=layer_idx)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
query_length: int,
|
||||
) -> torch.Tensor:
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
layer_module = self.layer[i]
|
||||
|
||||
hidden_states = layer_module(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
query_length=query_length,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025
|
||||
class Blip2QFormerModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Blip2QFormerConfig,
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
self.encoder = Blip2QFormerEncoder(config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query_embeds: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
) -> torch.Tensor:
|
||||
query_length = query_embeds.shape[1]
|
||||
|
||||
embedding_output = self.layernorm(query_embeds)
|
||||
embedding_output = self.dropout(embedding_output)
|
||||
|
||||
sequence_output = self.encoder(
|
||||
embedding_output,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
query_length=query_length,
|
||||
)
|
||||
|
||||
return sequence_output
|
||||
|
||||
|
||||
def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
|
||||
return hf_config.num_query_tokens
|
||||
|
||||
|
||||
def get_max_blip2_image_tokens(ctx: InputContext):
|
||||
hf_config = ctx.get_hf_config(Blip2Config)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
if isinstance(vision_config, Blip2VisionConfig):
|
||||
return get_max_blip_image_tokens(vision_config)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def dummy_seq_data_for_blip2(
|
||||
hf_config: Blip2Config,
|
||||
seq_len: int,
|
||||
num_images: int,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
):
|
||||
if image_feature_size_override is None:
|
||||
image_feature_size = get_blip2_image_feature_size(hf_config)
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
return SequenceData.from_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
)
|
||||
|
||||
|
||||
def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
|
||||
mm_counts: Mapping[str, int]):
|
||||
hf_config = ctx.get_hf_config(Blip2Config)
|
||||
vision_config = hf_config.vision_config
|
||||
num_images = mm_counts["image"]
|
||||
|
||||
seq_data = dummy_seq_data_for_blip2(
|
||||
hf_config,
|
||||
seq_len,
|
||||
num_images,
|
||||
image_token_id=BLIP2_IMAGE_TOKEN_ID,
|
||||
)
|
||||
|
||||
if isinstance(vision_config, Blip2VisionConfig):
|
||||
mm_data = dummy_image_for_blip(vision_config, num_images)
|
||||
|
||||
return seq_data, mm_data
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
|
||||
hf_config = ctx.get_hf_config(Blip2Config)
|
||||
image_feature_size = get_blip2_image_feature_size(hf_config)
|
||||
|
||||
# The original model places image tokens at the front
|
||||
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
|
||||
new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
|
||||
new_token_ids += llm_inputs["prompt_token_ids"]
|
||||
|
||||
new_prompt = llm_inputs.get("prompt")
|
||||
if new_prompt is not None:
|
||||
new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
|
||||
|
||||
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
|
||||
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def __init__(self,
|
||||
config: Blip2Config,
|
||||
multimodal_config: MultiModalConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
# TODO: Optionally initializes this for supporting embeddings.
|
||||
self.vision_model = BlipVisionModel(config.vision_config)
|
||||
|
||||
self.query_tokens = nn.Parameter(
|
||||
torch.zeros(1, config.num_query_tokens,
|
||||
config.qformer_config.hidden_size))
|
||||
|
||||
self.qformer = Blip2QFormerModel(config.qformer_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
self.language_projection = nn.Linear(
|
||||
config.qformer_config.hidden_size,
|
||||
config.text_config.hidden_size,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config, cache_config, quant_config)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
@cached_property
|
||||
def sampler(self):
|
||||
if hasattr(self.language_model, "sampler"):
|
||||
return self.language_model.sampler
|
||||
|
||||
return Sampler()
|
||||
|
||||
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
||||
h = w = self.config.vision_config.image_size
|
||||
expected_dims = (3, h, w)
|
||||
actual_dims = tuple(data.shape[1:])
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = ("batch_size", *map(str, expected_dims))
|
||||
raise ValueError(
|
||||
f"The expected shape of pixel values is {expected_expr}. "
|
||||
f"You supplied {tuple(data.shape)}.")
|
||||
|
||||
return data
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[Blip2ImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
if pixel_values is None and image_embeds is None:
|
||||
return None
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
pixel_values = pixel_values.squeeze(1)
|
||||
|
||||
return Blip2ImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(pixel_values),
|
||||
)
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
image_embeds = image_embeds.squeeze(1)
|
||||
|
||||
return Blip2ImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=image_embeds,
|
||||
)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _image_pixels_to_features(self, vision_model: BlipVisionModel,
|
||||
pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the vision tower
|
||||
image_features = vision_model(pixel_values)
|
||||
|
||||
return image_features
|
||||
|
||||
def _process_image_pixels(self,
|
||||
inputs: Blip2ImagePixelInputs) -> torch.Tensor:
|
||||
assert self.vision_model is not None
|
||||
|
||||
pixel_values = inputs["data"]
|
||||
|
||||
return self._image_pixels_to_features(self.vision_model, pixel_values)
|
||||
|
||||
def _process_image_input(self,
|
||||
image_input: Blip2ImageInputs) -> torch.Tensor:
|
||||
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
assert self.vision_model is not None
|
||||
image_features = self._process_image_pixels(image_input)
|
||||
|
||||
query_tokens = self.query_tokens.expand(image_features.shape[0], -1,
|
||||
-1)
|
||||
query_output = self.qformer(
|
||||
query_embeds=query_tokens,
|
||||
encoder_hidden_states=image_features,
|
||||
)
|
||||
|
||||
return self.language_projection(query_output)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: object,
|
||||
) -> Union[SamplerOutput, IntermediateTensors]:
|
||||
"""Run forward pass for BLIP-2.
|
||||
|
||||
One key thing to understand is the `input_ids` already accounts for the
|
||||
positions of the to-be-inserted image embeddings.
|
||||
|
||||
Concretely, consider a text prompt:
|
||||
`"Question: What's the content of the image? Answer:"`.
|
||||
|
||||
Tokenizer outputs:
|
||||
`[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.
|
||||
|
||||
To reserve space in KV cache, we have to insert placeholder tokens
|
||||
before they are inputted to the model, so the input processor prepends
|
||||
dummy tokens (denoted as `50265`), resulting in:
|
||||
`[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.
|
||||
|
||||
We insert 32 tokens since it corresponds to the number of query
|
||||
embeddings outputted by the Q-Former and inputted to the language model.
|
||||
|
||||
This way, the `positions` and `attn_metadata` are consistent
|
||||
with the `input_ids`.
|
||||
|
||||
Args:
|
||||
input_ids: Flattened (concatenated) input_ids corresponding to a
|
||||
batch.
|
||||
pixel_values: The pixels in each input image.
|
||||
|
||||
See also:
|
||||
:class:`Blip2ImageInputs`
|
||||
"""
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
else:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
if image_input is not None:
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
inputs_embeds = self.language_model.model.get_input_embeddings(
|
||||
input_ids)
|
||||
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, vision_embeddings,
|
||||
BLIP2_IMAGE_TOKEN_ID)
|
||||
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.language_model.compute_logits(hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights)
|
||||
362
vllm/model_executor/models/bloom.py
Normal file
362
vllm/model_executor/models/bloom.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only BLOOM model compatible with HuggingFace weights."""
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import BloomConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
||||
base = torch.tensor(
|
||||
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != total_num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2,
|
||||
total_num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(start=1,
|
||||
end=1 + 2 * num_remaining_heads,
|
||||
step=2,
|
||||
dtype=torch.int32)
|
||||
slopes = torch.cat(
|
||||
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
return slopes
|
||||
|
||||
|
||||
class BloomAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.total_num_heads = config.n_head
|
||||
self.head_dim = self.hidden_size // self.total_num_heads
|
||||
assert self.head_dim * self.total_num_heads == self.hidden_size
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
assert self.total_num_heads % tp_world_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_world_size
|
||||
|
||||
self.query_key_value = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Create the alibi slopes and slice them.
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
head_end = (tp_rank + 1) * self.num_heads
|
||||
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
|
||||
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
|
||||
|
||||
scaling = self.head_dim**-0.5
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
scaling,
|
||||
alibi_slopes=alibi_slopes,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
del position_ids # Unused.
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
output, _ = self.dense(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class BloomMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.dense_h_to_4h = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
4 * hidden_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
4 * hidden_size,
|
||||
hidden_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x, _ = self.dense_h_to_4h(x)
|
||||
x = self.gelu_impl(x)
|
||||
x, _ = self.dense_4h_to_h(x)
|
||||
return x
|
||||
|
||||
|
||||
class BloomBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
self.self_attention = BloomAttention(config, cache_config,
|
||||
quant_config)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = BloomMLP(config, quant_config)
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
config.apply_residual_connection_post_layernorm)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# Layer norm at the beginning of the transformer layer.
|
||||
layernorm_output = self.input_layernorm(hidden_states)
|
||||
|
||||
# Layer norm post the self attention.
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = hidden_states
|
||||
|
||||
# Self attention.
|
||||
attention_output = self.self_attention(
|
||||
position_ids=position_ids,
|
||||
hidden_states=layernorm_output,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
attention_output = attention_output + residual
|
||||
layernorm_output = self.post_attention_layernorm(attention_output)
|
||||
|
||||
# Get residual
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = attention_output
|
||||
|
||||
# MLP.
|
||||
output = self.mlp(layernorm_output) + residual
|
||||
return output
|
||||
|
||||
|
||||
class BloomModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
# Embedding + LN Embedding
|
||||
self.word_embeddings = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
self.embed_dim,
|
||||
)
|
||||
self.word_embeddings_layernorm = nn.LayerNorm(
|
||||
self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
# Transformer blocks
|
||||
self.start_layer, self.end_layer, self.h = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: BloomBlock(config, cache_config, quant_config),
|
||||
prefix=f"{prefix}.h")
|
||||
|
||||
# Final Layer Norm
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(["hidden_states"],
|
||||
config.hidden_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
hidden_states = self.word_embeddings(input_ids)
|
||||
hidden_states = self.word_embeddings_layernorm(hidden_states)
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
position_ids,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BloomForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = BloomModel(config, cache_config, quant_config)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.transformer.word_embeddings
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(self.config.vocab_size,
|
||||
self.config.hidden_size)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.transformer.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in weights:
|
||||
if name == "lm_head.weight":
|
||||
continue
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
|
||||
if "query_key_value" in name:
|
||||
# NOTE: BLOOM's fused QKV's output_dim has the shape of
|
||||
# (num_heads * 3 * head_size), while the
|
||||
# required shape is (3 * num_heads * head_size).
|
||||
# Thus, we need weight conversion.
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
num_heads = self.config.num_attention_heads
|
||||
if output_dim is not None:
|
||||
loaded_weight_shape = loaded_weight.shape
|
||||
loaded_weight = loaded_weight.view(
|
||||
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
|
||||
loaded_weight_shape[output_dim + 1:])
|
||||
loaded_weight = loaded_weight.transpose(
|
||||
output_dim, output_dim + 1)
|
||||
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
|
||||
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
1104
vllm/model_executor/models/chameleon.py
Normal file
1104
vllm/model_executor/models/chameleon.py
Normal file
File diff suppressed because it is too large
Load Diff
671
vllm/model_executor/models/chatglm.py
Normal file
671
vllm/model_executor/models/chatglm.py
Normal file
@@ -0,0 +1,671 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/THUDM/GLM-4
|
||||
"""Inference-only ChatGLM model compatible with THUDM weights."""
|
||||
from argparse import Namespace
|
||||
from array import array
|
||||
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
|
||||
MultiModalInputs)
|
||||
from vllm.multimodal.base import MultiModalData
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def calculate_image_placeholder(vision_config):
|
||||
return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2
|
||||
|
||||
|
||||
def mm_input_mapper_for_glmv(
|
||||
ctx: InputContext,
|
||||
data: MultiModalData[object],
|
||||
) -> Dict:
|
||||
model_config = ctx.model_config
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer,
|
||||
trust_remote_code=True)
|
||||
if tokenizer is None:
|
||||
raise RuntimeError("No HuggingFace processor is available "
|
||||
"to process the image object")
|
||||
try:
|
||||
raw_batch_data = tokenizer.apply_chat_template(
|
||||
conversation=[{
|
||||
"role": "user",
|
||||
"image": data
|
||||
}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True).data
|
||||
except Exception:
|
||||
logger.error("Failed to process image (%s)", data)
|
||||
raise
|
||||
pixel_values = raw_batch_data['images']
|
||||
|
||||
return MultiModalInputs({'pixel_values': pixel_values})
|
||||
|
||||
|
||||
def merge_glm_vision_embeddings(
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
vision_embeddings: torch.Tensor,
|
||||
boi_token_id: int,
|
||||
eoi_token_id: int,
|
||||
) -> torch.Tensor:
|
||||
|
||||
boi_positions = (input_ids == boi_token_id).nonzero(as_tuple=True)[0]
|
||||
eoi_positions = (input_ids == eoi_token_id).nonzero(as_tuple=True)[0]
|
||||
|
||||
mask = torch.zeros_like(input_ids, dtype=torch.bool)
|
||||
|
||||
for boi_pos, eoi_pos in zip(boi_positions, eoi_positions):
|
||||
assert boi_pos < eoi_pos
|
||||
mask[boi_pos:eoi_pos + 1] = True
|
||||
inputs_embeds[mask] = vision_embeddings.view(-1,
|
||||
vision_embeddings.shape[-1])
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
class GLMImagePixelInputs(TypedDict):
|
||||
pixel_values: torch.Tensor
|
||||
"""Shape: `(batch_size, num_channels, height, width)`"""
|
||||
|
||||
|
||||
def get_max_glmv_image_tokens(ctx: InputContext):
|
||||
hf_config = ctx.get_hf_config(ChatGLMConfig)
|
||||
|
||||
vision_config = getattr(hf_config, 'vision_config', None)
|
||||
if vision_config is None:
|
||||
return 1
|
||||
elif isinstance(vision_config, dict):
|
||||
return calculate_image_placeholder(vision_config)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def dummy_data_for_glmv(
|
||||
ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
|
||||
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
|
||||
hf_config = ctx.get_hf_config(ChatGLMConfig)
|
||||
vision_config = getattr(hf_config, 'vision_config', None)
|
||||
|
||||
if vision_config is None:
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
|
||||
seq_data = SequenceData(token_ids)
|
||||
return seq_data, None
|
||||
elif isinstance(vision_config, dict):
|
||||
image_size = vision_config["image_size"]
|
||||
image_placeholder_length = calculate_image_placeholder(vision_config)
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [hf_config.boi_token_id] +
|
||||
[0] * image_placeholder_length +
|
||||
[hf_config.eoi_token_id])
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0] * (seq_len - image_placeholder_length - 2))
|
||||
seq_data = SequenceData(token_ids)
|
||||
|
||||
mm_data = {
|
||||
"image": Image.new("RGB", (image_size, image_size), color=0)
|
||||
}
|
||||
|
||||
return seq_data, mm_data
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def find_all_positions(input_ids: List[int], target: int) -> List[int]:
|
||||
return [index for index, value in enumerate(input_ids) if value == target]
|
||||
|
||||
|
||||
def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
hf_config = ctx.get_hf_config(ChatGLMConfig)
|
||||
vision_config = getattr(hf_config, 'vision_config', None)
|
||||
|
||||
if vision_config is None:
|
||||
return llm_inputs
|
||||
elif isinstance(vision_config, dict):
|
||||
image_placeholder_length = calculate_image_placeholder(vision_config)
|
||||
else:
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
input_ids = llm_inputs.get("prompt_token_ids")
|
||||
position_ids = llm_inputs.get("position_ids")
|
||||
tokenizer = cached_get_tokenizer(
|
||||
ctx.model_config.model,
|
||||
trust_remote_code=ctx.model_config.trust_remote_code)
|
||||
|
||||
try:
|
||||
raw_batch_data = tokenizer.apply_chat_template(
|
||||
conversation=[{
|
||||
"role": "user",
|
||||
"image": llm_inputs['multi_modal_data']["image"],
|
||||
"content": llm_inputs['prompt']
|
||||
}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True).data
|
||||
except Exception:
|
||||
logger.error("Failed to process content (%s)", llm_inputs['prompt'])
|
||||
raise
|
||||
input_ids = raw_batch_data['input_ids'][0].tolist()
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = list(range(len(input_ids)))
|
||||
boi_token_id = hf_config.boi_token_id
|
||||
eoi_token_id = hf_config.eoi_token_id
|
||||
boi_positions = find_all_positions(input_ids, boi_token_id)
|
||||
eoi_positions = find_all_positions(input_ids, eoi_token_id)
|
||||
|
||||
assert len(boi_positions) == len(eoi_positions)
|
||||
|
||||
new_input_ids = []
|
||||
new_position_ids = []
|
||||
final_processed_position = 0
|
||||
final_processed_position = 0
|
||||
|
||||
for boi_position, eoi_position in zip(boi_positions, eoi_positions):
|
||||
assert boi_position < eoi_position
|
||||
new_input_ids.extend(input_ids[final_processed_position:boi_position +
|
||||
1])
|
||||
new_position_ids.extend(
|
||||
list(range(final_processed_position, boi_position + 1)))
|
||||
new_input_ids.extend([input_ids[boi_position + 1]] *
|
||||
image_placeholder_length)
|
||||
new_position_ids.extend([boi_position + 1] * image_placeholder_length)
|
||||
final_processed_position = eoi_position
|
||||
|
||||
new_input_ids.extend(input_ids[final_processed_position:])
|
||||
new_position_ids.extend(
|
||||
list(range(final_processed_position, len(input_ids))))
|
||||
|
||||
assert len(new_input_ids) == len(new_position_ids)
|
||||
|
||||
llm_inputs["prompt_token_ids"] = new_input_ids
|
||||
llm_inputs["position_ids"] = new_position_ids
|
||||
return llm_inputs
|
||||
|
||||
|
||||
class GLMAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.multi_query_attention = config.multi_query_attention
|
||||
self.total_num_kv_heads = (config.multi_query_group_num
|
||||
if config.multi_query_attention else
|
||||
config.num_attention_heads)
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = config.hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.query_key_value = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=config.add_bias_linear or config.add_qkv_bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
config.hidden_size,
|
||||
bias=config.add_bias_linear,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
|
||||
rope_ratio = getattr(config, "rope_ratio", 1.0)
|
||||
max_positions = getattr(config, "seq_length", 8192)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim // 2,
|
||||
max_position=max_positions,
|
||||
base=10000 * rope_ratio,
|
||||
is_neox_style=False,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(position_ids, q, k)
|
||||
context_layer = self.attn(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
)
|
||||
attn_output, _ = self.dense(context_layer)
|
||||
return attn_output
|
||||
|
||||
|
||||
class GLMMLP(nn.Module):
|
||||
"""MLP.
|
||||
|
||||
MLP will take the input with h hidden state, project it to 4*h
|
||||
hidden dimension, perform nonlinear transformation, and project the
|
||||
state back into h hidden dimension.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.add_bias = config.add_bias_linear
|
||||
|
||||
# Project to 4h.
|
||||
self.dense_h_to_4h = MergedColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
[config.ffn_hidden_size] * 2,
|
||||
bias=config.add_bias_linear,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.activation_func = SiluAndMul()
|
||||
|
||||
# Project back to h.
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
config.ffn_hidden_size,
|
||||
config.hidden_size,
|
||||
bias=config.add_bias_linear,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# [s, b, 4hp]
|
||||
intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
|
||||
intermediate_parallel = self.activation_func(intermediate_parallel)
|
||||
# [s, b, h]
|
||||
output, _ = self.dense_4h_to_h(intermediate_parallel)
|
||||
return output
|
||||
|
||||
|
||||
class GLMBlock(nn.Module):
|
||||
"""A single transformer layer.
|
||||
|
||||
Transformer layer takes input with size [s, b, h] and returns an
|
||||
output of the same size.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
config.apply_residual_connection_post_layernorm)
|
||||
|
||||
self.fp32_residual_connection = config.fp32_residual_connection
|
||||
|
||||
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||
# Layernorm on the input data.
|
||||
self.input_layernorm = layer_norm_func(config.hidden_size,
|
||||
eps=config.layernorm_epsilon)
|
||||
|
||||
# Self attention.
|
||||
self.self_attention = GLMAttention(config, cache_config, quant_config)
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
|
||||
# Layernorm on the attention output
|
||||
self.post_attention_layernorm = layer_norm_func(
|
||||
config.hidden_size, eps=config.layernorm_epsilon)
|
||||
|
||||
# MLP
|
||||
self.mlp = GLMMLP(config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# hidden_states: [num_tokens, h]
|
||||
# Layer norm at the beginning of the transformer layer.
|
||||
layernorm_output = self.input_layernorm(hidden_states)
|
||||
# Self attention.
|
||||
attention_output = self.self_attention(
|
||||
hidden_states=layernorm_output,
|
||||
position_ids=position_ids,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Residual connection.
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = hidden_states
|
||||
|
||||
layernorm_input = residual + attention_output
|
||||
|
||||
# Layer norm post the self attention.
|
||||
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
||||
|
||||
# Second residual connection.
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = layernorm_input
|
||||
|
||||
output = self.mlp(layernorm_output) + residual
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class GLMTransformer(nn.Module):
|
||||
"""Transformer class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.post_layer_norm = config.post_layer_norm
|
||||
|
||||
# Number of layers.
|
||||
self.num_layers = config.num_layers
|
||||
|
||||
# Transformer layers.
|
||||
self.layers = nn.ModuleList([
|
||||
GLMBlock(config, cache_config, quant_config)
|
||||
for i in range(self.num_layers)
|
||||
])
|
||||
|
||||
if self.post_layer_norm:
|
||||
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||
# Final layer norm before output.
|
||||
self.final_layernorm = layer_norm_func(
|
||||
config.hidden_size, eps=config.layernorm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
for i in range(self.num_layers):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=position_ids,
|
||||
kv_cache=kv_caches[i],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
# Final layer norm.
|
||||
if self.post_layer_norm:
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ChatGLMModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
|
||||
self.num_layers = config.num_layers
|
||||
self.multi_query_group_num = config.multi_query_group_num
|
||||
self.kv_channels = config.kv_channels
|
||||
self.encoder = GLMTransformer(config, cache_config, quant_config)
|
||||
|
||||
self.output_layer = ParallelLMHead(config.padded_vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
|
||||
vision_config_flag = getattr(config, 'vision_config', None)
|
||||
if vision_config_flag is not None:
|
||||
self.vision_config = Namespace(**config.vision_config)
|
||||
self.vision = EVA2CLIPModel(self.config, quant_config)
|
||||
else:
|
||||
self.vision = None
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> GLMImagePixelInputs:
|
||||
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
if pixel_values is not None and self.vision is not None:
|
||||
if isinstance(pixel_values, torch.Tensor):
|
||||
if pixel_values.ndim > 2:
|
||||
pixel_values = torch.concat(list(pixel_values))
|
||||
elif isinstance(pixel_values, list):
|
||||
return torch.concat(pixel_values)
|
||||
else:
|
||||
raise TypeError("""pixel_values must be a torch.Tensor
|
||||
or a list of torch.Tensor
|
||||
""")
|
||||
return GLMImagePixelInputs(pixel_values=pixel_values)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor:
|
||||
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
if image_input["pixel_values"] is not None:
|
||||
pixel_values = image_input["pixel_values"].to(
|
||||
dtype=inputs_embeds.dtype)
|
||||
image_embeds = self.vision(pixel_values)
|
||||
|
||||
boi_token_id = self.config.boi_token_id
|
||||
eoi_token_id = self.config.eoi_token_id
|
||||
|
||||
inputs_embeds = merge_glm_vision_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
vision_embeddings=image_embeds,
|
||||
boi_token_id=boi_token_id,
|
||||
eoi_token_id=eoi_token_id)
|
||||
|
||||
# Run encoder.
|
||||
hidden_states = self.encoder(
|
||||
hidden_states=inputs_embeds,
|
||||
position_ids=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
|
||||
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
packed_modules_mapping = {
|
||||
"query_key_value": ["query_key_value"],
|
||||
"dense_h_to_4h": ["dense_h_to_4h"]
|
||||
}
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
"query_key_value",
|
||||
"dense",
|
||||
"dense_h_to_4h",
|
||||
"dense_4h_to_h",
|
||||
]
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ChatGLMConfig,
|
||||
multimodal_config: MultiModalConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.max_position_embeddings = getattr(config, "max_sequence_length",
|
||||
8192)
|
||||
self.transformer = ChatGLMModel(config, cache_config, quant_config)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.transformer.output_layer.weight = (
|
||||
self.transformer.embedding.weight)
|
||||
self.lm_head = self.transformer.output_layer
|
||||
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, **kwargs)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# Merge two ColumnParallelLinear into one MergedColumnParallelLinear
|
||||
merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
|
||||
"transformer.vision.linear_proj.merged_proj.weight": {
|
||||
"transformer.vision.linear_proj.gate_proj.weight": None,
|
||||
"transformer.vision.linear_proj.dense_h_to_4h.weight": None,
|
||||
}
|
||||
}
|
||||
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in weights:
|
||||
is_weight_to_be_merge = False
|
||||
for _, merged_weight_dict in merged_weights_dict.items():
|
||||
if name in merged_weight_dict:
|
||||
assert merged_weight_dict[name] is None
|
||||
merged_weight_dict[name] = loaded_weight
|
||||
is_weight_to_be_merge = True
|
||||
if is_weight_to_be_merge:
|
||||
continue
|
||||
if "rotary_pos_emb.inv_freq" in name:
|
||||
continue
|
||||
if "word_embeddings" in name:
|
||||
name = name.replace(".word_embeddings", "")
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
for combined_name, merged_weight_dict in merged_weights_dict.items():
|
||||
if combined_name in params_dict:
|
||||
param = params_dict[combined_name]
|
||||
combined_weight = torch.cat(list(merged_weight_dict.values()),
|
||||
dim=0)
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, combined_weight)
|
||||
474
vllm/model_executor/models/clip.py
Normal file
474
vllm/model_executor/models/clip.py
Normal file
@@ -0,0 +1,474 @@
|
||||
"""Minimal implementation of CLIPVisionModel intended to be only used
|
||||
within a vision language model."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from transformers import CLIPVisionConfig
|
||||
from transformers.models.clip.modeling_clip import CLIPSdpaAttention
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import LLMInputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
USE_XFORMERS_OPS = True
|
||||
except ImportError:
|
||||
USE_XFORMERS_OPS = False
|
||||
|
||||
USE_XFORMERS_OPS = True
|
||||
|
||||
|
||||
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
assert image_size % patch_size == 0
|
||||
return image_size // patch_size
|
||||
|
||||
|
||||
def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
|
||||
grid_length = get_clip_patch_grid_length(image_size=image_size,
|
||||
patch_size=patch_size)
|
||||
return grid_length * grid_length
|
||||
|
||||
|
||||
def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
|
||||
return get_clip_num_patches(image_size=hf_config.image_size,
|
||||
patch_size=hf_config.patch_size) + 1
|
||||
|
||||
|
||||
def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
|
||||
return get_clip_image_feature_size(hf_config)
|
||||
|
||||
|
||||
def dummy_seq_data_for_clip(
|
||||
hf_config: CLIPVisionConfig,
|
||||
seq_len: int,
|
||||
num_images: int,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
):
|
||||
if image_feature_size_override is None:
|
||||
image_feature_size = get_clip_image_feature_size(hf_config)
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
return SequenceData.from_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
)
|
||||
|
||||
|
||||
def dummy_image_for_clip(
|
||||
hf_config: CLIPVisionConfig,
|
||||
num_images: int,
|
||||
*,
|
||||
image_width_override: Optional[int] = None,
|
||||
image_height_override: Optional[int] = None,
|
||||
):
|
||||
width = height = hf_config.image_size
|
||||
if image_width_override is not None:
|
||||
width = image_width_override
|
||||
if image_height_override is not None:
|
||||
height = image_height_override
|
||||
|
||||
image = Image.new("RGB", (width, height), color=0)
|
||||
return {"image": image if num_images == 1 else [image] * num_images}
|
||||
|
||||
|
||||
def dummy_video_for_clip(
|
||||
hf_config: CLIPVisionConfig,
|
||||
num_frames: int,
|
||||
*,
|
||||
image_width_override: Optional[int] = None,
|
||||
image_height_override: Optional[int] = None,
|
||||
):
|
||||
pil_frame = dummy_image_for_clip(
|
||||
hf_config,
|
||||
num_images=1,
|
||||
image_width_override=image_width_override,
|
||||
image_height_override=image_height_override)
|
||||
np_frame = np.array(pil_frame["image"])
|
||||
mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
|
||||
mm_data = {"video": mm_data_per_video}
|
||||
return mm_data
|
||||
|
||||
|
||||
def input_processor_for_clip(
|
||||
model_config: ModelConfig,
|
||||
hf_config: CLIPVisionConfig,
|
||||
llm_inputs: LLMInputs,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[Union[int, List[int]]] = None,
|
||||
):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
|
||||
if image_feature_size_override is None:
|
||||
image_data = multi_modal_data["image"]
|
||||
if isinstance(image_data, Image.Image):
|
||||
image_feature_size = get_clip_image_feature_size(hf_config)
|
||||
elif isinstance(image_data, torch.Tensor):
|
||||
num_images, image_feature_size, hidden_size = image_data.shape
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
llm_inputs.get("prompt"),
|
||||
llm_inputs["prompt_token_ids"],
|
||||
placeholder_token_id=image_token_id,
|
||||
repeat_count=image_feature_size,
|
||||
)
|
||||
|
||||
# NOTE: Create a defensive copy of the original inputs
|
||||
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
|
||||
class CLIPVisionEmbeddings(nn.Module):
|
||||
|
||||
def __init__(self, config: CLIPVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.num_patches = get_clip_num_patches(image_size=self.image_size,
|
||||
patch_size=self.patch_size)
|
||||
self.num_positions = self.num_patches + 1
|
||||
self.position_embedding = nn.Embedding(self.num_positions,
|
||||
self.embed_dim)
|
||||
self.register_buffer("position_ids",
|
||||
torch.arange(self.num_positions).expand((1, -1)),
|
||||
persistent=False)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(
|
||||
dtype=target_dtype)) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class CLIPParallelAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
"embed_dim must be divisible by num_heads "
|
||||
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads}).")
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=self.embed_dim,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.num_heads,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
input_size=self.embed_dim,
|
||||
output_size=self.embed_dim,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
):
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||
|
||||
query_states = query_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
key_states = key_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
value_states = value_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
query_states = query_states.transpose(1,2)
|
||||
key_states = key_states.transpose(1,2)
|
||||
value_states = value_states.transpose(1,2)
|
||||
attn = (query_states @ (key_states * self.scale).permute(0,1,3,2)).float()
|
||||
attn = torch.softmax(attn, dim=-1).to(value_states.dtype)
|
||||
out = attn @ value_states
|
||||
out = out.transpose(1,2).reshape(bsz, tgt_len, -1)
|
||||
# TODO: use F.attention when supported
|
||||
# out = xops.memory_efficient_attention_forward(query_states,
|
||||
# key_states,
|
||||
# value_states,
|
||||
# p=self.dropout,
|
||||
# scale=self.scale)
|
||||
attn_output, _ = self.out_proj(out)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
class CLIPMLP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
self.fc1 = ColumnParallelLinear(config.hidden_size,
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
self.fc2 = RowParallelLinear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states, _ = self.fc2(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
|
||||
num_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
|
||||
self.self_attn = CLIPParallelAttention(config,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
self.self_attn = CLIPSdpaAttention(config)
|
||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = CLIPMLP(config, quant_config=quant_config)
|
||||
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder consisting of `config.num_hidden_layers` self
|
||||
attention layers. Each layer is a [`CLIPEncoderLayer`].
|
||||
|
||||
Args:
|
||||
config: CLIPConfig
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
if num_hidden_layers_override is None:
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
else:
|
||||
num_hidden_layers = num_hidden_layers_override
|
||||
self.layers = nn.ModuleList([
|
||||
CLIPEncoderLayer(config=config, quant_config=quant_config)
|
||||
for _ in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(self, inputs_embeds: torch.Tensor):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPVisionTransformer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(config)
|
||||
|
||||
# NOTE: This typo of "layrnorm" is not fixed on purpose to match
|
||||
# the original transformers code and name of the model weights.
|
||||
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
self.encoder = CLIPEncoder(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override)
|
||||
|
||||
if len(self.encoder.layers) > config.num_hidden_layers:
|
||||
raise ValueError(
|
||||
f"The original encoder only has {config.num_hidden_layers} "
|
||||
f"layers, but you requested {len(self.encoder.layers)} layers."
|
||||
)
|
||||
elif len(self.encoder.layers) == config.num_hidden_layers:
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
else:
|
||||
# post_layernorm is unused when we extract intermediate features
|
||||
# In this case, we can skip it to conserve memory
|
||||
self.post_layernorm = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.pre_layrnorm(hidden_states)
|
||||
hidden_states = self.encoder(inputs_embeds=hidden_states)
|
||||
|
||||
if self.post_layernorm is None:
|
||||
return hidden_states
|
||||
|
||||
return self.post_layernorm(hidden_states)
|
||||
|
||||
|
||||
class CLIPVisionModel(nn.Module):
|
||||
|
||||
config_class = CLIPVisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def __init__(self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
super().__init__()
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
num_heads = config.num_attention_heads
|
||||
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
|
||||
|
||||
self.vision_model = CLIPVisionTransformer(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
return self.vision_model(pixel_values)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
# (TODO) Add prefix argument for filtering out weights to be loaded
|
||||
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
] if self.shard_weight else []
|
||||
params_dict = dict(self.named_parameters())
|
||||
layer_count = len(self.vision_model.encoder.layers)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
# post_layernorm is not needed in CLIPVisionModel
|
||||
if (name.startswith("vision_model.post_layernorm")
|
||||
and self.vision_model.post_layernorm is None):
|
||||
continue
|
||||
|
||||
# omit layers when num_hidden_layers_override is set
|
||||
if name.startswith("vision_model.encoder.layers"):
|
||||
layer_idx = int(name.split(".")[3])
|
||||
if layer_idx >= layer_count:
|
||||
continue
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
441
vllm/model_executor/models/commandr.py
Normal file
441
vllm/model_executor/models/commandr.py
Normal file
@@ -0,0 +1,441 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This file is based on the LLama model definition file in transformers
|
||||
"""PyTorch Cohere model."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from transformers import CohereConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name,
|
||||
row_parallel_weight_loader)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
|
||||
@torch.compile
|
||||
def layer_norm_func(hidden_states, weight, variance_epsilon):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
mean = hidden_states.mean(-1, keepdim=True)
|
||||
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
|
||||
variance_epsilon)
|
||||
hidden_states = weight.to(torch.float32) * hidden_states
|
||||
return hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
|
||||
def __init__(self, param_shape=None, eps=1e-5):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(param_shape))
|
||||
self.variance_epsilon = eps
|
||||
set_weight_attrs(self.weight,
|
||||
{"weight_loader": row_parallel_weight_loader})
|
||||
|
||||
def forward(self, hidden_states, residuals=None):
|
||||
hidden_states = layer_norm_func(hidden_states, self.weight,
|
||||
self.variance_epsilon)
|
||||
return hidden_states, residuals
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
|
||||
class CohereMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CohereConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
[self.intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
self.intermediate_size,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class CohereAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CohereConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.config = config
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.head_dim = self.hidden_size // self.total_num_heads
|
||||
self.total_num_kv_heads = config.num_key_value_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.max_position_embeddings = getattr(
|
||||
config, "model_max_length", None) or getattr(
|
||||
config, "max_position_embeddings", 8192)
|
||||
self.rope_theta = config.rope_theta
|
||||
self.rope_scaling = getattr(config, "rope_scaling", None)
|
||||
self.use_qk_norm = getattr(config, "use_qk_norm", False)
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
rope_scaling=self.rope_scaling,
|
||||
is_neox_style=False,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = LayerNorm(param_shape=(self.num_heads,
|
||||
self.head_dim),
|
||||
eps=config.layer_norm_eps)
|
||||
self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
|
||||
self.head_dim),
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def _apply_qk_norm(self, q, k):
|
||||
q = q.view(*q.shape[:-1], -1, self.head_dim)
|
||||
k = k.view(*k.shape[:-1], -1, self.head_dim)
|
||||
q, _ = self.q_norm(q)
|
||||
k, _ = self.k_norm(k)
|
||||
q = q.view(*q.shape[:-2], -1)
|
||||
k = k.view(*k.shape[:-2], -1)
|
||||
return q, k
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.use_qk_norm:
|
||||
q, k = self._apply_qk_norm(q, k)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class CohereDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: CohereConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = CohereAttention(config,
|
||||
cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
self.mlp = CohereMLP(config, quant_config=quant_config)
|
||||
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states_attention = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states_mlp = self.mlp(hidden_states)
|
||||
# Add everything together
|
||||
hidden_states = residual + hidden_states_attention + hidden_states_mlp
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class CohereModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CohereConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||
config.hidden_size)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: CohereDecoderLayer(config, cache_config,
|
||||
quant_config),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.norm = LayerNorm(param_shape=(config.hidden_size),
|
||||
eps=config.layer_norm_eps)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
|
||||
]
|
||||
embedding_modules = {"embed_tokens": "input_embeddings"}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CohereConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# currently all existing command R models have `tie_word_embeddings`
|
||||
# enabled
|
||||
assert config.tie_word_embeddings
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.quant_config = quant_config
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size,
|
||||
scale=config.logit_scale)
|
||||
self.model = CohereModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config=lora_config)
|
||||
self.sampler = Sampler()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
is_not_lora = hasattr(self.model.embed_tokens, 'weight')
|
||||
if is_not_lora:
|
||||
logits = self.logits_processor(self.model.embed_tokens,
|
||||
hidden_states, sampling_metadata)
|
||||
else:
|
||||
logits = self.logits_processor(self.model.embed_tokens.base_layer,
|
||||
hidden_states, sampling_metadata)
|
||||
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, shard_name, shard_id in stacked_params_mapping:
|
||||
if shard_name not in name:
|
||||
continue
|
||||
name = name.replace(shard_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# lm_head is not used in vllm as it is tied with embed_token.
|
||||
# To prevent errors, skip loading lm_head.weight.
|
||||
if "lm_head.weight" in name:
|
||||
continue
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
439
vllm/model_executor/models/dbrx.py
Normal file
439
vllm/model_executor/models/dbrx.py
Normal file
@@ -0,0 +1,439 @@
|
||||
# coding=utf-8
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
|
||||
class DbrxRouter(nn.Module):
|
||||
"""A Router implementation for DBRX that returns logits for each expert
|
||||
per token.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_total_experts = config.ffn_config.moe_num_experts
|
||||
self.d_model = config.d_model
|
||||
self.layer = ReplicatedLinear(
|
||||
self.d_model,
|
||||
self.num_total_experts,
|
||||
bias=False,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=None,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
router_logits, _ = self.layer(hidden_states)
|
||||
return router_logits
|
||||
|
||||
|
||||
class DbrxExperts(FusedMoE):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().__init__(
|
||||
num_experts=config.ffn_config.moe_num_experts,
|
||||
top_k=config.ffn_config.moe_top_k,
|
||||
hidden_size=config.d_model,
|
||||
intermediate_size=config.ffn_config.ffn_hidden_size,
|
||||
params_dtype=params_dtype,
|
||||
reduce_results=True,
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
tp_size=get_tensor_model_parallel_world_size(),
|
||||
)
|
||||
self.config = config
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.d_model = config.d_model
|
||||
self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
|
||||
self.tp_size)
|
||||
|
||||
# Define custom weight loader for dbrx model
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
|
||||
weight_name: str):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
param_data = param.data
|
||||
shard_size = self.intermediate_size
|
||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||
# DBRX uses GLU for each experts.
|
||||
# GLU has 3 linear layers: w1, v1 and w2.
|
||||
if weight_name.endswith("w1"):
|
||||
loaded_weight = torch.reshape(
|
||||
loaded_weight,
|
||||
[-1, self.intermediate_size * self.tp_size, self.d_model],
|
||||
)
|
||||
param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
|
||||
if weight_name.endswith("v1"):
|
||||
loaded_weight = torch.reshape(
|
||||
loaded_weight,
|
||||
[-1, self.intermediate_size * self.tp_size, self.d_model],
|
||||
)
|
||||
param_data[:,
|
||||
shard_size:2 * shard_size, :] = loaded_weight[:,
|
||||
shard, :]
|
||||
if weight_name.endswith("w2"):
|
||||
loaded_weight = torch.reshape(
|
||||
loaded_weight,
|
||||
[-1, self.intermediate_size * self.tp_size, self.d_model],
|
||||
).transpose(1, 2)
|
||||
param_data[:] = loaded_weight[:, :, shard]
|
||||
|
||||
|
||||
class DbrxMoE(nn.Module):
|
||||
"""A tensor-parallel MoE implementation for DBRX.
|
||||
|
||||
Each expert's weights are sharded across all ranks and a fused MoE
|
||||
kernel is used for the forward pass, and finally we reduce the outputs
|
||||
across ranks.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
self.router = DbrxRouter(config, self.params_dtype)
|
||||
|
||||
self.experts = DbrxExperts(config=config,
|
||||
quant_config=quant_config,
|
||||
params_dtype=self.params_dtype)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.d_model)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.router(hidden_states)
|
||||
final_hidden_states = self.experts(hidden_states, router_logits)
|
||||
return final_hidden_states.view(orig_shape)
|
||||
|
||||
|
||||
class DbrxAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
self.total_num_heads = config.n_heads
|
||||
self.head_dim = self.d_model // self.total_num_heads
|
||||
self.total_num_kv_heads = config.attn_config.kv_n_heads
|
||||
self.clip_qkv = config.attn_config.clip_qkv
|
||||
self.rope_theta = config.attn_config.rope_theta
|
||||
self.max_position = config.max_seq_len
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
self.Wqkv = QKVParallelLinear(
|
||||
self.d_model,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.d_model,
|
||||
self.d_model,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position,
|
||||
base=int(self.rope_theta),
|
||||
is_neox_style=True,
|
||||
)
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_size = tp_world_size
|
||||
assert self.total_num_heads % tp_world_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_world_size
|
||||
if self.total_num_kv_heads >= tp_world_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_world_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_world_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.Wqkv(hidden_states)
|
||||
if self.clip_qkv is not None:
|
||||
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(position_ids, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
hidden_states, _ = self.out_proj(attn_output)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DbrxFusedNormAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
self.attn = DbrxAttention(config, cache_config, quant_config)
|
||||
self.norm_1 = nn.LayerNorm(self.d_model)
|
||||
self.norm_2 = nn.LayerNorm(self.d_model)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm_1(hidden_states)
|
||||
x = self.attn(
|
||||
position_ids=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = residual + x
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm_2(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class DbrxBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
|
||||
quant_config)
|
||||
self.ffn = DbrxMoE(config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states, residual = self.norm_attn_norm(
|
||||
position_ids=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = self.ffn(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DbrxModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.wte = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.d_model,
|
||||
)
|
||||
self.start_layer, self.end_layer, self.blocks = make_layers(
|
||||
config.n_layers,
|
||||
lambda prefix: DbrxBlock(config, cache_config, quant_config),
|
||||
prefix=f"{prefix}.blocks",
|
||||
)
|
||||
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
|
||||
for module in self.modules():
|
||||
if hasattr(module, "bias") and isinstance(module.bias,
|
||||
nn.Parameter):
|
||||
# Remove the bias term in Linear and LayerNorm.
|
||||
module.register_parameter("bias", None)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(["hidden_states"],
|
||||
config.d_model))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
hidden_states = self.wte(input_ids)
|
||||
else:
|
||||
assert intermediate_tensors
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
block = self.blocks[i]
|
||||
hidden_states = block(
|
||||
position_ids,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.norm_f(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DbrxForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
if config.tie_word_embeddings:
|
||||
raise ValueError(
|
||||
"tie_word_embeddings is not supported for Dbrx models.")
|
||||
self.quant_config = quant_config
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.transformer = DbrxModel(config, cache_config, quant_config)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.d_model,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.transformer.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: Optional[torch.Tensor],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
expert_params_mapping = [(
|
||||
"w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
|
||||
f"mlp.{weight_name}",
|
||||
) for weight_name in ["w1", "v1", "w2"]]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, weight_name in expert_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, weight_name)
|
||||
break
|
||||
else:
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
129
vllm/model_executor/models/decilm.py
Normal file
129
vllm/model_executor/models/decilm.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 DeciAI Research Team. All rights reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on MistralAI GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only DeciLM model compatible with HuggingFace weights."""
|
||||
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
|
||||
from .utils import is_pp_missing_parameter
|
||||
|
||||
|
||||
class DeciLMForCausalLM(LlamaForCausalLM):
|
||||
"""
|
||||
Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct.
|
||||
Based on the llama executor.
|
||||
|
||||
The main difference is that DeciLM uses Variable Grouped Query Attention.
|
||||
The constant number of GQA heads in the decoder is overridden with a value
|
||||
per layer.
|
||||
|
||||
Usually, in the HuggingFace implementation, instead of
|
||||
"config.num_key_value_heads", we use
|
||||
"config.num_key_value_heads_per_layer[i]" which varies.
|
||||
|
||||
Currently, PagedAttention does not work well with variable GQA, so we
|
||||
normalize the weights upon loading, and use uniform GQA with the max value
|
||||
instead.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
|
||||
delattr(config, "num_key_value_heads_per_layer")
|
||||
super().__init__(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
lora_config=lora_config)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if "k_proj" in name or "v_proj" in name:
|
||||
loaded_weight = self._degroup_weight(loaded_weight)
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = self.config.hidden_size // self.config.num_attention_heads
|
||||
target_num_kv_heads = self.config.num_key_value_heads
|
||||
num_kv_heads = loaded_weight.shape[0] // head_size
|
||||
n_repeats = target_num_kv_heads / num_kv_heads
|
||||
assert n_repeats == int(n_repeats)
|
||||
|
||||
n_repeats = int(n_repeats)
|
||||
loaded_weight = loaded_weight.view(num_kv_heads, head_size,
|
||||
hidden_size)
|
||||
loaded_weight = torch.repeat_interleave(loaded_weight,
|
||||
repeats=n_repeats,
|
||||
dim=0)
|
||||
loaded_weight = loaded_weight.reshape(target_num_kv_heads * head_size,
|
||||
hidden_size)
|
||||
|
||||
return loaded_weight
|
||||
480
vllm/model_executor/models/deepseek.py
Normal file
480
vllm/model_executor/models/deepseek.py
Normal file
@@ -0,0 +1,480 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Deepseek model."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
|
||||
class DeepseekMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class DeepseekMoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.top_k = config.num_experts_per_tok
|
||||
if self.tp_size > self.n_routed_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {self.n_routed_experts}.")
|
||||
|
||||
self.experts = nn.ModuleList([
|
||||
DeepseekMLP(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False)
|
||||
for idx in range(self.n_routed_experts)
|
||||
])
|
||||
self.pack_params()
|
||||
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
self.n_routed_experts,
|
||||
bias=False,
|
||||
quant_config=None)
|
||||
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = (config.moe_intermediate_size *
|
||||
config.n_shared_experts)
|
||||
self.shared_experts = DeepseekMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
)
|
||||
|
||||
def pack_params(self):
|
||||
w1 = []
|
||||
w2 = []
|
||||
for expert in self.experts:
|
||||
w1.append(expert.gate_up_proj.weight)
|
||||
w2.append(expert.down_proj.weight)
|
||||
self.w1 = torch._utils._flatten_dense_tensors(w1)
|
||||
w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
|
||||
for data, param in zip(w1s, w1):
|
||||
param.data = data
|
||||
self.w1 = self.w1.view(len(w1), *w1s[0].shape)
|
||||
|
||||
self.w2 = torch._utils._flatten_dense_tensors(w2)
|
||||
w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
|
||||
for data, param in zip(w2s, w2):
|
||||
param.data = data
|
||||
|
||||
self.w2 = self.w2.view(len(w2), *w2s[0].shape)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
if self.config.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = fused_moe(hidden_states,
|
||||
self.w1,
|
||||
self.w2,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=self.config.norm_topk_prob,
|
||||
inplace=True)
|
||||
|
||||
if self.config.n_shared_experts is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
class DeepseekAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class DeepseekDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = DeepseekAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if (config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0):
|
||||
self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
|
||||
else:
|
||||
self.mlp = DeepseekMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class DeepseekModel(nn.Module):
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: DeepseekDecoderLayer(config,
|
||||
int(prefix.split(".")[-1]),
|
||||
cache_config,
|
||||
quant_config=quant_config),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata, residual)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DeepseekForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = DeepseekModel(config, cache_config, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: Optional[torch.Tensor],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Skip experts that are not assigned to this worker.
|
||||
if (("mlp.experts." in name or "mlp.shared_experts." in name)
|
||||
and name not in params_dict):
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Skip experts that are not assigned to this worker.
|
||||
if (("mlp.experts." in name or "mlp.shared_experts." in name)
|
||||
and name not in params_dict):
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user