[feature] support pcp + mtp (in pd co-locate scenario) (#4098)

1. Support PCP (prefill context parallelism) + MTP (multi-token prediction) in the PD co-locate scenario (see the config sketch below).
2. Fix PCP-related bugs in the llmdatadist connector and clean up the code.
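
For context, the MTP half of this feature keys on vLLM's speculative-decoding settings; below is a minimal, hypothetical sketch of the two fields the touched code reads. The field names mirror the identifiers in the diff, the values are illustrative, and how this config reaches the worker is omitted here.

# Hypothetical illustration only: field names mirror the identifiers used in the
# diff below (speculative_config.method, speculative_config.num_speculative_tokens).
speculative_config = {
    "method": "deepseek_mtp",      # DeepSeek multi-token-prediction drafting
    "num_speculative_tokens": 1,   # extra draft tokens proposed per decode step
}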

- vLLM version: v0.11.0
- vLLM main: 83f478bb19

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Author: zhangsicheng5 (committed by GitHub)
Date: 2025-11-12 17:22:21 +08:00
Parent: 1b4ce63ec9
Commit: a123f355e9
6 changed files with 246 additions and 97 deletions


@@ -371,7 +371,7 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
 def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
-                                   runtime_shape, speculative_config):
+                                   runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
@@ -388,16 +388,14 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
         decode_meta = forward_context.attn_metadata[key].decode
         seq_len = decode_meta.cp_seq_len
-        if speculative_config and speculative_config.method == "deepseek_mtp":
-            spec_multiple = speculative_config.num_speculative_tokens + 1
-            seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
-                                       len(seq_len))
-        else:
-            pad_length = runtime_shape - len(seq_len)
-            pad_tensor = torch.zeros(pad_length,
-                                     dtype=seq_len.dtype,
-                                     device=seq_len.device)
-            seq_len = torch.cat([seq_len, pad_tensor], dim=0)
+        # For pcp + spec decode, we flatten seq_lens
+        # to avoid irregular spec_attn_mask shape,
+        # so there's no need to divide runtime_shape by spec_multiple
+        pad_length = runtime_shape - len(seq_len)
+        pad_tensor = torch.zeros(pad_length,
+                                 dtype=seq_len.dtype,
+                                 device=seq_len.device)
+        seq_len = torch.cat([seq_len, pad_tensor], dim=0)
         torch.npu.graph_task_update_begin(update_stream, handle)
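
To make the new padding path concrete, here is a minimal standalone sketch (plain torch instead of torch.npu, made-up shapes, not the vllm-ascend code itself) contrasting the removed deepseek_mtp branch with the unified zero-padding that replaces it; runtime_shape and cp_seq_len mirror the identifiers in the hunk above.

import torch

# Illustrative values only: runtime_shape stands for the captured graph's padded
# decode batch size, cp_seq_len for the per-request context-parallel sequence
# lengths held in the decode metadata.
runtime_shape = 8
num_speculative_tokens = 1                       # deepseek_mtp draft tokens
cp_seq_len = torch.tensor([5, 7, 3], dtype=torch.int32)

# Removed branch: with MTP active, each request occupied spec_multiple slots of
# the graph batch, so seq_len was padded only up to runtime_shape // spec_multiple.
spec_multiple = num_speculative_tokens + 1
old_seq_len = cp_seq_len.tolist() + [0] * (runtime_shape // spec_multiple
                                           - len(cp_seq_len))

# New path: seq_lens is flattened for pcp + spec decode (one entry per graph
# slot in the real code), so it is simply zero-padded up to runtime_shape.
pad_length = runtime_shape - len(cp_seq_len)
pad_tensor = torch.zeros(pad_length, dtype=cp_seq_len.dtype,
                         device=cp_seq_len.device)
new_seq_len = torch.cat([cp_seq_len, pad_tensor], dim=0)

print(old_seq_len)   # [5, 7, 3, 0]               -> runtime_shape // spec_multiple entries
print(new_seq_len)   # tensor([5, 7, 3, 0, ...])  -> runtime_shape entries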