[Feat][SP] Suport SP for VL MoE models (#7044)
### What this PR does / why we need it?
2nd PR for https://github.com/vllm-project/vllm-ascend/issues/5712,
extend SP to VL MoE models.
### Does this PR introduce _any_ user-facing change?
remove `sp_threshold` in additional config and reuse `sp_min_token_num`
from vLLM.
### How was this patch tested?
- Model: Qwen3-VL-30B-A3B,
- TP4 DP2
- 100 reqs
- max concurrency 1
| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|------------|---------------------|------------------------|
| 4k | 429.40 | 323.3 |
| 16k | 1297.01 | 911.74 |
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
40
vllm_ascend/compilation/passes/allgather_chunk_noop_pass.py
Normal file
40
vllm_ascend/compilation/passes/allgather_chunk_noop_pass.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
class AllGatherChunkNoOpCleanupPass(VllmInductorPass):
|
||||
"""Fold all_gather + sequence_parallel_chunk_impl into identity."""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
self.tp_group = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_allgather_chunk_noop_cleanup_pass")
|
||||
self._register_patterns()
|
||||
|
||||
def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.vllm.all_gather(x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name)
|
||||
|
||||
def _empty(self, *args, **kwargs):
|
||||
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)
|
||||
|
||||
def _register_patterns(self) -> None:
|
||||
def pattern(input: torch.Tensor) -> torch.Tensor:
|
||||
gathered = self._all_gather(input)
|
||||
return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)
|
||||
|
||||
def replacement(input: torch.Tensor) -> torch.Tensor:
|
||||
return input
|
||||
|
||||
pm.register_replacement(pattern, replacement, [self._empty(8, 16)], pm.fwd_only, self.patterns)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
self.begin()
|
||||
matched_count = self.patterns.apply(graph)
|
||||
logger.debug("AllGatherChunkNoOpCleanupPass replaced %s patterns", matched_count)
|
||||
self.end_and_log()
|
||||
Reference in New Issue
Block a user