diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 59d6fcb2..258d5e3a 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -869,6 +869,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         else:
             num_heads = self.num_heads
 
+        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
         k_nope = self.key_cache.view(self.key_cache.shape[0],
                                      self.key_cache.shape[1], -1)
         value = self.value_cache.view(self.key_cache.shape[0],
@@ -879,7 +880,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             'num_key_value_heads':
             self.num_kv_heads,
             'input_layout':
-            "TND",
+            "BSND",
             'atten_mask':
             None,
             'scale':
@@ -895,14 +896,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
             'block_size':
             self.key_cache.shape[1],
             'actual_seq_lengths_kv':
-            attn_metadata.seq_lens_list[:attn_metadata.num_decode_tokens],
-            'actual_seq_lengths':
-            attn_metadata.actual_seq_lengths_q[:attn_metadata.
-                                               num_decode_tokens]
+            attn_metadata.decode_meta.
+            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
         }
         graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
-        num_tokens = query.shape[0]
+        num_tokens = q_nope.shape[0]
 
         if forward_context.capturing:
             stream = torch_npu.npu.current_stream()
@@ -914,16 +913,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
             workspace = graph_params.workspaces.get(num_tokens)
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                    query, k_nope, value, **common_kwargs)
+                    q_nope, k_nope, value, **common_kwargs)
                 update_graph_params_workspaces(num_tokens,
                                                weak_ref_tensors(workspace))
 
-            attn_out = torch.empty_like(query)
+            attn_out = torch.empty_like(q_nope)
             attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
                                    dtype=torch.float,
-                                   device=query.device)
+                                   device=q_nope.device)
             graph_params.attn_params[num_tokens].append(
-                (weak_ref_tensors(query), weak_ref_tensors(k_nope),
+                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
                  weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
                  self.scale, attn_metadata.block_tables,
                  self.key_cache.shape[1], attn_metadata.decode_meta.
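Note on the hunks above: the decode query moves from the flat TND layout to BSND by viewing the [num_tokens, num_heads, head_dim] tensor as [batch=num_tokens, seq=1, num_heads, head_dim], and the kernel outputs are viewed back to 3-D afterwards (see the views added in the @@ -943 hunk below). A minimal standalone sketch of that reshape bookkeeping, with assumed toy shapes; it mirrors the views in this patch but is illustrative, not the operator's contract:

    import torch

    # Toy shapes, assumed for illustration only.
    num_tokens, num_heads, head_dim = 4, 8, 128
    query = torch.randn(num_tokens, num_heads, head_dim)

    # BSND: batch = num_tokens, seq_len = 1 for single-token decode steps.
    q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
    assert q_nope.shape == (num_tokens, 1, num_heads, head_dim)

    # The kernel returns a BSND output and a [B, N, 1, 1] lse; both are
    # viewed back to the 3-D layout the surrounding code expects.
    attn_out = torch.empty_like(q_nope)
    attn_lse = torch.empty((num_tokens, num_heads, 1, 1), dtype=torch.float)
    attn_out_3d = attn_out.view(attn_out.shape[0], attn_out.shape[2],
                                attn_out.shape[3])
    attn_lse_3d = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
    assert attn_out_3d.shape == (num_tokens, num_heads, head_dim)
    assert attn_lse_3d.shape == (num_tokens, num_heads, 1)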
@@ -933,7 +932,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                  self.pcp_rank, self.dcp_rank, self.dcp_size))
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
-                query,
+                q_nope,
                 k_nope,
                 value,
                 **common_kwargs,
@@ -943,7 +942,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
             graph_params.handles[num_tokens].append(handle)
         else:
             attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
-                query, k_nope, value, **common_kwargs)
+                q_nope, k_nope, value, **common_kwargs)
+
+        attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
+                                 attn_out.shape[3])
+        attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
 
         attn_out_lse_list = []
         # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
@@ -1017,7 +1020,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 prefill_query, key, value, attn_metadata,
                 output[num_decode_tokens:], prefill_query.shape[0])
             attn_metadata.seq_lens = seq_lens_back
-            output[num_decode_tokens:] = output_prefill
+            output[num_decode_tokens:output_prefill.shape[0] +
+                   num_decode_tokens] = output_prefill
         return output
 
     def forward(
@@ -1089,7 +1093,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
         if has_prefill:
             if self.pcp_size > 1:
                 kv = torch.cat([key, value], dim=-1)
-                all_kv = get_pcp_group().all_gather(kv, dim=0)
+                num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
+                all_kv = get_pcp_group().all_gather(
+                    kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0)
                 pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None
                 all_kv = torch.index_select(all_kv, 0,
                                             pcp_allgather_restore_idx)
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index 41476ccc..d9e08c84 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -301,9 +301,9 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
     ):
         (q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table,
          block_size, actual_seq_lengths_kv, attn_output,
-         softmax_lse, cp_rank, dcp_rank, dcp_size) = param
+         softmax_lse, pcp_rank, dcp_rank, dcp_size) = param
         actual_seq_lengths_kv = forward_context.attn_metadata[
-            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
+            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank,
                                                             dcp_rank]
         pad_length = runtime_shape - len(actual_seq_lengths_kv)
         pad_tensor = np.zeros(pad_length,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 0abce926..4e88b405 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -476,6 +476,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
                                                        dtype=torch.int32,
                                                        device=self.device)
+            self.num_actual_tokens_pcp_padded = 0
         if self.speculative_config and self.pcp_size > 1:
             self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
                                                   dtype=torch.int32,
@@ -1915,7 +1916,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             hidden_states = hidden_states[:-pad_size, :]
 
         if self.pcp_size > 1:
-            hidden_states = get_pcp_group().all_gather(hidden_states, 0)
+            hidden_states = get_pcp_group().all_gather(
+                hidden_states[:self.num_actual_tokens_pcp_padded //
+                              self.pcp_size], 0)
             hidden_states = torch.index_select(
                 hidden_states, 0,
                 self.pcp_allgather_restore_idx[:hidden_states.shape[0]])
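Note on the @@ -1915 hunk above: with prefill context parallelism, each rank's hidden_states buffer only holds meaningful rows up to num_actual_tokens_pcp_padded // pcp_size, so the patch slices before the all_gather and then restores token order with a precomputed index. A toy sketch of that index math, assuming two ranks and a stand-in permutation (no real collective is issued here):

    import torch

    # Toy setup: pcp_size ranks each hold an equal shard of the padded
    # token dimension; all_gather concatenates shards along dim 0.
    pcp_size = 2
    num_actual_tokens_pcp_padded = 8   # total padded tokens across ranks
    per_rank = num_actual_tokens_pcp_padded // pcp_size

    # Each rank's buffer may be longer than its shard; only the first
    # `per_rank` rows are meaningful, hence the slice before gathering.
    shards = [torch.arange(per_rank) + 100 * r for r in range(pcp_size)]
    gathered = torch.cat([s[:per_rank] for s in shards], dim=0)

    # pcp_allgather_restore_idx (a stand-in permutation here) maps
    # gathered positions back to the original token order.
    restore_idx = torch.arange(gathered.shape[0]).flip(0)
    restored = torch.index_select(gathered, 0, restore_idx)
    print(restored.shape)  # torch.Size([8])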
@@ -4304,6 +4307,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
                           >= self.input_batch.num_prompt_tokens[:num_reqs])
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
+        self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
         long_seq_metadata = None
         if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = AscendPrefillContextParallelMetadata(
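Note on the acl_graph.py hunk: at graph replay time each rank reads its own column of num_computed_tokens_of_pcp_dcp (indexed by the corrected pcp_rank, not cp_rank) and zero-pads the resulting per-token KV lengths up to the captured runtime_shape. A small sketch of that padding with assumed toy shapes; update_attn_dcp_pcp_params is the real code path, everything below is illustrative:

    import numpy as np

    # One row per decode token, one column per (pcp_rank, dcp_rank) pair;
    # each rank reads its own column.
    num_tokens, pcp_size, dcp_size = 3, 2, 2
    num_computed_tokens_of_pcp_dcp = np.arange(
        num_tokens * pcp_size * dcp_size).reshape(num_tokens, pcp_size,
                                                  dcp_size)
    pcp_rank, dcp_rank = 1, 0
    actual_seq_lengths_kv = num_computed_tokens_of_pcp_dcp[:, pcp_rank,
                                                           dcp_rank]

    # When replaying a graph captured for `runtime_shape` tokens, the
    # list is zero-padded up to that shape.
    runtime_shape = 5
    pad_length = runtime_shape - len(actual_seq_lengths_kv)
    pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
    padded = np.concatenate([actual_seq_lengths_kv, pad_tensor])
    print(padded)  # last `pad_length` entries are zero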