[Bugfix]fix pcp dcp attn aclgraph (#4066)

### What this PR does / why we need it?
In the DCP-PCP ACL-graph mode scenario, a shape mismatch occurs when
multiple batches are run. This PR fixes the problem by removing the
`FULL_DECODE_ONLY` padding of `num_computed_tokens_of_pcp_dcp` from the
attention metadata builder and instead padding `actual_seq_lengths_kv`
(as a NumPy array) up to `runtime_shape` inside
`update_attn_dcp_pcp_params`.

- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
weiguihua2
2025-11-08 18:47:12 +08:00
committed by GitHub
parent 48094148f8
commit 1d7cb5880a
2 changed files with 8 additions and 14 deletions

View File

@@ -26,7 +26,7 @@ import torch.nn as nn
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
-from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.config import VllmConfig
 from vllm.distributed import (get_dcp_group,
                               get_decode_context_model_parallel_rank,
                               get_decode_context_model_parallel_world_size)
@@ -387,16 +387,6 @@ class AscendAttentionMetadataBuilder:
                 num_computed_tokens_of_pcp_dcp)
             num_computed_tokens_array = num_computed_tokens_array[:
                                                                   num_decodes]
-            pad_length = common_attn_metadata.num_input_tokens - num_actual_tokens_pcp_padded // self.pcp_size
-            if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and pad_length > 0:
-                pad_tensor = np.zeros(
-                    (pad_length, num_computed_tokens_array.shape[1],
-                     num_computed_tokens_array.shape[2]),
-                    dtype=num_computed_tokens_array.dtype)
-                num_computed_tokens_array = np.concatenate(
-                    [num_computed_tokens_array, pad_tensor], axis=0)
             batch_seq_mask = (
                 num_computed_tokens_array[:, self.pcp_rank,
                                           self.dcp_rank] == 0)

View File

@@ -7,6 +7,7 @@ from dataclasses import dataclass
 from typing import Any, Callable, Optional
 from unittest.mock import patch

+import numpy as np
 import torch
 import torch_npu
 import vllm.envs as envs
@@ -326,9 +327,12 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
     actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:,
                                                                                      pcp_rank,
                                                                                      dcp_rank]
-    if (runtime_shape - len(actual_seq_lengths_kv)) > 0:
-        actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * (
-            runtime_shape - len(actual_seq_lengths_kv))
+    pad_length = runtime_shape - len(actual_seq_lengths_kv)
+    if pad_length > 0:
+        pad_tensor = np.zeros(pad_length,
+                              dtype=actual_seq_lengths_kv.dtype)
+        actual_seq_lengths_kv = np.concatenate(
+            [actual_seq_lengths_kv, pad_tensor])
     actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[:
                                                               attn_metadata