This commit is contained in:
2026-01-09 13:34:11 +08:00
parent dfa6476b58
commit b2ef04d792
538 changed files with 105693 additions and 2 deletions

View File

@@ -0,0 +1,119 @@
import importlib
from typing import Dict, List, Optional, Type
import torch.nn as nn
from vllm.logger import init_logger
from vllm.utils import is_hip
logger = init_logger(__name__)
# Architecture -> (module, class).
_MODELS = {
"AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"LlavaForConditionalGeneration":
("llava", "LlavaForConditionalGeneration"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
}
# Architecture -> type.
# out of tree models
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS = []
# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
"Qwen2ForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
"MistralForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
"MixtralForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
}
class ModelRegistry:
@staticmethod
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
if model_arch not in _MODELS:
return None
if is_hip():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(
f"Model architecture {model_arch} is not supported by "
"ROCm for now.")
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
logger.warning(
"Model architecture %s is partially supported by ROCm: %s",
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
module_name, model_cls_name = _MODELS[model_arch]
module = importlib.import_module(
f"vllm.model_executor.models.{module_name}")
return getattr(module, model_cls_name, None)
@staticmethod
def get_supported_archs() -> List[str]:
return list(_MODELS.keys())
@staticmethod
def register_model(model_arch: str, model_cls: Type[nn.Module]):
if model_arch in _MODELS:
logger.warning(
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls.__name__)
global _OOT_MODELS
_OOT_MODELS[model_arch] = model_cls
__all__ = [
"ModelRegistry",
]

View File

@@ -0,0 +1,410 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
class BaiChuanMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class BaiChuanAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
hidden_size: int,
num_heads: int,
position_embedding: str,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
)
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.postion_embedding = position_embedding
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
# pylint: disable=invalid-name
self.W_pack = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
# Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI":
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes)
else:
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
)
self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class BaiChuanDecoderLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
position_embedding: str,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = BaiChuanAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
position_embedding=position_embedding,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
)
self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class BaiChuanModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
position_embedding: str,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config, position_embedding, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class BaiChuanBaseForCausalLM(nn.Module):
packed_modules_mapping = {
"W_pack": ["W_pack"],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"W_pack",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
config,
position_embedding: str,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if name == "lm_head.weight":
# Unlike Baichuan, Baichuan2 normalizes the head weights.
# Refer to:
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
# Distinguish between Baichuan and Baichuan2 by checking the
# vocab size. This is suggested by
# https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
is_baichuan2 = self.config.vocab_size == 125696
if is_baichuan2:
loaded_weight = torch.nn.functional.normalize(
loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 13B and Baichuan2 7B/13B."""
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", quant_config, lora_config)
else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", quant_config, lora_config)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 7B."""
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__(config, "ROPE", quant_config, lora_config)

View File

@@ -0,0 +1,327 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The vLLM team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import BloomConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
class BloomAttention(nn.Module):
def __init__(
self,
config: BloomConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
self.total_num_heads = config.n_head
self.head_dim = self.hidden_size // self.total_num_heads
assert self.head_dim * self.total_num_heads == self.hidden_size
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
bias=True,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
)
# Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
del position_ids # Unused.
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.dense(attn_output)
return output
class BloomMLP(nn.Module):
def __init__(
self,
config: BloomConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(
hidden_size,
4 * hidden_size,
quant_config=quant_config,
)
self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
quant_config=quant_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.dense_h_to_4h(x)
x = self.gelu_impl(x)
x, _ = self.dense_4h_to_h(x)
return x
class BloomBlock(nn.Module):
def __init__(
self,
config: BloomConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config, quant_config)
self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config, quant_config)
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Layer norm post the self attention.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self attention.
attention_output = self.self_attention(
position_ids=position_ids,
hidden_states=layernorm_output,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
attention_output = attention_output + residual
layernorm_output = self.post_attention_layernorm(attention_output)
# Get residual
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# MLP.
output = self.mlp(layernorm_output) + residual
return output
class BloomModel(nn.Module):
def __init__(
self,
config: BloomConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.embed_dim = config.hidden_size
# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size,
self.embed_dim,
)
self.word_embeddings_layernorm = nn.LayerNorm(
self.embed_dim, eps=config.layer_norm_epsilon)
# Transformer blocks
self.h = nn.ModuleList([
BloomBlock(config, quant_config)
for _ in range(config.num_hidden_layers)
])
# Final Layer Norm
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
for i in range(len(self.h)):
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
attn_metadata,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class BloomForCausalLM(nn.Module):
def __init__(
self,
config: BloomConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.transformer = BloomModel(config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if name == "lm_head.weight":
continue
if not name.startswith("transformer."):
name = "transformer." + name
param = params_dict[name]
if "query_key_value" in name:
# NOTE: BLOOM's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,386 @@
# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
class GLMAttention(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.multi_query_attention = config.multi_query_attention
self.total_num_kv_heads = (config.multi_query_group_num
if config.multi_query_attention else
config.num_attention_heads)
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
)
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
rope_ratio = getattr(config, "rope_ratio", 1.0)
max_positions = getattr(config, "seq_length", 8192)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim // 2,
max_position=max_positions,
base=10000 * rope_ratio,
is_neox_style=False,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn(
q,
k,
v,
kv_cache,
attn_metadata,
)
attn_output, _ = self.dense(context_layer)
return attn_output
class GLMMLP(nn.Module):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.add_bias = config.add_bias_linear
# Project to 4h.
self.dense_h_to_4h = MergedColumnParallelLinear(
config.hidden_size,
[config.ffn_hidden_size] * 2,
bias=config.add_bias_linear,
quant_config=quant_config,
)
self.activation_func = SiluAndMul()
# Project back to h.
self.dense_4h_to_h = RowParallelLinear(
config.ffn_hidden_size,
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
)
def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output, _ = self.dense_4h_to_h(intermediate_parallel)
return output
class GLMBlock(nn.Module):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
self.fp32_residual_connection = config.fp32_residual_connection
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Layernorm on the input data.
self.input_layernorm = layer_norm_func(config.hidden_size,
eps=config.layernorm_epsilon)
# Self attention.
self.self_attention = GLMAttention(config, quant_config)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
self.post_attention_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon)
# MLP
self.mlp = GLMMLP(config, quant_config)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output = self.self_attention(
hidden_states=layernorm_output,
position_ids=position_ids,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
layernorm_input = residual + attention_output
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
output = self.mlp(layernorm_output) + residual
return output
class GLMTransformer(nn.Module):
"""Transformer class."""
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.post_layer_norm = config.post_layer_norm
# Number of layers.
self.num_layers = config.num_layers
# Transformer layers.
self.layers = nn.ModuleList(
[GLMBlock(config, quant_config) for i in range(self.num_layers)])
if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Final layer norm before output.
self.final_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
for i in range(self.num_layers):
layer = self.layers[i]
hidden_states = layer(
hidden_states=hidden_states,
position_ids=position_ids,
kv_cache=kv_caches[i],
attn_metadata=attn_metadata,
)
# Final layer norm.
if self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
class ChatGLMModel(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
config.hidden_size)
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids)
# Run encoder.
hidden_states = self.encoder(
hidden_states=inputs_embeds,
position_ids=position_ids,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
return hidden_states
class ChatGLMForCausalLM(nn.Module):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
}
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
config: ChatGLMConfig,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config: ChatGLMConfig = config
self.quant_config = quant_config
self.transformer = ChatGLMModel(config, quant_config)
self.lm_head_weight = self.transformer.output_layer.weight
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_pos_emb.inv_freq" in name:
continue
if "word_embeddings" in name:
name = name.replace(".word_embeddings", "")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,373 @@
# coding=utf-8
# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
from typing import Iterable, List, Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn.parameter import Parameter
from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput
@torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
variance_epsilon)
hidden_states = weight.to(torch.float32) * hidden_states
return hidden_states.to(input_dtype)
class LayerNorm(nn.Module):
def __init__(self, param_shape=None, eps=1e-5):
super().__init__()
self.weight = nn.Parameter(torch.ones(param_shape))
self.variance_epsilon = eps
set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
def forward(self, hidden_states, residuals=None):
hidden_states = layer_norm_func(hidden_states, self.weight,
self.variance_epsilon)
return hidden_states, residuals
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
shard_dim = 0 if param.dim() != 1 else None
param_data = param.data
if shard_dim is not None:
shard_size = param_data.shape[shard_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_up_proj = MergedColumnParallelLinear(
self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=False,
quant_config=quant_config,
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class CohereAttention(nn.Module):
def __init__(
self,
config: CohereConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
tp_size = get_tensor_model_parallel_world_size()
self.config = config
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.total_num_heads = config.num_attention_heads
self.num_heads = self.total_num_heads // tp_size
self.head_dim = self.hidden_size // self.total_num_heads
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = getattr(
config, "model_max_length", None) or getattr(
config, "max_position_embeddings", 8192)
self.rope_theta = config.rope_theta
self.rope_scaling = getattr(config, "rope_scaling", None)
self.use_qk_norm = getattr(config, "use_qk_norm", False)
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
rope_scaling=self.rope_scaling,
is_neox_style=False,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
)
if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads,
self.head_dim),
eps=config.layer_norm_eps)
self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
self.head_dim),
eps=config.layer_norm_eps)
def _apply_qk_norm(self, q, k):
q = q.view(*q.shape[:-1], -1, self.head_dim)
k = k.view(*k.shape[:-1], -1, self.head_dim)
q, _ = self.q_norm(q)
k, _ = self.k_norm(k)
q = q.view(*q.shape[:-2], -1)
k = k.view(*k.shape[:-2], -1)
return q, k
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class CohereDecoderLayer(nn.Module):
def __init__(self,
config: CohereConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = CohereAttention(config, quant_config=quant_config)
self.mlp = CohereMLP(config, quant_config=quant_config)
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states_attention = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states_mlp = self.mlp(hidden_states)
# Add everything together
hidden_states = residual + hidden_states_attention + hidden_states_mlp
return hidden_states, residual
class CohereModel(nn.Module):
def __init__(
self,
config: CohereConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
CohereDecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class CohereForCausalLM(nn.Module):
def __init__(
self,
config: CohereConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config.vocab_size,
scale=config.logit_scale)
self.model = CohereModel(config, quant_config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.embed_tokens.weight,
hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params = set()
for name, loaded_weight in weights:
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)

View File

@@ -0,0 +1,413 @@
# coding=utf-8
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.dbrx import DbrxConfig
class DbrxRouter(nn.Module):
"""A Router implementation for DBRX that returns logits for each expert
per token.
"""
def __init__(
self,
config: DbrxConfig,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.ffn_config.moe_num_experts
self.d_model = config.d_model
self.layer = ReplicatedLinear(
self.d_model,
self.num_total_experts,
bias=False,
params_dtype=params_dtype,
quant_config=None,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
router_logits, _ = self.layer(hidden_states)
return router_logits
class DbrxExperts(nn.Module):
"""A tensor-parallel MoE implementation for DBRX.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.ffn_config.moe_num_experts
self.top_k = config.ffn_config.moe_top_k
self.d_model = config.d_model
self.intermediate_size = (config.ffn_config.ffn_hidden_size //
self.tp_size)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.router = DbrxRouter(config, self.params_dtype)
self.ws = nn.Parameter(
torch.empty(
self.num_total_experts,
2 * self.intermediate_size,
self.d_model,
device="cuda",
dtype=self.params_dtype,
))
self.w2s = nn.Parameter(
torch.empty(
self.num_total_experts,
self.d_model,
self.intermediate_size,
device="cuda",
dtype=self.params_dtype,
))
set_weight_attrs(
self.ws,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2s,
{
"weight_loader": self.weight_loader,
},
)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
# DBRX uses GLU for each experts.
# GLU has 3 linear layers: w1, v1 and w2.
if weight_name.endswith("w1"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
if weight_name.endswith("v1"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:,
shard_size:2 * shard_size, :] = loaded_weight[:,
shard, :]
if weight_name.endswith("w2"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
).transpose(1, 2)
param_data[:] = loaded_weight[:, :, shard]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.d_model)
# router_logits: (num_tokens, n_experts)
router_logits = self.router(hidden_states)
final_hidden_states = fused_moe(
hidden_states,
self.ws,
self.w2s,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
class DbrxAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.total_num_heads = config.n_heads
self.head_dim = self.d_model // self.total_num_heads
self.total_num_kv_heads = config.attn_config.kv_n_heads
self.clip_qkv = config.attn_config.clip_qkv
self.rope_theta = config.attn_config.rope_theta
self.max_position = config.max_seq_len
# pylint: disable=invalid-name
self.Wqkv = QKVParallelLinear(
self.d_model,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
self.d_model,
self.d_model,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
tp_world_size = get_tensor_model_parallel_world_size()
self.tp_size = tp_world_size
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
hidden_states, _ = self.out_proj(attn_output)
return hidden_states
class DbrxFusedNormAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.attn = DbrxAttention(config, quant_config)
self.norm_1 = nn.LayerNorm(self.d_model)
self.norm_2 = nn.LayerNorm(self.d_model)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm_1(hidden_states)
x = self.attn(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + x
residual = hidden_states
hidden_states = self.norm_2(hidden_states)
return hidden_states, residual
class DbrxBlock(nn.Module):
def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config)
self.ffn = DbrxExperts(config, quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states, residual = self.norm_attn_norm(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = self.ffn(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
class DbrxModel(nn.Module):
def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.wte = VocabParallelEmbedding(
config.vocab_size,
config.d_model,
)
self.blocks = nn.ModuleList(
[DbrxBlock(config, quant_config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules():
if hasattr(module, "bias") and isinstance(module.bias,
nn.Parameter):
# Remove the bias term in Linear and LayerNorm.
module.register_parameter("bias", None)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.blocks)):
block = self.blocks[i]
hidden_states = block(
position_ids,
hidden_states,
kv_caches[i],
attn_metadata,
)
hidden_states = self.norm_f(hidden_states)
return hidden_states
class DbrxForCausalLM(nn.Module):
def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.d_model,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = [(
"ws" if weight_name in ["w1", "v1"] else "w2s",
f"experts.mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, weight_name)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,122 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 DeciAI Research Team. All rights reserved.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on MistralAI GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Tuple
import torch
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM
class DeciLMForCausalLM(LlamaForCausalLM):
"""
Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct.
Based on the llama executor.
The main difference is that DeciLM uses Variable Grouped Query Attention.
The constant number of GQA heads in the decoder is overridden with a value
per layer.
Usually, in the HuggingFace implementation, instead of
"config.num_key_value_heads", we use
"config.num_key_value_heads_per_layer[i]" which varies.
Currently, PagedAttention does not work well with variable GQA, so we
normalize the weights upon loading, and use uniform GQA with the max value
instead.
"""
def __init__(
self,
config: Optional[PretrainedConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer")
super().__init__(config=config,
quant_config=quant_config,
lora_config=lora_config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "k_proj" in name or "v_proj" in name:
loaded_weight = self._degroup_weight(loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
hidden_size = self.config.hidden_size
head_size = self.config.hidden_size // self.config.num_attention_heads
target_num_kv_heads = self.config.num_key_value_heads
num_kv_heads = loaded_weight.shape[0] // head_size
n_repeats = target_num_kv_heads / num_kv_heads
assert n_repeats == int(n_repeats)
n_repeats = int(n_repeats)
loaded_weight = loaded_weight.view(num_kv_heads, head_size,
hidden_size)
loaded_weight = torch.repeat_interleave(loaded_weight,
repeats=n_repeats,
dim=0)
loaded_weight = loaded_weight.reshape(target_num_kv_heads * head_size,
hidden_size)
return loaded_weight

View File

@@ -0,0 +1,438 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class DeepseekMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class DeepseekMoE(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.n_routed_experts = config.n_routed_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.n_routed_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.n_routed_experts}.")
self.experts = nn.ModuleList([
DeepseekMLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False)
for idx in range(self.n_routed_experts)
])
self.pack_params()
self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts,
bias=False,
quant_config=None)
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
self.shared_experts = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
)
def pack_params(self):
w1 = []
w2 = []
for expert in self.experts:
w1.append(expert.gate_up_proj.weight)
w2.append(expert.down_proj.weight)
self.w1 = torch._utils._flatten_dense_tensors(w1)
w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
for data, param in zip(w1s, w1):
param.data = data
self.w1 = self.w1.view(len(w1), *w1s[0].shape)
self.w2 = torch._utils._flatten_dense_tensors(w2)
w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
for data, param in zip(w2s, w2):
param.data = data
self.w2 = self.w2.view(len(w2), *w2s[0].shape)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.config.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob,
inplace=True)
if self.config.n_shared_experts is not None:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
class DeepseekAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class DeepseekDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = DeepseekAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
else:
self.mlp = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class DeepseekModel(nn.Module):
fall_back_to_pt_during_load = False
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], attn_metadata,
residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class DeepseekForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = DeepseekModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,444 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Falcon model."""
import math
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import RWConfig
FalconConfig = Union[HF_FalconConfig, RWConfig]
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(1,
1 + 2 * num_remaining_heads,
2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
class FalconAttention(nn.Module):
def __init__(
self,
config: FalconConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.head_dim = self.hidden_size // self.total_num_heads
assert self.head_dim * self.total_num_heads == self.hidden_size
self.new_decoder_architecture = config.new_decoder_architecture
self.multi_query = config.multi_query
if self.new_decoder_architecture:
self.total_num_kv_heads = config.num_kv_heads
elif self.multi_query:
self.total_num_kv_heads = 1
else:
self.total_num_kv_heads = self.total_num_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.bias,
skip_bias_add=True,
quant_config=quant_config,
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=config.bias,
skip_bias_add=True,
quant_config=quant_config,
reduce_results=self.reduce_row_parallel_results)
self.use_rotary = config.rotary
self.use_alibi = config.alibi
assert not (self.use_rotary and self.use_alibi), (
"Rotary and alibi are mutually exclusive.")
if self.use_rotary:
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config,
"max_position_embeddings", 8192)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads)
elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
self.inv_norm_factor)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
self.attn = Attention(self.num_heads,
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
alibi_slopes=alibi_slopes)
else:
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, bias = self.query_key_value(hidden_states)
if bias is not None:
qkv += bias
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_rotary:
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output, bias = self.dense(attn_output)
return attn_output, bias
class FalconMLP(nn.Module):
def __init__(
self,
config: FalconConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
4 * hidden_size,
bias=config.bias,
skip_bias_add=True,
quant_config=quant_config)
self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
bias=config.bias,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results,
quant_config=quant_config)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
x, bias = self.dense_h_to_4h(x)
if bias is not None:
x += bias
x = self.act(x)
x, bias = self.dense_4h_to_h(x)
return x, bias
class FalconDecoderLayer(nn.Module):
def __init__(
self,
config: FalconConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config, quant_config)
self.mlp = FalconMLP(config, quant_config)
self.config = config
if config.new_decoder_architecture:
# The layer norm before self-attention
self.ln_attn = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
# The layer norm before the MLP
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
else:
self.input_layernorm = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
if not config.parallel_attn:
self.post_attention_layernorm = LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
if self.config.new_decoder_architecture:
attention_layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
attention_layernorm_out = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = self.self_attention(
positions=positions,
hidden_states=attention_layernorm_out,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
if self.reduce_row_parallel_results and attention_bias is not None:
attention_output += attention_bias
if not self.config.new_decoder_architecture:
if self.config.parallel_attn:
mlp_layernorm_out = attention_layernorm_out
else:
residual += attention_output
mlp_layernorm_out = self.post_attention_layernorm(residual)
# MLP.
mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
if self.reduce_row_parallel_results and mlp_bias is not None:
mlp_output += mlp_bias
if not self.reduce_row_parallel_results:
# When MLP and Attention layers are parallel, we can use
# only one all-reduce operator to reduce the results from
# both MLP and Attention layers.
mlp_output += attention_output
mlp_output = tensor_model_parallel_all_reduce(mlp_output)
if attention_bias is not None:
mlp_output += attention_bias
if mlp_bias is not None:
mlp_output += mlp_bias
output = mlp_output + residual
return output
class FalconModel(nn.Module):
def __init__(
self,
config: FalconConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.use_alibi = config.alibi
# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size,
self.embed_dim,
)
# Transformer blocks
self.h = nn.ModuleList([
FalconDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
for i in range(len(self.h)):
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class FalconForCausalLM(nn.Module):
def __init__(
self,
config: FalconConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.transformer = FalconModel(config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
positions,
kv_caches,
attn_metadata,
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
total_num_heads = self.config.num_attention_heads
if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads
elif self.config.multi_query:
total_num_kv_heads = 1
else:
total_num_kv_heads = total_num_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if name == "lm_head.weight":
# Falcon uses tied embeddings.
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
if "query_key_value" in name:
output_dim = getattr(param, "output_dim", None)
loaded_weight_shape = loaded_weight.shape
if output_dim is not None:
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] +
(total_num_kv_heads, num_query_heads_per_kv_head + 2,
-1) + loaded_weight_shape[output_dim + 1:])
wq = loaded_weight.narrow(
output_dim + 1, 0,
num_query_heads_per_kv_head).reshape(
*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
wk = loaded_weight.narrow(
output_dim + 1, num_query_heads_per_kv_head,
1).reshape(*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
wv = loaded_weight.narrow(
output_dim + 1, num_query_heads_per_kv_head + 1,
1).reshape(*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,394 @@
# coding=utf-8
# Copyright 2023 The vLLM team.
# Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
# Copyright (c) Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from functools import lru_cache
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
logger = init_logger(__name__)
@lru_cache(maxsize=None)
def _get_gemma_act_fn(
hidden_act: Optional[str],
hidden_activation: Optional[str],
) -> nn.Module:
if hidden_activation is None:
if hidden_act is not None:
logger.warning(
"Gemma's activation function was incorrectly set to exact GeLU "
"in the config JSON file when it was initially released. "
"Changing the activation function to approximate GeLU "
"(`gelu_pytorch_tanh`). If you want to use the legacy "
"`%s`, edit the config JSON to set "
"`hidden_activation=%s` instead of `hidden_act`. "
"See https://github.com/huggingface/transformers/pull/29402 "
"for more details.", hidden_act, hidden_act)
return GeluAndMul(approximate="tanh")
elif hidden_activation == "gelu_pytorch_tanh":
return GeluAndMul(approximate="tanh")
elif hidden_activation == "gelu":
return GeluAndMul(approximate="none")
else:
raise ValueError(f"Activation function {hidden_act} is not "
"supported for Gemma models.")
class GemmaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: Optional[str] = None,
hidden_activation: Optional[str] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class GemmaAttention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
max_position_embeddings: int = 8192,
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=self.rope_theta,
is_neox_style=True,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class GemmaDecoderLayer(nn.Module):
def __init__(
self,
config: GemmaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GemmaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings,
rope_theta=config.rope_theta,
quant_config=quant_config,
)
self.mlp = GemmaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
hidden_activation=getattr(config, "hidden_activation", None),
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class GemmaModel(nn.Module):
def __init__(
self,
config: GemmaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
GemmaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Normalize the embedding by sqrt(hidden_size)
# The normalizer's data type should be downcasted to the model's
# data type such as bfloat16, not float32.
# See https://github.com/huggingface/transformers/pull/29402
normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
hidden_states *= self.normalizer
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class GemmaForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
config: GemmaConfig,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
del lora_config # Unused.
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = GemmaModel(config, quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.embed_tokens.weight,
hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params = set()
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# GemmaRMSNorm is different from Llama's in that it multiplies
# (1 + weight) to the output, instead of just weight.
if "norm.weight" in name:
loaded_weight += 1.0
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
"Some weights are not initialized from checkpoints: "
f"{unloaded_params}")

View File

@@ -0,0 +1,267 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import GPT2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class GPT2Attention(nn.Module):
def __init__(
self,
config: GPT2Config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim**-0.5
self.c_attn = QKVParallelLinear(
self.hidden_size,
self.head_dim,
total_num_heads,
bias=True,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
)
self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output, _ = self.c_proj(attn_output)
return attn_output
class GPT2MLP(nn.Module):
def __init__(
self,
intermediate_size: int,
config: GPT2Config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=True,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
quant_config=quant_config,
)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.c_proj(hidden_states)
return hidden_states
class GPT2Block(nn.Module):
def __init__(
self,
config: GPT2Config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config, quant_config)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# residual connection
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
return hidden_states
class GPT2Model(nn.Module):
def __init__(
self,
config: GPT2Config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
assert not config.add_cross_attention
assert not config.scale_attn_by_inverse_layer_idx
assert not config.reorder_and_upcast_attn
self.embed_dim = config.hidden_size
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([
GPT2Block(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
for i in range(len(self.h)):
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPT2LMHeadModel(nn.Module):
def __init__(
self,
config: GPT2Config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.transformer = GPT2Model(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if not name.startswith("transformer."):
name = "transformer." + name
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,275 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class GPTBigCodeAttention(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
self.tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % self.tensor_model_parallel_world_size == 0
self.num_heads = (total_num_heads //
self.tensor_model_parallel_world_size)
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim**-0.5
self.multi_query = config.multi_query
if self.multi_query:
total_num_kv_heads = 1
self.num_kv_heads = 1
else:
total_num_kv_heads = total_num_heads
self.num_kv_heads = self.num_heads
self.kv_dim = self.head_dim * self.num_kv_heads
self.c_attn = QKVParallelLinear(
self.hidden_size,
self.head_dim,
total_num_heads,
total_num_kv_heads,
bias=True,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scale,
num_kv_heads=self.num_kv_heads)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split(
[
self.hidden_size // self.tensor_model_parallel_world_size,
self.kv_dim, self.kv_dim
],
dim=-1,
)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output, _ = self.c_proj(attn_output)
return attn_output
class GPTBigMLP(nn.Module):
def __init__(
self,
intermediate_size: int,
config: GPTBigCodeConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=True,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
quant_config=quant_config,
)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.c_proj(hidden_states)
return hidden_states
class GPTBigCodeBlock(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# residual connection
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
return hidden_states
class GPTBigCodeModel(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
assert not config.add_cross_attention
self.embed_dim = config.hidden_size
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([
GPTBigCodeBlock(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
for i in range(len(self.h)):
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPTBigCodeForCausalLM(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "lm_head.weight" in name:
continue
if ".attn.bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,281 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import GPTJConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class GPTJAttention(nn.Module):
def __init__(
self,
config: GPTJConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
self.qkv_proj = QKVParallelLinear(
config.hidden_size,
self.head_size,
self.total_num_heads,
bias=False,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=False,
quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
scaling = self.head_size**-0.5
assert getattr(config, "rotary", True)
assert config.rotary_dim % 2 == 0
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.rotary_emb = get_rope(
self.head_size,
rotary_dim=config.rotary_dim,
max_position=max_position_embeddings,
base=rope_theta,
is_neox_style=False,
)
self.attn = Attention(self.num_heads, self.head_size, scaling)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output, _ = self.out_proj(attn_output)
return attn_output
class GPTJMLP(nn.Module):
def __init__(
self,
intermediate_size: int,
config: GPTJConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.n_embd
self.fc_in = ColumnParallelLinear(
hidden_size,
intermediate_size,
quant_config=quant_config,
)
self.fc_out = RowParallelLinear(
intermediate_size,
hidden_size,
quant_config=quant_config,
)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc_in(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.fc_out(hidden_states)
return hidden_states
class GPTJBlock(nn.Module):
def __init__(
self,
config: GPTJConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
inner_dim = (4 * config.n_embd
if config.n_inner is None else config.n_inner)
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config, quant_config)
self.mlp = GPTJMLP(inner_dim, config, quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
mlp_output = self.mlp(hidden_states)
hidden_states = attn_output + mlp_output + residual
return hidden_states
class GPTJModel(nn.Module):
def __init__(
self,
config: GPTJConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.n_embd
self.wte = VocabParallelEmbedding(
config.vocab_size,
self.embed_dim,
)
self.h = nn.ModuleList(
[GPTJBlock(config, quant_config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.h)):
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
attn_metadata,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPTJForCausalLM(nn.Module):
def __init__(
self,
config: GPTJConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
assert not config.tie_word_embeddings
self.transformer = GPTJModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.n_embd,
bias=True,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "attn.bias" in name or "attn.masked_bias" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,295 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class GPTNeoXAttention(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
self.bias = getattr(config, "attention_bias", True)
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.query_key_value = QKVParallelLinear(
config.hidden_size,
self.head_size,
self.total_num_heads,
bias=self.bias,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=self.bias,
quant_config=quant_config,
)
scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.rotary_emb = get_rope(
self.head_size,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
base=rope_theta,
)
self.attn = Attention(self.num_heads, self.head_size, scaling)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.dense(attn_output)
return output
class GPTNeoXMLP(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.dense_h_to_4h = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
)
self.dense_4h_to_h = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
)
self.act = get_act_fn(config.hidden_act, quant_config,
config.intermediate_size)
def forward(self, hidden_states):
hidden_states, _ = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.dense_4h_to_h(hidden_states)
return hidden_states
class GPTNeoXLayer(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config, quant_config)
self.mlp = GPTNeoXMLP(config, quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
attn_input = self.input_layernorm(hidden_states)
attn_output = self.attention(
position_ids=position_ids,
hidden_states=attn_input,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
if self.use_parallel_residual:
# pseudocode:
# x = x + attn(ln1(x)) + mlp(ln2(x))
mlp_input = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(mlp_input)
hidden_states = mlp_output + attn_output + hidden_states
else:
# pseudocode:
# x = x + attn(ln1(x))
# x = x + mlp(ln2(x))
attn_output = attn_output + hidden_states
mlp_input = self.post_attention_layernorm(attn_output)
mlp_output = self.mlp(mlp_input)
hidden_states = mlp_output + attn_output
return hidden_states
class GPTNeoXModel(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.embed_in = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
GPTNeoXLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
attn_metadata,
)
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class GPTNeoXForCausalLM(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.gpt_neox = GPTNeoXModel(config, quant_config)
self.embed_out = ParallelLMHead(
config.vocab_size,
config.hidden_size,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.embed_out.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using OpenRLHF may include
# these tensors in the checkpoint. Skip them.
continue
param = params_dict[name]
if "query_key_value" in name:
# NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,323 @@
# -*- coding: utf-8 -*-
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class InternLM2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.w2 = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.w2(x)
return x
class InternLM2Attention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.wqkv = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.wo = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.wqkv(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.wo(attn_output)
return output
class InternLMDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.attention = InternLM2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
)
self.feed_forward = InternLM2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.attention_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.attention_norm(hidden_states)
else:
hidden_states, residual = self.attention_norm(
hidden_states, residual)
hidden_states = self.attention(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.ffn_norm(hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
class InternLM2Model(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.tok_embeddings = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
InternLMDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.tok_embeddings(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class InternLM2ForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = InternLM2Model(config, quant_config)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.output.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "w1", 0),
("gate_up_proj", "w3", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
if "wqkv" in name:
config = self.config
kv_groups = (config.num_attention_heads //
config.num_key_value_heads)
head_dim = config.hidden_size // config.num_attention_heads
loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
head_dim,
loaded_weight.shape[-1])
wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
dim=1)
wq = wq.reshape(-1, wq.shape[-1])
wk = wk.reshape(-1, wk.shape[-1])
wv = wv.reshape(-1, wv.shape[-1])
weight_loader = param.weight_loader
weight_loader(param, wq, 'q')
weight_loader(param, wk, 'k')
weight_loader(param, wv, 'v')
else:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,333 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights
# reserved.
# Copyright 2023 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Jais model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import JAISConfig
class SwiGLUActivation(nn.Module):
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
return x1 * nn.functional.silu(x2)
def _get_alibi_slopes(n):
def get_slopes_power_of_2(n):
start = 2**(-(2**-(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes(
2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
class JAISAttention(nn.Module):
def __init__(
self,
config: JAISConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
if hasattr(config, "scale_qk_dot_by_d"):
config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d
self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5
self.scale = self.head_dim**-self.attn_scale_power
self.c_attn = QKVParallelLinear(
self.hidden_size,
self.head_dim,
total_num_heads,
bias=True,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
)
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end]
self.attn = Attention(
self.num_heads,
self.head_dim,
scale=self.scale,
alibi_slopes=alibi_slopes,
)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output, _ = self.c_proj(attn_output)
return attn_output
class JAISMLP(nn.Module):
def __init__(
self,
intermediate_size: int,
config: JAISConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.swiglu = config.activation_function == "swiglu"
self.c_fc = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=True,
quant_config=quant_config,
)
self.c_fc2 = (ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=True,
quant_config=quant_config,
) if self.swiglu else None)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
quant_config=quant_config,
)
self.act = SwiGLUActivation()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.swiglu:
hidden_states2, _ = self.c_fc2(hidden_states)
hidden_states, _ = self.c_fc(hidden_states)
hidden_states = (self.act(hidden_states, hidden_states2)
if self.swiglu else self.act(hidden_states))
hidden_states, _ = self.c_proj(hidden_states)
return hidden_states
class JAISBlock(nn.Module):
def __init__(
self,
config: JAISConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = JAISAttention(config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = JAISMLP(inner_dim, config, quant_config)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# residual connection
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
return hidden_states
class JAISModel(nn.Module):
def __init__(
self,
config: JAISConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
assert not config.add_cross_attention
assert not config.scale_attn_by_inverse_layer_idx
assert not config.reorder_and_upcast_attn
self.embed_dim = config.hidden_size
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = (nn.Embedding(config.max_position_embeddings,
self.embed_dim)
if config.position_embedding_type != "alibi" else None)
if hasattr(config, "embeddings_scale"):
self.embeddings_scale = config.embeddings_scale
else:
self.embeddings_scale = config.mup_embeddings_scale
self.h = nn.ModuleList([
JAISBlock(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
if self.wpe is not None:
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
hidden_states = inputs_embeds
hidden_states *= torch.tensor(float(self.embeddings_scale),
dtype=hidden_states.dtype)
for i in range(len(self.h)):
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class JAISLMHeadModel(nn.Module):
def __init__(
self,
config: JAISConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.transformer = JAISModel(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale
else:
self.output_logits_scale = (config.mup_output_alpha *
config.mup_width_scale)
self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
scale=self.output_logits_scale)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if "relative_pe" in name:
continue
if not name.startswith("transformer."):
name = "transformer." + name
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,442 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QKVParallelLinear] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class LlamaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
# This will be overwritten by model initialization if we are using it.
# N.B. currently we only support per tensor scalar scaling factors
# & only applicable to ROCm (AMD GPU).
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
self.kv_scale = 1.0
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=bias,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
self.kv_scale)
output, _ = self.o_proj(attn_output)
return output
class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
sliding_window = getattr(config, "sliding_window", None)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
sliding_window=sliding_window,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class LlamaModel(nn.Module):
def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class LlamaForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.model = LlamaModel(config, quant_config, lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type):
layer_self_attn = self.model.layers[layer_idx].self_attn
if is_hip():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.kv_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")

View File

@@ -0,0 +1,239 @@
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from transformers import CLIPVisionModel, LlavaConfig
from vllm.attention import AttentionMetadata
from vllm.config import VisionLanguageConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
def __init__(self, vision_hidden_size: int, text_hidden_size: int,
projector_hidden_act: str):
super().__init__()
self.linear_1 = nn.Linear(vision_hidden_size,
text_hidden_size,
bias=True)
self.act = get_act_fn(projector_hidden_act)
self.linear_2 = nn.Linear(text_hidden_size,
text_hidden_size,
bias=True)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
def _merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
image_token_id: int):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = (input_ids == image_token_id)
inputs_embeds[mask] = vision_embeddings.view(-1,
vision_embeddings.shape[-1])
class LlavaForConditionalGeneration(nn.Module):
def __init__(self,
config: "LlavaConfig",
vision_language_config: VisionLanguageConfig,
quant_config: Optional["QuantizationConfig"] = None) -> None:
super().__init__()
self.config = config
self.vision_language_config = vision_language_config
assert self.vision_language_config, (
"Provide `image_input_type` and other vision "
"related configurations through LLM entrypoint "
"or engine arguments.")
if self.vision_language_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
self.vision_tower = CLIPVisionModel(config.vision_config)
else:
self.vision_tower = None
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.quant_config = quant_config
self.language_model = LlamaModel(config.text_config, quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.text_config.hidden_size,
org_num_embeddings=self.language_model.org_vocab_size)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
image_input: Optional[torch.Tensor] = None
) -> SamplerOutput: # noqa: E501
"""Run forward pass for Llava 1.5.
One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted image embeddings.
Concretely, consider a text prompt:
"<image>\nUSER: What's the content of the image?\nASSISTANT:".
Tokenizer outputs:
[1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
The to-be-inserted image has a size of 576 (24 * 24) along the context
length dimension.
`input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
9047, 13566, 29901].
There will be 576 `32000` in the `input_ids`.
(32000 is the token id for `<image>`.)
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
The model takes two types of image inputs:
PIXEL_VALUES and IMAGE_FEATURES.
The following shows how each maps to huggingface implementation.
PIXEL_VALUES:
- https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
IMAGE_FEATURES:
- https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
before going through the multi modal projector.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
image_input: A batch of image inputs.
For PIXEL_VALUES, expecting [1, 3, 336, 336].
For IMAGE_FEATURES, expecting [1, 576, 1024].
"""
if image_input is not None:
if list(image_input.shape[1:]) != list(
self.vision_language_config.image_input_shape[1:]):
raise ValueError(
f"The expected image tensor shape is batch dimension "
f"plus "
f"{self.vision_language_config.image_input_shape[1:]}."
f" You supplied {image_input.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")
if self.vision_tower is not None:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = self.vision_tower(image_input,
output_hidden_states=True)
image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if self.config.vision_feature_select_strategy == "default":
image_features = image_features[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
image_features = image_features
else:
raise ValueError(
f"Unexpected select feature strategy: "
f"{self.config.vision_feature_select_strategy}")
else:
image_features = image_input
vision_embeddings = self.multi_modal_projector(image_features)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
_merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vision_language_config.image_token_id)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model(input_ids,
positions,
kv_caches,
attn_metadata,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# only doing this for language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_tower is not None:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
else:
for (param_name, weight_name,
shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,531 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import math
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput
class MiniCPMMoE(nn.Module):
"""A tensor-parallel MoE implementation that shards each expert
across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
):
super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
self.num_total_experts = num_experts
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // self.tp_size
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.gate = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
quant_config=None)
self.ws = nn.Parameter(
torch.empty(self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
device="cuda",
dtype=self.params_dtype))
self.w2s = nn.Parameter(
torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
device="cuda",
dtype=self.params_dtype))
set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2s, {
"weight_loader": self.weight_loader,
})
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
if weight_name.endswith("w1.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w3.weight"):
param_data[expert_id,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.ws,
self.w2s,
router_logits,
self.top_k,
renormalize=True,
inplace=True)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
class MiniCPMMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class MiniCPMAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
# set rope as fp32 instead of bf16
self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache(
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
orig_dtype = q.dtype
q, k = q.float(), k.float()
q, k = self.rotary_emb(positions, q, k)
q, k = q.to(orig_dtype), k.to(orig_dtype)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class MiniCPMDecoderLayer(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = MiniCPMAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
)
self.num_experts = getattr(self.config, "num_experts", 0)
if self.num_experts == 0:
self.mlp = MiniCPMMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
else:
self.mlp = MiniCPMMoE(num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states * \
(self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states * \
(self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))
return hidden_states, None
class MiniCPMModel(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
MiniCPMDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
embedding = self.embed_tokens(input_ids)
return embedding * self.config.scale_emb
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
residual,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class MiniCPMForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.num_experts = getattr(self.config, "num_experts", 0)
self.quant_config = quant_config
self.model = MiniCPMModel(config,
quant_config,
lora_config=lora_config)
unpadded_vocab_size = config.vocab_size
if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size
if not self.config.tie_word_embeddings:
self.lm_head = ParallelLMHead(
unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
self.scale_width = self.config.hidden_size / self.config.dim_model_base
self.logits_processor = LogitsProcessor(unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head_weight = self.lm_head.weight
logits = self.logits_processor(lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = [
# (param_name, weight_name, expert_id)
("ws" if weight_name in ["w1", "w3"] else "w2s",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.num_experts)
for weight_name in ["w1", "w2", "w3"]
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name, expert_id in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,583 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import MixtralConfig
from vllm import _custom_ops as ops
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput
from vllm.utils import print_warning_once
class MixtralMoE(nn.Module):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
self.num_total_experts = num_experts
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // self.tp_size
self.quant_config = quant_config
# FIXME(pcmoritz): Make this more general to support different
# quantization schemes
self.use_fp8 = isinstance(quant_config, Fp8Config)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
quant_config=None)
if self.use_fp8:
params_dtype = torch.float8_e4m3fn
self.w13_weight = nn.Parameter(
torch.empty(self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
dtype=params_dtype))
self.w2_weight = nn.Parameter(
torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
dtype=params_dtype))
set_weight_attrs(self.w13_weight, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2_weight, {
"weight_loader": self.weight_loader,
})
# Used for fp8.
self.w13_scale = None
self.w2_scale = None
self.a13_scale = None
self.a2_scale = None
if self.use_fp8:
# WEIGHT_SCALE (for fp8)
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(self.w13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2_scale, {
"weight_loader": self.weight_loader,
})
# ACT_SCALE (for fp8)
if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
self.a13_scale = nn.Parameter(torch.zeros(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
self.a2_scale = nn.Parameter(torch.zeros(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
set_weight_attrs(self.a13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.a2_scale, {
"weight_loader": self.weight_loader,
})
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
if weight_name.endswith("w1.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w3.weight"):
param_data[expert_id,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
if "act_scale" in weight_name or "weight_scale" in weight_name:
param_data[expert_id] = loaded_weight
def process_weights_after_loading(self):
# Fp8 is the only case where we need to process after loading.
if not self.use_fp8:
return
# If checkpoint is fp16, quantize here.
if not self.quant_config.is_checkpoint_fp8_serialized:
w13_weight = torch.empty_like(self.w13_weight.data,
dtype=torch.float8_e4m3fn)
w2_weight = torch.empty_like(self.w2_weight.data,
dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts):
w13_weight[expert, :, :], self.w13_scale[
expert] = ops.scaled_fp8_quant(
self.w13_weight.data[expert, :, :])
w2_weight[expert, :, :], self.w2_scale[
expert] = ops.scaled_fp8_quant(
self.w2_weight.data[expert, :, :])
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
# If checkpoint is fp8 + static, cleanup act_scales.
# Since state_dict has an act_scale per expert but our kernels
# are passed one act_scale shared across all experts.
elif self.quant_config.activation_scheme == "static":
if self.a13_scale is None or self.a2_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(self.a13_scale)
or not all_close_1d(self.a2_scale)):
print_warning_once(
"Found act_scales that are not equal for fp8 MoE layer. "
"Using the maximum across experts for each layer. ")
self.a13_scale = nn.Parameter(self.a13_scale.max(),
requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(),
requires_grad=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w13_weight,
self.w2_weight,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
use_fp8=self.use_fp8,
w1_scale=self.w13_scale,
w2_scale=self.w2_scale,
a1_scale=self.a13_scale,
a2_scale=self.a2_scale)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
class MixtralAttention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
if isinstance(
quant_config,
Fp8Config) and not quant_config.is_checkpoint_fp8_serialized:
print_warning_once(
"For Mixtral FP8 quantization, we currently do not quantize "
"the attention layers until their FP8 performance is improved."
)
quant_config = None
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MixtralAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
quant_config=quant_config)
self.block_sparse_moe = MixtralMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.block_sparse_moe(hidden_states)
return hidden_states, residual
class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
MixtralDecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], attn_metadata,
residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class MixtralForCausalLM(nn.Module):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.model = MixtralModel(config,
quant_config,
lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
expert_params_mapping = [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id)
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name, expert_id in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))

View File

@@ -0,0 +1,404 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class MixtralMLP(nn.Module):
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.num_experts = num_experts
self.ffn_dim = intermediate_size
self.hidden_dim = hidden_size
self.w1 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
quant_config=quant_config)
self.w2 = ReplicatedLinear(self.ffn_dim,
self.hidden_dim,
bias=False,
quant_config=quant_config)
self.w3 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
quant_config=quant_config)
# TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
w1_out, _ = self.w1(hidden_states)
w1_out = self.act_fn(w1_out)
w3_out, _ = self.w3(hidden_states)
current_hidden_states = w1_out * w3_out
current_hidden_states, _ = self.w2(current_hidden_states)
return current_hidden_states
class MixtralMoE(nn.Module):
def __init__(
self,
config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.num_total_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.num_total_experts}.")
# Split experts equally between ranks
self.expert_indicies = np.array_split(range(
self.num_total_experts), self.tp_size)[self.rank].tolist()
if not self.expert_indicies:
raise ValueError(
f"Rank {self.rank} has no experts assigned to it.")
self.experts = nn.ModuleList([
MixtralMLP(self.num_total_experts,
config.hidden_size,
config.intermediate_size,
quant_config=quant_config)
if idx in self.expert_indicies else None
for idx in range(self.num_total_experts)
])
self.gate = ReplicatedLinear(config.hidden_size,
self.num_total_experts,
bias=False,
quant_config=None)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights,
self.top_k,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
final_hidden_states = None
for expert_idx in self.expert_indicies:
expert_layer = self.experts[expert_idx]
expert_mask = (selected_experts == expert_idx)
expert_weights = (routing_weights * expert_mask).sum(dim=-1,
keepdim=True)
current_hidden_states = expert_layer(hidden_states).mul_(
expert_weights)
if final_hidden_states is None:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)
return tensor_model_parallel_all_reduce(final_hidden_states).view(
num_tokens, hidden_dim)
class MixtralAttention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MixtralAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
quant_config=quant_config)
self.block_sparse_moe = MixtralMoE(config=config,
quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.block_sparse_moe(hidden_states)
return hidden_states, residual
class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
MixtralDecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], attn_metadata,
residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class MixtralForCausalLM(nn.Module):
fall_back_to_pt_during_load = False
def __init__(
self,
config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = MixtralModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if ("block_sparse_moe.experts." in name
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,295 @@
# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig
def _get_alibi_slopes(
total_num_heads: int,
alibi_bias_max: int,
) -> torch.Tensor:
next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
m = m.mul(alibi_bias_max / next_power_of_2)
slopes = 1.0 / torch.pow(2, m)
if next_power_of_2 != total_num_heads:
slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
return slopes
class MPTAttention(nn.Module):
def __init__(
self,
config: MPTConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.total_num_heads = config.n_heads
self.head_dim = self.d_model // self.total_num_heads
self.clip_qkv = config.attn_config["clip_qkv"]
self.qk_ln = config.attn_config["qk_ln"]
self.alibi_bias_max = config.attn_config["alibi_bias_max"]
if "kv_n_heads" in config.attn_config:
self.total_num_kv_heads = config.attn_config['kv_n_heads']
else:
self.total_num_kv_heads = self.total_num_heads
assert not config.attn_config["prefix_lm"]
assert config.attn_config["alibi"]
# pylint: disable=invalid-name
self.Wqkv = QKVParallelLinear(
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=not config.no_bias,
quant_config=quant_config,
)
if self.qk_ln:
self.q_ln = nn.LayerNorm(self.d_model)
self.k_ln = nn.LayerNorm(self.d_model)
self.out_proj = RowParallelLinear(
self.d_model,
self.d_model,
bias=not config.no_bias,
quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
# Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads,
self.alibi_bias_max)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
self.head_dim = self.d_model // self.total_num_heads
scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
del position_ids # unused.
qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.qk_ln:
q = self.q_ln(q)
k = self.k_ln(k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.out_proj(attn_output)
return output
class MPTMLP(nn.Module):
def __init__(
self,
config: MPTConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.d_model
expansion_ratio = config.expansion_ratio
intermediate_size = expansion_ratio * hidden_size
self.up_proj = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=not config.no_bias,
quant_config=quant_config,
)
self.act = get_act_fn("gelu", quant_config, intermediate_size)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=not config.no_bias,
quant_config=quant_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.up_proj(x)
x = self.act(x)
x, _ = self.down_proj(x)
return x
class MPTBlock(nn.Module):
def __init__(
self,
config: MPTConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.d_model
self.norm_1 = nn.LayerNorm(hidden_size)
self.attn = MPTAttention(config, quant_config)
self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config, quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
x = self.norm_1(hidden_states)
x = self.attn(
position_ids=position_ids,
hidden_states=x,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = hidden_states + x
x = self.norm_2(hidden_states)
x = self.ffn(x)
hidden_states = hidden_states + x
return hidden_states
class MPTModel(nn.Module):
def __init__(
self,
config: MPTConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
assert config.embedding_fraction == 1.0
assert config.norm_type == "low_precision_layernorm"
self.wte = VocabParallelEmbedding(
config.vocab_size,
config.d_model,
)
self.blocks = nn.ModuleList(
[MPTBlock(config, quant_config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias:
for module in self.modules():
if hasattr(module, "bias") and isinstance(
module.bias, nn.Parameter):
# Remove the bias term in Linear and LayerNorm.
module.register_parameter("bias", None)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.blocks)):
block = self.blocks[i]
hidden_states = block(
position_ids,
hidden_states,
kv_caches[i],
attn_metadata,
)
hidden_states = self.norm_f(hidden_states)
return hidden_states
class MPTForCausalLM(nn.Module):
def __init__(
self,
config: MPTConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
assert config.tie_word_embeddings
self.quant_config = quant_config
self.transformer = MPTModel(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,356 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
# Copyright 2024 The vLLM team.
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import OlmoConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class OlmoAttention(nn.Module):
"""
This is the attention block where the output is computed as
``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(
self,
config: OlmoConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
self.total_num_heads = config.num_attention_heads
assert self.hidden_size % self.total_num_heads == 0
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = self.hidden_size // self.total_num_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.clip_qkv = config.clip_qkv
# Attention input projection. Projects x -> (q, k, v)
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
bias=config.attention_bias,
quant_config=quant_config,
)
# Rotary embeddings.
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
)
self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling)
# Attention output projection.
self.o_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=config.attention_bias,
quant_config=quant_config,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class OlmoMLP(nn.Module):
"""
This is the MLP block where the output is computed as
``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(
self,
config: OlmoConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
# Feed-forward input projection.
self.gate_up_proj = MergedColumnParallelLinear(
self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
)
# Activation function.
self.act_fn = SiluAndMul()
# Feed-forward output projection.
self.down_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=False,
quant_config=quant_config,
)
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class OlmoDecoderLayer(nn.Module):
"""
This is a typical transformer block where the output is
computed as ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(self,
config: OlmoConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
# Attention block.
self.self_attn = OlmoAttention(config, quant_config)
# MLP block.
self.mlp = OlmoMLP(config, quant_config)
# LayerNorm
self.input_layernorm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False,
bias=False)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False,
bias=False)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Attention block.
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(positions, hidden_states, kv_cache,
attn_metadata)
hidden_states = hidden_states + residual
# MLP block.
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class OlmoModel(nn.Module):
def __init__(self,
config: OlmoConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
OlmoDecoderLayer(config, quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False,
bias=False)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
hidden_states = inputs_embeds
# Apply blocks one-by-one.
for layer_idx, decoder_layer in enumerate(self.layers):
# shape: (batch_size, seq_len, d_model)
hidden_states = decoder_layer(
positions,
hidden_states,
kv_caches[layer_idx],
attn_metadata,
)
# Apply final layer norm.
# shape: (batch_size, seq_len or 1, d_model)
hidden_states = self.norm(hidden_states)
return hidden_states
class OlmoForCausalLM(nn.Module):
"""
Extremely barebones HF model wrapper.
"""
def __init__(self,
config: OlmoConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.model = OlmoModel(config, quant_config)
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
else:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.lm_head_weight = self.lm_head.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,349 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import OPTConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class OPTLearnedPositionalEmbedding(nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int):
# OPT is set up so that if padding_idx is specified then offset the
# embedding ids by 2 and adjust num_embeddings appropriately. Other
# models don't have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)
def forward(self, positions: torch.Tensor):
return super().forward(positions + self.offset)
class OPTAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.embed_dim = embed_dim
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
total_num_heads = num_heads
assert num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
embed_dim,
self.head_dim,
total_num_heads,
bias=bias,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
quant_config=quant_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.out_proj(attn_output)
return output
class OPTDecoderLayer(nn.Module):
def __init__(
self,
config: OPTConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.self_attn = OPTAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
bias=config.enable_bias,
quant_config=quant_config,
)
self.do_layer_norm_before = config.do_layer_norm_before
self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
self.fc1 = ColumnParallelLinear(
self.embed_dim,
config.ffn_dim,
bias=config.enable_bias,
quant_config=quant_config,
)
self.activation_fn = get_act_fn(config.activation_function,
quant_config, config.ffn_dim)
self.fc2 = RowParallelLinear(
config.ffn_dim,
self.embed_dim,
bias=config.enable_bias,
quant_config=quant_config,
)
self.final_layer_norm = nn.LayerNorm(
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class OPTDecoder(nn.Module):
def __init__(
self,
config: OPTConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.word_embed_proj_dim,
)
# Positional embeddings are replicated (not sharded).
self.embed_positions = OPTLearnedPositionalEmbedding(
config.max_position_embeddings, config.hidden_size)
# Project out & in will be replicated if they exist.
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = ReplicatedLinear(config.hidden_size,
config.word_embed_proj_dim,
bias=False,
quant_config=quant_config)
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
config.hidden_size,
bias=False,
quant_config=quant_config)
else:
self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to
# keep backward compatibility with checkpoints that have been fine-tuned
# before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm(
config.hidden_size,
elementwise_affine=config.layer_norm_elementwise_affine)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList([
OPTDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
pos_embeds = self.embed_positions(positions)
if self.project_in is not None:
inputs_embeds, _ = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
hidden_states, _ = self.project_out(hidden_states)
return hidden_states
class OPTModel(nn.Module):
def __init__(
self,
config: OPTConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.decoder = OPTDecoder(config, quant_config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
return self.decoder(input_ids, positions, kv_caches, attn_metadata)
class OPTForCausalLM(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = OPTModel(config, quant_config)
self.lm_head_weight = self.model.decoder.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "lm_head.weight" in name:
continue
if name.startswith("decoder."):
name = "model." + name
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,320 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py
# Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
# Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class OrionMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class OrionAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class OrionDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = OrionAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
)
self.mlp = OrionMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, None
class OrionModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
OrionDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
residual,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class OrionForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = OrionModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,301 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class PhiAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
# pylint: disable=C0103
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_size,
self.total_num_heads,
bias=True,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
quant_config=quant_config,
)
scaling = self.head_size**-0.5
rotary_dim = int(config.partial_rotary_factor *
(config.hidden_size // config.num_attention_heads))
assert rotary_dim % 2 == 0
# pylint: disable=C0301
# Refer to:
# https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
rope_theta = 10000
max_position_embeddings = getattr(config, "n_positions", 2048)
self.rotary_emb = get_rope(
self.head_size,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
base=rope_theta,
)
self.attn = Attention(self.num_heads, self.head_size, scaling)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.dense(attn_output)
return output
class PhiMLP(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
n_inner = getattr(config, "n_inner", None)
n_inner = n_inner if n_inner is not None else 4 * config.hidden_size
self.fc1 = ColumnParallelLinear(
config.hidden_size,
n_inner,
quant_config=quant_config,
)
self.fc2 = RowParallelLinear(
n_inner,
config.hidden_size,
quant_config=quant_config,
)
self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
def forward(self, hidden_states):
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class PhiLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.self_attn = PhiAttention(config, quant_config)
self.mlp = PhiMLP(config, quant_config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_outputs = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = attn_outputs + feed_forward_hidden_states + residual
return hidden_states
class PhiModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
PhiLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(self.config.num_hidden_layers):
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
)
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
class PhiForCausalLM(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = PhiModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v")
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# pylint: disable=E1136
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,285 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class QWenMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str = "silu",
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.c_proj(x)
return x
class QWenAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
max_position_embeddings: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
)
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.c_attn = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
bias=True,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.scaling = self.head_dim**-0.5
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.c_proj(attn_output)
return output
class QWenBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
self.attn = QWenAttention(config.hidden_size,
config.num_attention_heads,
config.max_position_embeddings,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
quant_config=quant_config)
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mlp = QWenMLP(config.hidden_size,
config.intermediate_size // 2,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
else:
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.ln_2(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class QWenModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.wte = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.h = nn.ModuleList([
QWenBlock(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
residual = None
for i in range(len(self.h)):
layer = self.h[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
residual,
)
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states
class QWenLMHeadModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.transformer = QWenModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "w2", 0),
("gate_up_proj", "w1", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,367 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import Qwen2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class Qwen2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class Qwen2Attention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
use_sliding_window: bool = False,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window if use_sliding_window else None
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=self.rope_theta,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class Qwen2DecoderLayer(nn.Module):
def __init__(
self,
config: Qwen2Config,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
use_sliding_window = (config.use_sliding_window
and layer_idx < config.max_window_layers)
self.self_attn = Qwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
use_sliding_window=use_sliding_window,
quant_config=quant_config,
sliding_window=config.sliding_window)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class Qwen2Model(nn.Module):
def __init__(
self,
config: Qwen2Config,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
Qwen2DecoderLayer(config, layer_idx, quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class Qwen2ForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
config: Qwen2Config,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
del lora_config
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = Qwen2Model(config, quant_config)
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size)
self.lm_head_weight = self.lm_head.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,447 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class Qwen2MoeMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class Qwen2MoeSparseMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.n_routed_experts = config.num_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.n_routed_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.n_routed_experts}.")
self.experts = nn.ModuleList([
Qwen2MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False)
for idx in range(self.n_routed_experts)
])
self.pack_params()
self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts,
bias=False,
quant_config=None)
if config.shared_expert_intermediate_size > 0:
self.shared_expert = Qwen2MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.shared_expert_intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
)
else:
self.shared_expert = None
self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
1,
bias=False)
def pack_params(self):
w1 = []
w2 = []
for expert in self.experts:
w1.append(expert.gate_up_proj.weight)
w2.append(expert.down_proj.weight)
self.w1 = torch._utils._flatten_dense_tensors(w1)
w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
for data, param in zip(w1s, w1):
param.data = data
self.w1 = self.w1.view(len(w1), *w1s[0].shape)
self.w2 = torch._utils._flatten_dense_tensors(w2)
w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
for data, param in zip(w2s, w2):
param.data = data
self.w2 = self.w2.view(len(w2), *w2s[0].shape)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
if self.shared_expert_gate is not None:
shared_output = F.sigmoid(
self.shared_expert_gate(hidden_states)) * shared_output
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob,
inplace=True)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
class Qwen2MoeAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class Qwen2MoeDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = Qwen2MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
)
if (config.num_experts is not None
and (layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen2MoeSparseMoeBlock(config=config,
quant_config=quant_config)
else:
self.mlp = Qwen2MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class Qwen2MoeModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], attn_metadata,
residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class Qwen2MoeForCausalLM(nn.Module):
fall_back_to_pt_during_load = False
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = Qwen2MoeModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_expert." in name)
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_expert." in name)
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,301 @@
# coding=utf-8
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class StablelmMLP(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_up_proj = MergedColumnParallelLinear(
config.hidden_size, [config.intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=False)
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class StablelmAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
self.num_heads = self.total_num_heads // tp_size
self.total_num_key_value_heads = config.num_key_value_heads
if self.total_num_key_value_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_key_value_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_key_value_heads == 0
self.num_key_value_heads = max(
1, self.total_num_key_value_heads // tp_size)
self.head_dim = self.hidden_size // self.total_num_heads
self.max_position_embeddings = config.max_position_embeddings
rope_pct = getattr(config, "rope_pct",
getattr(config, "partial_rotary_factor", 1))
self.rotary_ndims = int(self.head_dim * rope_pct)
self.scaling = self.head_dim**-0.5
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_key_value_heads * self.head_dim
self.qkv_bias = getattr(config, "use_qkv_bias", False)
if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
raise ValueError(f"hidden_size must be divisible by num_heads "
f"(got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads}).")
self.qkv_proj = QKVParallelLinear(self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_key_value_heads,
self.qkv_bias,
quant_config=quant_config)
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_ndims,
max_position=self.config.max_position_embeddings,
base=self.config.rope_theta,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_key_value_heads)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class StablelmDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.self_attn = StablelmAttention(config)
self.mlp = StablelmMLP(config, quant_config)
norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05))
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, residual
class StableLMEpochModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
StablelmDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05))
self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class StablelmForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = StableLMEpochModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,302 @@
# coding=utf-8
# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import Starcoder2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class Starcoder2Attention(nn.Module):
def __init__(self,
config: Starcoder2Config,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = self.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = config.rope_theta
self.max_position_embeddings = config.max_position_embeddings
self.use_bias = config.use_bias
self.sliding_window = config.sliding_window
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=self.use_bias,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=self.use_bias,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class Starcoder2MLP(nn.Module):
def __init__(self,
config: Starcoder2Config,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.c_fc = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
bias=config.use_bias,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=config.use_bias,
quant_config=quant_config,
)
self.act = get_act_fn(config.hidden_act, quant_config,
config.intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.c_proj(hidden_states)
return hidden_states
class Starcoder2DecoderLayer(nn.Module):
def __init__(self,
config: Starcoder2Config,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config, quant_config=quant_config)
self.mlp = Starcoder2MLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Starcoder2Model(nn.Module):
def __init__(self,
config: Starcoder2Config,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
# TODO: consider padding_idx (currently removed)
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
Starcoder2DecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(positions, hidden_states, kv_caches[i],
attn_metadata)
hidden_states = self.norm(hidden_states)
return hidden_states
class Starcoder2ForCausalLM(nn.Module):
def __init__(self,
config: Starcoder2Config,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.model = Starcoder2Model(config, quant_config=quant_config)
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
else:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
self.lm_head_weight = self.lm_head.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,366 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/xverse/XVERSE-7B/blob/main/modeling_xverse.py
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Xverse model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
class XverseMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate, _ = self.gate_up_proj(x)
x = self.act_fn(gate)
x, _ = self.down_proj(x)
return x
class XverseAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
# partition the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=bias,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class XverseDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
sliding_window = getattr(config, "sliding_window", None)
self.self_attn = XverseAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=getattr(config, "bias", False),
sliding_window=sliding_window,
)
self.mlp = XverseMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class XverseModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
XverseDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class XverseForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
lora_config=None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = XverseModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if ("rotary_emb.inv_freq" in name
or "rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)