[Feature] Add MLAProcess for DeepSeek MLA on NPU (#10130)
@@ -43,6 +43,10 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.amx_utils import PackWeightMethod
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
+    NPUFusedMLAPreprocess,
+    is_mla_preprocess_enabled,
+)
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
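The new import pulls in both the fused preprocess module and the runtime gate that decides whether it is used. As a rough illustration of how such a gate is typically wired, here is a minimal sketch; the environment-variable name and the availability check are assumptions for this example, not the actual sglang implementation:

import importlib.util
import os

def is_mla_preprocess_enabled() -> bool:
    # Hypothetical gate: only enable the fused path when it is explicitly
    # requested and an Ascend NPU runtime (torch_npu) can be imported.
    requested = os.getenv("SGLANG_USE_MLA_PREPROCESS", "0") == "1"
    has_npu = importlib.util.find_spec("torch_npu") is not None
    return requested and has_npu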
@@ -1177,6 +1181,12 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.weight_block_size = (
                 self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
             )
+        self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
+        if self.is_mla_preprocess_enabled:
+            assert (
+                quant_config.get_name() == "w8a8_int8"
+            ), "MLA Preprocess only works with W8A8Int8"
+        self.mla_preprocess = None

     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
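In the constructor, the new flag is recorded and the quantization scheme is validated up front, so a misconfigured deployment fails at model construction rather than on the first forward pass; the preprocess module itself is created later. A small, self-contained sketch of that check with stand-in config objects (the stub classes below are illustrative and only mirror the get_name() interface used in the diff):

class _W8A8Int8Config:
    def get_name(self) -> str:
        return "w8a8_int8"

class _Fp8Config:
    def get_name(self) -> str:
        return "fp8"

def check_mla_preprocess_supported(quant_config) -> None:
    # Mirrors the constructor-time assertion: the fused NPU preprocess
    # kernels only handle W8A8 int8 quantized weights.
    assert (
        quant_config.get_name() == "w8a8_int8"
    ), "MLA Preprocess only works with W8A8Int8"

check_mla_preprocess_supported(_W8A8Int8Config())  # passes silently
# check_mla_preprocess_supported(_Fp8Config())     # would raise AssertionError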
@@ -1263,9 +1273,28 @@ class DeepseekV2AttentionMLA(nn.Module):
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA:
-            inner_state = self.forward_absorb_prepare(
-                positions, hidden_states, forward_batch, zero_allocator
-            )
+            if not self.is_mla_preprocess_enabled:
+                inner_state = self.forward_absorb_prepare(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
+            else:
+                # TODO(iforgetmyname): to be separated as a standalone func
+                if self.mla_preprocess is None:
+                    self.mla_preprocess = NPUFusedMLAPreprocess(
+                        self.fused_qkv_a_proj_with_mqa,
+                        self.q_a_layernorm,
+                        self.kv_a_layernorm,
+                        self.q_b_proj,
+                        self.w_kc,
+                        self.rotary_emb,
+                        self.layer_id,
+                        self.num_local_heads,
+                        self.qk_nope_head_dim,
+                        self.qk_rope_head_dim,
+                    )
+                inner_state = self.mla_preprocess.forward(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
             inner_state = self.forward_absorb_fused_mla_rope_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
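On the MLA dispatch path, the change keeps the original forward_absorb_prepare route when the fused path is disabled and otherwise builds NPUFusedMLAPreprocess lazily on first use, presumably because inputs such as self.w_kc only exist after checkpoint loading and weight post-processing. The TODO in the diff says this branch should become a standalone function; a sketch of what that extraction could look like as a method on the attention layer (the helper name _prepare_mla_inner_state is hypothetical, not part of the change):

def _prepare_mla_inner_state(self, positions, hidden_states, forward_batch, zero_allocator):
    # Illustrative extraction of the new MLA branch into a single helper.
    if not self.is_mla_preprocess_enabled:
        return self.forward_absorb_prepare(
            positions, hidden_states, forward_batch, zero_allocator
        )
    if self.mla_preprocess is None:
        # Constructed lazily on the first call so that weights prepared
        # after loading (e.g. self.w_kc) are available to the fused kernel.
        self.mla_preprocess = NPUFusedMLAPreprocess(
            self.fused_qkv_a_proj_with_mqa,
            self.q_a_layernorm,
            self.kv_a_layernorm,
            self.q_b_proj,
            self.w_kc,
            self.rotary_emb,
            self.layer_id,
            self.num_local_heads,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
        )
    return self.mla_preprocess.forward(
        positions, hidden_states, forward_batch, zero_allocator
    )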