diff --git a/docs/source/user_guide/additional_config.md b/docs/source/user_guide/additional_config.md
index b769a31..90002db 100644
--- a/docs/source/user_guide/additional_config.md
+++ b/docs/source/user_guide/additional_config.md
@@ -44,6 +44,7 @@ The details of each config option are as follows:
 | `use_cached_graph` | bool | `False` | Whether to use cached graph |
 | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
 | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
+| `enable_kv_nz` | bool | `False` | Whether to enable the NZ layout for the KV cache |
 
 **ascend_scheduler_config**
 
@@ -64,7 +65,8 @@ A full example of additional configuration is as follows:
         "use_cached_graph": true,
         "graph_batch_sizes": [1, 2, 4, 8],
         "graph_batch_sizes_init": false,
-        "enable_multistream_moe": false
+        "enable_multistream_moe": false,
+        "enable_kv_nz": false
     },
     "ascend_scheduler_config": {
         "enabled": true,
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index abb6039..f7441f5 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -58,6 +58,7 @@ class TorchairGraphConfig:
             "enable_multistream_moe", False)
         self.enable_view_optimize = torchair_graph_config.get(
             "enable_view_optimize", True)
+        self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
 
         if not isinstance(self.graph_batch_sizes, list):
             raise TypeError("graph_batch_sizes must be list[int]")
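With the `ascend_config.py` change above, the new flag is read from the `torchair_graph_config` section of `additional_config`. A minimal sketch of turning it on from vLLM's offline API (the model name is a placeholder, not part of this patch; any MLA model served through vllm-ascend applies):

```python
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model for illustration
    additional_config={
        "torchair_graph_config": {
            "enabled": True,       # NZ layout only takes effect in torchair graph mode
            "enable_kv_nz": True,  # the flag added by this patch
        },
    },
)
```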
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 7f60342..9e10815 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -480,6 +480,7 @@ class AscendMLAImpl(MLAAttentionImpl):
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
         if speculative_config is not None:
@@ -662,6 +663,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv,
             self.kv_a_layernorm.weight,
@@ -671,7 +673,37 @@ class AscendMLAImpl(MLAAttentionImpl):
             kv_cache[1],
             kv_cache[0],
             epsilon=self.kv_a_layernorm.variance_epsilon,
-            cache_mode="PA",
+            cache_mode=cache_mode,
+        )
+        return k_pe, k_nope
+
+    def exec_kv_prefill(
+        self,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        kv_cache: Tuple,
+        slots: torch.Tensor,
+    ):
+
+        B = hidden_states.shape[0]
+        N = self.num_kv_heads
+        S = 1
+        kv = self.kv_a_proj_with_mqa(hidden_states)[0]
+        # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
+        kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
+        _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+            is_output_kv=True,
         )
         return k_pe, k_nope
 
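For readers unfamiliar with the NZ ("fractal") layout behind the new `PA_NZ`/`PA_BLK_NZ` cache modes: the assumption here is the usual 16-element fragmentation, where the last dimension is split into 16-wide fragments and the fragment index moves ahead of the row index. A pure-torch illustration of that reordering (shapes invented; on NPU the conversion happens inside `npu_kv_rmsnorm_rope_cache`, and the decode path below merely views the cache as `[-1, num_kv_heads, dim // 16, block_size, 16]`):

```python
import torch

num_blocks, num_kv_heads, block_size, kv_lora_rank = 4, 1, 128, 512

# ND (plain row-major) cache tile, as used when enable_kv_nz is False.
k_nope_nd = torch.randn(num_blocks, num_kv_heads, block_size, kv_lora_rank)

# NZ reordering: split the head dim into 16-wide fragments and move the
# fragment index ahead of block_size.
k_nope_nz = (k_nope_nd
             .view(num_blocks, num_kv_heads, block_size, kv_lora_rank // 16, 16)
             .permute(0, 1, 3, 2, 4)
             .contiguous())
assert k_nope_nz.shape == (num_blocks, num_kv_heads,
                           kv_lora_rank // 16, block_size, 16)
```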
@@ -709,34 +741,42 @@ class AscendMLAImpl(MLAAttentionImpl):
         # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
         if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
             assert num_tokens % self.spec_token_num == 0
-            q_nope = (q_nope.view(
-                num_tokens // (self.spec_token_num + 1),
-                self.spec_token_num + 1,
-                self.num_heads,
-                -1,
-            ).transpose(1, 2).contiguous())
-            q_pe = (q_pe.view(
-                num_tokens // (self.spec_token_num + 1),
-                self.spec_token_num + 1,
-                self.num_heads,
-                -1,
-            ).transpose(1, 2).contiguous())
+            q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
+                                 self.spec_token_num + 1, self.num_heads,
+                                 -1)
+            q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
+                             self.spec_token_num + 1, self.num_heads, -1)
+            if not self.enable_kv_nz:
+                q_nope = q_nope.transpose(1, 2).contiguous()
+                q_pe = q_pe.transpose(1, 2).contiguous()
             sparse_mode = 3
             spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
         else:
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            if self.enable_kv_nz:
+                q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
+            else:
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
             sparse_mode = 0
             spec_attn_mask = None
         # shape of knope/k_pe for npu graph mode should be:
         # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
         block_size = kv_c_and_k_pe_cache[0].shape[1]
-        k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
-                             self.kv_lora_rank)
-        k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
-                         self.qk_rope_head_dim)
+        if self.enable_kv_nz:
+            k_nope = k_nope.view(-1, self.num_kv_heads,
+                                 self.kv_lora_rank // 16, block_size, 16)
+            k_pe = k_pe.view(-1, self.num_kv_heads,
+                             self.qk_rope_head_dim // 16, block_size, 16)
+            input_layout = "BSND"
+        else:
+            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
+                                 self.kv_lora_rank)
+            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
+                             self.qk_rope_head_dim)
+            input_layout = "BNSD"
 
-        attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
             q_nope,
             k_nope,
             k_nope,
@@ -744,7 +784,7 @@ class AscendMLAImpl(MLAAttentionImpl):
             key_rope=k_pe,
             num_heads=self.num_heads,
             num_key_value_heads=self.num_kv_heads,
-            input_layout="BNSD",
+            input_layout=input_layout,
             atten_mask=spec_attn_mask,
             sparse_mode=sparse_mode,
             scale=self.scale,
@@ -793,10 +833,11 @@ class AscendMLAImpl(MLAAttentionImpl):
         ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
-            kv_c, k_pe = self.kv_a_proj_with_mqa(
-                hidden_states_or_kv_c_normed)[0].split(
-                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            if not self.torchair_graph_enabled:
+                kv_c, k_pe = self.kv_a_proj_with_mqa(
+                    hidden_states_or_kv_c_normed)[0].split(
+                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         else:
             kv_c_normed = hidden_states_or_kv_c_normed
         assert attn_metadata.num_decodes is not None and \
@@ -809,16 +850,18 @@ class AscendMLAImpl(MLAAttentionImpl):
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_toks, ...]
-        kv_c_normed = kv_c_normed[:num_actual_toks, ...]
-        prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
+        if not self.torchair_graph_enabled:
+            kv_c_normed = kv_c_normed[:num_actual_toks, ...]
+            prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
         if not self.running_in_graph:
             hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
-            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            k_pe = k_pe[:num_actual_toks, ...]
-            k_pe = k_pe.unsqueeze(1)
-            decode_k_pe = k_pe[:num_decode_tokens]
-            prefill_k_pe = k_pe[num_decode_tokens:]
+            if not self.torchair_graph_enabled:
+                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+                k_pe = k_pe[:num_actual_toks, ...]
+                k_pe = k_pe.unsqueeze(1)
+                decode_k_pe = k_pe[:num_decode_tokens]
+                prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
         if has_decode:
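The decode-path query reshapes above differ only in where the singleton sequence axis sits: with NZ enabled the fused attention call uses `input_layout="BSND"` (sequence before heads), otherwise `"BNSD"` (heads before sequence). A quick pure-torch shape check with invented dimensions:

```python
import torch

num_tokens, num_heads, head_dim = 8, 16, 192
q = torch.randn(num_tokens, num_heads * head_dim)

# enable_kv_nz=True  -> input_layout="BSND": [batch, seq, heads, dim]
q_bsnd = q.view(num_tokens, 1, num_heads, head_dim)

# enable_kv_nz=False -> input_layout="BNSD": [batch, heads, seq, dim]
q_bnsd = q.view(num_tokens, num_heads, 1, head_dim)

# The two views hold the same data; only the axis order differs.
assert torch.equal(q_bsnd.transpose(1, 2), q_bnsd)
```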
@@ -855,22 +898,25 @@ class AscendMLAImpl(MLAAttentionImpl):
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
+                seq_len = self.rotary_emb.max_position_embeddings
+                cos = self.rotary_emb.cos_cached[:seq_len].to(
+                    dtype=prefill_q_pe.dtype)
+                sin = self.rotary_emb.sin_cached[:seq_len].to(
+                    dtype=prefill_q_pe.dtype)
+                cos = cos[attn_metadata.prefill.input_positions]
+                sin = sin[attn_metadata.prefill.input_positions]
+                cos = cos[:, None, None, :]
+                sin = sin[:, None, None, :]
+
+                prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
+                prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
+                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                    attn_metadata.slot_mapping)
+
+                kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
+                prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
                 prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
                                                  -1)
-                if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
-                    # NOTE: When scaling not specified
-                    ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
-                    prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
-                    prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
-                    prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
-                    prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
-                else:
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
                 prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
             else:
                 prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
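`rope_single` and the `cos_cached`/`sin_cached` buffers used above are referenced but not defined in this patch. Purely as a reading aid, here is a generic single-tensor rotary application matching the broadcast shapes the hunk sets up (a half-rotation sketch under that assumption; the real helper may use a fused NPU kernel and a different pairing):

```python
import torch

def rope_single_sketch(x: torch.Tensor, cos: torch.Tensor,
                       sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary embedding to a single tensor.

    x:        [num_tokens, num_heads, rope_dim]
    cos, sin: [num_tokens, 1, 1, rope_dim], as built in the hunk above.
    """
    num_tokens, num_heads, rope_dim = x.shape
    x = x.view(num_tokens, num_heads, 1, rope_dim)
    # Half-rotation: pair the first half of the dim with the second half.
    x1, x2 = torch.chunk(x, 2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    out = x * cos + rotated * sin
    return out.view(num_tokens, num_heads, rope_dim)
```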