Files
xc-llm-ascend/vllm_ascend/compilation/passes/allgather_chunk_noop_pass.py
realliujiaxu 5d12446573 [Feat][SP] Support SP for VL MoE models (#7044)
### What this PR does / why we need it?

Second PR for https://github.com/vllm-project/vllm-ascend/issues/5712:
extends sequence parallelism (SP) to VL MoE models.


### Does this PR introduce _any_ user-facing change?
Yes: `sp_threshold` is removed from the additional config; vLLM's own
`sp_min_token_num` is reused instead.


### How was this patch tested?
- Model: Qwen3-VL-30B-A3B
- TP4 DP2
- 100 reqs
- max concurrency 1

| Seq length | Mean TTFT (ms), main | Mean TTFT (ms), this PR |
|------------|----------------------|-------------------------|
| 4k         | 429.40               | 323.30                  |
| 16k        | 1297.01              | 911.74                  |
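For reference, the relative improvement implied by the table above can be computed directly from the reported means:

```python
# Mean TTFT (ms) from the benchmark table: main vs. this PR.
main = {"4k": 429.40, "16k": 1297.01}
this_pr = {"4k": 323.30, "16k": 911.74}

for seq, base in main.items():
    reduction = (base - this_pr[seq]) / base * 100
    print(f"{seq}: {reduction:.1f}% lower mean TTFT")
# 4k: 24.7% lower mean TTFT
# 16k: 29.7% lower mean TTFT
```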

- vLLM version: v0.16.0
- vLLM main: 4034c3d32e

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
2026-03-24 17:16:00 +08:00


import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.logger import logger


class AllGatherChunkNoOpCleanupPass(VllmInductorPass):
    """Fold all_gather + sequence_parallel_chunk_impl into identity."""

    def __init__(self, config: VllmConfig):
        super().__init__(config)
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="npu_allgather_chunk_noop_cleanup_pass")
        self._register_patterns()

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        # Gather the sequence-parallel shards from all TP ranks along dim 0.
        return torch.ops.vllm.all_gather(x,
                                         dim=0,
                                         world_size=self.tp_size,
                                         group_name=self.tp_group.unique_name)

    def _empty(self, *args, **kwargs):
        # Placeholder tensor used only to trace the pattern below.
        return torch.empty(*args,
                           dtype=self.model_dtype,
                           device=self.device,
                           **kwargs)

    def _register_patterns(self) -> None:

        def pattern(input: torch.Tensor) -> torch.Tensor:
            # all_gather immediately followed by taking back this rank's
            # chunk reproduces the original shard ...
            gathered = self._all_gather(input)
            return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)

        def replacement(input: torch.Tensor) -> torch.Tensor:
            # ... so the pair can be folded into the input itself.
            return input

        pm.register_replacement(pattern, replacement, [self._empty(8, 16)],
                                pm.fwd_only, self.patterns)

    def __call__(self, graph: torch.fx.Graph) -> None:
        self.begin()
        matched_count = self.patterns.apply(graph)
        logger.debug("AllGatherChunkNoOpCleanupPass replaced %s patterns",
                     matched_count)
        self.end_and_log()
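Why the rewrite is sound can be shown without torch: an all-gather along the sequence dimension followed by a chunk that takes back this rank's slice is the identity on that rank's shard. A minimal pure-Python sketch of that algebra, with lists standing in for tensors (these helper names are illustrative, not the actual vLLM ops):

```python
def all_gather(shards):
    # Every rank contributes its shard; the result is the concatenation
    # along the sequence dimension (dim 0 in the pass above).
    return [token for shard in shards for token in shard]

def sequence_parallel_chunk(gathered, rank, world_size):
    # Take back this rank's contiguous slice of the gathered sequence.
    n = len(gathered) // world_size
    return gathered[rank * n:(rank + 1) * n]

# all_gather followed by taking this rank's chunk is a no-op:
shards = [[0, 1], [2, 3], [4, 5], [6, 7]]  # 4 ranks, 2 tokens each
for rank, shard in enumerate(shards):
    assert sequence_parallel_chunk(all_gather(shards), rank, 4) == shard
```

The compilation pass performs exactly this cancellation at the FX-graph level, removing a redundant collective instead of executing it.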