Files
xc-llm-ascend/vllm_ascend/compilation/passes/sequence_parallelism.py
realliujiaxu 5d12446573 [Feat][SP] Support SP for VL MoE models (#7044)
### What this PR does / why we need it?

This is the second PR for https://github.com/vllm-project/vllm-ascend/issues/5712;
it extends sequence parallelism (SP) to VL MoE models.


### Does this PR introduce _any_ user-facing change?
Yes: `sp_threshold` is removed from the additional config; `sp_min_token_num`
from vLLM is reused instead.


### How was this patch tested?
- Model: Qwen3-VL-30B-A3B
- TP4 DP2
- 100 reqs
- max concurrency 1

| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|------------|---------------------|------------------------|
| 4k         | 429.40               | 323.3                  |
| 16k        | 1297.01              | 911.74                |

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
2026-03-24 17:16:00 +08:00

235 lines
8.8 KiB
Python

import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce
from vllm.logger import logger
from vllm_ascend.compilation.passes.noop_elimination import NoOpEliminationPass
from vllm_ascend.utils import is_moe_model
# Default minimum token count returned by get_sp_min_token_num() for models
# that is_moe_model() does not report as MoE.
SP_MIN_TOKEN_NUM_DEFAULT = 1000
def get_sp_min_token_num(config: VllmConfig) -> int:
    """Return the minimum token count used to gate sequence parallelism.

    Args:
        config: The vLLM configuration handed to ``is_moe_model``.

    Returns:
        ``1`` when ``is_moe_model(config)`` is truthy, otherwise
        ``SP_MIN_TOKEN_NUM_DEFAULT``.
    """
    return 1 if is_moe_model(config) else SP_MIN_TOKEN_NUM_DEFAULT
class _SequenceParallelPatternHelper:
    """Shared base class for the sequence-parallelism pattern definitions.

    Records the tensor-parallel (TP) group, its world size and this rank's
    index at construction time, and offers thin wrappers around the TP
    collectives (``_all_reduce``, ``_reduce_scatter``, ``_all_gather``) plus
    a tensor factory (``empty``).
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
    ):
        """Store the norm epsilon, dtype, device and TP topology.

        Args:
            epsilon: Epsilon forwarded to the RMS-norm op by subclasses.
            dtype: dtype used for tensors created through ``empty``.
            device: Stored on the instance as ``self.device``.
        """
        self.eps = epsilon
        self.dtype = dtype
        # NOTE(review): kept for reference only -- ``empty`` below always
        # allocates on "npu" and never reads this attribute.
        self.device = device

        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tp_group().rank_in_group

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``tensor_model_parallel_all_reduce`` to ``x``."""
        return tensor_model_parallel_all_reduce(x)

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        """Call the ``vllm.reduce_scatter`` op on ``x`` along dim 0."""
        return torch.ops.vllm.reduce_scatter(
            x,
            dim=0,
            world_size=self.tp_size,
            group_name=self.tp_group.unique_name,
        )

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """Call the ``vllm.all_gather`` op on ``x`` along dim 0."""
        return torch.ops.vllm.all_gather(
            x,
            dim=0,
            world_size=self.tp_size,
            group_name=self.tp_group.unique_name,
        )

    def empty(self, *shape, **options):
        """Create an uninitialized "npu" tensor of dtype ``self.dtype``.

        Positional and keyword arguments are forwarded to ``torch.empty``.
        """
        return torch.empty(*shape, dtype=self.dtype, device="npu", **options)
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Middle-layer sequence-parallelism rewrite.

    Replaces ``all_reduce + npu_add_rms_norm_bias`` with
    ``reduce_scatter + npu_add_rms_norm_bias + all_gather``. Both the
    normalized result and the updated residual are pattern outputs.
    """

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        """Configure the pattern.

        Args:
            vllm_config: Supplies ``model_config.dtype`` for example tensors.
            eps: Epsilon passed to ``npu_add_rms_norm_bias``; it is part of
                the traced pattern.
        """
        super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())

    # ``empty`` is inherited from _SequenceParallelPatternHelper; the former
    # override here duplicated the base implementation verbatim.

    def get_inputs(self):
        """Build example inputs for tracing: ``[input, weight, residual]``.

        ``input`` and ``residual`` are (8, 16); ``weight`` is (16,).
        """
        input = self.empty(8, 16)
        weight = self.empty(16)
        residual = self.empty(8, 16)
        return [input, weight, residual]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its replacement on ``pm_pass``."""

        # Parameter names of ``pattern`` and ``replacement`` are kept
        # identical (including ``input``) so both signatures line up.
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(input)
            result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
                reduced, residual, weight, None, self.eps
            )
            return result, residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(input)
            # presumably trims the residual to match the scattered shard --
            # TODO confirm against the maybe_chunk_residual op definition.
            residual = torch.ops.vllm.maybe_chunk_residual(scattered, residual)
            result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
                scattered, residual, weight, None, self.eps
            )
            gathered = self._all_gather(result)
            return gathered, residual

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Last-layer variant of the all_reduce + AddRMSNormBias rewrite.

    Same rewrite as the middle-layer pattern, except that only the
    normalized result is returned; the residual output of
    ``npu_add_rms_norm_bias`` is discarded.
    """

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        """Configure the pattern.

        Args:
            vllm_config: Supplies ``model_config.dtype`` for example tensors.
            eps: Epsilon passed to ``npu_add_rms_norm_bias``.
        """
        super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())

    def get_inputs(self):
        """Build example inputs for tracing: ``[input, weight, residual]``."""
        input = self.empty(8, 16)
        weight = self.empty(16)
        residual = self.empty(8, 16)
        return [input, weight, residual]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its replacement on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> torch.Tensor:
            reduced = self._all_reduce(input)
            normed, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(
                reduced, residual, weight, None, self.eps
            )
            return normed

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> torch.Tensor:
            scattered = self._reduce_scatter(input)
            # presumably trims the residual to match the scattered shard --
            # TODO confirm against the maybe_chunk_residual op definition.
            local_residual = torch.ops.vllm.maybe_chunk_residual(scattered, residual)
            normed, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(
                scattered, local_residual, weight, None, self.eps
            )
            return self._all_gather(normed)

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class Qwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Middle-layer rewrite for Qwen3-VL with a deepstack embedding add.

    Matches ``all_reduce + add(deepstack_input_embeds) + AddRMSNormBias`` and
    rewrites it to ``reduce_scatter + chunk(deepstack_input_embeds) + add +
    AddRMSNormBias + all_gather``.
    """

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        """Configure the pattern.

        Args:
            vllm_config: Supplies ``model_config.dtype`` for example tensors.
            eps: Epsilon passed to ``npu_add_rms_norm_bias``.
        """
        super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())

    def get_inputs(self):
        """Build example inputs for tracing.

        Returns ``[input, weight, residual, deepstack_input_embeds]``; all
        are (8, 16) except ``weight``, which is (16,).
        """
        input = self.empty(8, 16)
        weight = self.empty(16)
        residual = self.empty(8, 16)
        deepstack_input_embeds = self.empty(8, 16)
        return [input, weight, residual, deepstack_input_embeds]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its replacement on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            deepstack_input_embeds: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(input)
            summed = reduced + deepstack_input_embeds
            result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
                summed, residual, weight, None, self.eps
            )
            return result, residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            deepstack_input_embeds: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(input)
            # Take this rank's slice of the embeds along dim 0.
            # NOTE(review): this presumably relies on dim 0 splitting evenly
            # across tp_size so the slice lines up with the reduce_scatter
            # output -- confirm.
            local_embeds = deepstack_input_embeds.chunk(self.tp_size)[self.tp_rank]
            summed = scattered + local_embeds
            residual = torch.ops.vllm.maybe_chunk_residual(scattered, residual)
            result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
                summed, residual, weight, None, self.eps
            )
            gathered = self._all_gather(result)
            return gathered, residual

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class SequenceParallelismPass(VllmInductorPass):
    """Sequence-parallelism compilation pass.

    Registers the all_reduce + RMS-norm patterns defined in this module for
    each supported epsilon and applies them to a graph. A no-op elimination
    pass runs first, and ``is_applicable_for_range`` gates the pass on the
    start of the compile range.
    """

    def __init__(self, config: VllmConfig):
        """Build the pattern matcher and register every pattern.

        Args:
            config: vLLM configuration, forwarded to the base class, the
                no-op elimination pass and each pattern.
        """
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_pass")
        self.noop_cleanup = NoOpEliminationPass(config)

        # Epsilon is baked into the traced pattern, so each value needs its
        # own registration.
        for epsilon in [1e-5, 1e-6]:
            MiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
            LastAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
            Qwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)

        self.min_tokens = get_sp_min_token_num(config)

    def __call__(self, graph: torch.fx.Graph):
        """Run no-op cleanup, then apply the registered replacements.

        The number of replaced patterns is stored in ``self.matched_count``.

        Args:
            graph: The graph to rewrite in place. The debug output prints
                ``graph.graph`` when that attribute exists (as on a
                ``GraphModule``) and ``graph`` itself otherwise.
        """
        import logging

        # Decide once whether debug output is wanted, so that graph
        # stringification and pattern pretty-printing are skipped entirely
        # when DEBUG logging is off.
        debug_enabled = logger.isEnabledFor(logging.DEBUG)

        self.begin()
        self.noop_cleanup(graph)  # Eliminate redundant view-like operations
        if debug_enabled:
            logger.debug("after noop_cleanup %s", getattr(graph, "graph", graph))

        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

        if debug_enabled:
            logger.debug("after apply replacement %s", getattr(graph, "graph", graph))

            from torch._inductor.pattern_matcher import PatternPrettyPrinter

            registered = (
                entry
                for pattern_entry in self.patterns.patterns.values()
                for entry in pattern_entry
            )
            for pattern_idx, entry in enumerate(registered):
                logger.debug("Pattern %d: %s", pattern_idx, PatternPrettyPrinter.run(entry.pattern))

        self.end_and_log()

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Report whether the pass should run for ``compile_range``.

        Returns:
            ``True`` when ``compile_range.start`` is at least
            ``self.min_tokens``.
        """
        applicable = compile_range.start >= self.min_tokens
        logger.debug("SequenceParallelismPass compile_range=%r applicable=%r", compile_range, applicable)
        return applicable