[Bugfix]fix pcp dcp attn aclgraph (#4066)
### What this PR does / why we need it?
In the DCP-PCP graph mode scenario, there is a shape issue with multiple
batches. This PR fixes this problem.
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -7,6 +7,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm.envs as envs
|
||||
@@ -326,9 +327,12 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
||||
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:,
|
||||
pcp_rank,
|
||||
dcp_rank]
|
||||
if (runtime_shape - len(actual_seq_lengths_kv)) > 0:
|
||||
actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * (
|
||||
runtime_shape - len(actual_seq_lengths_kv))
|
||||
pad_length = runtime_shape - len(actual_seq_lengths_kv)
|
||||
if pad_length > 0:
|
||||
pad_tensor = np.zeros(pad_length,
|
||||
dtype=actual_seq_lengths_kv.dtype)
|
||||
actual_seq_lengths_kv = np.concatenate(
|
||||
[actual_seq_lengths_kv, pad_tensor])
|
||||
|
||||
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[:
|
||||
attn_metadata
|
||||
|
||||
Reference in New Issue
Block a user