[Feat][SP] Suport SP for VL MoE models (#7044)

### What this PR does / why we need it?

2nd PR for https://github.com/vllm-project/vllm-ascend/issues/5712,
extend SP to VL MoE models.


### Does this PR introduce _any_ user-facing change?
remove `sp_threshold` in additional config and reuse `sp_min_token_num`
from vLLM.


### How was this patch tested?
- Model: Qwen3-VL-30B-A3B, 
- TP4 DP2
- 100 reqs
- max concurrency 1

| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|------------|---------------------|------------------------|
| 4k         | 429.40               | 323.3                  |
| 16k        | 1297.01              | 911.74                |

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
realliujiaxu
2026-03-24 17:16:00 +08:00
committed by GitHub
parent 9615bc33fd
commit 5d12446573
21 changed files with 947 additions and 54 deletions

View File

@@ -70,6 +70,8 @@ class GraphFusionPassManager:
self.passes.append(MulsAddFusionPass(config))
if config.compilation_config.pass_config.enable_sp:
from .passes.sequence_parallelism import AscendSequenceParallelismPass
from .passes.sequence_parallelism import SequenceParallelismPass
from .passes.sequence_parallelism_moe import SequenceParallelismMoePass
self.passes.append(AscendSequenceParallelismPass(config))
self.passes.append(SequenceParallelismPass(config))
self.passes.append(SequenceParallelismMoePass(config))

View File

@@ -0,0 +1,40 @@
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.logger import logger
class AllGatherChunkNoOpCleanupPass(VllmInductorPass):
"""Fold all_gather + sequence_parallel_chunk_impl into identity."""
def __init__(self, config: VllmConfig):
super().__init__(config)
self.tp_group = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_allgather_chunk_noop_cleanup_pass")
self._register_patterns()
def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.all_gather(x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name)
def _empty(self, *args, **kwargs):
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)
def _register_patterns(self) -> None:
def pattern(input: torch.Tensor) -> torch.Tensor:
gathered = self._all_gather(input)
return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)
def replacement(input: torch.Tensor) -> torch.Tensor:
return input
pm.register_replacement(pattern, replacement, [self._empty(8, 16)], pm.fwd_only, self.patterns)
def __call__(self, graph: torch.fx.Graph) -> None:
self.begin()
matched_count = self.patterns.apply(graph)
logger.debug("AllGatherChunkNoOpCleanupPass replaced %s patterns", matched_count)
self.end_and_log()

View File

@@ -0,0 +1,62 @@
from collections.abc import Iterable
import torch
import torch.fx
from torch import SymInt
from torch.fx.experimental.symbolic_shapes import statically_known_true
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
from vllm.logger import logger
class NoOpEliminationPass(VllmInductorPass):
"""Remove no-op view/reshape nodes after pattern rewrites."""
def __call__(self, graph: torch.fx.Graph) -> None:
fx_graph = graph.graph if hasattr(graph, "graph") else graph
removed = 0
for node in list(fx_graph.nodes):
if not self._is_view_like(node):
continue
input_node = node.args[0]
if not isinstance(input_node, torch.fx.Node):
continue
input_meta = input_node.meta.get("val")
output_meta = node.meta.get("val")
if input_meta is None or output_meta is None:
continue
input_shape = getattr(input_meta, "shape", None)
output_shape = getattr(output_meta, "shape", None)
if input_shape is None or output_shape is None:
continue
if self._all_dims_equivalent(input_shape, output_shape):
node.replace_all_uses_with(input_node)
fx_graph.erase_node(node)
removed += 1
logger.debug("NoOpEliminationPass removed %s no-op views", removed)
@staticmethod
def _is_view_like(node: torch.fx.Node) -> bool:
return (node.op == "call_method" and node.target in {"view", "reshape"}) or (
node.op == "call_function"
and node.target
in {
torch.ops.aten.view.default,
torch.ops.aten.reshape.default,
}
)
@staticmethod
def _dims_equivalent(dim: int | SymInt, i_dim: int | SymInt) -> bool:
return statically_known_true(dim == i_dim) # type: ignore[no-any-return]
def _all_dims_equivalent(self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]) -> bool:
dims_ = list(dims)
i_dims_ = list(i_dims)
if len(dims_) != len(i_dims_):
return False
return all(self._dims_equivalent(s, i_s) for s, i_s in zip(dims_, i_dims_))

View File

@@ -7,21 +7,25 @@ from vllm.config.utils import Range
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce
from vllm.logger import logger
from vllm_ascend.compilation.passes.noop_elimination import NoOpEliminationPass
from vllm_ascend.utils import is_moe_model
SP_THRESHOLD = 1000
SP_MIN_TOKEN_NUM_DEFAULT = 1000
def get_sp_threshold(config: VllmConfig):
def get_sp_min_token_num(config: VllmConfig) -> int:
if is_moe_model(config):
return 1
additional_config = config.additional_config if config.additional_config is not None else {}
return additional_config.get("sp_threshold", SP_THRESHOLD)
return SP_MIN_TOKEN_NUM_DEFAULT
class _SequenceParallelPatternHelper:
"""Helper for sequence parallelism patterns."""
"""Helper for sequence parallelism patterns.
Provides TP communication helper methods: _all_reduce, _reduce_scatter,
_all_gather, and tensor creation utilities.
"""
def __init__(
self,
@@ -49,7 +53,10 @@ class _SequenceParallelPatternHelper:
return torch.empty(*args, dtype=self.dtype, device="npu", **kws)
class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
"""Replaces all_reduce + AddRMSNormBias with reduce_scatter + AddRMSNormBias
+ all_gather for middle-layer sequence parallelism."""
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
@@ -92,7 +99,10 @@ class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
"""Same as MiddleAllReduceRMSNormPattern but for the last layer
(no residual backprop)."""
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
@@ -127,7 +137,13 @@ class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class Qwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
"""For Qwen3-VL middle layers with hidden_states + deepstack_input_embeds add.
Replaces all_reduce + add + AddRMSNormBias with reduce_scatter +
chunk(deepstack_input_embeds) + add + AddRMSNormBias + all_gather.
"""
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
@@ -168,25 +184,45 @@ class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AscendSequenceParallelismPass(VllmInductorPass):
class SequenceParallelismPass(VllmInductorPass):
"""Sequence parallelism compilation pass.
Registers and applies the above patterns. Runs noop cleanup first, then
uses token range to determine whether to enable SP.
"""
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_pass")
self.noop_cleanup = NoOpEliminationPass(config)
for epsilon in [1e-5, 1e-6]:
AscendMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
MiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
AscendLastAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
LastAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
AscendQwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
Qwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
self.min_tokens = get_sp_threshold(config)
self.min_tokens = get_sp_min_token_num(config)
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.noop_cleanup(graph) # Eliminate redundant view-like operations
logger.debug(f"after noop_cleanup {graph.graph}")
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
logger.debug(f"after apply replacement {graph.graph}")
from torch._inductor.pattern_matcher import PatternPrettyPrinter
pattern_idx = 0
for pattern_entry in self.patterns.patterns.values():
for p in pattern_entry:
p_str = PatternPrettyPrinter.run(p.pattern)
logger.debug("Pattern %d: %s", pattern_idx, p_str)
pattern_idx += 1
self.end_and_log()
def is_applicable_for_range(self, compile_range: Range) -> bool:

View File

@@ -0,0 +1,204 @@
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.passes.vllm_inductor_pass import PatternPrettyPrinter, VllmInductorPass
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.logger import logger
from vllm_ascend.compilation.passes.sequence_parallelism import (
_SequenceParallelPatternHelper,
get_sp_min_token_num,
)
class MiddleLayerAllgatherAddRMSNormPattern(_SequenceParallelPatternHelper):
"""Replaces all_gather + slice + AddRMSNormBias with AddRMSNormBias +
all_gather to avoid middle-layer shape mismatch."""
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
def get_inputs(self):
input = self.empty(5, 16)
weight = self.empty(16)
residual = self.empty(8, 16)
# num_tokens = 8
return [input, weight, residual]
def get_scalar_inputs(self):
return {"num_tokens": 8}
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
) -> tuple[torch.Tensor, torch.Tensor]:
all_gather = self._all_gather(input)
x_sliced = all_gather[:num_tokens]
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(x_sliced, residual, weight, None, self.eps)
return result, residual
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
) -> tuple[torch.Tensor, torch.Tensor]:
residual = torch.ops.vllm.maybe_chunk_residual(input, residual)
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(input, residual, weight, None, self.eps)
all_gather = self._all_gather(result)
return all_gather, residual
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass, scalar_workaround=self.get_scalar_inputs()
)
class LastLayerAllgatherRMSNormPattern(_SequenceParallelPatternHelper):
"""Same as MiddleLayerAllgatherAddRMSNormPattern but for the last layer (no residual)
all_gather + RMSNorm fusion."""
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
def get_inputs(self):
input = self.empty(5, 16)
weight = self.empty(16)
residual = self.empty(8, 16)
return [input, weight, residual]
def get_scalar_inputs(self):
return {"num_tokens": 8}
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
) -> tuple[torch.Tensor, torch.Tensor]:
all_gather = self._all_gather(input)
x_sliced = all_gather[:num_tokens]
result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(x_sliced, residual, weight, None, self.eps)
return result
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
) -> tuple[torch.Tensor, torch.Tensor]:
residual = torch.ops.vllm.maybe_chunk_residual(input, residual)
result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(input, residual, weight, None, self.eps)
all_gather = self._all_gather(result)
return all_gather
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass, scalar_workaround=self.get_scalar_inputs()
)
class Qwen3VLMiddleLayerAllgatherAddRMSNormPattern(_SequenceParallelPatternHelper):
"""Replaces all_gather + slice + add + AddRMSNormBias with add(chunk) +
AddRMSNormBias + all_gather for Qwen3-VL-style all_gather path."""
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
def get_inputs(self):
input = self.empty(5, 16)
weight = self.empty(16)
residual = self.empty(8, 16)
deepstack_input_embeds = self.empty(8, 16)
return [input, weight, residual, deepstack_input_embeds]
def get_scalar_inputs(self):
return {"num_tokens": 8}
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
deepstack_input_embeds: torch.Tensor,
num_tokens,
) -> tuple[torch.Tensor, torch.Tensor]:
all_gather = self._all_gather(input)
x_sliced = all_gather[:num_tokens]
add_ = x_sliced + deepstack_input_embeds
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps)
return result, residual
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
deepstack_input_embeds: torch.Tensor,
num_tokens,
) -> tuple[torch.Tensor, torch.Tensor]:
chunk = deepstack_input_embeds.chunk(self.tp_size)[self.tp_rank]
add_ = input + chunk
residual = torch.ops.vllm.maybe_chunk_residual(input, residual)
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps)
all_gather = self._all_gather(result)
return all_gather, residual
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass, scalar_workaround=self.get_scalar_inputs()
)
class AllGatherChunkNoOpPattern(_SequenceParallelPatternHelper):
"""Folds all_gather + sequence_parallel_chunk_impl into identity (no-op)."""
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
def get_inputs(self):
return [self.empty(8, 16)]
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor) -> torch.Tensor:
gathered = self._all_gather(input)
return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)
def replacement(input: torch.Tensor) -> torch.Tensor:
return input
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class SequenceParallelismMoePass(VllmInductorPass):
"""Sequence parallelism AllGather epilogue pass.
Applies AllGather-based patterns: MiddleLayerAllgatherAddRMSNormPattern,
LastLayerAllgatherRMSNormPattern, Qwen3VLMiddleLayerAllgatherAddRMSNormPattern,
and AllGatherChunkNoOpPattern (all_gather + sequence_parallel_chunk_impl -> identity).
"""
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_allgather_ep_pass")
for epsilon in [1e-5, 1e-6]:
MiddleLayerAllgatherAddRMSNormPattern(config, epsilon).register(self.patterns)
LastLayerAllgatherRMSNormPattern(config, epsilon).register(self.patterns)
Qwen3VLMiddleLayerAllgatherAddRMSNormPattern(config, epsilon).register(self.patterns)
AllGatherChunkNoOpPattern(config).register(self.patterns)
self.min_tokens = get_sp_min_token_num(config)
def __call__(self, graph: torch.fx.Graph):
self.begin()
logger.debug(f"before apply replacement {graph}")
self.matched_count = self.patterns.apply(graph)
logger.debug(f"after apply replacement {graph}")
logger.debug("SequenceParallelismMoePass replaced %s patterns", self.matched_count)
pattern_idx = 0
for pattern_entry in self.patterns.patterns.values():
for p in pattern_entry:
p_str = PatternPrettyPrinter.run(p.pattern)
logger.debug("Pattern %d: %s", pattern_idx, p_str)
pattern_idx += 1
self.end_and_log()
def is_applicable_for_range(self, compile_range: Range) -> bool:
applicable = compile_range.start >= self.min_tokens
logger.debug(f"SequenceParallelismMoePass {compile_range=} {applicable=}")
return applicable