[feature] support pcp + mtp (in pd co-locate scenario) (#4098)
1. support pcp + mtp in pd co-locate scenario
2. llmdatadist connector pcp related bugfix and cleancode
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
This commit is contained in:
@@ -371,7 +371,7 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
||||
|
||||
|
||||
def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
|
||||
runtime_shape, speculative_config):
|
||||
runtime_shape):
|
||||
graph_params = get_graph_params()
|
||||
# FIXME: Behold! We are using a temporary hack here to update the args
|
||||
# for each layer's attention op in the graph.
|
||||
@@ -388,16 +388,14 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
|
||||
decode_meta = forward_context.attn_metadata[key].decode
|
||||
seq_len = decode_meta.cp_seq_len
|
||||
|
||||
if speculative_config and speculative_config.method == "deepseek_mtp":
|
||||
spec_multiple = speculative_config.num_speculative_tokens + 1
|
||||
seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
|
||||
len(seq_len))
|
||||
else:
|
||||
pad_length = runtime_shape - len(seq_len)
|
||||
pad_tensor = torch.zeros(pad_length,
|
||||
dtype=seq_len.dtype,
|
||||
device=seq_len.device)
|
||||
seq_len = torch.cat([seq_len, pad_tensor], dim=0)
|
||||
# For pcp + spec decode, we flatten seq_lens
|
||||
# to avoid irregular spec_attn_mask shape,
|
||||
# so there's no need to divide runtime_shape by spec_multiple
|
||||
pad_length = runtime_shape - len(seq_len)
|
||||
pad_tensor = torch.zeros(pad_length,
|
||||
dtype=seq_len.dtype,
|
||||
device=seq_len.device)
|
||||
seq_len = torch.cat([seq_len, pad_tensor], dim=0)
|
||||
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user