Sync from v0.13
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
# coding=utf-8
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
@@ -21,36 +23,52 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Mixtral model."""
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
from itertools import islice
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import MixtralConfig
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import (
|
||||
get_ep_group,
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.utils import print_warning_once
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
|
||||
class MixtralMoE(nn.Module):
|
||||
@@ -68,192 +86,84 @@ class MixtralMoE(nn.Module):
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
tp_size: int | None = None,
|
||||
dp_size: int | None = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
|
||||
self.num_total_experts = num_experts
|
||||
self.top_k = top_k
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size // self.tp_size
|
||||
self.quant_config = quant_config
|
||||
|
||||
# FIXME(pcmoritz): Make this more general to support different
|
||||
# quantization schemes
|
||||
self.use_fp8 = isinstance(quant_config, Fp8Config)
|
||||
self.ep_group = get_ep_group().device_group
|
||||
self.ep_rank = get_ep_group().rank_in_group
|
||||
self.ep_size = self.ep_group.size()
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
# Expert Parallelism Load balancing settings.
|
||||
vllm_config = get_current_vllm_config()
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.enable_eplb = enable_eplb
|
||||
|
||||
self.n_routed_experts = num_experts
|
||||
self.n_logical_experts = num_experts
|
||||
self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts
|
||||
self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
|
||||
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
|
||||
self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
|
||||
self.physical_expert_end = (
|
||||
self.physical_expert_start + self.n_local_physical_experts
|
||||
)
|
||||
|
||||
# Gate always runs at half / full precision for now.
|
||||
self.gate = ReplicatedLinear(self.hidden_size,
|
||||
self.num_total_experts,
|
||||
bias=False,
|
||||
params_dtype=self.params_dtype,
|
||||
quant_config=None)
|
||||
|
||||
if self.use_fp8:
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
self.gate = ReplicatedLinear(
|
||||
hidden_size,
|
||||
num_experts,
|
||||
bias=False,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
self.w13_weight = nn.Parameter(
|
||||
torch.empty(self.num_total_experts,
|
||||
2 * self.intermediate_size,
|
||||
self.hidden_size,
|
||||
dtype=params_dtype))
|
||||
self.w2_weight = nn.Parameter(
|
||||
torch.empty(self.num_total_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
dtype=params_dtype))
|
||||
|
||||
set_weight_attrs(self.w13_weight, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.w2_weight, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
|
||||
# Used for fp8.
|
||||
self.w13_scale = None
|
||||
self.w2_scale = None
|
||||
self.a13_scale = None
|
||||
self.a2_scale = None
|
||||
|
||||
if self.use_fp8:
|
||||
# WEIGHT_SCALE (for fp8)
|
||||
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
|
||||
# If loading fp8 checkpoint, pass the weight loaders.
|
||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
||||
# process_weights_after_loading()
|
||||
if quant_config.is_checkpoint_fp8_serialized:
|
||||
set_weight_attrs(self.w13_scale, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.w2_scale, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
|
||||
# ACT_SCALE (for fp8)
|
||||
if quant_config.activation_scheme == "static":
|
||||
if not quant_config.is_checkpoint_fp8_serialized:
|
||||
raise ValueError(
|
||||
"Found static activation scheme for checkpoint that "
|
||||
"was not serialized fp8.")
|
||||
self.a13_scale = nn.Parameter(torch.zeros(
|
||||
self.num_total_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
self.a2_scale = nn.Parameter(torch.zeros(
|
||||
self.num_total_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
|
||||
set_weight_attrs(self.a13_scale, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.a2_scale, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
|
||||
weight_name: str, expert_id: int):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
param_data = param.data
|
||||
shard_size = self.intermediate_size
|
||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||
if weight_name.endswith("w1.weight"):
|
||||
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
|
||||
if weight_name.endswith("w3.weight"):
|
||||
param_data[expert_id,
|
||||
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
|
||||
if weight_name.endswith("w2.weight"):
|
||||
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
||||
if "act_scale" in weight_name or "weight_scale" in weight_name:
|
||||
param_data[expert_id] = loaded_weight
|
||||
|
||||
def process_weights_after_loading(self):
|
||||
# Fp8 is the only case where we need to process after loading.
|
||||
if not self.use_fp8:
|
||||
return
|
||||
|
||||
# If checkpoint is fp16, quantize here.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
w13_weight = torch.empty_like(self.w13_weight.data,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w2_weight = torch.empty_like(self.w2_weight.data,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
for expert in range(self.num_total_experts):
|
||||
w13_weight[expert, :, :], self.w13_scale[
|
||||
expert] = ops.scaled_fp8_quant(
|
||||
self.w13_weight.data[expert, :, :])
|
||||
w2_weight[expert, :, :], self.w2_scale[
|
||||
expert] = ops.scaled_fp8_quant(
|
||||
self.w2_weight.data[expert, :, :])
|
||||
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
|
||||
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
|
||||
|
||||
# If checkpoint is fp8 + static, cleanup act_scales.
|
||||
# Since state_dict has an act_scale per expert but our kernels
|
||||
# are passed one act_scale shared across all experts.
|
||||
elif self.quant_config.activation_scheme == "static":
|
||||
if self.a13_scale is None or self.a2_scale is None:
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None.")
|
||||
|
||||
if (not all_close_1d(self.a13_scale)
|
||||
or not all_close_1d(self.a2_scale)):
|
||||
print_warning_once(
|
||||
"Found act_scales that are not equal for fp8 MoE layer. "
|
||||
"Using the maximum across experts for each layer. ")
|
||||
|
||||
self.a13_scale = nn.Parameter(self.a13_scale.max(),
|
||||
requires_grad=False)
|
||||
self.a2_scale = nn.Parameter(self.a2_scale.max(),
|
||||
requires_grad=False)
|
||||
self.experts = FusedMoE(
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
params_dtype=params_dtype,
|
||||
reduce_results=True,
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
dp_size=dp_size,
|
||||
prefix=f"{prefix}.experts",
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
# NOTE: hidden_states can have either 1D or 2D shape.
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = fused_moe(hidden_states,
|
||||
self.w13_weight,
|
||||
self.w2_weight,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8=self.use_fp8,
|
||||
w1_scale=self.w13_scale,
|
||||
w2_scale=self.w2_scale,
|
||||
a1_scale=self.a13_scale,
|
||||
a2_scale=self.a2_scale)
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_size)
|
||||
final_hidden_states = self.experts(hidden_states, router_logits)
|
||||
return final_hidden_states.view(orig_shape)
|
||||
|
||||
|
||||
class MixtralAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
sliding_window: Optional[int] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -270,21 +180,13 @@ class MixtralAttention(nn.Module):
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
# MixtralConfig has an optional head_dim argument
|
||||
self.head_dim = getattr(config, "head_dim", None)
|
||||
if self.head_dim is None:
|
||||
self.head_dim = self.hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
if isinstance(
|
||||
quant_config,
|
||||
Fp8Config) and not quant_config.is_checkpoint_fp8_serialized:
|
||||
print_warning_once(
|
||||
"For Mixtral FP8 quantization, we currently do not quantize "
|
||||
"the attention layers until their FP8 performance is improved."
|
||||
)
|
||||
quant_config = None
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
@@ -293,18 +195,19 @@ class MixtralAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position,
|
||||
base=int(self.rope_theta),
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=True,
|
||||
)
|
||||
self.attn = Attention(
|
||||
@@ -312,128 +215,271 @@ class MixtralAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
sliding_window=self.sliding_window,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class MixtralDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
self.self_attn = MixtralAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
max_position=config.max_position_embeddings,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
sliding_window=config.sliding_window,
|
||||
quant_config=quant_config)
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.block_sparse_moe = MixtralMoE(
|
||||
num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.block_sparse_moe",
|
||||
enable_eplb=enable_eplb,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
residual: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = self.block_sparse_moe(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class MixtralModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.org_vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
MixtralDecoderLayer(config, quant_config=quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
self.enable_eplb = parallel_config.enable_eplb
|
||||
self.num_redundant_experts = parallel_config.eplb_config.num_redundant_experts
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: MixtralDecoderLayer(
|
||||
config,
|
||||
cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
enable_eplb=self.enable_eplb,
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states,
|
||||
kv_caches[i], attn_metadata,
|
||||
residual)
|
||||
intermediate_tensors: IntermediateTensors | None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
return FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="w1",
|
||||
ckpt_down_proj_name="w2",
|
||||
ckpt_up_proj_name="w3",
|
||||
num_experts=self.config.num_local_experts,
|
||||
num_redundant_experts=self.num_redundant_experts,
|
||||
)
|
||||
|
||||
class MixtralForCausalLM(nn.Module):
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
for name, loaded_weight in weights:
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if (
|
||||
name.endswith(".bias") or name.endswith("_bias")
|
||||
) and name not in params_dict:
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
if name.endswith("scale"):
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
is_expert_weight = False
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
is_expert_weight = True
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name_mapped, self):
|
||||
continue
|
||||
|
||||
if (
|
||||
name_mapped.endswith(".bias") or name_mapped.endswith("_bias")
|
||||
) and name_mapped not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name_mapped]
|
||||
weight_loader = typing.cast(
|
||||
Callable[..., bool], param.weight_loader
|
||||
)
|
||||
success = weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
name = name_mapped
|
||||
break
|
||||
else:
|
||||
if is_expert_weight:
|
||||
continue
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if (
|
||||
name.endswith(".bias") or name.endswith("_bias")
|
||||
) and name not in params_dict:
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
packed_modules_mapping = {
|
||||
@@ -445,139 +491,108 @@ class MixtralForCausalLM(nn.Module):
|
||||
}
|
||||
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
"qkv_proj",
|
||||
"o_proj",
|
||||
"embed_tokens",
|
||||
"lm_head",
|
||||
]
|
||||
embedding_modules = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
self.model = MixtralModel(config,
|
||||
quant_config,
|
||||
lora_config=lora_config)
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.model = MixtralModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
self.expert_weights = []
|
||||
self.moe_layers = []
|
||||
example_moe = None
|
||||
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
assert isinstance(layer, MixtralDecoderLayer)
|
||||
if hasattr(layer, "block_sparse_moe") and isinstance(
|
||||
layer.block_sparse_moe, MixtralMoE
|
||||
):
|
||||
example_moe = layer.block_sparse_moe
|
||||
self.moe_layers.append(layer.block_sparse_moe.experts)
|
||||
|
||||
self.num_moe_layers = len(self.moe_layers)
|
||||
|
||||
if example_moe is None:
|
||||
raise RuntimeError("No MixtralMoE layer found in model.layers.")
|
||||
|
||||
self.num_logical_experts = example_moe.n_logical_experts
|
||||
self.num_physical_experts = example_moe.n_physical_experts
|
||||
self.num_local_physical_experts = example_moe.n_local_physical_experts
|
||||
self.num_routed_experts = example_moe.n_routed_experts
|
||||
self.num_redundant_experts = example_moe.n_redundant_experts
|
||||
self.num_expert_groups = 1
|
||||
self.num_shared_experts = 0
|
||||
|
||||
def update_physical_experts_metadata(
|
||||
self,
|
||||
num_physical_experts: int,
|
||||
num_local_physical_experts: int,
|
||||
) -> None:
|
||||
assert self.num_local_physical_experts == num_local_physical_experts
|
||||
self.num_physical_experts = num_physical_experts
|
||||
self.num_local_physical_experts = num_local_physical_experts
|
||||
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
|
||||
for layer in self.model.layers:
|
||||
if hasattr(layer, "block_sparse_moe") and isinstance(
|
||||
layer.block_sparse_moe, MixtralMoE
|
||||
):
|
||||
moe = layer.block_sparse_moe
|
||||
moe.n_local_physical_experts = num_local_physical_experts
|
||||
moe.n_physical_experts = num_physical_experts
|
||||
moe.n_redundant_experts = self.num_redundant_experts
|
||||
moe.experts.update_expert_map()
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
sampling_metadata)
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: Optional[torch.Tensor],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
expert_params_mapping = [
|
||||
# These are the weight scales for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
|
||||
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
] + [
|
||||
# These are the weights for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
|
||||
f"experts.{expert_id}.{weight_name}.weight", expert_id)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
] + [
|
||||
# These are the activation scales for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
|
||||
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for param_name, weight_name, expert_id in expert_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
expert_id=expert_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
def all_close_1d(x: torch.Tensor) -> bool:
|
||||
assert len(x.shape) == 1
|
||||
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
return self.model.get_expert_mapping()
|
||||
|
||||
Reference in New Issue
Block a user