Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -1,29 +1,47 @@
# coding=utf-8
from typing import Iterable, List, Optional, Tuple
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from itertools import islice
import torch
import torch.nn as nn
from transformers import DbrxConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.dbrx import DbrxConfig
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (
AutoWeightsLoader,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
class DbrxRouter(nn.Module):
@@ -34,7 +52,7 @@ class DbrxRouter(nn.Module):
def __init__(
self,
config: DbrxConfig,
params_dtype: Optional[torch.dtype] = None,
params_dtype: torch.dtype | None = None,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
@@ -53,7 +71,80 @@ class DbrxRouter(nn.Module):
return router_logits
class DbrxExperts(nn.Module):
class DbrxExperts(FusedMoE):
def __init__(
self,
config: DbrxConfig,
quant_config: QuantizationConfig | None = None,
params_dtype: torch.dtype | None = None,
prefix: str = "",
):
super().__init__(
num_experts=config.ffn_config.moe_num_experts,
top_k=config.ffn_config.moe_top_k,
hidden_size=config.d_model,
intermediate_size=config.ffn_config.ffn_hidden_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=get_tensor_model_parallel_world_size(),
prefix=prefix,
)
self.config = config
self.d_model = config.d_model
self.intermediate_size = self.config.ffn_config.ffn_hidden_size // self.tp_size
# Define custom weight loader for dbrx model
def weight_loader(
self,
param: nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
param_name: str,
):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
# DBRX uses GLU for each experts.
# GLU has 3 linear layers: w1, v1 and w2.
if weight_name.endswith("w1"):
if param_name.endswith("weight"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
elif param_name.endswith("weight_scale"):
param_data[:, 0] = loaded_weight
else:
param_data = loaded_weight
if weight_name.endswith("v1"):
if param_name.endswith("weight"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:, shard_size : 2 * shard_size, :] = loaded_weight[
:, shard, :
]
elif param_name.endswith("weight_scale"):
param_data[:, 1] = loaded_weight
else:
param_data[:] = loaded_weight
if weight_name.endswith("w2"):
if param_name.endswith("weight"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
).transpose(1, 2)
param_data[:] = loaded_weight[:, :, shard]
else:
param_data[:] = loaded_weight
class DbrxMoE(nn.Module):
"""A tensor-parallel MoE implementation for DBRX.
Each expert's weights are sharded across all ranks and a fused MoE
@@ -64,109 +155,41 @@ class DbrxExperts(nn.Module):
def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None,
quant_config: QuantizationConfig | None = None,
params_dtype: torch.dtype | None = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.ffn_config.moe_num_experts
self.top_k = config.ffn_config.moe_top_k
self.d_model = config.d_model
self.intermediate_size = (config.ffn_config.ffn_hidden_size //
self.tp_size)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.router = DbrxRouter(config, self.params_dtype)
self.ws = nn.Parameter(
torch.empty(
self.num_total_experts,
2 * self.intermediate_size,
self.d_model,
device="cuda",
dtype=self.params_dtype,
))
self.w2s = nn.Parameter(
torch.empty(
self.num_total_experts,
self.d_model,
self.intermediate_size,
device="cuda",
dtype=self.params_dtype,
))
set_weight_attrs(
self.ws,
{
"weight_loader": self.weight_loader,
},
self.experts = DbrxExperts(
config=config,
quant_config=quant_config,
params_dtype=self.params_dtype,
prefix=f"{prefix}.experts",
)
set_weight_attrs(
self.w2s,
{
"weight_loader": self.weight_loader,
},
)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
# DBRX uses GLU for each experts.
# GLU has 3 linear layers: w1, v1 and w2.
if weight_name.endswith("w1"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
if weight_name.endswith("v1"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:,
shard_size:2 * shard_size, :] = loaded_weight[:,
shard, :]
if weight_name.endswith("w2"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
).transpose(1, 2)
param_data[:] = loaded_weight[:, :, shard]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.d_model)
# router_logits: (num_tokens, n_experts)
router_logits = self.router(hidden_states)
final_hidden_states = fused_moe(
hidden_states,
self.ws,
self.w2s,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape)
class DbrxAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.d_model = config.d_model
@@ -174,7 +197,10 @@ class DbrxAttention(nn.Module):
self.head_dim = self.d_model // self.total_num_heads
self.total_num_kv_heads = config.attn_config.kv_n_heads
self.clip_qkv = config.attn_config.clip_qkv
self.rope_theta = config.attn_config.rope_theta
rope_parameters = {
"rope_type": "default",
"rope_theta": int(config.attn_config.rope_theta),
}
self.max_position = config.max_seq_len
# pylint: disable=invalid-name
@@ -185,18 +211,19 @@ class DbrxAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.Wqkv",
)
self.out_proj = RowParallelLinear(
self.d_model,
self.d_model,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position,
base=int(self.rope_theta),
rope_parameters=rope_parameters,
is_neox_style=True,
)
@@ -221,35 +248,39 @@ class DbrxAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
hidden_states, _ = self.out_proj(attn_output)
return hidden_states
class DbrxFusedNormAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.d_model = config.d_model
self.attn = DbrxAttention(config, quant_config)
self.attn = DbrxAttention(
config, cache_config, quant_config, prefix=f"{prefix}.attn"
)
self.norm_1 = nn.LayerNorm(self.d_model)
self.norm_2 = nn.LayerNorm(self.d_model)
@@ -257,16 +288,12 @@ class DbrxFusedNormAttention(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm_1(hidden_states)
x = self.attn(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + x
residual = hidden_states
@@ -275,28 +302,27 @@ class DbrxFusedNormAttention(nn.Module):
class DbrxBlock(nn.Module):
def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config)
self.ffn = DbrxExperts(config, quant_config)
self.norm_attn_norm = DbrxFusedNormAttention(
config, cache_config, quant_config, prefix=f"{prefix}.norm_attn_norm"
)
self.ffn = DbrxMoE(config, quant_config, prefix=f"{prefix}.ffn")
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states, residual = self.norm_attn_norm(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = self.ffn(hidden_states)
hidden_states = hidden_states + residual
@@ -304,110 +330,155 @@ class DbrxBlock(nn.Module):
class DbrxModel(nn.Module):
def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.quant_config = quant_config
self.wte = VocabParallelEmbedding(
config.vocab_size,
config.d_model,
)
self.blocks = nn.ModuleList(
[DbrxBlock(config, quant_config) for _ in range(config.n_layers)])
self.start_layer, self.end_layer, self.blocks = make_layers(
config.n_layers,
lambda prefix: DbrxBlock(config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.blocks",
)
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules():
if hasattr(module, "bias") and isinstance(module.bias,
nn.Parameter):
if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
# Remove the bias term in Linear and LayerNorm.
module.register_parameter("bias", None)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states"], config.d_model
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.blocks)):
block = self.blocks[i]
hidden_states = block(
position_ids,
hidden_states,
kv_caches[i],
attn_metadata,
)
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
else:
assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"]
for block in islice(self.blocks, self.start_layer, self.end_layer):
hidden_states = block(position_ids, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm_f(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
expert_params_mapping = [
(
"w13" if weight_name in ["w1", "v1"] else "w2",
f"mlp.{weight_name}",
)
for weight_name in ["w1", "v1", "w2"]
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
class DbrxForCausalLM(nn.Module):
for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
):
if name.endswith(("w1", "w2", "v1")):
name = name + "_weight"
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, weight_name, name)
break
else:
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class DbrxForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
if config.tie_word_embeddings:
raise ValueError("tie_word_embeddings is not supported for Dbrx models.")
self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, quant_config)
self.transformer = DbrxModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.d_model,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.transformer(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = [(
"ws" if weight_name in ["w1", "v1"] else "w2s",
f"experts.mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, weight_name)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)