diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 6ea9058f..098e77c5 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -26,7 +26,7 @@ import torch.nn as nn import torch_npu from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) -from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config import VllmConfig from vllm.distributed import (get_dcp_group, get_decode_context_model_parallel_rank, get_decode_context_model_parallel_world_size) @@ -387,16 +387,6 @@ class AscendAttentionMetadataBuilder: num_computed_tokens_of_pcp_dcp) num_computed_tokens_array = num_computed_tokens_array[: num_decodes] - pad_length = common_attn_metadata.num_input_tokens - num_actual_tokens_pcp_padded // self.pcp_size - if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and pad_length > 0: - pad_tensor = np.zeros( - (pad_length, num_computed_tokens_array.shape[1], - num_computed_tokens_array.shape[2]), - dtype=num_computed_tokens_array.dtype) - - num_computed_tokens_array = np.concatenate( - [num_computed_tokens_array, pad_tensor], axis=0) - batch_seq_mask = ( num_computed_tokens_array[:, self.pcp_rank, self.dcp_rank] == 0) diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index af6322ab..5c65936e 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from typing import Any, Callable, Optional from unittest.mock import patch +import numpy as np import torch import torch_npu import vllm.envs as envs @@ -326,9 +327,12 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape): actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank, dcp_rank] - if (runtime_shape - len(actual_seq_lengths_kv)) > 0: - actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * ( - runtime_shape - len(actual_seq_lengths_kv)) + pad_length = runtime_shape - len(actual_seq_lengths_kv) + if pad_length > 0: + pad_tensor = np.zeros(pad_length, + dtype=actual_seq_lengths_kv.dtype) + actual_seq_lengths_kv = np.concatenate( + [actual_seq_lengths_kv, pad_tensor]) actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[: attn_metadata