[Feature] Add MLAProcess for DeepSeek MLA on NPU (#10130)

This commit is contained in:
Even Zhou
2025-09-23 08:17:48 +08:00
committed by GitHub
parent 0753ef831e
commit d27a6f7092
7 changed files with 369 additions and 23 deletions

View File

@@ -43,6 +43,10 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.amx_utils import PackWeightMethod
from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
NPUFusedMLAPreprocess,
is_mla_preprocess_enabled,
)
from sglang.srt.layers.communicator import (
LayerCommunicator,
LayerScatterModes,
@@ -1177,6 +1181,12 @@ class DeepseekV2AttentionMLA(nn.Module):
self.weight_block_size = (
self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
)
self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
if self.is_mla_preprocess_enabled:
assert (
quant_config.get_name() == "w8a8_int8"
), "MLA Preprocess only works with W8A8Int8"
self.mla_preprocess = None
def dispatch_attn_forward_method(
self, forward_batch: ForwardBatch
@@ -1263,9 +1273,28 @@ class DeepseekV2AttentionMLA(nn.Module):
positions, hidden_states, forward_batch, zero_allocator
)
elif attn_forward_method == AttnForwardMethod.MLA:
inner_state = self.forward_absorb_prepare(
positions, hidden_states, forward_batch, zero_allocator
)
if not self.is_mla_preprocess_enabled:
inner_state = self.forward_absorb_prepare(
positions, hidden_states, forward_batch, zero_allocator
)
else:
# TODO(iforgetmyname): to be separated as a standalone func
if self.mla_preprocess is None:
self.mla_preprocess = NPUFusedMLAPreprocess(
self.fused_qkv_a_proj_with_mqa,
self.q_a_layernorm,
self.kv_a_layernorm,
self.q_b_proj,
self.w_kc,
self.rotary_emb,
self.layer_id,
self.num_local_heads,
self.qk_nope_head_dim,
self.qk_rope_head_dim,
)
inner_state = self.mla_preprocess.forward(
positions, hidden_states, forward_batch, zero_allocator
)
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
inner_state = self.forward_absorb_fused_mla_rope_prepare(
positions, hidden_states, forward_batch, zero_allocator