[feat]decode convert bsnd to tnd and fix bug when pcp and dcp (#3980)
### What this PR does / why we need it?
1. In the attention_v1 module, convert the input layout from BSND to TND when PCP and DCP are used.
2. Fix a torchair bug: service startup problem.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -7,7 +7,6 @@ from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm.envs as envs
|
||||
@@ -302,16 +301,25 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(q_nope, k_nope, value, num_heads, num_kv_heads, scale,
|
||||
block_table, block_size, actual_seq_lengths_kv, attn_output,
|
||||
softmax_lse, pcp_rank, dcp_rank, dcp_size) = param
|
||||
actual_seq_lengths_kv = forward_context.attn_metadata[
|
||||
key].decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank,
|
||||
dcp_rank]
|
||||
pad_length = runtime_shape - len(actual_seq_lengths_kv)
|
||||
pad_tensor = np.zeros(pad_length,
|
||||
dtype=actual_seq_lengths_kv.dtype)
|
||||
actual_seq_lengths_kv = np.concatenate(
|
||||
[actual_seq_lengths_kv, pad_tensor])
|
||||
block_table, block_size, actual_seq_lengths_kv,
|
||||
actual_seq_lengths_q, attn_output, softmax_lse, dcp_size,
|
||||
pcp_rank, dcp_rank) = param
|
||||
attn_metadata = forward_context.attn_metadata[key]
|
||||
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:,
|
||||
pcp_rank,
|
||||
dcp_rank]
|
||||
if (runtime_shape - len(actual_seq_lengths_kv)) > 0:
|
||||
actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * (
|
||||
runtime_shape - len(actual_seq_lengths_kv))
|
||||
|
||||
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[:
|
||||
attn_metadata
|
||||
.
|
||||
num_decode_tokens]
|
||||
if (runtime_shape - len(actual_seq_lengths_q)):
|
||||
actual_seq_lengths_q = actual_seq_lengths_q + [
|
||||
actual_seq_lengths_q[-1]
|
||||
] * (runtime_shape - len(actual_seq_lengths_q))
|
||||
if dcp_size > 1:
|
||||
num_heads = num_heads * dcp_size
|
||||
|
||||
@@ -323,7 +331,7 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
||||
value,
|
||||
num_heads=num_heads,
|
||||
num_key_value_heads=num_kv_heads,
|
||||
input_layout="BSND",
|
||||
input_layout="TND",
|
||||
atten_mask=None,
|
||||
scale=scale,
|
||||
antiquant_mode=0,
|
||||
@@ -332,6 +340,7 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
||||
block_table=block_table,
|
||||
block_size=block_size,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
actual_seq_lengths=actual_seq_lengths_q,
|
||||
workspace=graph_params.workspaces.get(runtime_shape),
|
||||
out=[attn_output, softmax_lse])
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
Reference in New Issue
Block a user