diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index b9e51a3..f50fe56 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -21,7 +21,7 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -579,13 +579,18 @@ class AscendMLAImpl(MLAAttentionImpl):
             " please make sure after the tensor parallel split, num_heads / num_kv_heads in "
             "{32, 64, 128}.")
 
-    def _v_up_proj_and_o_proj(self, x):
+    def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
+        npu_prefetch(self.o_proj.weight,
+                     x,
+                     max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                     enabled=enable_multistream_mla)
         return self.o_proj(x, is_prefill=False)[0]
 
     # Return `ql_nope`, `q_pe`
@@ -864,7 +869,6 @@ class AscendMLAImpl(MLAAttentionImpl):
         sin: torch.Tensor,
         kv_cache: Tuple,
         slots: torch.Tensor,
-        enable_multistream_mla: bool = False,
     ):
 
         B = hidden_states.shape[0]
@@ -874,21 +878,18 @@ class AscendMLAImpl(MLAAttentionImpl):
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
-        with npu_stream_switch("mla_secondary",
-                               0,
-                               enabled=enable_multistream_mla):
-            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-                kv,
-                self.kv_a_layernorm.weight,
-                cos,
-                sin,
-                slots.to(torch.int64),
-                kv_cache[1],
-                kv_cache[0],
-                epsilon=self.kv_a_layernorm.variance_epsilon,
-                cache_mode=cache_mode,
-            )
-        return k_pe, k_nope
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+        )
+        return k_pe, k_nope, kv
 
     def exec_kv_prefill(
         self,
@@ -940,6 +941,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         k_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: AscendMLAMetadata,
+        enable_multistream_mla: bool = False,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
@@ -1020,7 +1022,8 @@ class AscendMLAImpl(MLAAttentionImpl):
                 out=attn_output)
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
-            return self._v_up_proj_and_o_proj(attn_output)
+            return self._v_up_proj_and_o_proj(attn_output,
+                                              enable_multistream_mla)
         else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
@@ -1037,6 +1040,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
         enable_multistream_mla: bool = False,
+        ckq: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
@@ -1091,6 +1095,15 @@ class AscendMLAImpl(MLAAttentionImpl):
                 sin = sin[attn_metadata.decode.input_positions]
                 cos = cos[:, None, None, :]
                 sin = sin[:, None, None, :]
+                with npu_stream_switch("mla_secondary",
+                                       0,
+                                       enabled=enable_multistream_mla):
+                    npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                    ckq,
+                                    enabled=enable_multistream_mla)
+                    decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                        attn_metadata.slot_mapping)
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1100,12 +1113,13 @@ class AscendMLAImpl(MLAAttentionImpl):
                 npu_wait_tensor(decode_hs_or_q_c,
                                 sin,
                                 enabled=enable_multistream_mla)
+                npu_wait_tensor(decode_hs_or_q_c,
+                                decode_kv,
+                                enabled=enable_multistream_mla)
+
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
-                decode_k_pe, decode_k_nope = self.exec_kv(
-                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                    attn_metadata.slot_mapping, enable_multistream_mla)
                 with npu_stream_switch("mla_secondary",
                                        0,
                                        enabled=enable_multistream_mla):
@@ -1194,7 +1208,8 @@ class AscendMLAImpl(MLAAttentionImpl):
             if self.running_in_graph:
                 return self._forward_decode(decode_ql_nope, decode_q_pe,
                                             decode_k_nope, decode_k_pe,
-                                            kv_cache, attn_metadata)
+                                            kv_cache, attn_metadata,
+                                            enable_multistream_mla)
             else:
                 output_decode = self._forward_decode(decode_ql_nope,
                                                      decode_q_pe,
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 39ae170..52360ae 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -74,8 +74,7 @@ from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
-from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
-                               npu_wait_tensor)
+from vllm_ascend.utils import dispose_tensor, npu_prefetch
 
 
 class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -567,12 +566,12 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
                                   and attn_metadata.num_decodes > 0)
         forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
         if self.q_lora_rank is not None:
+            npu_prefetch(self.q_a_proj.weight,
+                         hidden_states,
+                         enabled=enable_multistream_mla)
             ckq = self.q_a_proj(hidden_states)[0]
-            npu_wait_tensor(hidden_states, ckq, enabled=enable_multistream_mla)
-            with npu_stream_switch("mla_secondary",
-                                   0,
-                                   enabled=enable_multistream_mla):
-                hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            forward_kwargs['ckq'] = ckq
         else:
             hidden_states_or_q_c = hidden_states
         if self.torchair_graph_enabled:
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 2fcc4f0..016cfd2 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -416,6 +416,20 @@ def npu_wait_tensor(self: torch.Tensor,
     return _npu_wait_tensor(self, dependency) if enabled else self
 
 
+# TODO(wxy): Move to ops module
+def npu_prefetch(input: torch.Tensor,
+                 dependency: torch.Tensor,
+                 max_size: int = 0,
+                 *,
+                 enabled: bool = True):
+    if not enabled:
+        return
+    input_size = input.element_size() * input.numel()
+    if max_size <= 0 or max_size > input_size:
+        max_size = input_size
+    torch_npu.npu_prefetch(input, dependency, max_size)
+
+
 # TODO(zzzzwwjj): move this into forward_context
 class FusedMoEState(Enum):
     AllGather = 0
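
Usage note (an illustrative sketch, not part of the patch above): the new `npu_prefetch` helper is a thin guard around `torch_npu.npu_prefetch` that clamps the requested prefetch size to the weight's actual byte size and becomes a no-op when `enabled=False`. A minimal call pattern is sketched below; it assumes an Ascend device with `torch_npu` installed, and the `project_output` function and the plain `torch.nn.Linear` stand-in for vLLM's parallel `o_proj` layer are hypothetical.

import torch
import torch_npu  # Ascend-only dependency backing torch_npu.npu_prefetch

from vllm_ascend.utils import npu_prefetch

# Same 16MB cap the patch uses in _v_up_proj_and_o_proj.
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024


def project_output(o_proj: torch.nn.Linear,
                   attn_output: torch.Tensor,
                   enable_multistream_mla: bool = False) -> torch.Tensor:
    # Start pulling the projection weight toward the compute core, keyed on the
    # activation tensor; skipped entirely unless multistream MLA is enabled.
    npu_prefetch(o_proj.weight,
                 attn_output,
                 max_size=MAX_O_PROJ_PREFETCH_SIZE,
                 enabled=enable_multistream_mla)
    return o_proj(attn_output)

Passing max_size=0 (the default) prefetches the whole weight, since the helper clamps any non-positive or oversized request to input.element_size() * input.numel().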