adapt the mla_v1 with the mla_preprocess kernel (#3397)
### What this PR does / why we need it? This pull request integrates a new `mla_preprocess` kernel to create an optimized path for MLA (Multi-Layer Attention) decode operations on Ascend hardware, controlled by an environment flag. The changes include new utility functions for weight transformation, a method to prepare weights for the fused kernel, and logic to route decode-only batches to this new path. My review identified a critical bug in the `transdata` utility function where padding dimensions are swapped, which will lead to incorrect tensor shapes and kernel failures. Additionally, I've pointed out a high-severity maintainability issue in the trans_rope_weight function, which modifies its input in-place, and I have provided a pure-function alternative. ### Does this PR introduce _any_ user-facing change? No user-facing changes by default. User can enable the `mla_preprocess` kernel in model by enable the env-var `VLLM_ASCEND_ENABLE_MLAPO`. ### How was this patch tested? Dedicated Ascend kernels are not covered by our CI yet, so no extra automated tests were added. Future MLA-focused regression runs will cover this path. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: Chen Chen <0109chenchen@gmail.com>
This commit is contained in:
@@ -16,11 +16,13 @@ from vllm.model_executor.layers.linear import (LinearBase,
|
|||||||
from vllm.utils import cdiv, round_down
|
from vllm.utils import cdiv, round_down
|
||||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||||
|
|
||||||
|
from vllm_ascend import envs
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
maybe_save_kv_layer_to_connector,
|
maybe_save_kv_layer_to_connector,
|
||||||
split_decodes_and_prefills,
|
split_decodes_and_prefills,
|
||||||
|
trans_rope_weight, transdata,
|
||||||
wait_for_kv_layer_from_connector)
|
wait_for_kv_layer_from_connector)
|
||||||
from vllm_ascend.compilation.acl_graph import get_graph_params
|
from vllm_ascend.compilation.acl_graph import get_graph_params
|
||||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||||
@@ -639,6 +641,87 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
||||||
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
||||||
|
|
||||||
|
if envs.VLLM_ASCEND_ENABLE_MLAPO:
|
||||||
|
self._process_weights_for_fused_mlapo(act_dtype)
|
||||||
|
|
||||||
|
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
|
||||||
|
kv_a_proj_wt = self.kv_a_proj_with_mqa.weight.data
|
||||||
|
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
|
||||||
|
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
|
||||||
|
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
|
||||||
|
wd_qkv = torch.cat((kv_a_proj_wt, self.q_a_proj.weight.data), dim=-1)
|
||||||
|
wd_qkv = wd_qkv.t().contiguous()
|
||||||
|
wd_qkv = transdata(wd_qkv,
|
||||||
|
block_size=(16, 32)).unsqueeze(0).contiguous()
|
||||||
|
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
|
||||||
|
|
||||||
|
kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale
|
||||||
|
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
|
||||||
|
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl,
|
||||||
|
self.qk_rope_head_dim)
|
||||||
|
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
|
||||||
|
self.deq_scale_qkv = torch.cat(
|
||||||
|
(kv_a_proj_deq_scl, self.q_a_proj.deq_scale), dim=-1).contiguous()
|
||||||
|
|
||||||
|
kv_a_proj_qt_bias = self.kv_a_proj_with_mqa.quant_bias
|
||||||
|
kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
|
||||||
|
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias,
|
||||||
|
self.qk_rope_head_dim)
|
||||||
|
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
|
||||||
|
self.quant_bias_qkv = torch.cat(
|
||||||
|
(kv_a_proj_qt_bias, self.q_a_proj.quant_bias),
|
||||||
|
dim=-1).contiguous()
|
||||||
|
|
||||||
|
wu_q = self.q_proj.weight.data
|
||||||
|
wu_q = wu_q.t().reshape(self.num_heads,
|
||||||
|
self.qk_nope_head_dim + self.qk_rope_head_dim,
|
||||||
|
-1)
|
||||||
|
wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
|
||||||
|
wu_q = wu_q.reshape(
|
||||||
|
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
|
||||||
|
-1)
|
||||||
|
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
|
||||||
|
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
|
||||||
|
|
||||||
|
qb_deq_scl = self.q_proj.deq_scale.data
|
||||||
|
qb_deq_scl = qb_deq_scl.reshape(
|
||||||
|
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
|
||||||
|
qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
|
||||||
|
self.qb_deq_scl = qb_deq_scl.reshape(
|
||||||
|
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
|
||||||
|
|
||||||
|
qb_qt_bias = self.q_proj.quant_bias.data
|
||||||
|
qb_qt_bias = qb_qt_bias.reshape(
|
||||||
|
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
|
||||||
|
qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
|
||||||
|
self.qb_qt_bias = qb_qt_bias.reshape(
|
||||||
|
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
|
||||||
|
|
||||||
|
device = self.q_a_proj.weight.device
|
||||||
|
self.gamma0 = torch.ones(
|
||||||
|
[self.q_a_proj.weight.shape[-1]],
|
||||||
|
dtype=act_dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.beta0 = torch.zeros(
|
||||||
|
[self.q_a_proj.weight.shape[-1]],
|
||||||
|
dtype=act_dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.gamma1 = self.q_a_layernorm.weight.data
|
||||||
|
self.beta1 = self.q_a_layernorm.bias.data
|
||||||
|
self.gamma2 = self.kv_a_layernorm.weight.data
|
||||||
|
self.quant_scale0 = self.q_a_proj.input_scale.data
|
||||||
|
self.quant_offset0 = self.q_a_proj.input_offset.data
|
||||||
|
self.quant_scale1 = self.q_proj.input_scale.data
|
||||||
|
self.quant_offset1 = self.q_proj.input_offset.data
|
||||||
|
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||||
|
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||||
|
|
||||||
def _compute_prefill_context(
|
def _compute_prefill_context(
|
||||||
self,
|
self,
|
||||||
q_nope: torch.Tensor,
|
q_nope: torch.Tensor,
|
||||||
@@ -961,6 +1044,68 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
current_ms_metadata.before_comm_event.wait()
|
current_ms_metadata.before_comm_event.wait()
|
||||||
return self._v_up_proj(attn_output)
|
return self._v_up_proj(attn_output)
|
||||||
|
|
||||||
|
def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
|
||||||
|
bsz = attn_metadata.num_decode_tokens
|
||||||
|
hidden_states = hidden_states[:bsz]
|
||||||
|
|
||||||
|
cos_shape = attn_metadata.decode.cos.shape
|
||||||
|
cos = attn_metadata.decode.cos.view(cos_shape[0], cos_shape[-1])
|
||||||
|
sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1])
|
||||||
|
|
||||||
|
decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1]
|
||||||
|
decode_q_nope = torch.empty(
|
||||||
|
(hidden_states.shape[0], self.W_UK_T.shape[0],
|
||||||
|
decode_k_nope.shape[-1]),
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device,
|
||||||
|
)
|
||||||
|
decode_q_pe = torch.empty(
|
||||||
|
(hidden_states.shape[0], self.W_UK_T.shape[0],
|
||||||
|
decode_k_pe.shape[-1]),
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.ops._C_ascend.mla_preprocess(
|
||||||
|
hidden_states,
|
||||||
|
self.gamma0,
|
||||||
|
self.beta0,
|
||||||
|
self.wd_qkv,
|
||||||
|
self.deq_scale_qkv,
|
||||||
|
self.gamma1,
|
||||||
|
self.beta1,
|
||||||
|
self.wu_q,
|
||||||
|
self.qb_deq_scl,
|
||||||
|
self.gamma2,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
self.W_UK_T,
|
||||||
|
decode_k_nope,
|
||||||
|
decode_k_pe,
|
||||||
|
attn_metadata.slot_mapping[:bsz].flatten(),
|
||||||
|
quant_scale0=self.quant_scale0,
|
||||||
|
quant_offset0=self.quant_offset0,
|
||||||
|
bias0=self.quant_bias_qkv,
|
||||||
|
quant_scale1=self.quant_scale1,
|
||||||
|
quant_offset1=self.quant_offset1,
|
||||||
|
bias1=self.qb_qt_bias,
|
||||||
|
ctkv_scale=self.ctkv_scale,
|
||||||
|
q_nope_scale=self.q_nope_scale,
|
||||||
|
cache_mode="krope_ctkv",
|
||||||
|
quant_mode="per_tensor_quant_asymm",
|
||||||
|
q_out0=decode_q_nope,
|
||||||
|
kv_cache_out0=decode_k_nope,
|
||||||
|
q_out1=decode_q_pe,
|
||||||
|
kv_cache_out1=decode_k_pe,
|
||||||
|
)
|
||||||
|
decode_q_nope = decode_q_nope.view(bsz, self.num_heads,
|
||||||
|
self.kv_lora_rank)
|
||||||
|
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
|
||||||
|
|
||||||
|
decode_preprocess_res = DecodeMLAPreprocessResult(
|
||||||
|
decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
|
||||||
|
return decode_preprocess_res, None
|
||||||
|
|
||||||
def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
|
def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
|
||||||
attn_metadata, need_gather_q_kv):
|
attn_metadata, need_gather_q_kv):
|
||||||
# MLA Preprocess:
|
# MLA Preprocess:
|
||||||
@@ -1065,9 +1210,15 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
device=hidden_states.device)
|
device=hidden_states.device)
|
||||||
|
|
||||||
# MLA Preprocess
|
# MLA Preprocess
|
||||||
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
|
forward_context = get_forward_context()
|
||||||
layer_name, hidden_states, kv_cache, attn_metadata,
|
if (envs.VLLM_ASCEND_ENABLE_MLAPO and
|
||||||
need_gather_q_kv)
|
(attn_metadata is None or not forward_context.with_prefill)):
|
||||||
|
decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
|
||||||
|
hidden_states, kv_cache, attn_metadata)
|
||||||
|
else:
|
||||||
|
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
|
||||||
|
layer_name, hidden_states, kv_cache, attn_metadata,
|
||||||
|
need_gather_q_kv)
|
||||||
|
|
||||||
if decode_preprocess_res is not None:
|
if decode_preprocess_res is not None:
|
||||||
# MLA Preprocess for decoding
|
# MLA Preprocess for decoding
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||||
has_kv_transfer_group,
|
has_kv_transfer_group,
|
||||||
@@ -153,3 +154,39 @@ def version_check():
|
|||||||
if full_date >= "20250919":
|
if full_date >= "20250919":
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def round_up(val: int, align: int) -> int:
|
||||||
|
if align == 0:
|
||||||
|
return 0
|
||||||
|
return -(val // -align) * align
|
||||||
|
|
||||||
|
|
||||||
|
def trans_rope_weight(weight, rope_dim):
|
||||||
|
if rope_dim == 0:
|
||||||
|
return weight.contiguous()
|
||||||
|
nope_part = weight[..., :-rope_dim, :]
|
||||||
|
rope_part = weight[..., -rope_dim:, :]
|
||||||
|
reordered_rope_part = torch.cat(
|
||||||
|
(rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
|
||||||
|
return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
def transdata(nd_mat, block_size: tuple = (16, 16)):
|
||||||
|
r = round_up(nd_mat.shape[0], block_size[0])
|
||||||
|
c = round_up(nd_mat.shape[1], block_size[1])
|
||||||
|
r_pad = r - nd_mat.shape[0]
|
||||||
|
c_pad = c - nd_mat.shape[1]
|
||||||
|
nd_mat = F.pad(nd_mat, (0, r_pad, 0, c_pad))
|
||||||
|
nz_mat = torch.permute(
|
||||||
|
torch.reshape(
|
||||||
|
nd_mat,
|
||||||
|
(r // block_size[0], block_size[0], c // block_size[1],
|
||||||
|
block_size[1]),
|
||||||
|
),
|
||||||
|
[2, 0, 1, 3],
|
||||||
|
)
|
||||||
|
nz_mat = torch.reshape(
|
||||||
|
nz_mat,
|
||||||
|
(nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
|
||||||
|
return nz_mat
|
||||||
|
|||||||
Reference in New Issue
Block a user