mooncake connector support pipeline parallel & fix pp with flashcomm1 (#4054)

### What this PR does / why we need it?
To support pipeline parallel with PD disaggregation, this PR support PP
in mooncake connector and fix other bugs when enable pp with other
optimization params, including following changes:
- mooncake connector support pp in prefill, we do not support decode pp
currently
- fix bugs when enable both pp and flashcomm1
- optimize ascend-scheduler to support full batch in multiple pipeline
stages, original implementation would cause all pipeline stages
batch_size total summed to max_num_seq, which makes pipeline is not
full, this optimization can make all stages running with full batch_size
= max_num_seq, the same changes will contribute to vllm scheduler too.

### Does this PR introduce _any_ user-facing change?
add `pp_size` in mooncake connector kv_connector_extra_config
```
"kv_connector_extra_config": {
            "use_ascend_direct": true,
            "prefill": {
                    "dp_size": 1,
                    "tp_size": 4,
                    "pp_size": 4
             },
             "decode": {
                    "dp_size": 16,
                    "tp_size": 1
             }
        }
```

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Signed-off-by: Kurumi5210 <Jaychou1620@Gmail.com>
Signed-off-by: Kurumi5210 <jaychou1620@gmail.com>
Signed-off-by: 秋刀鱼 <jaychou1620@Gmail.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: zss <zss@qq.com>
Co-authored-by: zss <3265779424@qq.com>
This commit is contained in:
lidenghui1110
2025-12-10 16:01:43 +08:00
committed by GitHub
parent ce5872705e
commit a82b0fa70e
5 changed files with 394 additions and 141 deletions

View File

@@ -46,7 +46,8 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
@@ -1765,11 +1766,22 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
else:
assert intermediate_tensors is not None
assert self.intermediate_tensors is not None
# If both flashcomm1 and pp are used simultaneously,
# the shape of the received data and the shape of the space to be copied to will not match,
# requiring a recalculation of the incoming data's shape.
tp_size = get_tensor_model_parallel_world_size()
num_input_tokens_with_flashcomm1 = num_input_tokens
if enable_sp():
num_input_tokens_with_flashcomm1 = (num_input_tokens +
tp_size - 1) // tp_size
for k, v in intermediate_tensors.items():
self.intermediate_tensors[k][:num_input_tokens].copy_(
v[:num_input_tokens], non_blocking=True)
self.intermediate_tensors[
k][:num_input_tokens_with_flashcomm1].copy_(
v[:num_input_tokens_with_flashcomm1],
non_blocking=True)
intermediate_tensors = IntermediateTensors({
k: v[:num_input_tokens]
k:
v[:num_input_tokens_with_flashcomm1]
for k, v in self.intermediate_tensors.items()
})
@@ -2044,7 +2056,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
update_attn_params(self.update_stream, forward_context,
maybe_padded_num_tokens)
if get_forward_context().sp_enabled:
if get_forward_context().sp_enabled and not isinstance(
hidden_states, IntermediateTensors):
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
pad_size = get_forward_context().pad_size
if pad_size > 0:
@@ -2366,7 +2379,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
moe_comm_type = MoECommType.ALLGATHER
elif soc_version in {AscendDeviceType._910B}:
if (num_tokens <= self.mc2_tokens_capacity
and self.parallel_config.world_size_across_dp >= 16):
and self.parallel_config.world_size_across_dp /
self.parallel_config.pipeline_parallel_size >= 16):
moe_comm_type = MoECommType.MC2
else:
# Currently, w4a8_dynamic does not support allgatherep
@@ -3131,10 +3145,16 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
# When PP and flashcomm1 are enabled, during dummy_run the estimated space should divide num_tokens by tp_size;
# otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading to incorrect memory estimation and potentially causing OOM.
actual_tokens = num_tokens
if enable_sp():
tp_size = get_tensor_model_parallel_world_size()
actual_tokens = num_tokens // tp_size
if self.intermediate_tensors is None:
self.intermediate_tensors = (
self.model.make_empty_intermediate_tensors(
batch_size=num_tokens,
batch_size=actual_tokens,
dtype=self.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({