Support DeepSeek V3.2 Exp (#11061)

Co-authored-by: Stefan He <11166516+hebiao064@users.noreply.github.com>
Co-authored-by: Liangsheng Yin <95566987+hnyls2002@users.noreply.github.com>
Co-authored-by: Baizhou Zhang <56809903+fridge003@users.noreply.github.com>
Co-authored-by: DarkSharpness <76582120+darksharpness@users.noreply.github.com>
Co-authored-by: ZhengdQin <46387172+zhengdqin@users.noreply.github.com>
Co-authored-by: DarkSharpness <2040703891@qq.com>
Co-authored-by: hnyls2002 <lsyincs@gmail.com>
Co-authored-by: Zhengda Qin <zhengdqin@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: HAI <hixiao@gmail.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Author: fzyzcjy
Date: 2025-10-06 15:24:15 +08:00
Committed by: GitHub
Parent: 292a867ad9
Commit: efbc687c28
29 changed files with 4540 additions and 139 deletions


@@ -15,6 +15,7 @@
# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
from __future__ import annotations
import concurrent.futures
import logging
@@ -25,10 +26,16 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig
from sglang.srt import single_batch_overlap
from sglang.srt.configs.model_config import (
get_nsa_index_head_dim,
get_nsa_index_n_heads,
get_nsa_index_topk,
is_deepseek_nsa,
)
from sglang.srt.debug_utils.dumper import dumper
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
@@ -48,6 +55,7 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
NPUFusedMLAPreprocess,
is_mla_preprocess_enabled,
)
from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
from sglang.srt.layers.communicator import (
LayerCommunicator,
LayerScatterModes,
@@ -172,10 +180,13 @@ elif _is_hip:
from sglang.srt.layers.quantization.awq_triton import (
awq_dequantize_triton as awq_dequantize,
)
elif _is_npu:
import custom_ops
import sgl_kernel_npu
import torch_npu
else:
pass
_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()
@@ -184,6 +195,7 @@ logger = logging.getLogger(__name__)
FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
"fa3",
"nsa",
"flashinfer",
"cutlass_mla",
"trtllm_mla",
@@ -204,6 +216,9 @@ class AttnForwardMethod(IntEnum):
# Use absorbed multi-latent attention
MLA = auto()
# Use Deepseek V3.2 sparse multi-latent attention
NPU_MLA_SPARSE = auto()
# Use multi-head attention, but with KV cache chunked.
# This method can avoid OOM when prefix lengths are long.
MHA_CHUNKED_KV = auto()
@@ -246,9 +261,15 @@ def handle_attention_ascend(attn, forward_batch):
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
return AttnForwardMethod.MHA
if hasattr(attn, "indexer"):
return AttnForwardMethod.NPU_MLA_SPARSE
else:
return AttnForwardMethod.MHA
else:
return AttnForwardMethod.MLA
if hasattr(attn, "indexer"):
return AttnForwardMethod.NPU_MLA_SPARSE
else:
return AttnForwardMethod.MLA
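# NOTE: on Ascend, attention layers built with an NSA indexer (DeepSeek V3.2)
# are routed to the sparse MLA path above; all other layers keep the original
# dispatch (MHA for plain extend, MLA otherwise).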
def _get_sum_extend_prefix_lens(forward_batch):
@@ -267,7 +288,9 @@ def _is_extend_without_speculative(forward_batch):
)
def _handle_attention_backend(attn, forward_batch, backend_name):
def _handle_attention_backend(
attn: DeepseekV2AttentionMLA, forward_batch, backend_name
):
sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
disable_ragged = (
backend_name in ["flashinfer", "flashmla"]
@@ -333,6 +356,10 @@ def handle_attention_aiter(attn, forward_batch):
return AttnForwardMethod.MLA
def handle_attention_nsa(attn, forward_batch):
return AttnForwardMethod.MLA
def handle_attention_triton(attn, forward_batch):
if (
_is_extend_without_speculative(forward_batch)
@@ -1005,6 +1032,10 @@ class DeepseekV2AttentionMLA(nn.Module):
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
# NOTE: modifications to rope_scaling must be done early enough, because e.g. the Indexer needs them
if rope_scaling:
rope_scaling["rope_type"] = "deepseek_yarn"
# For tensor parallel attention
if self.q_lora_rank is not None:
self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
@@ -1042,6 +1073,26 @@ class DeepseekV2AttentionMLA(nn.Module):
prefix=add_prefix("kv_a_proj_with_mqa", prefix),
)
self.use_nsa = is_deepseek_nsa(config)
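# The NSA indexer scores previously cached tokens and returns the top-k
# positions that attention is restricted to. Its head count, head dim, and
# top-k budget come from the config helpers imported above
# (get_nsa_index_n_heads / get_nsa_index_head_dim / get_nsa_index_topk);
# scale_fmt="ue8m0" and block_size=128 are fixed quantization settings for
# the indexer's scoring path.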
if self.use_nsa:
self.indexer = Indexer(
hidden_size=hidden_size,
index_n_heads=get_nsa_index_n_heads(config),
index_head_dim=get_nsa_index_head_dim(config),
rope_head_dim=qk_rope_head_dim,
index_topk=get_nsa_index_topk(config),
q_lora_rank=q_lora_rank,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
scale_fmt="ue8m0",
block_size=128,
rope_scaling=rope_scaling,
prefix=add_prefix("indexer", prefix),
quant_config=quant_config,
layer_id=layer_id,
alt_stream=alt_stream,
)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -1064,9 +1115,6 @@ class DeepseekV2AttentionMLA(nn.Module):
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
if rope_scaling:
rope_scaling["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope_wrapper(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
@@ -1193,8 +1241,8 @@ class DeepseekV2AttentionMLA(nn.Module):
self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
if self.is_mla_preprocess_enabled:
assert (
quant_config.get_name() == "w8a8_int8"
), "MLA Preprocess only works with W8A8Int8"
quant_config is None or quant_config.get_name() == "w8a8_int8"
), "MLA Preprocess only works with Unquant or W8A8Int8"
self.mla_preprocess = None
def dispatch_attn_forward_method(
@@ -1272,7 +1320,6 @@ class DeepseekV2AttentionMLA(nn.Module):
return hidden_states, None, forward_batch, None
attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
if attn_forward_method == AttnForwardMethod.MHA:
inner_state = self.forward_normal_prepare(
positions, hidden_states, forward_batch, zero_allocator
@@ -1304,6 +1351,10 @@ class DeepseekV2AttentionMLA(nn.Module):
inner_state = self.mla_preprocess.forward(
positions, hidden_states, forward_batch, zero_allocator
)
elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
inner_state = self.forward_npu_sparse_prepare(
positions, hidden_states, forward_batch, zero_allocator
)
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
inner_state = self.forward_absorb_fused_mla_rope_prepare(
positions, hidden_states, forward_batch, zero_allocator
@@ -1329,6 +1380,8 @@ class DeepseekV2AttentionMLA(nn.Module):
return self.forward_normal_chunked_kv_core(*inner_state)
elif attn_forward_method == AttnForwardMethod.MLA:
return self.forward_absorb_core(*inner_state)
elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
return self.forward_npu_sparse_core(*inner_state)
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
return self.forward_absorb_fused_mla_rope_core(*inner_state)
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
@@ -1424,6 +1477,7 @@ class DeepseekV2AttentionMLA(nn.Module):
):
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
q_lora = None
if self.q_lora_rank is not None:
if (
(not isinstance(hidden_states, tuple))
@@ -1462,6 +1516,10 @@ class DeepseekV2AttentionMLA(nn.Module):
q = self.q_a_layernorm(q)
k_nope = self.kv_a_layernorm(k_nope)
# q_lora needed by indexer
if self.use_nsa:
q_lora = q
k_nope = k_nope.unsqueeze(1)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
@@ -1527,14 +1585,41 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = q_nope_out.transpose(0, 1)
if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
not _use_aiter or not _is_gfx95_supported
not _use_aiter or not _is_gfx95_supported or self.use_nsa
):
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
topk_indices = None
if q_lora is not None:
topk_indices = self.indexer(
x=hidden_states,
q_lora=q_lora,
positions=positions,
forward_batch=forward_batch,
layer_id=self.layer_id,
)
return (
q_pe,
k_pe,
q_nope_out,
k_nope,
forward_batch,
zero_allocator,
positions,
topk_indices,
)
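# topk_indices is computed only when q_lora was captured above, i.e. when the
# NSA indexer exists; for dense checkpoints it stays None and
# forward_absorb_core keeps the ordinary absorbed-MLA call.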
def forward_absorb_core(
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
self,
q_pe,
k_pe,
q_nope_out,
k_nope,
forward_batch,
zero_allocator,
positions,
topk_indices,
):
if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
extra_args = {}
@@ -1543,6 +1628,7 @@ class DeepseekV2AttentionMLA(nn.Module):
"cos_sin_cache": self.rotary_emb.cos_sin_cache,
"is_neox": self.rotary_emb.is_neox_style,
}
attn_output = self.attn_mqa(
q_nope_out,
k_nope,
@@ -1551,6 +1637,7 @@ class DeepseekV2AttentionMLA(nn.Module):
q_rope=q_pe,
k_rope=k_pe,
**extra_args,
**(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
)
else:
if _use_aiter_gfx95:
@@ -1570,7 +1657,13 @@ class DeepseekV2AttentionMLA(nn.Module):
q = torch.cat([q_nope_out, q_pe], dim=-1)
k = torch.cat([k_nope, k_pe], dim=-1)
attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
attn_output = self.attn_mqa(
q,
k,
k_nope,
forward_batch,
**(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
)
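# In both branches the topk_indices kwarg is forwarded only when it is not
# None, so backends without NSA support see an unchanged attn_mqa signature.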
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
if self.use_deep_gemm_bmm:
@@ -1652,6 +1745,221 @@ class DeepseekV2AttentionMLA(nn.Module):
return output
def forward_npu_sparse_prepare(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
):
"""
Reuses the `self.q_lora_rank is not None` branch of forward_absorb_prepare and adds the NSA indexer call.
"""
if self.is_mla_preprocess_enabled and forward_batch.forward_mode.is_decode():
if self.mla_preprocess is None:
self.mla_preprocess = NPUFusedMLAPreprocess(
self.fused_qkv_a_proj_with_mqa,
self.q_a_layernorm,
self.kv_a_layernorm,
self.q_b_proj,
self.w_kc,
self.rotary_emb,
self.layer_id,
self.num_local_heads,
self.qk_nope_head_dim,
self.qk_rope_head_dim,
)
(
q_pe,
k_pe,
q_nope_out,
k_nope,
forward_batch,
zero_allocator,
positions,
) = self.mla_preprocess.forward(
positions, hidden_states, forward_batch, zero_allocator
)
fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
q, _ = fused_qkv_a_proj_out.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
)
q_lora = self.q_a_layernorm(q)
else:
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
if (
(not isinstance(hidden_states, tuple))
and hidden_states.shape[0] <= 16
and self.use_min_latency_fused_a_gemm
):
fused_qkv_a_proj_out = dsv3_fused_a_gemm(
hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
)
else:
fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
q, latent_cache = fused_qkv_a_proj_out.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
)
k_nope = latent_cache[..., : self.kv_lora_rank]
# overlap qk norm
if self.alt_stream is not None and get_is_capture_mode():
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
q = self.q_a_layernorm(q)
with torch.cuda.stream(self.alt_stream):
k_nope = self.kv_a_layernorm(k_nope)
current_stream.wait_stream(self.alt_stream)
else:
if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
q, k_nope = fused_rms_mxfp4_quant(
q,
self.q_a_layernorm.weight,
self.q_a_layernorm.variance_epsilon,
k_nope,
self.kv_a_layernorm.weight,
self.kv_a_layernorm.variance_epsilon,
)
else:
q = self.q_a_layernorm(q)
k_nope = self.kv_a_layernorm(k_nope)
q_lora = q.clone() # required for topk_indices
k_nope = k_nope.unsqueeze(1)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
q_nope, q_pe = q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
if self.use_deep_gemm_bmm:
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
per_token_group_quant_mla_deep_gemm_masked_fp8(
q_nope.transpose(0, 1)
)
)
q_nope_out = q_nope.new_empty(
(self.num_local_heads, aligned_m, self.kv_lora_rank)
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
(q_nope_val, q_nope_scale),
(self.w_kc, self.w_scale_k),
q_nope_out,
masked_m,
expected_m,
)
q_nope_out = q_nope_out[:, :expected_m, :]
elif _is_hip:
# TODO(haishaw): add bmm_fp8 to ROCm
if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
x = q_nope.transpose(0, 1)
q_nope_out = torch.empty(
x.shape[0],
x.shape[1],
self.w_kc.shape[2],
device=x.device,
dtype=torch.bfloat16,
)
batched_gemm_afp4wfp4_pre_quant(
x,
self.w_kc.transpose(-2, -1),
self.w_scale_k.transpose(-2, -1),
torch.bfloat16,
q_nope_out,
)
else:
q_nope_out = torch.bmm(
q_nope.to(torch.bfloat16).transpose(0, 1),
self.w_kc.to(torch.bfloat16) * self.w_scale,
)
elif self.w_kc.dtype == torch.float8_e4m3fn:
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
q_nope.transpose(0, 1),
zero_allocator.allocate(1),
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
)
else:
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
q_nope_out = q_nope_out.transpose(0, 1)
if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
not _use_aiter or not _is_gfx95_supported
):
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
# TODO: multi-stream indexer
topk_indices = self.indexer(
hidden_states, q_lora, positions, forward_batch, self.layer_id
)
return (
q_pe,
k_pe,
q_nope_out,
k_nope,
topk_indices,
forward_batch,
zero_allocator,
positions,
)
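# Note the tuple order: unlike forward_absorb_prepare, topk_indices is placed
# before forward_batch/zero_allocator/positions, matching the positional
# unpacking in forward_npu_sparse_core below.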
def forward_npu_sparse_core(
self,
q_pe,
k_pe,
q_nope_out,
k_nope,
topk_indices,
forward_batch,
zero_allocator,
positions,
):
attn_output = self.attn_mqa(
q_nope_out.contiguous(),
k_nope.contiguous(),
k_nope.contiguous(),
forward_batch,
save_kv_cache=True, # False if forward_batch.forward_mode.is_extend() else True,
q_rope=q_pe.contiguous(),
k_rope=k_pe.contiguous(),
topk_indices=topk_indices,
)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
attn_bmm_output = torch.empty(
(attn_output.shape[0], self.num_local_heads, self.v_head_dim),
dtype=attn_output.dtype,
device=attn_output.device,
)
if not forward_batch.forward_mode.is_decode():
attn_output = attn_output.transpose(0, 1)
torch.bmm(
attn_output,
self.w_vc,
out=attn_bmm_output.view(
-1, self.num_local_heads, self.v_head_dim
).transpose(0, 1),
)
else:
attn_output = attn_output.contiguous()
torch.ops.npu.batch_matmul_transpose(
attn_output, self.w_vc, attn_bmm_output
)
attn_bmm_output = attn_bmm_output.reshape(
-1, self.num_local_heads * self.v_head_dim
)
output, _ = self.o_proj(attn_bmm_output)
return output
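# Prefill/extend multiplies the attention output by w_vc with a plain
# torch.bmm; decode uses the fused NPU batch_matmul_transpose kernel. Both
# paths write into the preallocated attn_bmm_output buffer.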
def forward_absorb_fused_mla_rope_prepare(
self,
positions: torch.Tensor,
@@ -2134,7 +2442,6 @@ class DeepseekV2DecoderLayer(nn.Module):
zero_allocator: BumpAllocator,
gemm_output_zero_allocator: BumpAllocator = None,
) -> torch.Tensor:
quant_format = (
"mxfp4"
if _is_gfx95_supported
@@ -3099,6 +3406,7 @@ AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
AttentionBackendRegistry.register("fa4", handle_attention_fa4)
AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
AttentionBackendRegistry.register("aiter", handle_attention_aiter)
AttentionBackendRegistry.register("nsa", handle_attention_nsa)
AttentionBackendRegistry.register("triton", handle_attention_triton)
@@ -3106,4 +3414,8 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass
EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):
pass
EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM, DeepseekV32ForCausalLM]
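
For orientation, here is a minimal, self-contained sketch of the mechanism that topk_indices enables: an indexer scores the cached tokens, only the index_topk best positions are gathered, and attention is computed over that subset. The function and tensor names below are illustrative only and do not appear in this commit; the real path passes topk_indices into attn_mqa and lets the attention backend do the gathering.

import torch

def sparse_topk_attention(q, k_cache, v_cache, scores, index_topk):
    # q: [n_heads, head_dim]; k_cache, v_cache: [seq_len, n_heads, head_dim]
    # scores: [seq_len], one relevance score per cached token (the indexer's output)
    topk = min(index_topk, scores.shape[0])
    topk_indices = torch.topk(scores, k=topk).indices        # positions to keep
    k = k_cache.index_select(0, topk_indices)                 # [topk, n_heads, head_dim]
    v = v_cache.index_select(0, topk_indices)
    logits = torch.einsum("hd,thd->ht", q, k) / q.shape[-1] ** 0.5
    weights = torch.softmax(logits, dim=-1)
    return torch.einsum("ht,thd->hd", weights, v)             # [n_heads, head_dim]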