diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index ac27231..a662265 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -16,11 +16,13 @@ from vllm.model_executor.layers.linear import (LinearBase,
 from vllm.utils import cdiv, round_down
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 
+from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          maybe_save_kv_layer_to_connector,
                                          split_decodes_and_prefills,
+                                         trans_rope_weight, transdata,
                                          wait_for_kv_layer_from_connector)
 from vllm_ascend.compilation.acl_graph import get_graph_params
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
@@ -639,6 +641,87 @@ class AscendMLAImpl(MLAAttentionImpl):
         #     self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
         #     self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
 
+        if envs.VLLM_ASCEND_ENABLE_MLAPO:
+            self._process_weights_for_fused_mlapo(act_dtype)
+
+    def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
+        kv_a_proj_wt = self.kv_a_proj_with_mqa.weight.data
+        kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
+        kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
+        kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
+        wd_qkv = torch.cat((kv_a_proj_wt, self.q_a_proj.weight.data), dim=-1)
+        wd_qkv = wd_qkv.t().contiguous()
+        wd_qkv = transdata(wd_qkv,
+                           block_size=(16, 32)).unsqueeze(0).contiguous()
+        self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
+
+        kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale
+        kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
+            self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
+        kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl,
+                                              self.qk_rope_head_dim)
+        kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(
+            self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
+        self.deq_scale_qkv = torch.cat(
+            (kv_a_proj_deq_scl, self.q_a_proj.deq_scale), dim=-1).contiguous()
+
+        kv_a_proj_qt_bias = self.kv_a_proj_with_mqa.quant_bias
+        kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(
+            self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
+        kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias,
+                                              self.qk_rope_head_dim)
+        kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(
+            self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
+        self.quant_bias_qkv = torch.cat(
+            (kv_a_proj_qt_bias, self.q_a_proj.quant_bias),
+            dim=-1).contiguous()
+
+        wu_q = self.q_proj.weight.data
+        wu_q = wu_q.t().reshape(self.num_heads,
+                                self.qk_nope_head_dim + self.qk_rope_head_dim,
+                                -1)
+        wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
+        wu_q = wu_q.reshape(
+            self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
+            -1)
+        wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
+        self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
+
+        qb_deq_scl = self.q_proj.deq_scale.data
+        qb_deq_scl = qb_deq_scl.reshape(
+            self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
+        qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
+        self.qb_deq_scl = qb_deq_scl.reshape(
+            self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
+
+        qb_qt_bias = self.q_proj.quant_bias.data
+        qb_qt_bias = qb_qt_bias.reshape(
+            self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
+        qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
+        self.qb_qt_bias = qb_qt_bias.reshape(
+            self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
+
+        device = self.q_a_proj.weight.device
+        self.gamma0 = torch.ones(
+            [self.q_a_proj.weight.shape[-1]],
+            dtype=act_dtype,
+            device=device,
+        )
+        self.beta0 = torch.zeros(
+            [self.q_a_proj.weight.shape[-1]],
+            dtype=act_dtype,
+            device=device,
+        )
+        self.gamma1 = self.q_a_layernorm.weight.data
+        self.beta1 = self.q_a_layernorm.bias.data
+        self.gamma2 = self.kv_a_layernorm.weight.data
+        self.quant_scale0 = self.q_a_proj.input_scale.data
+        self.quant_offset0 = self.q_a_proj.input_offset.data
+        self.quant_scale1 = self.q_proj.input_scale.data
+        self.quant_offset1 = self.q_proj.input_offset.data
+        self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
+        self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
+
     def _compute_prefill_context(
         self,
         q_nope: torch.Tensor,
@@ -961,6 +1044,68 @@ class AscendMLAImpl(MLAAttentionImpl):
                 current_ms_metadata.before_comm_event.wait()
                 return self._v_up_proj(attn_output)
 
+    def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
+        bsz = attn_metadata.num_decode_tokens
+        hidden_states = hidden_states[:bsz]
+
+        cos_shape = attn_metadata.decode.cos.shape
+        cos = attn_metadata.decode.cos.view(cos_shape[0], cos_shape[-1])
+        sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1])
+
+        decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1]
+        decode_q_nope = torch.empty(
+            (hidden_states.shape[0], self.W_UK_T.shape[0],
+             decode_k_nope.shape[-1]),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        decode_q_pe = torch.empty(
+            (hidden_states.shape[0], self.W_UK_T.shape[0],
+             decode_k_pe.shape[-1]),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        torch.ops._C_ascend.mla_preprocess(
+            hidden_states,
+            self.gamma0,
+            self.beta0,
+            self.wd_qkv,
+            self.deq_scale_qkv,
+            self.gamma1,
+            self.beta1,
+            self.wu_q,
+            self.qb_deq_scl,
+            self.gamma2,
+            cos,
+            sin,
+            self.W_UK_T,
+            decode_k_nope,
+            decode_k_pe,
+            attn_metadata.slot_mapping[:bsz].flatten(),
+            quant_scale0=self.quant_scale0,
+            quant_offset0=self.quant_offset0,
+            bias0=self.quant_bias_qkv,
+            quant_scale1=self.quant_scale1,
+            quant_offset1=self.quant_offset1,
+            bias1=self.qb_qt_bias,
+            ctkv_scale=self.ctkv_scale,
+            q_nope_scale=self.q_nope_scale,
+            cache_mode="krope_ctkv",
+            quant_mode="per_tensor_quant_asymm",
+            q_out0=decode_q_nope,
+            kv_cache_out0=decode_k_nope,
+            q_out1=decode_q_pe,
+            kv_cache_out1=decode_k_pe,
+        )
+        decode_q_nope = decode_q_nope.view(bsz, self.num_heads,
+                                           self.kv_lora_rank)
+        decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
+
+        decode_preprocess_res = DecodeMLAPreprocessResult(
+            decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
+        return decode_preprocess_res, None
+
     def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
                         attn_metadata, need_gather_q_kv):
         # MLA Preprocess:
@@ -1065,9 +1210,15 @@ class AscendMLAImpl(MLAAttentionImpl):
                                    device=hidden_states.device)
 
         # MLA Preprocess
-        decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
-            layer_name, hidden_states, kv_cache, attn_metadata,
-            need_gather_q_kv)
+        forward_context = get_forward_context()
+        if (envs.VLLM_ASCEND_ENABLE_MLAPO and
+                (attn_metadata is None or not forward_context.with_prefill)):
+            decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
+                hidden_states, kv_cache, attn_metadata)
+        else:
+            decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
+                layer_name, hidden_states, kv_cache, attn_metadata,
+                need_gather_q_kv)
 
         if decode_preprocess_res is not None:
             # MLA Preprocess for decoding
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index 8dc7efc..007b055 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -3,6 +3,7 @@ from dataclasses import dataclass
 from typing import Any, List
 
 import torch
+import torch.nn.functional as F
 import torch_npu
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group,
@@ -153,3 +154,39 @@ def version_check():
         if full_date >= "20250919":
             return True
     return False
+
+
+def round_up(val: int, align: int) -> int:
+    if align == 0:
+        return 0
+    return -(val // -align) * align
+
+
+def trans_rope_weight(weight, rope_dim):
+    if rope_dim == 0:
+        return weight.contiguous()
+    nope_part = weight[..., :-rope_dim, :]
+    rope_part = weight[..., -rope_dim:, :]
+    reordered_rope_part = torch.cat(
+        (rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
+    return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous()
+
+
+def transdata(nd_mat, block_size: tuple = (16, 16)):
+    r = round_up(nd_mat.shape[0], block_size[0])
+    c = round_up(nd_mat.shape[1], block_size[1])
+    r_pad = r - nd_mat.shape[0]
+    c_pad = c - nd_mat.shape[1]
+    nd_mat = F.pad(nd_mat, (0, c_pad, 0, r_pad))
+    nz_mat = torch.permute(
+        torch.reshape(
+            nd_mat,
+            (r // block_size[0], block_size[0], c // block_size[1],
+             block_size[1]),
+        ),
+        [2, 0, 1, 3],
+    )
+    nz_mat = torch.reshape(
+        nz_mat,
+        (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
+    return nz_mat
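
A couple of CPU-side sketches to sanity-check the new utils helpers (plain PyTorch, illustrative shapes; neither depends on torch_npu). First, `trans_rope_weight`: it leaves the leading nope rows untouched and regroups the trailing rope rows from interleaved order into evens-then-odds, presumably the layout the fused kernel consumes.

```python
import torch

def trans_rope_weight(weight, rope_dim):
    # Same logic as the helper added in vllm_ascend/attention/utils.py:
    # keep the leading "nope" rows, regroup the trailing rope rows from
    # interleaved [r0, r1, r2, ...] into [evens..., odds...].
    if rope_dim == 0:
        return weight.contiguous()
    nope_part = weight[..., :-rope_dim, :]
    rope_part = weight[..., -rope_dim:, :]
    reordered = torch.cat(
        (rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
    return torch.cat((nope_part, reordered), dim=-2).contiguous()

w = torch.arange(12.0).reshape(6, 2)  # 2 nope rows + 4 rope rows
out = trans_rope_weight(w, rope_dim=4)
# rope rows (first column) were [4, 6, 8, 10]; evens first -> [4, 8, 6, 10]
assert out[2:, 0].tolist() == [4.0, 8.0, 6.0, 10.0]
```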
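Second, `transdata` (ND to NZ re-tiling): pad both dims up to the block size, cut the matrix into `(block_size[0], block_size[1])` tiles, and put the column-block axis outermost. Note the `F.pad` argument order: PyTorch pads the last dimension first, so the column pad `(0, c_pad)` must precede the row pad `(0, r_pad)`; the sketch below uses that corrected ordering.

```python
import torch
import torch.nn.functional as F

def round_up(val: int, align: int) -> int:
    return 0 if align == 0 else -(val // -align) * align

def transdata(nd_mat, block_size=(16, 16)):
    r = round_up(nd_mat.shape[0], block_size[0])
    c = round_up(nd_mat.shape[1], block_size[1])
    # F.pad takes (left, right, top, bottom) for a 2-D tensor:
    # the column pad goes first, the row pad second.
    nd_mat = F.pad(nd_mat, (0, c - nd_mat.shape[1], 0, r - nd_mat.shape[0]))
    nz = nd_mat.reshape(r // block_size[0], block_size[0],
                        c // block_size[1], block_size[1]).permute(2, 0, 1, 3)
    return nz.reshape(nz.shape[0], nz.shape[1] * nz.shape[2], nz.shape[3])

x = torch.randn(20, 40)
y = transdata(x, block_size=(16, 32))
# 20 rows pad to 32, 40 cols pad to 64 -> 2 column-blocks, each (32, 32)
assert y.shape == (2, 32, 32)
# original data survives in the first column-block; padding is zero
assert torch.equal(y[0, :20, :32], x[:, :32])
```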
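Finally, the dispatch added to `forward()` routes only decode-only batches (or the metadata-free dummy run) through the fused `_mla_decode_preprocess`; any batch containing prefill falls back to the existing `_mla_preprocess`. A condition-only sketch; `use_fused_mlapo` is a hypothetical helper name, but the predicate is copied from the diff.

```python
def use_fused_mlapo(mlapo_enabled, attn_metadata, with_prefill):
    # Mirrors: envs.VLLM_ASCEND_ENABLE_MLAPO and
    #          (attn_metadata is None or not forward_context.with_prefill)
    return mlapo_enabled and (attn_metadata is None or not with_prefill)

assert use_fused_mlapo(True, None, with_prefill=False)           # dummy run
assert use_fused_mlapo(True, object(), with_prefill=False)       # decode-only
assert not use_fused_mlapo(True, object(), with_prefill=True)    # has prefill
assert not use_fused_mlapo(False, object(), with_prefill=False)  # flag off
```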