[Feat][SP] Support SP for VL MoE models (#7044)
### What this PR does / why we need it?
Second PR for https://github.com/vllm-project/vllm-ascend/issues/5712,
extending sequence parallelism (SP) to VL MoE models.
### Does this PR introduce _any_ user-facing change?
Yes: `sp_threshold` is removed from `additional_config`, and the SP token
threshold is now taken from vLLM's `sp_min_token_num`.
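For illustration only, a sketch of what this means on the user side. Only the two key names (`sp_threshold`, `sp_min_token_num`) come from the description above; the value and the exact config surface are assumptions.

```python
# Illustrative only: the SP threshold no longer lives in vllm-ascend's
# additional_config.

# Before this PR (hypothetical value): SP gated by a vendor-specific key.
old_additional_config = {"sp_threshold": 1024}

# After this PR: no SP key in additional_config; the threshold comes from
# vLLM's own sp_min_token_num setting instead.
new_additional_config = {}
```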
### How was this patch tested?
- Model: Qwen3-VL-30B-A3B
- TP4 DP2
- 100 requests
- max concurrency 1 (an illustrative TTFT-measurement sketch follows below)
| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|------------|---------------------|------------------------|
| 4k | 429.40 | 323.3 |
| 16k | 1297.01 | 911.74 |
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
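The numbers above come from the standard vLLM serving benchmark; the snippet below is only an illustrative way to measure TTFT against an already-running server with the `openai` client, not the harness actually used for this table.

```python
# Illustrative TTFT measurement: time from request start to first streamed
# chunk, against a vLLM server assumed to be listening on localhost:8000.
import time
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

def measure_ttft_ms(prompt: str, model: str = "Qwen3-VL-30B-A3B") -> float:
    start = time.perf_counter()
    stream = client.completions.create(
        model=model, prompt=prompt, max_tokens=16, stream=True
    )
    for _ in stream:  # first chunk received -> time to first token
        return (time.perf_counter() - start) * 1000.0
    return float("nan")

print(f"TTFT: {measure_ttft_ms('hello ' * 1024):.2f} ms")
```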
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
```diff
@@ -33,7 +33,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
 from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEPrepareOutput
 from vllm_ascend.quantization.quant_type import QuantType
-from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
+from vllm_ascend.utils import enable_sp, enable_sp_by_pass, npu_stream_switch, prefill_context_parallel_enable


 class PrepareAndFinalize(ABC):
@@ -324,7 +324,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         Returns:
             MoEPrepareOutput with global tensors.
         """
-        if enable_sp():
+        if enable_sp() or enable_sp_by_pass():
             return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)

         return self._prepare_with_dp_group(hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce)
@@ -433,7 +433,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         Returns:
             Tensor with shape [local_num_tokens, hidden_size]
         """
-        if enable_sp():
+        if enable_sp() or enable_sp_by_pass():
             return self._finalize_with_ep_group(hidden_states)

         return self._finalize_with_dp_group(hidden_states, reduce_results)
```
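Both hunks above extend the same gate: the EP-group prepare/finalize path is now taken when either `enable_sp()` or the new `enable_sp_by_pass()` is true. A minimal, self-contained sketch of that dispatch; the flags and return strings below are stand-ins, not the vllm_ascend helpers.

```python
# Toy model of the dispatch changed in the two hunks above.
from dataclasses import dataclass

@dataclass
class Flags:
    sp: bool = False          # stand-in for enable_sp()
    sp_by_pass: bool = False  # stand-in for enable_sp_by_pass()

def prepare(flags: Flags) -> str:
    # Mirrors: if enable_sp() or enable_sp_by_pass(): use the EP-group path.
    if flags.sp or flags.sp_by_pass:
        return "prepare_with_ep_group"
    return "prepare_with_dp_group"

assert prepare(Flags(sp_by_pass=True)) == "prepare_with_ep_group"
assert prepare(Flags()) == "prepare_with_dp_group"
```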
```diff
@@ -17,7 +17,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
 from vllm_ascend.ops.rotary_embedding import rope_forward_oot
 from vllm_ascend.ops.triton.muls_add import muls_add_triton
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
-from vllm_ascend.utils import npu_stream_switch, prefetch_stream
+from vllm_ascend.utils import enable_sp_by_pass, npu_stream_switch, prefetch_stream


 def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
@@ -43,7 +43,7 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
     except AssertionError:
         return x

-    flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
+    flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled or (enable_sp_by_pass() and is_ep_comm)
     if flash_comm_v1_enabled and label:
         dp_metadata = forward_context.dp_metadata
         if dp_metadata is None or not is_ep_comm:
@@ -53,6 +53,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
                 x = x[:-pad_size]
         else:
             x = get_ep_group().all_gather(x, 0)
+            if enable_sp_by_pass():  # TODO: do unpad
+                return x
             # unpad
             num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
             result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype)
```
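The two hunks above gate the all-gather path on `enable_sp_by_pass()` and, for now, skip the unpad step (marked TODO). The gather-then-unpad bookkeeping itself looks roughly like the following single-process, plain-torch sketch; the tensors and per-rank counts are made up for illustration.

```python
# Each rank pads its tokens to the per-rank maximum, the padded shards are
# concatenated (standing in for the all-gather), and padding is stripped
# using the real token counts.
import torch

hidden = 4
num_tokens_across_dp = torch.tensor([3, 1, 2])  # real tokens per rank
max_tokens = int(num_tokens_across_dp.max())

# Each rank's padded shard, shape [max_tokens, hidden].
shards = [torch.randn(max_tokens, hidden) for _ in num_tokens_across_dp]
gathered = torch.cat(shards, dim=0)  # "all_gather" result

# Unpad: keep only the first n real rows of each rank's shard.
rows = [
    gathered[i * max_tokens : i * max_tokens + int(n)]
    for i, n in enumerate(num_tokens_across_dp)
]
unpadded = torch.cat(rows, dim=0)
assert unpadded.shape[0] == int(num_tokens_across_dp.sum())
```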
```diff
@@ -74,7 +76,11 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
     except AssertionError:
         return tensor_model_parallel_all_reduce(x)

-    if not getattr(forward_context, "flash_comm_v1_enabled", False):
+    flash_comm_v1_enabled = getattr(forward_context, "flash_comm_v1_enabled", False) or (
+        enable_sp_by_pass() and is_ep_comm
+    )
+
+    if not flash_comm_v1_enabled:
         return tensor_model_parallel_all_reduce(x)

     dp_metadata = forward_context.dp_metadata
@@ -84,6 +90,8 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
             x = F.pad(x, (0, 0, 0, pad_size))
         return tensor_model_parallel_reduce_scatter(x, 0)
     else:
+        if enable_sp_by_pass():
+            return get_ep_group().reduce_scatter(x.view(-1, *x.shape[1:]), 0)
         # padding
         dp_size = get_dp_group().world_size
         num_tokens_across_dp_cpu = get_forward_context().dp_metadata.num_tokens_across_dp_cpu
```
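On the reverse path, `_maybe_pad_and_reduce_impl` pads the token dimension so it splits evenly across ranks and then reduce-scatters (over the EP group when the SP bypass is active). A plain-torch sketch of just the padding arithmetic, with made-up sizes:

```python
# After padding, each rank ends up holding padded_tokens // world_size rows
# of the reduced result.
import torch
import torch.nn.functional as F

world_size = 4
num_tokens, hidden = 10, 8
x = torch.randn(num_tokens, hidden)

pad_size = (-num_tokens) % world_size      # 2 rows of padding here
x_padded = F.pad(x, (0, 0, 0, pad_size))   # pad along the token dimension
assert x_padded.shape[0] % world_size == 0

per_rank_rows = x_padded.shape[0] // world_size
print(f"each rank keeps {per_rank_rows} rows after reduce-scatter")
```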
```diff
@@ -107,7 +115,7 @@ def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_c


 def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
-    if _EXTRA_CTX.flash_comm_v1_enabled:
+    if _EXTRA_CTX.flash_comm_v1_enabled or enable_sp_by_pass():
         return torch.empty(
             (x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
         )
```
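The fake (meta) kernel has to mirror the same predicate, because under `torch.compile` it is the fake implementation that reports output shapes: if the real op reduce-scatters, the fake must report the reduce-scattered shape. A hypothetical stand-alone version of that shape rule; the function and parameter names are stand-ins for the vllm_ascend helpers.

```python
# Shape-only stand-in for the fake kernel changed above.
import torch

def maybe_pad_and_reduce_fake(x: torch.Tensor, flash_comm_v1: bool,
                              sp_by_pass: bool, tp_world_size: int) -> torch.Tensor:
    if flash_comm_v1 or sp_by_pass:
        # reduce-scatter path: each rank keeps 1/tp_world_size of the rows
        return torch.empty((x.shape[0] // tp_world_size, *x.shape[1:]),
                           device=x.device, dtype=x.dtype)
    # all-reduce path: shape is unchanged
    return torch.empty_like(x)

x = torch.randn(16, 8)
assert maybe_pad_and_reduce_fake(x, False, True, 4).shape == (4, 8)
```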