# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import re
from typing import Iterable, Set, Tuple
import torch
from torch import nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import (
get_ep_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
get_data_parallel_group_world_size,
get_tp_group,
)
from vllm.distributed.communication_op import (
tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.interfaces import SupportsEagle
from vllm.model_executor.models.utils import maybe_prefix
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.attention import AttentionMetadata
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.v1.attention.backends.utils import (
MLUCommonAttentionMetadata,
get_common_metadata,
)
from vllm_mlu.model_executor.layers.indexer import Indexer
from vllm_mlu.model_executor.layers.compressor import Compressor
from vllm_mlu import _mlu_ops as mlu_ops
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.attention.layer import MLAAttention
from vllm_mlu.model_executor.layers.sparse_moe_mlp import MoeGroupInfo, SparseMoeMlp
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
logger = init_logger(__name__)
class HCHead(torch.nn.Module):
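    """Output head for the hc (hyper-connection style) residual streams.

    Collapses the per-token [*, hc_mult, dim] streams back into a single
    [*, dim] hidden state using learned sigmoid mixing weights.
    """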
def __init__(
self,
hc_mult,
dim,
hc_eps,
norm_eps,
prefix: str = "",
):
super().__init__()
self.hc_mult: int = hc_mult
self.dim: int = dim
self.hc_dim: int = hc_mult * dim
self.hc_eps = hc_eps
self.norm_eps = norm_eps
self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, self.hc_dim, dtype=torch.float), requires_grad=False)
self.hc_head_base = nn.Parameter(torch.empty(hc_mult, dtype=torch.float), requires_grad=False)
self.hc_head_scale = nn.Parameter(torch.empty(1, dtype=torch.float), requires_grad=False)
def forward(self, x: torch.Tensor):
shape, dtype = x.size(), x.dtype
x = x.flatten(-2).float()
rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
mixes = F.linear(x, self.hc_head_fn) * rsqrt
pre = torch.sigmoid(mixes * self.hc_head_scale + self.hc_head_base) + self.hc_eps
y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=-2)
return y.to(dtype)
class HCPre(torch.nn.Module):
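    """Pre-mixing for the hc residual streams.

    Produces the single [*, dim] input fed to the following sublayer, plus the
    (post, comb) mixing weights that HCPost later uses to write the sublayer
    output back into the streams. The weights come from the fused
    mlu_ops.hc_split_sinkhorn kernel.
    """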
def __init__(
self,
hc_mult,
dim,
hc_sinkhorn_iters,
hc_eps,
norm_eps,
prefix: str = "",
):
super().__init__()
self.hc_mult: int = hc_mult
self.dim: int = dim
self.hc_dim: int = hc_mult * dim
self.hc_sinkhorn_iters = hc_sinkhorn_iters
self.hc_eps = hc_eps
mix_hc = (2 + hc_mult) * hc_mult
self.norm_eps = norm_eps
self.hc_fn = nn.Parameter(torch.empty(mix_hc, self.hc_dim, dtype=torch.float), requires_grad=False)
self.hc_base = nn.Parameter(torch.empty(mix_hc, dtype=torch.float), requires_grad=False)
self.hc_scale = nn.Parameter(torch.empty(3, dtype=torch.float), requires_grad=False)
def forward(
self,
x: torch.Tensor,
rsqrt: torch.Tensor | None = None,
):
shape, dtype = x.size(), x.dtype
x = x.flatten(-2).float()
if rsqrt is None:
rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
mixes = F.linear(x, self.hc_fn)
pre, post, comb = mlu_ops.hc_split_sinkhorn(
mixes.unsqueeze(0),
self.hc_scale,
self.hc_base,
rsqrt.squeeze(-1).unsqueeze(0),
self.hc_mult,
self.hc_sinkhorn_iters,
self.hc_eps,
)
pre, post, comb = pre.squeeze(0), post.squeeze(0), comb.squeeze(0)
y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=-2)
return y.to(dtype), post, comb
class HCPost(torch.nn.Module):
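    """Post-mixing for the hc residual streams.

    Writes a sublayer output x back into the hc_mult residual streams using the
    (post, comb) weights from HCPre, and optionally returns the RMS normalizer
    of the result so the next HCPre can reuse it instead of recomputing it.
    """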
def __init__(
self,
norm_eps: float,
prefix: str = "",
):
        super().__init__()
        self.norm_eps = norm_eps
def forward(
self,
x: torch.Tensor,
residual: torch.Tensor,
post: torch.Tensor,
comb: torch.Tensor,
compute_rms: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor | None]:
        # x: [bs, dim], residual: [bs, hc, dim], post: [bs, hc], comb: [bs, hc, hc]
        # Returns:
        #   y: [bs, hc, dim]
        #      = [bs, hc, 1] * [bs, 1, dim] + torch.sum([bs, hc, hc, 1] * [bs, hc, 1, dim], -2)
        #   rsqrt: optional [bs, 1], returned only when compute_rms is True
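        # use_tmo is hard-coded, so the fused mlu_ops.fused_mhc_post kernel is always
        # taken; the pure-PyTorch path below is kept as an eager fallback for the
        # same computation.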
use_tmo = True
if use_tmo:
y, rsqrt = mlu_ops.fused_mhc_post(x, residual, post, comb, compute_rms, self.norm_eps)
return y, (rsqrt.unsqueeze(-1) if rsqrt is not None else None)
y = post.unsqueeze(-1) * x.unsqueeze(-2) + \
torch.sum(comb.unsqueeze(-1) * residual.unsqueeze(-2), dim=-3)
rsqrt = (
torch.rsqrt(y.type_as(x).flatten(-2).float().square().mean(-1, keepdim=True) + self.norm_eps)
if compute_rms
else None
)
return y.type_as(x), rsqrt
class MLUDeepseekV4MoE(SparseMoeMlp):
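    """DeepSeek-V4 MoE block built on SparseMoeMlp.

    Routed experts are sharded across the expert-parallel group, with a single
    shared expert run on every rank; router logits are computed in float32, and
    hash-based routing is enabled for the first n_hash_layers layers.
    """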
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
**kwargs,
):
layer_id = int(prefix.split(sep=".")[-2])
self.layer_id = layer_id
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
SparseMoeMlp.__init__(
self,
num_experts=config.n_routed_experts,
top_k=config.n_activated_experts,
hidden_size=config.dim,
intermediate_size=config.moe_inter_dim,
up_proj_name='w13',
is_gated=True,
down_proj_name='w2',
has_bias=False,
hidden_act='silu',
params_dtype=torch.float,
quant_config=quant_config,
is_use_fused_moe=True,
expert_group=1,
topk_group=1,
scoring_func=config.score_func,
topk_method='',
routed_scaling_factor=config.route_scale,
use_hash=(layer_id < config.n_hash_layers),
vocab_size=config.vocab_size,
prefix=prefix,
)
self.dim = config.dim
world_size = get_ep_group().world_size
self.world_size = world_size
assert config.n_routed_experts % world_size == 0, \
f"Number of experts must be divisible by world size (world_size={world_size})"
self.n_routed_experts = config.n_routed_experts
self.n_local_experts = self.n_routed_experts // world_size
self.n_activated_experts = config.n_activated_experts
self.experts_start_idx = get_ep_group().rank_in_group * self.n_local_experts
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
assert config.n_shared_experts == 1
self.shared_experts = FeedForward(
hidden_size=config.dim,
intermediate_size=config.moe_inter_dim,
hidden_act='silu',
up_proj_name='w13',
is_gated=True,
down_proj_name='w2',
bias=False,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.shared_experts",
)
def forward(
self,
hidden_states: torch.Tensor,
input_ids: torch.Tensor,
) -> torch.Tensor:
shape = hidden_states.size()
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states.float())
hidden_states = self.forward_experts(
hidden_states,
router_logits,
shared_output=shared_output,
input_ids=input_ids,
)
hidden_states = self.reduce_results(hidden_states)
return hidden_states.view(shape)
class MLUDeepseekV4Attention(nn.Module):
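    """DeepSeek-V4 sparse attention for MLU.

    Uses a low-rank q projection (wq_a/wq_b), a single shared kv head (wkv) and a
    sliding-window KV cache. Layers with compress_ratio > 0 additionally run a
    Compressor that maintains a compressed KV cache, and layers with
    compress_ratio == 4 also run an Indexer that selects top-k KV blocks. The
    output projection is factorized into o_groups low-rank groups (wo_a/wo_b).
    """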
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
layer_id = int(prefix.split(sep=".")[-2])
self.layer_id = layer_id
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size()
self.attn_data_parallel_size = get_data_parallel_group_world_size()
self.attn_tensor_parallel_size = get_tensor_model_parallel_world_size()
self.num_heads = vllm_config.model_config.hf_config.n_heads
assert self.num_heads % self.tp_size == 0
self.num_local_heads = self.num_heads // self.attn_tensor_parallel_size
self.model_type = config.model_type
self.use_indexer = hasattr(config, 'index_n_heads')
self.hidden_size = config.dim
self.head_dim = config.head_dim
self.q_lora_rank = config.q_lora_rank
self.rope_head_dim = config.rope_head_dim
self.eps = config.norm_eps
self.o_groups = config.o_groups
self.o_local_groups = self.o_groups // self.attn_tensor_parallel_size
self.softmax_scale = self.head_dim ** -0.5
self.compress_ratio = config.compress_ratios[layer_id]
self.window_size = config.window_size
self.max_model_len = vllm_config.model_config.max_model_len
self.original_seq_len = config.original_seq_len
self.index_topk = config.index_topk
self.o_lora_rank = config.o_lora_rank
self.rope_theta = getattr(config, "rope_theta", 10000)
self.rope_scaling = getattr(config, "rope_scaling", None)
tp_group = get_tp_group()
        # Disable YaRN and use the base rope_theta in pure sliding-window attention.
        # rope_scaling is expected to be present in the config; it is adjusted in
        # place before being passed to get_rope.
        if self.compress_ratio > 1:
            max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", 65536)
            self.rope_scaling["rope_type"] = 'deepseek_yarn'
        else:
            max_position_embeddings = 0
            self.rope_scaling["rope_type"] = 'default'
        self.rope_scaling["original_max_position_embeddings"] = 0
self.rotary_emb = get_rope(
self.rope_head_dim,
rotary_dim=self.rope_head_dim,
max_position=max_position_embeddings,
base=config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta,
rope_scaling=self.rope_scaling,
is_neox_style=False,
)
self.output_rotary_emb = get_rope(
self.rope_head_dim,
rotary_dim=self.rope_head_dim,
max_position=max_position_embeddings,
base=config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta,
rope_scaling=self.rope_scaling,
dtype=torch.float32,
is_neox_style=False,
inverse=True,
)
if self.q_lora_rank is not None:
self.wq_a = ReplicatedLinear(
self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=None,
prefix=f"{prefix}.wq_a",
)
self.q_norm = RMSNorm(
self.q_lora_rank,
eps=self.eps,
)
self.wq_b = ColumnParallelLinear(
self.q_lora_rank,
self.num_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wq_b",
tp_group=tp_group,
)
self.wkv = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=None,
prefix=f"{prefix}.wkv",
)
self.kv_norm = RMSNorm(
self.head_dim,
eps=self.eps,
)
if get_tensor_model_parallel_world_size() <= self.o_groups:
self.wo_a = ColumnParallelLinear(
self.num_heads * self.head_dim // self.o_groups,
self.o_groups * self.o_lora_rank,
bias=False,
quant_config=None,
prefix=f"{prefix}.wo_a",
)
self.wo_b = RowParallelLinear(
self.o_groups * self.o_lora_rank,
self.hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.wo_b",
tp_group=tp_group,
)
else:
self.wo_a = ReplicatedLinear(
self.num_heads * self.head_dim // self.o_groups,
self.o_groups * self.o_lora_rank,
bias=False,
quant_config=None,
prefix=f"{prefix}.wo_a",
)
self.wo_b = ReplicatedLinear(
self.o_groups * self.o_lora_rank,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wo_b",
)
self.attn = MLAAttention(
self.num_local_heads, # num_heads
self.softmax_scale, # scale
self.head_dim - self.rope_head_dim, # qk_nope_head_dim
self.rope_head_dim, # qk_rope_head_dim
self.head_dim, # v_head_dim
self.q_lora_rank, # q_lora_rank
self.head_dim, # kv_lora_rank
self.wkv, # kv_b_proj
cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config,
# extra_impl_args
num_kv_heads=1,
prefix=f"{prefix}.attn",
use_fused_mla_qkv=False,
)
if self.compress_ratio:
self.compressor = Compressor(vllm_config, self.rotary_emb, self.compress_ratio, self.head_dim, False, f"{prefix}.compressor")
if self.compress_ratio == 4:
self.indexer = Indexer(vllm_config, self.rotary_emb, self.compress_ratio, f"{prefix}.indexer")
else:
self.indexer = None
self.attn_sink = nn.Parameter(torch.empty(self.num_local_heads, dtype=torch.float32))
def forward_sparse_attn(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
window_compress_params: dict | None,
window_slot_mapping: torch.Tensor,
compressor_slot_mapping: dict | None,
) -> torch.Tensor:
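        # Rough flow: project q and the shared kv, write the new kv into the
        # sliding-window cache, optionally build compressed KV (Compressor) and
        # top-k block tables (Indexer), run paged attention over the concatenated
        # window + compressed/selected blocks, and finally apply the grouped
        # low-rank output projection.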
num_tokens = hidden_states.shape[0]
if self.q_lora_rank is not None:
q = self.wq_a(hidden_states)[0]
q = self.q_norm(q)
qr = q
q = self.wq_b(q)[0].view(-1, self.num_local_heads, self.head_dim)
q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps)
_, q_pe = q.split([self.head_dim - self.rope_head_dim, self.rope_head_dim], dim=-1)
kv = self.wkv(hidden_states)[0]
kv = self.kv_norm(kv)
kv = kv.unsqueeze(-2)
kv_pe = kv[..., -self.rope_head_dim :]
q_pe, kv_pe = self.rotary_emb(positions, q_pe, kv_pe, only_prefill=False)
common_metadata = get_common_metadata()
query_start_loc = common_metadata.query_start_loc
query_lens = query_start_loc[1:] - query_start_loc[:-1]
key_cache = kv_cache[0][0]
mlu_ops.reshape_paged_cache(
kv,
None,
key_cache,
None,
window_slot_mapping,
)
if self.compress_ratio:
offsets = query_lens if common_metadata.is_prefill_only else torch.full_like(query_lens, self.window_size)
if self.indexer is not None:
indexer_kv_cache = kv_cache[2]
compress_block_tables, compress_context_lens = self.indexer(
hidden_states,
qr,
positions,
offsets,
attn_metadata,
batch_to_kv_state,
indexer_kv_cache,
compressor_slot_mapping[(0, self.compress_ratio)],
)
if self.compress_ratio:
compress_kv = self.compressor(
hidden_states,
positions,
attn_metadata,
batch_to_kv_state,
key_cache,
self.window_size,
compressor_slot_mapping[(self.window_size, self.compress_ratio)],
)
if common_metadata.is_prefill_only:
kv = torch.cat([kv, compress_kv], dim=0)
        assert window_compress_params is not None
if self.compress_ratio:
if self.indexer is not None:
window_block_tables = window_compress_params.get("window_block_tables", None)
window_context_lens = window_compress_params.get("window_context_lens", None)
new_block_tables = torch.empty([num_tokens, self.window_size + self.index_topk], dtype=torch.int32, device=hidden_states.device)
new_context_lens = torch.empty([num_tokens], dtype=torch.int32, device=hidden_states.device)
mlu_ops.concat_block_table(
window_block_tables,
window_context_lens,
compress_block_tables,
compress_context_lens,
new_block_tables,
new_context_lens,
)
max_contxt_len = self.window_size + self.index_topk
else:
new_block_tables = window_compress_params.get("compress_block_tables", None)
new_context_lens = window_compress_params.get("compress_context_lens", None)
max_contxt_len = self.window_size + (self.max_model_len // self.compress_ratio)
else:
new_block_tables = window_compress_params.get("window_block_tables", None)
new_context_lens = window_compress_params.get("window_context_lens", None)
max_contxt_len = self.window_size
attn_output = torch.zeros_like(q)
total_token = q.size(0)
assert total_token == new_block_tables.size(0)
q_ = q.view(total_token, -1, self.num_local_heads, self.head_dim)
attn_output = attn_output.view(total_token, -1, self.num_local_heads, self.head_dim)
if common_metadata.is_prefill_only:
kv_cache_ = kv.unsqueeze(1) # insert block_size, [total_token, 1, head_dim] -> [total_token, 1, 1, head_dim]
else:
kv_cache_ = kv_cache[0].view(-1, 1, 1, self.head_dim)
mlu_ops.single_query_cached_kv_attn(
q=q_,
k_cache=kv_cache_,
v_cache=None,
out=attn_output,
block_tables=new_block_tables,
context_lens=new_context_lens,
k_cache_quant_scale=None,
v_cache_quant_scale=None,
alibi_slopes=None,
max_contxt_len=max_contxt_len,
windows_size_left=-1,
windows_size_right=-1,
softmax_scale=self.softmax_scale,
compute_dtype=torch.float,
learnable_sink=self.attn_sink,
)
attn_output = attn_output.reshape(-1, self.num_local_heads, self.head_dim).to(torch.float)
attn_output_pe = attn_output[..., -self.rope_head_dim:]
attn_output_pe, _ = self.output_rotary_emb(positions, attn_output_pe, None, only_prefill=False)
attn_output = attn_output.to(dtype=torch.bfloat16)
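        # Grouped low-rank output projection: heads are split into o_groups groups,
        # each projected to o_lora_rank by wo_a (einsum), then combined by wo_b.
        # When TP <= o_groups the groups are sharded across TP ranks; otherwise the
        # heads are all-gathered first and the projection runs replicated.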
if get_tensor_model_parallel_world_size() <= self.o_groups:
attn_output = attn_output.reshape(num_tokens, self.o_local_groups, -1)
wo_a = self.wo_a.weight.view(self.o_local_groups, self.o_lora_rank, -1)
o = torch.einsum("ngd,grd->ngr", attn_output, wo_a)
output = self.wo_b(o.flatten(-2))[0]
output = tensor_model_parallel_all_reduce(output)
else:
            # (token, num_local_heads, head_dim) -> (token, num_local_heads * head_dim),
            # then all-gather across TP to recover (token, num_heads * head_dim).
            attn_output = attn_output.flatten(-2).contiguous()
            attn_output = tensor_model_parallel_all_gather(attn_output, dim=-1)
            # (token, num_heads * head_dim) -> (token, num_heads, head_dim)
            attn_output = attn_output.reshape(-1, self.num_heads, self.head_dim).contiguous()
wo_a = self.wo_a.weight.view(self.o_groups, self.o_lora_rank, -1)
attn_output = attn_output.reshape(num_tokens, self.o_groups, -1)
o = torch.einsum("ngd,grd->ngr", attn_output, wo_a)
output = self.wo_b(o.flatten(-2))[0]
return output
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
batch_to_kv_state: torch.Tensor,
window_compress_params: dict | None,
window_slot_mapping: torch.Tensor,
compressor_slot_mapping: dict | None,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return torch.empty_like(hidden_states)
# self.attn and self.attn_decoder always have the same attn_metadata
# and share the same kv cache for each layer
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.attn.layer_name]
kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
output = self.forward_sparse_attn(
positions,
hidden_states,
kv_cache,
attn_metadata,
batch_to_kv_state,
window_compress_params,
window_slot_mapping,
compressor_slot_mapping,
)
return output
class MLUDeepseekV4DecoderLayer(nn.Module):
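    """A DeepSeek-V4 decoder layer.

    Each sublayer (attention, MoE FFN) is wrapped by an HCPre / HCPost pair that
    mixes the hc residual streams before and after the sublayer.
    """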
def __init__(
self,
vllm_config: VllmConfig,
prefix: str,
config: PretrainedConfig | None = None,
) -> None:
super().__init__()
if config is None:
config = vllm_config.model_config.hf_config
self.config = config
self.dim = config.dim
layer_idx = int(prefix.split(sep=".")[-1])
self.layer_idx = layer_idx
self.attn = MLUDeepseekV4Attention(
vllm_config=vllm_config,
prefix=f"{prefix}.attn",
)
self.hc_mult = config.hc_mult
self.mix_hc = (2 + self.hc_mult) * self.hc_mult
self.hc_dim = self.hc_mult * config.dim
self.norm_eps = config.norm_eps
self.hc_sinkhorn_iters = config.hc_sinkhorn_iters
self.hc_eps = config.hc_eps
self.hc_attn_pre = HCPre(
self.hc_mult,
config.dim,
self.hc_sinkhorn_iters,
self.hc_eps,
self.norm_eps,
prefix=f"{prefix}.hc_attn_pre"
)
self.hc_attn_post = HCPost(
self.norm_eps,
)
self.hc_ffn_pre = HCPre(
self.hc_mult,
config.dim,
self.hc_sinkhorn_iters,
self.hc_eps,
self.norm_eps,
prefix=f"{prefix}.hc_attn_pre"
)
self.hc_ffn_post = HCPost(
self.norm_eps,
)
self.attn_norm = RMSNorm(config.dim, config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.ffn = MLUDeepseekV4MoE(
vllm_config=vllm_config,
prefix=f"{prefix}.ffn",
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_ids: torch.Tensor | None,
residual: torch.Tensor | None,
batch_to_kv_state: torch.Tensor,
window_compress_params: dict | None = None,
hc_attn_pre_norm: torch.Tensor | None = None,
window_slot_mapping: torch.Tensor | None = None,
compressor_slot_mapping: dict | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor | None]:
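        # hc stream flow: HCPre collapses the hc streams into the sublayer input,
        # HCPost writes the sublayer output back into the streams. The RMS
        # normalizer returned by each HCPost is forwarded to the next HCPre
        # (hc_attn_pre_norm across layers) to avoid recomputing it.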
residual = hidden_states
hidden_states, post, comb = self.hc_attn_pre(hidden_states, rsqrt=hc_attn_pre_norm)
hidden_states = self.attn_norm(hidden_states)
hidden_states = self.attn(
positions,
hidden_states,
batch_to_kv_state,
window_compress_params,
window_slot_mapping,
compressor_slot_mapping,
)
hidden_states, hc_ffn_pre_norm = self.hc_attn_post(
hidden_states,
residual,
post,
comb,
compute_rms=True,
)
residual = hidden_states
is_last_layer = (self.layer_idx == self.config.n_layers - 1)
hidden_states, post, comb = self.hc_ffn_pre(hidden_states, rsqrt=hc_ffn_pre_norm)
hidden_states = self.ffn_norm(hidden_states)
hidden_states = self.ffn(hidden_states, input_ids)
hidden_states, hc_attn_pre_norm = self.hc_ffn_post(
hidden_states,
residual,
post,
comb,
compute_rms=(not is_last_layer),
)
return hidden_states, hc_attn_pre_norm
@support_torch_compile
class MLUDeepseekV4Model(nn.Module):
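    """DeepSeek-V4 transformer stack.

    Each token is expanded into hc_mult residual streams at the embedding and
    collapsed again by HCHead before the final norm. The sliding-window and
    compressed-KV block tables are built once per forward pass and shared by all
    decoder layers.
    """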
fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.device = current_platform.device_type
        self.compress_ratio = 128  # compress ratio of the compressor-only layers; used to size the compressed block tables
self.window_size = config.window_size
self.max_model_len = vllm_config.model_config.max_model_len
self.vocab_size = config.vocab_size
self.norm_eps = config.norm_eps
self.hc_eps = config.hc_eps
        self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.dim,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
self.layers = nn.ModuleList()
for layer_id in range(config.n_layers):
self.layers.append(MLUDeepseekV4DecoderLayer(
vllm_config=vllm_config,
prefix=f"{prefix}.layers.{layer_id}",
config=config,
))
self.hc_mult = config.hc_mult
self.dim = config.dim
self.hc_head = HCHead(
self.hc_mult,
self.dim,
self.hc_eps,
self.norm_eps,
prefix=f"{prefix}.hc_head",
)
self.norm = RMSNorm(config.dim, self.norm_eps)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
batch_to_kv_state: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
window_slot_mapping: torch.Tensor | None = None,
compressor_slot_mapping: dict | None = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
hidden_states = hidden_states.unsqueeze(1).repeat(1, self.config.hc_mult, 1)
common_metadata = get_common_metadata()
if common_metadata is not None:
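            # Build per-request block tables for the sliding-window KV cache and the
            # compressed KV cache, plus their concatenation; these are shared by all
            # decoder layers in this forward pass.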
total_token_num = hidden_states.size(0)
window_block_tables = torch.empty([total_token_num, self.window_size], dtype=torch.int32, device=hidden_states.device)
window_context_lens = torch.empty([total_token_num], dtype=torch.int32, device=hidden_states.device)
kv_cache_size = self.window_size + (self.max_model_len // self.compress_ratio if self.compress_ratio else 0)
compress_block_tables = torch.empty([total_token_num, kv_cache_size], dtype=torch.int32, device=hidden_states.device)
compress_context_lens = torch.empty([total_token_num], dtype=torch.int32, device=hidden_states.device)
mlu_ops.get_window_block_tables(
window_size = self.window_size,
block_size = 1,
seq_k_lens = common_metadata.seq_lens,
query_start_loc = common_metadata.query_start_loc,
block_table = common_metadata.block_table_tensor,
window_block_tables = window_block_tables,
window_context_lens = window_context_lens
)
# get_compress_block_tables
query_start_loc = common_metadata.query_start_loc
query_lens = query_start_loc[1:] - query_start_loc[:-1]
compress_lens = query_lens // self.compress_ratio
cu_compress_lens = torch.cat([
torch.tensor([0], dtype=compress_lens.dtype, device=compress_lens.device),
torch.cumsum(compress_lens, dim=0)
])
            offsets = (cu_compress_lens[:-1] + total_token_num) if common_metadata.is_prefill_only else torch.full_like(query_lens, self.window_size)
mlu_ops.get_compress_block_tables(
ratio = self.compress_ratio,
block_size = 1,
seq_k_lens = common_metadata.seq_lens,
query_start_loc = common_metadata.query_start_loc,
offset = offsets,
block_table = common_metadata.block_table_tensor,
compress_block_tables = compress_block_tables,
compress_context_lens = compress_context_lens,
)
win_comp_block_tables = torch.empty([total_token_num, kv_cache_size], dtype=torch.int32, device=hidden_states.device)
win_comp_context_lens = torch.empty([total_token_num], dtype=torch.int32, device=hidden_states.device)
mlu_ops.concat_block_table(
window_block_tables,
window_context_lens,
compress_block_tables,
compress_context_lens,
win_comp_block_tables,
win_comp_context_lens,
)
window_compress_params = {
"window_block_tables": window_block_tables,
"window_context_lens": window_context_lens,
"compress_block_tables": win_comp_block_tables,
"compress_context_lens": win_comp_context_lens,
}
else:
window_compress_params = None
hc_attn_pre_norm = None
for layer in self.layers:
hidden_states, hc_attn_pre_norm = layer(
positions,
hidden_states,
input_ids,
None,
batch_to_kv_state,
window_compress_params,
hc_attn_pre_norm=hc_attn_pre_norm,
window_slot_mapping=window_slot_mapping,
compressor_slot_mapping=compressor_slot_mapping,
)
hidden_states = self.hc_head(hidden_states)
hidden_states = self.norm(hidden_states).to(dtype=torch.float)
return hidden_states
class MLUDeepseekV4ForCausalLM(nn.Module, SupportsEagle):
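    """DeepSeek-V4 causal LM: embedding/decoder stack plus a float32 lm_head."""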
packed_modules_mapping = {
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = MLUDeepseekV4Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.lm_head = ColumnParallelLinear(
config.dim,
config.vocab_size,
params_dtype=torch.float32,
quant_config=quant_config,
bias=False,
skip_bias_add=True,
return_bias=False,
)
def update_forward_args(self, args, kwargs):
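        """Derive the window and compressor slot mappings from the common attention
        metadata and inject them into the model's forward kwargs."""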
window_size = self.config.window_size
# Part 1. window slot mapping.
common_metadata: MLUCommonAttentionMetadata = get_common_metadata()
if common_metadata is None or common_metadata.block_table_tensor is None:
window_slot_mapping = None
elif common_metadata.is_prefill_only:
block_table = common_metadata.block_table_tensor
query_start_loc = common_metadata.query_start_loc
window_slot_mapping = torch.empty([query_start_loc[-1]], dtype=torch.int32, device=block_table.device)
window_slot_mapping.fill_(-1)
for i, seq_len in enumerate(common_metadata.seq_lens):
if seq_len < window_size:
window_slot_mapping[query_start_loc[i]: query_start_loc[i+1]].copy_(block_table[i, :seq_len])
else:
# | <------- seqlen--------> |
# | other | window size |
# | other | tail | head |
# move head to the front of window, and move tail to the latter.
tail_pos = query_start_loc[i].item() + seq_len - window_size
head_size = seq_len % window_size
tail_size = window_size - head_size
window_slot_mapping[tail_pos: tail_pos + tail_size].copy_(
block_table[i, head_size:window_size],
)
window_slot_mapping[tail_pos + tail_size: tail_pos + window_size].copy_(
block_table[i, :head_size]
)
else:
block_table = common_metadata.block_table_tensor
window_pos = (common_metadata.seq_lens - 1) % window_size
window_slot_mapping = torch.gather(block_table, 1, window_pos.unsqueeze(1)).squeeze(1)
kwargs["window_slot_mapping"] = window_slot_mapping
# Part 2. compressor slot mapping
assert set(self.config.compress_ratios) == {0, 4, 128}
# The pairs <window_offset, compress_ratio> <128, 128> <128, 4> <0, 4> contain all cases.
# <128, 128> and <128, 4> indicate attn.compressor, and
# <0, 4> indicates attn.indexer.compressor.
window_offsets = [128, 128, 0]
compress_ratios = [128, 4, 4]
        # dict key: (window_offset, compress_ratio)
compressor_slot_mapping = dict()
if common_metadata is None or common_metadata.block_table_tensor is None:
pass
elif common_metadata.is_prefill_only:
block_tables = common_metadata.block_table_tensor
            query_start_loc = common_metadata.query_start_loc
query_lens = (query_start_loc[1:] - query_start_loc[:-1]).tolist()
for compress_ratio, window_offset in zip(compress_ratios, window_offsets):
slot_lens = [q // compress_ratio for q in query_lens]
cu_slot_lens = torch.cat([
torch.tensor([0], dtype=torch.int32, device='cpu'),
torch.cumsum(torch.tensor(slot_lens, dtype=torch.int32, device='cpu'), dim=0)],
)
                slot_mapping = torch.empty(sum(slot_lens), dtype=torch.int32, device=block_tables.device)
for i in range(len(query_lens)):
slot_mapping[cu_slot_lens[i]: cu_slot_lens[i+1]] = \
block_tables[i, window_offset: window_offset + slot_lens[i]]
compressor_slot_mapping[(window_offset, compress_ratio)] = slot_mapping
else:
block_tables = common_metadata.block_table_tensor
seq_lens = common_metadata.seq_lens
query_start_loc = common_metadata.query_start_loc
query_lens = query_start_loc[1:] - query_start_loc[:-1]
for compress_ratio, window_offset in zip(compress_ratios, window_offsets):
offset = window_offset + (seq_lens - query_lens) // compress_ratio
slot_mapping = torch.gather(block_tables, 1, offset.unsqueeze(1)).squeeze(1)
compressor_slot_mapping[(window_offset, compress_ratio)] = slot_mapping
kwargs["compressor_slot_mapping"] = compressor_slot_mapping
return args, kwargs
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
batch_to_kv_state: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
window_slot_mapping: torch.Tensor | None = None,
compressor_slot_mapping: dict | None = None,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model(
input_ids,
positions,
batch_to_kv_state,
inputs_embeds,
window_slot_mapping,
compressor_slot_mapping,
)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
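        """Load checkpoint weights.

        Remaps checkpoint names (hc_* tensors, gate/topk, embed/head) to module
        names, shards routed experts across EP ranks and attn_sink across TP ranks,
        and stacks w1/w3 into the packed w13 parameter.
        """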
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
('w13', 'w1', 0),
('w13', 'w3', 1),
]
for name, m in self.model.named_modules():
if isinstance(m, SparseMoeMlp):
m.pack_params()
moe_group_info = MoeGroupInfo()
moe_ep_size = moe_group_info.moe_ep_size
moe_ep_rank = moe_group_info.moe_ep_rank
num_total_experts = self.config.n_routed_experts
start_expert_id = moe_ep_rank * ((num_total_experts + moe_ep_size - 1) // moe_ep_size)
expert_num_per_rank = (num_total_experts + moe_ep_size - 1) // moe_ep_size
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
pattern = r'layers\.([0-9]*)\.'
match = re.search(pattern, name)
if match:
layer_id = int(match.group(1))
if layer_id >= self.config.n_layers:
continue
            # Parameters that are not supported yet (e.g. MTP weights) are skipped.
            skipped_parameters = ['mtp']
            if any(param in name for param in skipped_parameters):
continue
name = name.replace("embed.weight", "embed_tokens.weight")
name = "model." + name
name = name.replace("model.head.weight", "lm_head.weight")
if "ffn.experts." in name:
expert_id = int(name.split(".")[-3])
                if expert_id < start_expert_id or expert_id >= start_expert_id + expert_num_per_rank:
continue
new_expert_id = expert_id - start_expert_id
name = name.replace(f"experts.{expert_id}", f"experts.{new_expert_id}")
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if "w1.weight" not in name and \
"w3.weight" not in name:
continue
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# remap parameter name for hc pre
name = name.replace("hc_attn_base", "hc_attn_pre.hc_base")
name = name.replace("hc_attn_fn", "hc_attn_pre.hc_fn")
name = name.replace("hc_attn_scale", "hc_attn_pre.hc_scale")
name = name.replace("hc_ffn_base", "hc_ffn_pre.hc_base")
name = name.replace("hc_ffn_fn", "hc_ffn_pre.hc_fn")
name = name.replace("hc_ffn_scale", "hc_ffn_pre.hc_scale")
# remap parameter name for hc head
name = name.replace("hc_head_base", "hc_head.hc_head_base")
name = name.replace("hc_head_fn", "hc_head.hc_head_fn")
name = name.replace("hc_head_scale", "hc_head.hc_head_scale")
name = name.replace("gate.tid2eid", "deepseekv4_topk.tid2eid")
name = name.replace("ffn.gate.bias", "ffn.deepseekv4_topk.bias")
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
if 'attn_sink' in name:
num_heads = self.config.n_heads
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
assert num_heads % tp_size == 0
num_local_heads = num_heads // tp_size
loaded_weight = loaded_weight[tp_rank * num_local_heads: (tp_rank + 1) * num_local_heads]
weight_loader(param, loaded_weight)
loaded_params.add(name)
if diff := set(params_dict.keys()) - loaded_params:
logger.error(f"The following params are not loaded: {diff}")
for name, m in self.model.named_modules():
if isinstance(m, SparseMoeMlp):
m.pack_params_after_loading()
return set(loaded_params)