[Feat][SP] Support SP for VL MoE models (#7044)
### What this PR does / why we need it?
Second PR for https://github.com/vllm-project/vllm-ascend/issues/5712,
extending SP to VL MoE models.
### Does this PR introduce _any_ user-facing change?
Yes. The `sp_threshold` option is removed from `additional_config`; the
pass now reuses `sp_min_token_num` from vLLM.
### How was this patch tested?
- Model: Qwen3-VL-30B-A3B
- TP4 DP2
- 100 requests
- max concurrency 1
| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|------------|---------------------|------------------------|
| 4k | 429.40 | 323.3 |
| 16k | 1297.01 | 911.74 |
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
```diff
@@ -7,21 +7,25 @@ from vllm.config.utils import Range
 from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce
 from vllm.logger import logger
 
 from vllm_ascend.compilation.passes.noop_elimination import NoOpEliminationPass
+from vllm_ascend.utils import is_moe_model
 
-SP_THRESHOLD = 1000
+SP_MIN_TOKEN_NUM_DEFAULT = 1000
 
 
-def get_sp_threshold(config: VllmConfig):
-    additional_config = config.additional_config if config.additional_config is not None else {}
-    return additional_config.get("sp_threshold", SP_THRESHOLD)
+def get_sp_min_token_num(config: VllmConfig) -> int:
+    if is_moe_model(config):
+        return 1
+    return SP_MIN_TOKEN_NUM_DEFAULT
 
 
 class _SequenceParallelPatternHelper:
-    """Helper for sequence parallelism patterns."""
+    """Helper for sequence parallelism patterns.
+
+    Provides TP communication helper methods: _all_reduce, _reduce_scatter,
+    _all_gather, and tensor creation utilities.
+    """
 
     def __init__(
         self,
```
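The behavioral change is visible above: for MoE models the token floor drops to 1 (SP is always eligible), while dense models keep the default floor of 1000. A minimal sketch of that gating decision, with `is_moe` standing in for the real `is_moe_model(config)` call and the config plumbing stripped out:

```python
# Sketch of the SP gating this PR implements; `is_moe` is a stand-in for
# is_moe_model(config), and the VllmConfig plumbing is intentionally omitted.
def sp_enabled(num_tokens: int, is_moe: bool) -> bool:
    # MoE models get a floor of 1 token (SP always eligible); dense models
    # keep the 1000-token default so short prefills skip the
    # reduce_scatter/all_gather overhead.
    min_token_num = 1 if is_moe else 1000
    return num_tokens >= min_token_num

assert sp_enabled(1, is_moe=True)
assert not sp_enabled(512, is_moe=False)
assert sp_enabled(4096, is_moe=False)
```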
```diff
@@ -49,7 +53,10 @@ class _SequenceParallelPatternHelper:
         return torch.empty(*args, dtype=self.dtype, device="npu", **kws)
 
 
-class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
+class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
+    """Replaces all_reduce + AddRMSNormBias with reduce_scatter + AddRMSNormBias
+    + all_gather for middle-layer sequence parallelism."""
+
     def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
         super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
```
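A note on why this rewrite preserves numerics: RMSNorm normalizes each token (row) independently, so applying it to each rank's sequence shard after a reduce_scatter and then all-gathering the results matches all-reducing first and normalizing the full sequence. A single-process sketch with the collectives simulated by sums and chunks; the real pattern also carries the residual add and bias of AddRMSNormBias, which are omitted here:

```python
# Numerical sketch (CPU, single process) of the rewrite's equivalence.
# Collectives are simulated: all_reduce = sum of per-rank partials,
# reduce_scatter = sum of matching sequence chunks, all_gather = cat.
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Row-wise (per-token) normalization, as in RMSNorm.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

tp_size, seq_len, hidden = 4, 8, 16
partials = [torch.randn(seq_len, hidden) for _ in range(tp_size)]  # per-rank partial sums

# Original graph: all_reduce -> RMSNorm over the full sequence.
full = torch.stack(partials).sum(0)   # simulated all_reduce
ref = rms_norm(full)

# Rewritten graph: reduce_scatter -> RMSNorm on each shard -> all_gather.
shards = [torch.stack([p.chunk(tp_size)[r] for p in partials]).sum(0)  # simulated reduce_scatter
          for r in range(tp_size)]
out = torch.cat([rms_norm(s) for s in shards])  # simulated all_gather

torch.testing.assert_close(ref, out)
```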
```diff
@@ -92,7 +99,10 @@ class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
         pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
 
 
-class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
+class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
+    """Same as MiddleAllReduceRMSNormPattern but for the last layer
+    (no residual backprop)."""
+
     def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
         super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
```
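For context on the `pm.register_replacement` calls each pattern class makes: the inductor pattern matcher traces a `pattern` function and a `replacement` function, then rewrites matching subgraphs in the compiled fx graph. A toy registration against the same API; the add-with-zero pattern is purely illustrative and not from this PR, and these are private torch internals whose signatures can shift between releases:

```python
# Toy example of the registration API used by the pattern classes above.
# torch._inductor is a private API; this mirrors how the file calls it.
import torch
from torch._inductor import pattern_matcher as pm

toy_patterns = pm.PatternMatcherPass(pass_name="toy_pass")

def pattern(x: torch.Tensor, y: torch.Tensor):
    return (x + y) + 0.0   # graph shape to search for

def replacement(x: torch.Tensor, y: torch.Tensor):
    return x + y           # what matching subgraphs are rewritten to

pm.register_replacement(pattern, replacement,
                        [torch.empty(4), torch.empty(4)],  # trace inputs
                        pm.fwd_only, toy_patterns)
```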
```diff
@@ -127,7 +137,13 @@ class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
         pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
 
 
-class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
+class Qwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
+    """For Qwen3-VL middle layers with hidden_states + deepstack_input_embeds add.
+
+    Replaces all_reduce + add + AddRMSNormBias with reduce_scatter +
+    chunk(deepstack_input_embeds) + add + AddRMSNormBias + all_gather.
+    """
+
     def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
         super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
```
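The Qwen3-VL variant adds one wrinkle: after the reduce_scatter each rank holds only `seq_len / tp_size` tokens, so the replicated `deepstack_input_embeds` must be chunked down to the matching shard before the add, per the docstring above. A sketch of that shard selection (function name, shapes, and rank plumbing are illustrative):

```python
# Sketch of the deepstack shard selection the Qwen3-VL pattern relies on.
# `deepstack_input_embeds` follows the docstring above; the helper itself
# is illustrative, not code from this PR.
import torch

def add_deepstack_local(hidden_shard: torch.Tensor,
                        deepstack_input_embeds: torch.Tensor,
                        tp_rank: int, tp_size: int) -> torch.Tensor:
    # After reduce_scatter this rank only holds seq_len / tp_size tokens,
    # so add the matching chunk of the (replicated) deepstack embeds.
    local_embeds = deepstack_input_embeds.chunk(tp_size, dim=0)[tp_rank]
    return hidden_shard + local_embeds
```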
```diff
@@ -168,25 +184,45 @@ class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper)
         pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
 
 
-class AscendSequenceParallelismPass(VllmInductorPass):
+class SequenceParallelismPass(VllmInductorPass):
+    """Sequence parallelism compilation pass.
+
+    Registers and applies the above patterns. Runs noop cleanup first, then
+    uses token range to determine whether to enable SP.
+    """
+
     def __init__(self, config: VllmConfig):
         super().__init__(config)
 
         self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_pass")
+        self.noop_cleanup = NoOpEliminationPass(config)
 
         for epsilon in [1e-5, 1e-6]:
-            AscendMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
+            MiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
-            AscendLastAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
+            LastAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
-            AscendQwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
+            Qwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
 
-        self.min_tokens = get_sp_threshold(config)
+        self.min_tokens = get_sp_min_token_num(config)
 
     def __call__(self, graph: torch.fx.Graph):
         self.begin()
+        self.noop_cleanup(graph)  # Eliminate redundant view-like operations
+        logger.debug(f"after noop_cleanup {graph.graph}")
         self.matched_count = self.patterns.apply(graph)
         logger.debug("Replaced %s patterns", self.matched_count)
+        logger.debug(f"after apply replacement {graph.graph}")
+
+        from torch._inductor.pattern_matcher import PatternPrettyPrinter
+
+        pattern_idx = 0
+        for pattern_entry in self.patterns.patterns.values():
+            for p in pattern_entry:
+                p_str = PatternPrettyPrinter.run(p.pattern)
+                logger.debug("Pattern %d: %s", pattern_idx, p_str)
+                pattern_idx += 1
+
         self.end_and_log()
+
+    def is_applicable_for_range(self, compile_range: Range) -> bool:
```
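The diff is truncated after the `is_applicable_for_range` signature, so its body is not shown. Going only by the class docstring ("uses token range to determine whether to enable SP") and the `self.min_tokens` field, a hypothetical body could look like the following; the `Range.start` attribute and the TP divisibility check are assumptions, not the actual code:

```python
# Hypothetical reconstruction -- the real body is cut off in the diff above.
# Assumes Range exposes a `start` lower bound and reuses the imports shown
# in the first hunk (get_tensor_model_parallel_world_size).
def is_applicable_for_range(self, compile_range: Range) -> bool:
    tp_size = get_tensor_model_parallel_world_size()
    # Enable SP only when every token count in the compiled range clears
    # the floor and the sequence splits evenly across TP ranks.
    return (compile_range.start >= self.min_tokens
            and compile_range.start % tp_size == 0)
```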