[Feat][SP] Suport SP for VL MoE models (#7044)

### What this PR does / why we need it?

2nd PR for https://github.com/vllm-project/vllm-ascend/issues/5712,
extend SP to VL MoE models.


### Does this PR introduce _any_ user-facing change?
remove `sp_threshold` in additional config and reuse `sp_min_token_num`
from vLLM.


### How was this patch tested?
- Model: Qwen3-VL-30B-A3B, 
- TP4 DP2
- 100 reqs
- max concurrency 1

| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|------------|---------------------|------------------------|
| 4k         | 429.40               | 323.3                  |
| 16k        | 1297.01              | 911.74                |

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
realliujiaxu
2026-03-24 17:16:00 +08:00
committed by GitHub
parent 9615bc33fd
commit 5d12446573
21 changed files with 947 additions and 54 deletions

View File

@@ -33,7 +33,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEPrepareOutput
from vllm_ascend.quantization.quant_type import QuantType
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
from vllm_ascend.utils import enable_sp, enable_sp_by_pass, npu_stream_switch, prefill_context_parallel_enable
class PrepareAndFinalize(ABC):
@@ -324,7 +324,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
Returns:
MoEPrepareOutput with global tensors.
"""
if enable_sp():
if enable_sp() or enable_sp_by_pass():
return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)
return self._prepare_with_dp_group(hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce)
@@ -433,7 +433,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
Returns:
Tensor with shape [local_num_tokens, hidden_size]
"""
if enable_sp():
if enable_sp() or enable_sp_by_pass():
return self._finalize_with_ep_group(hidden_states)
return self._finalize_with_dp_group(hidden_states, reduce_results)

View File

@@ -17,7 +17,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.ops.rotary_embedding import rope_forward_oot
from vllm_ascend.ops.triton.muls_add import muls_add_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
from vllm_ascend.utils import enable_sp_by_pass, npu_stream_switch, prefetch_stream
def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
@@ -43,7 +43,7 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
except AssertionError:
return x
flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled or (enable_sp_by_pass() and is_ep_comm)
if flash_comm_v1_enabled and label:
dp_metadata = forward_context.dp_metadata
if dp_metadata is None or not is_ep_comm:
@@ -53,6 +53,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
x = x[:-pad_size]
else:
x = get_ep_group().all_gather(x, 0)
if enable_sp_by_pass(): # TODO: do unpad
return x
# unpad
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype)
@@ -74,7 +76,11 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
except AssertionError:
return tensor_model_parallel_all_reduce(x)
if not getattr(forward_context, "flash_comm_v1_enabled", False):
flash_comm_v1_enabled = getattr(forward_context, "flash_comm_v1_enabled", False) or (
enable_sp_by_pass() and is_ep_comm
)
if not flash_comm_v1_enabled:
return tensor_model_parallel_all_reduce(x)
dp_metadata = forward_context.dp_metadata
@@ -84,6 +90,8 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
x = F.pad(x, (0, 0, 0, pad_size))
return tensor_model_parallel_reduce_scatter(x, 0)
else:
if enable_sp_by_pass():
return get_ep_group().reduce_scatter(x.view(-1, *x.shape[1:]), 0)
# padding
dp_size = get_dp_group().world_size
num_tokens_across_dp_cpu = get_forward_context().dp_metadata.num_tokens_across_dp_cpu
@@ -107,7 +115,7 @@ def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_c
def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
if _EXTRA_CTX.flash_comm_v1_enabled:
if _EXTRA_CTX.flash_comm_v1_enabled or enable_sp_by_pass():
return torch.empty(
(x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
)