[Bugfix] fix pcp dcp attn aclgraph (#4066)
### What this PR does / why we need it?
In the DCP-PCP ACL graph mode scenario, running multiple batches hits a shape mismatch in the attention metadata. This PR fixes it by removing the `CUDAGraphMode.FULL_DECODE_ONLY` padding of `num_computed_tokens_array` from `AscendAttentionMetadataBuilder` and instead zero-padding `actual_seq_lengths_kv` up to the captured `runtime_shape` with numpy in `update_attn_dcp_pcp_params`.
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
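One plausible reading of the hunks below (an assumption, not stated in the PR): `actual_seq_lengths_kv` is a numpy slice, and the old list-idiom padding `+ [0] * n` broadcasts element-wise on arrays instead of concatenating. A minimal standalone sketch of that failure mode, with illustrative values:

```python
import numpy as np

# Illustrative values: 3 live decode requests, graph captured for batch 8.
actual_seq_lengths_kv = np.array([5, 3, 7], dtype=np.int32)
runtime_shape = 8

# Old-style padding: on a numpy array, `+` means element-wise addition,
# so a length-5 pad list cannot be "appended" and numpy raises instead.
try:
    actual_seq_lengths_kv + [0] * (runtime_shape - len(actual_seq_lengths_kv))
except ValueError as err:
    print(err)  # operands could not be broadcast together with shapes (3,) (5,)
```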
In `AscendAttentionMetadataBuilder`, the builder-side padding (and the `CUDAGraphMode` import it needed) is removed:

```diff
@@ -26,7 +26,7 @@ import torch.nn as nn
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
-from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.config import VllmConfig
 from vllm.distributed import (get_dcp_group,
                               get_decode_context_model_parallel_rank,
                               get_decode_context_model_parallel_world_size)
@@ -387,16 +387,6 @@ class AscendAttentionMetadataBuilder:
             num_computed_tokens_of_pcp_dcp)
         num_computed_tokens_array = num_computed_tokens_array[:
                                                               num_decodes]
-        pad_length = common_attn_metadata.num_input_tokens - num_actual_tokens_pcp_padded // self.pcp_size
-        if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and pad_length > 0:
-            pad_tensor = np.zeros(
-                (pad_length, num_computed_tokens_array.shape[1],
-                 num_computed_tokens_array.shape[2]),
-                dtype=num_computed_tokens_array.dtype)
-
-            num_computed_tokens_array = np.concatenate(
-                [num_computed_tokens_array, pad_tensor], axis=0)
-
         batch_seq_mask = (
             num_computed_tokens_array[:, self.pcp_rank,
                                       self.dcp_rank] == 0)
```
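With the builder-side padding removed, `batch_seq_mask` is computed over the actual decode batch rather than the graph-padded one. A hedged sketch of that indexing, with made-up shapes and values (`pcp_size = dcp_size = 2` is assumed, not taken from the patch):

```python
import numpy as np

# num_computed_tokens_array: [num_decodes, pcp_size, dcp_size]; made-up values.
num_computed_tokens_array = np.array(
    [[[4, 0], [2, 1]],   # request 0
     [[0, 0], [0, 0]],   # request 1: no KV computed on any shard yet
     [[3, 2], [1, 0]]],  # request 2
    dtype=np.int32)
pcp_rank, dcp_rank = 0, 1

# True where this (pcp_rank, dcp_rank) shard holds no computed tokens.
batch_seq_mask = num_computed_tokens_array[:, pcp_rank, dcp_rank] == 0
print(batch_seq_mask)  # [ True  True False]
```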
In `update_attn_dcp_pcp_params`, the list-style padding is replaced with numpy zero-padding, adding the `numpy` import:

```diff
@@ -7,6 +7,7 @@ from dataclasses import dataclass
 from typing import Any, Callable, Optional
 from unittest.mock import patch
 
+import numpy as np
 import torch
 import torch_npu
 import vllm.envs as envs
@@ -326,9 +327,12 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
     actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:,
                                                                                      pcp_rank,
                                                                                      dcp_rank]
-    if (runtime_shape - len(actual_seq_lengths_kv)) > 0:
-        actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * (
-            runtime_shape - len(actual_seq_lengths_kv))
+    pad_length = runtime_shape - len(actual_seq_lengths_kv)
+    if pad_length > 0:
+        pad_tensor = np.zeros(pad_length,
+                              dtype=actual_seq_lengths_kv.dtype)
+        actual_seq_lengths_kv = np.concatenate(
+            [actual_seq_lengths_kv, pad_tensor])
 
     actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[:
                                                               attn_metadata
```
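The new code pads with an explicit zero array, which numpy concatenates correctly regardless of the input length. A self-contained sketch of the same technique (the helper name `pad_to_runtime_shape` is mine, not from the patch):

```python
import numpy as np

def pad_to_runtime_shape(actual_seq_lengths_kv: np.ndarray,
                         runtime_shape: int) -> np.ndarray:
    """Zero-pad per-request KV lengths up to the captured graph batch size."""
    pad_length = runtime_shape - len(actual_seq_lengths_kv)
    if pad_length > 0:
        pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
        actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])
    return actual_seq_lengths_kv

# 3 live requests padded up to a graph captured for batch 8; padded entries
# have length 0, matching the batch_seq_mask convention above.
print(pad_to_runtime_shape(np.array([5, 3, 7], dtype=np.int64), 8))
# -> [5 3 7 0 0 0 0 0]
```

Padding at graph-update time rather than in the metadata builder keeps the builder output sized to the real batch, while each captured `runtime_shape` gets its zeros appended exactly when that graph's parameters are refreshed.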