### What this PR does / why we need it?
This is the second PR for https://github.com/vllm-project/vllm-ascend/issues/5712.
It extends sequence parallelism (SP) to vision-language (VL) MoE models.
### Does this PR introduce _any_ user-facing change?
Yes. The `sp_threshold` option is removed from the additional config; the
existing `sp_min_token_num` setting from vLLM is reused instead.
### How was this patch tested?
- Model: Qwen3-VL-30B-A3B
- Parallelism: TP4, DP2
- Requests: 100
- Max concurrency: 1

| Seq length | Mean TTFT (ms), main | Mean TTFT (ms), this PR |
|------------|----------------------|-------------------------|
| 4k | 429.40 | 323.30 |
| 16k | 1297.01 | 911.74 |

- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
41 lines
1.7 KiB
Python
41 lines
1.7 KiB
Python
import torch
|
|
import torch._inductor.pattern_matcher as pm
|
|
from torch._inductor.pattern_matcher import PatternMatcherPass
|
|
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
|
from vllm.logger import logger
|
|
|
|
|
|
class AllGatherChunkNoOpCleanupPass(VllmInductorPass):
    """Fold ``all_gather`` followed by ``sequence_parallel_chunk_impl`` into identity.

    The pass looks for a tensor-parallel ``torch.ops.vllm.all_gather`` along
    dim 0 whose result is passed directly to
    ``torch.ops.vllm.sequence_parallel_chunk_impl``. It rewrites that pair so
    the chunk's consumers use the original, pre-gather tensor instead.

    NOTE(review): the rewrite assumes ``sequence_parallel_chunk_impl`` returns
    this rank's dim-0 slice of its input. For a tensor that was just gathered
    across the TP group, that slice is this rank's own input — confirm against
    the op's implementation.
    """

    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)
        # TP group and world size are read once here. ``_all_gather`` bakes them
        # into the op that the registered pattern matches.
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="npu_allgather_chunk_noop_cleanup_pass"
        )
        self._register_patterns()

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """Emit the dim-0 all_gather over this pass's TP group (used in the pattern)."""
        return torch.ops.vllm.all_gather(
            x,
            dim=0,
            world_size=self.tp_size,
            group_name=self.tp_group.unique_name,
        )

    def _empty(self, *args, **kwargs) -> torch.Tensor:
        """Allocate an uninitialized example tensor for tracing the pattern.

        Positional args are the shape. Extra keyword args are forwarded to
        ``torch.empty``. ``self.model_dtype`` and ``self.device`` are not
        assigned in this class — presumably set by ``VllmInductorPass``; verify.
        """
        return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)

    def _register_patterns(self) -> None:
        """Register the all_gather -> chunk pattern with an identity replacement."""

        def pattern(x: torch.Tensor) -> torch.Tensor:
            gathered = self._all_gather(x)
            return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)

        def replacement(x: torch.Tensor) -> torch.Tensor:
            # The gather/chunk round trip is treated as a no-op: forward the
            # input unchanged (see the class docstring for the assumption).
            return x

        # ``[self._empty(8, 16)]`` is the example input used to trace
        # ``pattern`` and ``replacement``. ``pm.fwd_only`` traces forward-only graphs.
        pm.register_replacement(
            pattern, replacement, [self._empty(8, 16)], pm.fwd_only, self.patterns
        )

    def __call__(self, graph: torch.fx.Graph) -> None:
        """Apply the registered pattern to ``graph`` and log how many matches were replaced."""
        self.begin()
        matched_count = self.patterns.apply(graph)
        logger.debug("AllGatherChunkNoOpCleanupPass replaced %s patterns", matched_count)
        self.end_and_log()
|