Support DeepSeekV3.2 with MLAPO operator (#4753)
### What this PR does / why we need it?
This PR adds support for the optimized MLAPO operator in DSV3.2 and this
operator provides an optimized implementation that avoids redundant
q_down recomputation.
The operator implementation and optimizations were introduced in PR
[#4707](https://github.com/vllm-project/vllm-ascend/pull/4707).
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: ZYang6263 <zy626375@gmail.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -7,15 +7,19 @@ from torch import nn
|
||||
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, ReplicatedLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
|
||||
from vllm_ascend import envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
trans_rope_weight, transdata,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.ops.shared_weight_layer import (
|
||||
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
|
||||
@@ -23,6 +27,7 @@ from vllm_ascend.ops.shared_weight_layer import (
|
||||
register_layer_to_shared_weight_series)
|
||||
from vllm_ascend.ops.triton.rope import rope_forward_triton
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
_round_up, dispose_layer, enable_sp,
|
||||
is_enable_nz, replace_layer)
|
||||
@@ -341,12 +346,13 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
|
||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
|
||||
|
||||
assert self.indexer is not None, "Indexer is required for DSA."
|
||||
|
||||
self.enable_sfa_cp = enable_sp()
|
||||
self.local_num_heads = self.num_heads
|
||||
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
if self.enable_sfa_cp:
|
||||
self.local_num_heads = self.num_heads * self.tp_size
|
||||
|
||||
@@ -454,6 +460,29 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
post_process_after_loading_for_shared_weight_series(
|
||||
self.o_proj)
|
||||
|
||||
if self.enable_mlapo:
|
||||
quant_method = getattr(
|
||||
getattr(self.fused_qkv_a_proj, "quant_method", None),
|
||||
"quant_method",
|
||||
None,
|
||||
)
|
||||
reasons = []
|
||||
if self.fused_qkv_a_proj is None or not isinstance(
|
||||
quant_method, AscendW8A8LinearMethod):
|
||||
reasons.append(
|
||||
"Currently mlapo only supports W8A8 quantization in MLA scenario."
|
||||
"Some layers in your model are not quantized with W8A8,"
|
||||
"thus mlapo is disabled for these layers.")
|
||||
if self.enable_sfa_cp:
|
||||
reasons.append("Currently mlapo does not support SFA with CP,"
|
||||
"thus mlapo is disabled for these layers.")
|
||||
if reasons:
|
||||
self.enable_mlapo = False
|
||||
for msg in reasons:
|
||||
logger.warning_once(msg)
|
||||
else:
|
||||
self._process_weights_for_fused_mlapo(act_dtype)
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
|
||||
x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
|
||||
@@ -555,6 +584,161 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
x = torch_npu.npu_interleave_rope(x, cos, sin)
|
||||
return x.view(B, N, D)
|
||||
|
||||
# Processing the input parameters for MLAPO by reordering and transposing
|
||||
# QKV(and part of Q) weight, applying RoPE-related dimension transformations,
|
||||
# and handling quantization parameters.
|
||||
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
|
||||
assert self.kv_a_proj_with_mqa is None
|
||||
assert self.fused_qkv_a_proj is not None
|
||||
|
||||
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
|
||||
..., self.q_lora_rank:].contiguous()
|
||||
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
|
||||
..., :self.q_lora_rank].contiguous()
|
||||
|
||||
self.fused_qkv_a_proj.weight = None
|
||||
|
||||
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
|
||||
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
|
||||
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
|
||||
wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
|
||||
wd_qkv = wd_qkv.t().contiguous()
|
||||
wd_qkv = transdata(wd_qkv,
|
||||
block_size=(16, 32)).unsqueeze(0).contiguous()
|
||||
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
|
||||
|
||||
kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[
|
||||
self.q_lora_rank:].contiguous()
|
||||
q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[:self.
|
||||
q_lora_rank].contiguous(
|
||||
)
|
||||
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
|
||||
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl,
|
||||
self.qk_rope_head_dim)
|
||||
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
|
||||
self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl),
|
||||
dim=-1).contiguous()
|
||||
|
||||
kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[
|
||||
self.q_lora_rank:].contiguous()
|
||||
q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[:self.
|
||||
q_lora_rank].contiguous(
|
||||
)
|
||||
|
||||
kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
|
||||
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias,
|
||||
self.qk_rope_head_dim)
|
||||
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
|
||||
self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias),
|
||||
dim=-1).contiguous()
|
||||
|
||||
wu_q = self.q_proj.weight.data
|
||||
wu_q = wu_q.t().reshape(self.num_heads,
|
||||
self.qk_nope_head_dim + self.qk_rope_head_dim,
|
||||
-1)
|
||||
wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
|
||||
wu_q = wu_q.reshape(
|
||||
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
|
||||
-1)
|
||||
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
|
||||
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
|
||||
|
||||
qb_deq_scl = self.q_proj.deq_scale.data
|
||||
qb_deq_scl = qb_deq_scl.reshape(
|
||||
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
|
||||
qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
|
||||
self.qb_deq_scl = qb_deq_scl.reshape(
|
||||
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
|
||||
|
||||
qb_qt_bias = self.q_proj.quant_bias.data
|
||||
qb_qt_bias = qb_qt_bias.reshape(
|
||||
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
|
||||
qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
|
||||
self.qb_qt_bias = qb_qt_bias.reshape(
|
||||
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
|
||||
|
||||
device = self.q_proj.weight.device
|
||||
self.gamma1 = self.q_a_layernorm.weight.data
|
||||
self.beta1 = self.q_a_layernorm.bias.data
|
||||
self.gamma2 = self.kv_a_layernorm.weight.data
|
||||
self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data
|
||||
self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data
|
||||
self.quant_scale1 = self.q_proj.input_scale.data
|
||||
self.quant_offset1 = self.q_proj.input_offset.data
|
||||
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||
|
||||
if self.vllm_config.kv_transfer_config is not None:
|
||||
self.fused_qkv_a_proj.deq_scale = None
|
||||
self.fused_qkv_a_proj.quant_bias = None
|
||||
self.q_proj.deq_scale = None
|
||||
self.q_proj.quant_bias = None
|
||||
torch.npu.empty_cache()
|
||||
|
||||
def _sfa_preprocessc_decode(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
attn_metadata: M,
|
||||
need_gather_q_kv: bool,
|
||||
num_actual_tokens: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states.contiguous(), need_gather_q_kv)
|
||||
k_nope, k_pe = kv_cache[0], kv_cache[1]
|
||||
ql_nope = torch.empty(
|
||||
(num_actual_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
q_pe = torch.empty(
|
||||
(num_actual_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
q_c = torch.empty(
|
||||
(num_actual_tokens, self.q_lora_rank),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
torch.ops._C_ascend.mla_preprocess(
|
||||
hidden_states,
|
||||
self.wd_qkv,
|
||||
self.deq_scale_qkv,
|
||||
self.gamma1,
|
||||
self.beta1,
|
||||
self.wu_q,
|
||||
self.qb_deq_scl,
|
||||
self.gamma2,
|
||||
attn_metadata.cos,
|
||||
attn_metadata.sin,
|
||||
self.W_UK_T,
|
||||
k_nope,
|
||||
k_pe,
|
||||
attn_metadata.slot_mapping[:num_actual_tokens].flatten(),
|
||||
quant_scale0=self.quant_scale0,
|
||||
quant_offset0=self.quant_offset0,
|
||||
bias0=self.quant_bias_qkv,
|
||||
quant_scale1=self.quant_scale1,
|
||||
quant_offset1=self.quant_offset1,
|
||||
bias1=self.qb_qt_bias,
|
||||
ctkv_scale=self.ctkv_scale,
|
||||
q_nope_scale=self.q_nope_scale,
|
||||
cache_mode="krope_ctkv",
|
||||
quant_mode="per_tensor_quant_asymm",
|
||||
enable_inner_out=True,
|
||||
q_out0=ql_nope,
|
||||
kv_cache_out0=k_nope,
|
||||
q_out1=q_pe,
|
||||
kv_cache_out1=k_pe,
|
||||
inner_out=q_c,
|
||||
)
|
||||
return hidden_states, ql_nope, q_pe, q_c
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer_name,
|
||||
@@ -565,69 +749,76 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
forward_context = get_forward_context()
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
if self.enable_sfa_cp:
|
||||
from vllm.forward_context import get_forward_context
|
||||
if not get_forward_context().in_profile_run:
|
||||
if is_hidden_layer(self.vllm_config, self.q_proj):
|
||||
reach_layer_for_shared_weight_series(self.q_proj)
|
||||
if is_hidden_layer(self.vllm_config, self.o_proj):
|
||||
reach_layer_for_shared_weight_series(self.o_proj)
|
||||
|
||||
if self.enable_sfa_cp and not forward_context.in_profile_run:
|
||||
if is_hidden_layer(self.vllm_config, self.q_proj):
|
||||
reach_layer_for_shared_weight_series(self.q_proj)
|
||||
if is_hidden_layer(self.vllm_config, self.o_proj):
|
||||
reach_layer_for_shared_weight_series(self.o_proj)
|
||||
return output.fill_(0)
|
||||
has_prefill = attn_metadata.has_prefill
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
cos = attn_metadata.cos
|
||||
sin = attn_metadata.sin
|
||||
actual_seq_lengths_query = attn_metadata.cum_query_lens
|
||||
actual_seq_lengths_key = attn_metadata.seq_lens
|
||||
hidden_states = hidden_states[:num_actual_tokens]
|
||||
if self.enable_sfa_cp:
|
||||
need_gather_q_kv = False
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
output_padded = output
|
||||
output = output[:num_actual_tokens]
|
||||
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
|
||||
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
|
||||
dependency=hidden_states,
|
||||
enabled=self.enable_prefetch)
|
||||
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
|
||||
q_c, kv_no_split = qkv_lora.split(
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
|
||||
dim=-1,
|
||||
)
|
||||
q_c = self.q_a_layernorm(q_c)
|
||||
|
||||
# Process for Flash Comm V1
|
||||
if need_gather_q_kv:
|
||||
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
q_c.contiguous(), need_gather_q_kv)
|
||||
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
kv_no_split.contiguous(), need_gather_q_kv)
|
||||
if self.enable_mlapo and not forward_context.with_prefill:
|
||||
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocessc_decode(
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
need_gather_q_kv=need_gather_q_kv,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
)
|
||||
else:
|
||||
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
|
||||
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
|
||||
dependency=hidden_states,
|
||||
enabled=self.enable_prefetch)
|
||||
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
|
||||
q_c, kv_no_split = qkv_lora.split(
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
|
||||
dim=-1,
|
||||
)
|
||||
q_c = self.q_a_layernorm(q_c)
|
||||
# Process for Flash Comm V1
|
||||
if need_gather_q_kv:
|
||||
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
q_c.contiguous(), need_gather_q_kv)
|
||||
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
kv_no_split.contiguous(), need_gather_q_kv)
|
||||
|
||||
if has_prefill:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
if has_prefill:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
|
||||
cos = attn_metadata.cos
|
||||
sin = attn_metadata.sin
|
||||
slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
slot_mapping_cp = None
|
||||
actual_seq_lengths_query = attn_metadata.cum_query_lens
|
||||
actual_seq_lengths_key = attn_metadata.seq_lens
|
||||
if self.enable_sfa_cp:
|
||||
assert attn_metadata.sfa_cp_context is not None
|
||||
slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
|
||||
actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
|
||||
actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
|
||||
slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
slot_mapping_cp = None
|
||||
if self.enable_sfa_cp:
|
||||
assert attn_metadata.sfa_cp_context is not None
|
||||
slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
|
||||
actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
|
||||
actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
|
||||
|
||||
self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
|
||||
slot_mapping_cp)
|
||||
self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
|
||||
slot_mapping_cp)
|
||||
|
||||
if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
|
||||
if is_hidden_layer(self.vllm_config, self.q_proj):
|
||||
reach_layer_for_shared_weight_series(self.q_proj)
|
||||
if is_hidden_layer(self.vllm_config, self.o_proj):
|
||||
reach_layer_for_shared_weight_series(self.o_proj)
|
||||
if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
|
||||
if is_hidden_layer(self.vllm_config, self.q_proj):
|
||||
reach_layer_for_shared_weight_series(self.q_proj)
|
||||
if is_hidden_layer(self.vllm_config, self.o_proj):
|
||||
reach_layer_for_shared_weight_series(self.o_proj)
|
||||
|
||||
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
|
||||
q_pe = self.rope_single(q_pe, cos, sin)
|
||||
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
|
||||
q_pe = self.rope_single(q_pe, cos, sin)
|
||||
|
||||
topk_indices = self.indexer_select(
|
||||
x=hidden_states,
|
||||
|
||||
Reference in New Issue
Block a user