[Feature] Add MLAProcess for DeepSeek MLA on NPU (#10130)
This commit is contained in:
@@ -9,6 +9,7 @@ from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from sglang.srt.configs.model_config import AttentionArch
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled
|
||||
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.layers.radix_attention import AttentionType
|
||||
@@ -401,7 +402,7 @@ class AscendAttnBackend(AttentionBackend):
|
||||
antiquant_scale=None,
|
||||
sparse_mode=0,
|
||||
)
|
||||
output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
|
||||
output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
|
||||
softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
|
||||
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
@@ -437,6 +438,10 @@ class AscendAttnBackend(AttentionBackend):
|
||||
q_rope: Optional[torch.Tensor] = None,
|
||||
k_rope: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if is_mla_preprocess_enabled():
|
||||
# MLAPO does saving kv_cache
|
||||
save_kv_cache = False
|
||||
|
||||
if self.graph_mode:
|
||||
return self.forward_decode_graph(
|
||||
q,
|
||||
|
||||
300
python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py
Normal file
300
python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py
Normal file
@@ -0,0 +1,300 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var, is_npu
|
||||
|
||||
_is_npu = is_npu()
|
||||
_ENABLE_MLA_PREPROCESS_FLAG = get_bool_env_var("SGLANG_NPU_USE_MLAPO")
|
||||
_NPU_FORMAT_NZ = 29
|
||||
|
||||
|
||||
def is_mla_preprocess_enabled() -> bool:
|
||||
return _is_npu and _ENABLE_MLA_PREPROCESS_FLAG
|
||||
|
||||
|
||||
if is_mla_preprocess_enabled():
|
||||
import sgl_kernel_npu
|
||||
import torch_npu
|
||||
|
||||
torch.npu.config.allow_internal_format = True
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
|
||||
|
||||
def round_up(val: int, align: int) -> int:
|
||||
if align == 0:
|
||||
return 0
|
||||
return -(val // -align) * align
|
||||
|
||||
|
||||
def transdata(nd_mat, block_size: tuple = (16, 16)):
|
||||
r = round_up(nd_mat.shape[0], block_size[0])
|
||||
c = round_up(nd_mat.shape[1], block_size[1])
|
||||
r_pad = r - nd_mat.shape[0]
|
||||
c_pad = c - nd_mat.shape[1]
|
||||
nd_mat = F.pad(nd_mat, ((0, r_pad, 0, c_pad)))
|
||||
nz_mat = torch.permute(
|
||||
torch.reshape(
|
||||
nd_mat,
|
||||
(r // block_size[0], block_size[0], c // block_size[1], block_size[1]),
|
||||
),
|
||||
[2, 0, 1, 3],
|
||||
)
|
||||
nz_mat = torch.reshape(
|
||||
nz_mat, (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])
|
||||
)
|
||||
return nz_mat
|
||||
|
||||
|
||||
def trans_rope_weight(weight, rope_dim):
|
||||
weight_1 = weight[..., -rope_dim::2, :].contiguous()
|
||||
weight_2 = weight[..., -rope_dim + 1 :: 2, :].contiguous()
|
||||
weight[..., -rope_dim:, :] = torch.cat([weight_1, weight_2], dim=-2)
|
||||
|
||||
return weight.contiguous()
|
||||
|
||||
|
||||
class NPUFusedMLAPreprocess(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
fused_qkv_a_proj_with_mqa,
|
||||
q_a_layernorm,
|
||||
kv_a_layernorm,
|
||||
q_b_proj,
|
||||
w_kc,
|
||||
rotary_emb,
|
||||
layer_id,
|
||||
num_local_heads,
|
||||
qk_nope_head_dim,
|
||||
qk_rope_head_dim,
|
||||
):
|
||||
super().__init__()
|
||||
self.qkv_a_proj = fused_qkv_a_proj_with_mqa
|
||||
self.q_a_layernorm = q_a_layernorm
|
||||
self.kv_a_layernorm = kv_a_layernorm
|
||||
self.q_b_proj = q_b_proj
|
||||
self.w_kc = w_kc.contiguous()
|
||||
self.rotary_emb = rotary_emb
|
||||
self.layer_id = layer_id
|
||||
self.has_preprocess_weights = False
|
||||
|
||||
self.q_lora_rank = self.q_b_proj.input_size # 1536
|
||||
self.kv_lora_rank = self.kv_a_layernorm.hidden_size # 512
|
||||
self.num_local_heads = num_local_heads # tp
|
||||
self.qk_nope_head_dim = qk_nope_head_dim # 128
|
||||
self.qk_rope_head_dim = qk_rope_head_dim # 64
|
||||
|
||||
def preprocess_weights(self, hidden_states):
|
||||
self.dummy = torch.empty(
|
||||
(hidden_states.shape[-1]),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
self.qkv_a_proj_input_offset = self.qkv_a_proj.input_offset.to(dtype=torch.int8)
|
||||
self.q_b_proj_input_offset = self.q_b_proj.input_offset.to(dtype=torch.int8)
|
||||
|
||||
# matmul_0 weight [7168, 2112]
|
||||
fused_qkv_a_proj_with_mqa_weight_q = self.qkv_a_proj.weight.data[
|
||||
:, : self.q_lora_rank
|
||||
].clone() # [7168, 1536]
|
||||
fused_qkv_a_proj_with_mqa_weight_kv = self.qkv_a_proj.weight.data[
|
||||
:, self.q_lora_rank :
|
||||
].clone() # [7168, 576]
|
||||
# rope fit
|
||||
fused_qkv_a_proj_with_mqa_weight_kv_t = (
|
||||
fused_qkv_a_proj_with_mqa_weight_kv.t().contiguous()
|
||||
)
|
||||
fused_qkv_a_proj_with_mqa_weight_kv_t = trans_rope_weight(
|
||||
fused_qkv_a_proj_with_mqa_weight_kv_t, self.qk_rope_head_dim
|
||||
)
|
||||
fused_qkv_a_proj_with_mqa_weight_kv = (
|
||||
fused_qkv_a_proj_with_mqa_weight_kv_t.t().contiguous()
|
||||
)
|
||||
# cat nz
|
||||
fused_qkv_a_proj_with_mqa_weight_new = torch.cat(
|
||||
(fused_qkv_a_proj_with_mqa_weight_kv, fused_qkv_a_proj_with_mqa_weight_q),
|
||||
dim=-1,
|
||||
)
|
||||
fused_qkv_a_proj_with_mqa_weight = (
|
||||
fused_qkv_a_proj_with_mqa_weight_new.t().contiguous()
|
||||
)
|
||||
fused_qkv_a_proj_with_mqa_weight_nz = (
|
||||
transdata(fused_qkv_a_proj_with_mqa_weight, block_size=(16, 32))
|
||||
.unsqueeze(0)
|
||||
.contiguous()
|
||||
)
|
||||
self.qkv_a_proj_weight_nz = torch_npu.npu_format_cast(
|
||||
fused_qkv_a_proj_with_mqa_weight_nz, _NPU_FORMAT_NZ
|
||||
)
|
||||
|
||||
# matmul_0 deq_scale [2112]
|
||||
fused_qkv_a_proj_with_mqa_deq_scale_q = self.qkv_a_proj.deq_scale.data[
|
||||
: self.q_lora_rank
|
||||
].clone() # [7168, 1536]
|
||||
fused_qkv_a_proj_with_mqa_deq_scale_kv = self.qkv_a_proj.deq_scale.data[
|
||||
self.q_lora_rank :
|
||||
].clone() # [7168, 576]
|
||||
# rope fit
|
||||
fused_qkv_a_proj_with_mqa_deq_scale_kv = (
|
||||
fused_qkv_a_proj_with_mqa_deq_scale_kv.reshape(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim, -1
|
||||
).contiguous()
|
||||
)
|
||||
fused_qkv_a_proj_with_mqa_deq_scale_kv = trans_rope_weight(
|
||||
fused_qkv_a_proj_with_mqa_deq_scale_kv, self.qk_rope_head_dim
|
||||
)
|
||||
fused_qkv_a_proj_with_mqa_deq_scale_kv = (
|
||||
fused_qkv_a_proj_with_mqa_deq_scale_kv.view(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim
|
||||
).contiguous()
|
||||
)
|
||||
self.qkv_a_proj_deq_scale_kvq = torch.cat(
|
||||
(
|
||||
fused_qkv_a_proj_with_mqa_deq_scale_kv,
|
||||
fused_qkv_a_proj_with_mqa_deq_scale_q,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# matmul_0 quant_bias [2112]
|
||||
fused_qkv_a_proj_with_mqa_quant_bias_q = self.qkv_a_proj.quant_bias.data[
|
||||
: self.q_lora_rank
|
||||
].clone() # [7168, 1536]
|
||||
fused_qkv_a_proj_with_mqa_quant_bias_kv = self.qkv_a_proj.quant_bias.data[
|
||||
self.q_lora_rank :
|
||||
].clone() # [7168, 576]
|
||||
# rope fit
|
||||
fused_qkv_a_proj_with_mqa_quant_bias_kv = (
|
||||
fused_qkv_a_proj_with_mqa_quant_bias_kv.reshape(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim, -1
|
||||
).contiguous()
|
||||
)
|
||||
fused_qkv_a_proj_with_mqa_quant_bias_kv = trans_rope_weight(
|
||||
fused_qkv_a_proj_with_mqa_quant_bias_kv, self.qk_rope_head_dim
|
||||
)
|
||||
fused_qkv_a_proj_with_mqa_quant_bias_kv = (
|
||||
fused_qkv_a_proj_with_mqa_quant_bias_kv.view(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim
|
||||
).contiguous()
|
||||
)
|
||||
self.qkv_a_proj_quant_bias_kvq = torch.cat(
|
||||
(
|
||||
fused_qkv_a_proj_with_mqa_quant_bias_kv,
|
||||
fused_qkv_a_proj_with_mqa_quant_bias_q,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# matmul_1 weight [1536, num_head * 192]
|
||||
q_b_proj_weight = self.q_b_proj.weight.data.clone()
|
||||
q_b_proj_weight = q_b_proj_weight.t().reshape(
|
||||
self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1
|
||||
)
|
||||
q_b_proj_weight = trans_rope_weight(q_b_proj_weight, self.qk_rope_head_dim)
|
||||
q_b_proj_weight = q_b_proj_weight.reshape(
|
||||
self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1
|
||||
)
|
||||
q_b_proj_weight_nz = (
|
||||
transdata(q_b_proj_weight, block_size=(16, 32)).unsqueeze(0).contiguous()
|
||||
)
|
||||
self.q_b_proj_weight_nz = torch_npu.npu_format_cast(
|
||||
q_b_proj_weight_nz, _NPU_FORMAT_NZ
|
||||
)
|
||||
|
||||
# matmul_1 deq_scale [num_head * 192]
|
||||
q_b_proj_deq_scale = self.q_b_proj.deq_scale.data.clone()
|
||||
q_b_proj_deq_scale = q_b_proj_deq_scale.reshape(
|
||||
self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1
|
||||
)
|
||||
q_b_proj_deq_scale = trans_rope_weight(
|
||||
q_b_proj_deq_scale, self.qk_rope_head_dim
|
||||
)
|
||||
self.q_b_proj_deq_scale = q_b_proj_deq_scale.reshape(
|
||||
self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)
|
||||
)
|
||||
|
||||
# matmul_1 quant_bias [num_head * 192]
|
||||
q_b_proj_quant_bias = self.q_b_proj.quant_bias.data.clone()
|
||||
q_b_proj_quant_bias = q_b_proj_quant_bias.reshape(
|
||||
self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1
|
||||
)
|
||||
q_b_proj_quant_bias = trans_rope_weight(
|
||||
q_b_proj_quant_bias, self.qk_rope_head_dim
|
||||
)
|
||||
self.q_b_proj_quant_bias = q_b_proj_quant_bias.reshape(
|
||||
self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)
|
||||
)
|
||||
|
||||
def get_sin_cos(self, positions):
|
||||
cos_sin = self.rotary_emb.cos_sin_cache[positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
cos = cos.repeat(1, 2)
|
||||
sin = sin.repeat(1, 2)
|
||||
return cos, sin
|
||||
|
||||
def get_kv_cache_and_cache_idx(self, forward_batch):
|
||||
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(self.layer_id)
|
||||
slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
|
||||
return k_cache, v_cache, slot_mapping
|
||||
|
||||
def forward(self, positions, hidden_states, forward_batch, zero_allocator):
|
||||
input_dtype = hidden_states.dtype
|
||||
if not self.has_preprocess_weights:
|
||||
self.preprocess_weights(hidden_states)
|
||||
self.has_preprocess_weights = True
|
||||
self.dtype = hidden_states.dtype
|
||||
|
||||
cos, sin = self.get_sin_cos(positions)
|
||||
k_cache, v_cache, slot_mapping = self.get_kv_cache_and_cache_idx(forward_batch)
|
||||
|
||||
q_nope_out = torch.empty(
|
||||
(hidden_states.shape[0], self.w_kc.shape[0], k_cache.shape[-1]),
|
||||
dtype=input_dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
q_rope_out = torch.empty(
|
||||
(hidden_states.shape[0], self.w_kc.shape[0], v_cache.shape[-1]),
|
||||
dtype=input_dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
# TODO: dummy inputs to be removed
|
||||
# https://github.com/sgl-project/sgl-kernel-npu/issues/78
|
||||
torch.ops.npu.mla_preprocess(
|
||||
hidden_states,
|
||||
self.dummy,
|
||||
self.dummy,
|
||||
self.qkv_a_proj_weight_nz,
|
||||
self.qkv_a_proj_deq_scale_kvq,
|
||||
self.q_a_layernorm.weight,
|
||||
self.q_a_layernorm.bias,
|
||||
self.q_b_proj_weight_nz,
|
||||
self.q_b_proj_deq_scale,
|
||||
self.kv_a_layernorm.weight,
|
||||
cos,
|
||||
sin,
|
||||
self.w_kc,
|
||||
k_cache,
|
||||
v_cache,
|
||||
slot_mapping,
|
||||
quant_scale0=self.qkv_a_proj.input_scale,
|
||||
quant_offset0=self.qkv_a_proj_input_offset,
|
||||
bias0=self.qkv_a_proj_quant_bias_kvq,
|
||||
quant_scale1=self.q_b_proj.input_scale,
|
||||
quant_offset1=self.q_b_proj_input_offset,
|
||||
bias1=self.q_b_proj_quant_bias,
|
||||
cache_mode="krope_ctkv",
|
||||
quant_mode="per_tensor_quant_asymm",
|
||||
q_out0=q_nope_out,
|
||||
kv_cache_out0=k_cache,
|
||||
q_out1=q_rope_out,
|
||||
kv_cache_out1=v_cache,
|
||||
)
|
||||
return (
|
||||
q_rope_out,
|
||||
v_cache,
|
||||
q_nope_out,
|
||||
k_cache,
|
||||
forward_batch,
|
||||
zero_allocator,
|
||||
positions,
|
||||
)
|
||||
@@ -782,27 +782,33 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
|
||||
# and generalization to more scenarios will be supported in the future.
|
||||
if query.shape[1] * query.shape[2] > 4096:
|
||||
return self.forward_native(positions, query, key, offsets)
|
||||
num_tokens = query.shape[0]
|
||||
rotary_mode = "half" if self.is_neox_style else "interleave"
|
||||
num_tokens, num_q_heads, _ = query.shape
|
||||
num_k_heads = key.shape[1]
|
||||
|
||||
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
|
||||
cos_sin = self.cos_sin_cache[
|
||||
torch.add(positions, offsets) if offsets is not None else positions
|
||||
]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
# Reshape to [batchsize, head_dim, seq, rotary_dim]
|
||||
cos = cos.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
|
||||
sin = sin.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
|
||||
|
||||
query_rot = query[..., : self.rotary_dim]
|
||||
key_rot = key[..., : self.rotary_dim]
|
||||
if self.rotary_dim < self.head_size:
|
||||
query_pass = query[..., self.rotary_dim :]
|
||||
key_pass = key[..., self.rotary_dim :]
|
||||
|
||||
query_rot, key_rot = torch_npu.npu_mrope(
|
||||
torch.add(positions, offsets) if offsets is not None else positions,
|
||||
query_rot.reshape(num_tokens, -1),
|
||||
key_rot.reshape(num_tokens, -1),
|
||||
self.cos_sin_cache,
|
||||
self.rotary_dim,
|
||||
mrope_section=[0, 0, 0],
|
||||
rotary_mode=rotary_mode,
|
||||
query_rot = torch_npu.npu_interleave_rope(
|
||||
query_rot.reshape(num_tokens, num_q_heads, 1, self.rotary_dim),
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
key_rot = torch_npu.npu_interleave_rope(
|
||||
key_rot.reshape(num_tokens, num_k_heads, 1, self.rotary_dim),
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
|
||||
key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
|
||||
|
||||
@@ -43,6 +43,10 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
||||
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.amx_utils import PackWeightMethod
|
||||
from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
|
||||
NPUFusedMLAPreprocess,
|
||||
is_mla_preprocess_enabled,
|
||||
)
|
||||
from sglang.srt.layers.communicator import (
|
||||
LayerCommunicator,
|
||||
LayerScatterModes,
|
||||
@@ -1177,6 +1181,12 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.weight_block_size = (
|
||||
self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
|
||||
)
|
||||
self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
|
||||
if self.is_mla_preprocess_enabled:
|
||||
assert (
|
||||
quant_config.get_name() == "w8a8_int8"
|
||||
), "MLA Preprocess only works with W8A8Int8"
|
||||
self.mla_preprocess = None
|
||||
|
||||
def dispatch_attn_forward_method(
|
||||
self, forward_batch: ForwardBatch
|
||||
@@ -1263,9 +1273,28 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
positions, hidden_states, forward_batch, zero_allocator
|
||||
)
|
||||
elif attn_forward_method == AttnForwardMethod.MLA:
|
||||
inner_state = self.forward_absorb_prepare(
|
||||
positions, hidden_states, forward_batch, zero_allocator
|
||||
)
|
||||
if not self.is_mla_preprocess_enabled:
|
||||
inner_state = self.forward_absorb_prepare(
|
||||
positions, hidden_states, forward_batch, zero_allocator
|
||||
)
|
||||
else:
|
||||
# TODO(iforgetmyname): to be separated as a standalone func
|
||||
if self.mla_preprocess is None:
|
||||
self.mla_preprocess = NPUFusedMLAPreprocess(
|
||||
self.fused_qkv_a_proj_with_mqa,
|
||||
self.q_a_layernorm,
|
||||
self.kv_a_layernorm,
|
||||
self.q_b_proj,
|
||||
self.w_kc,
|
||||
self.rotary_emb,
|
||||
self.layer_id,
|
||||
self.num_local_heads,
|
||||
self.qk_nope_head_dim,
|
||||
self.qk_rope_head_dim,
|
||||
)
|
||||
inner_state = self.mla_preprocess.forward(
|
||||
positions, hidden_states, forward_batch, zero_allocator
|
||||
)
|
||||
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
|
||||
inner_state = self.forward_absorb_fused_mla_rope_prepare(
|
||||
positions, hidden_states, forward_batch, zero_allocator
|
||||
|
||||
@@ -174,6 +174,8 @@ def is_blackwell():
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def is_sm100_supported(device=None) -> bool:
|
||||
if not is_cuda_alike():
|
||||
return False
|
||||
return (torch.cuda.get_device_capability(device)[0] == 10) and (
|
||||
torch.version.cuda >= "12.8"
|
||||
)
|
||||
@@ -181,6 +183,8 @@ def is_sm100_supported(device=None) -> bool:
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def is_sm90_supported(device=None) -> bool:
|
||||
if not is_cuda_alike():
|
||||
return False
|
||||
return (torch.cuda.get_device_capability(device)[0] == 9) and (
|
||||
torch.version.cuda >= "12.3"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user