[Feat][SP] Suport SP for VL MoE models (#7044)
### What this PR does / why we need it?
2nd PR for https://github.com/vllm-project/vllm-ascend/issues/5712,
extend SP to VL MoE models.
### Does this PR introduce _any_ user-facing change?
remove `sp_threshold` in additional config and reuse `sp_min_token_num`
from vLLM.
### How was this patch tested?
- Model: Qwen3-VL-30B-A3B,
- TP4 DP2
- 100 reqs
- max concurrency 1
| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|------------|---------------------|------------------------|
| 4k | 429.40 | 323.3 |
| 16k | 1297.01 | 911.74 |
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
40
vllm_ascend/compilation/passes/allgather_chunk_noop_pass.py
Normal file
40
vllm_ascend/compilation/passes/allgather_chunk_noop_pass.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
class AllGatherChunkNoOpCleanupPass(VllmInductorPass):
|
||||
"""Fold all_gather + sequence_parallel_chunk_impl into identity."""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
self.tp_group = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_allgather_chunk_noop_cleanup_pass")
|
||||
self._register_patterns()
|
||||
|
||||
def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.vllm.all_gather(x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name)
|
||||
|
||||
def _empty(self, *args, **kwargs):
|
||||
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)
|
||||
|
||||
def _register_patterns(self) -> None:
|
||||
def pattern(input: torch.Tensor) -> torch.Tensor:
|
||||
gathered = self._all_gather(input)
|
||||
return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)
|
||||
|
||||
def replacement(input: torch.Tensor) -> torch.Tensor:
|
||||
return input
|
||||
|
||||
pm.register_replacement(pattern, replacement, [self._empty(8, 16)], pm.fwd_only, self.patterns)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
self.begin()
|
||||
matched_count = self.patterns.apply(graph)
|
||||
logger.debug("AllGatherChunkNoOpCleanupPass replaced %s patterns", matched_count)
|
||||
self.end_and_log()
|
||||
62
vllm_ascend/compilation/passes/noop_elimination.py
Normal file
62
vllm_ascend/compilation/passes/noop_elimination.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch import SymInt
|
||||
from torch.fx.experimental.symbolic_shapes import statically_known_true
|
||||
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
class NoOpEliminationPass(VllmInductorPass):
|
||||
"""Remove no-op view/reshape nodes after pattern rewrites."""
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
fx_graph = graph.graph if hasattr(graph, "graph") else graph
|
||||
removed = 0
|
||||
for node in list(fx_graph.nodes):
|
||||
if not self._is_view_like(node):
|
||||
continue
|
||||
|
||||
input_node = node.args[0]
|
||||
if not isinstance(input_node, torch.fx.Node):
|
||||
continue
|
||||
|
||||
input_meta = input_node.meta.get("val")
|
||||
output_meta = node.meta.get("val")
|
||||
if input_meta is None or output_meta is None:
|
||||
continue
|
||||
|
||||
input_shape = getattr(input_meta, "shape", None)
|
||||
output_shape = getattr(output_meta, "shape", None)
|
||||
if input_shape is None or output_shape is None:
|
||||
continue
|
||||
|
||||
if self._all_dims_equivalent(input_shape, output_shape):
|
||||
node.replace_all_uses_with(input_node)
|
||||
fx_graph.erase_node(node)
|
||||
removed += 1
|
||||
|
||||
logger.debug("NoOpEliminationPass removed %s no-op views", removed)
|
||||
|
||||
@staticmethod
|
||||
def _is_view_like(node: torch.fx.Node) -> bool:
|
||||
return (node.op == "call_method" and node.target in {"view", "reshape"}) or (
|
||||
node.op == "call_function"
|
||||
and node.target
|
||||
in {
|
||||
torch.ops.aten.view.default,
|
||||
torch.ops.aten.reshape.default,
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _dims_equivalent(dim: int | SymInt, i_dim: int | SymInt) -> bool:
|
||||
return statically_known_true(dim == i_dim) # type: ignore[no-any-return]
|
||||
|
||||
def _all_dims_equivalent(self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]) -> bool:
|
||||
dims_ = list(dims)
|
||||
i_dims_ = list(i_dims)
|
||||
if len(dims_) != len(i_dims_):
|
||||
return False
|
||||
return all(self._dims_equivalent(s, i_s) for s, i_s in zip(dims_, i_dims_))
|
||||
@@ -7,21 +7,25 @@ from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.compilation.passes.noop_elimination import NoOpEliminationPass
|
||||
from vllm_ascend.utils import is_moe_model
|
||||
|
||||
SP_THRESHOLD = 1000
|
||||
SP_MIN_TOKEN_NUM_DEFAULT = 1000
|
||||
|
||||
|
||||
def get_sp_threshold(config: VllmConfig):
|
||||
def get_sp_min_token_num(config: VllmConfig) -> int:
|
||||
if is_moe_model(config):
|
||||
return 1
|
||||
|
||||
additional_config = config.additional_config if config.additional_config is not None else {}
|
||||
return additional_config.get("sp_threshold", SP_THRESHOLD)
|
||||
return SP_MIN_TOKEN_NUM_DEFAULT
|
||||
|
||||
|
||||
class _SequenceParallelPatternHelper:
|
||||
"""Helper for sequence parallelism patterns."""
|
||||
"""Helper for sequence parallelism patterns.
|
||||
|
||||
Provides TP communication helper methods: _all_reduce, _reduce_scatter,
|
||||
_all_gather, and tensor creation utilities.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -49,7 +53,10 @@ class _SequenceParallelPatternHelper:
|
||||
return torch.empty(*args, dtype=self.dtype, device="npu", **kws)
|
||||
|
||||
|
||||
class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
"""Replaces all_reduce + AddRMSNormBias with reduce_scatter + AddRMSNormBias
|
||||
+ all_gather for middle-layer sequence parallelism."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
|
||||
|
||||
@@ -92,7 +99,10 @@ class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
"""Same as MiddleAllReduceRMSNormPattern but for the last layer
|
||||
(no residual backprop)."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
|
||||
|
||||
@@ -127,7 +137,13 @@ class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
class Qwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
"""For Qwen3-VL middle layers with hidden_states + deepstack_input_embeds add.
|
||||
|
||||
Replaces all_reduce + add + AddRMSNormBias with reduce_scatter +
|
||||
chunk(deepstack_input_embeds) + add + AddRMSNormBias + all_gather.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
|
||||
|
||||
@@ -168,25 +184,45 @@ class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AscendSequenceParallelismPass(VllmInductorPass):
|
||||
class SequenceParallelismPass(VllmInductorPass):
|
||||
"""Sequence parallelism compilation pass.
|
||||
|
||||
Registers and applies the above patterns. Runs noop cleanup first, then
|
||||
uses token range to determine whether to enable SP.
|
||||
"""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_pass")
|
||||
self.noop_cleanup = NoOpEliminationPass(config)
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
AscendMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
MiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
|
||||
AscendLastAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
LastAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
|
||||
AscendQwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
Qwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
|
||||
self.min_tokens = get_sp_threshold(config)
|
||||
self.min_tokens = get_sp_min_token_num(config)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
self.noop_cleanup(graph) # Eliminate redundant view-like operations
|
||||
logger.debug(f"after noop_cleanup {graph.graph}")
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
logger.debug(f"after apply replacement {graph.graph}")
|
||||
|
||||
from torch._inductor.pattern_matcher import PatternPrettyPrinter
|
||||
|
||||
pattern_idx = 0
|
||||
for pattern_entry in self.patterns.patterns.values():
|
||||
for p in pattern_entry:
|
||||
p_str = PatternPrettyPrinter.run(p.pattern)
|
||||
logger.debug("Pattern %d: %s", pattern_idx, p_str)
|
||||
pattern_idx += 1
|
||||
|
||||
self.end_and_log()
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
|
||||
204
vllm_ascend/compilation/passes/sequence_parallelism_moe.py
Normal file
204
vllm_ascend/compilation/passes/sequence_parallelism_moe.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from vllm.compilation.passes.vllm_inductor_pass import PatternPrettyPrinter, VllmInductorPass
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.compilation.passes.sequence_parallelism import (
|
||||
_SequenceParallelPatternHelper,
|
||||
get_sp_min_token_num,
|
||||
)
|
||||
|
||||
|
||||
class MiddleLayerAllgatherAddRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
"""Replaces all_gather + slice + AddRMSNormBias with AddRMSNormBias +
|
||||
all_gather to avoid middle-layer shape mismatch."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
|
||||
|
||||
def get_inputs(self):
|
||||
input = self.empty(5, 16)
|
||||
weight = self.empty(16)
|
||||
residual = self.empty(8, 16)
|
||||
# num_tokens = 8
|
||||
return [input, weight, residual]
|
||||
|
||||
def get_scalar_inputs(self):
|
||||
return {"num_tokens": 8}
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_gather = self._all_gather(input)
|
||||
x_sliced = all_gather[:num_tokens]
|
||||
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(x_sliced, residual, weight, None, self.eps)
|
||||
|
||||
return result, residual
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(input, residual)
|
||||
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(input, residual, weight, None, self.eps)
|
||||
all_gather = self._all_gather(result)
|
||||
return all_gather, residual
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass, scalar_workaround=self.get_scalar_inputs()
|
||||
)
|
||||
|
||||
|
||||
class LastLayerAllgatherRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
"""Same as MiddleLayerAllgatherAddRMSNormPattern but for the last layer (no residual)
|
||||
all_gather + RMSNorm fusion."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
|
||||
|
||||
def get_inputs(self):
|
||||
input = self.empty(5, 16)
|
||||
weight = self.empty(16)
|
||||
residual = self.empty(8, 16)
|
||||
return [input, weight, residual]
|
||||
|
||||
def get_scalar_inputs(self):
|
||||
return {"num_tokens": 8}
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_gather = self._all_gather(input)
|
||||
x_sliced = all_gather[:num_tokens]
|
||||
result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(x_sliced, residual, weight, None, self.eps)
|
||||
|
||||
return result
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(input, residual)
|
||||
result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(input, residual, weight, None, self.eps)
|
||||
all_gather = self._all_gather(result)
|
||||
return all_gather
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass, scalar_workaround=self.get_scalar_inputs()
|
||||
)
|
||||
|
||||
|
||||
class Qwen3VLMiddleLayerAllgatherAddRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
"""Replaces all_gather + slice + add + AddRMSNormBias with add(chunk) +
|
||||
AddRMSNormBias + all_gather for Qwen3-VL-style all_gather path."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
|
||||
|
||||
def get_inputs(self):
|
||||
input = self.empty(5, 16)
|
||||
weight = self.empty(16)
|
||||
residual = self.empty(8, 16)
|
||||
deepstack_input_embeds = self.empty(8, 16)
|
||||
return [input, weight, residual, deepstack_input_embeds]
|
||||
|
||||
def get_scalar_inputs(self):
|
||||
return {"num_tokens": 8}
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
deepstack_input_embeds: torch.Tensor,
|
||||
num_tokens,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_gather = self._all_gather(input)
|
||||
x_sliced = all_gather[:num_tokens]
|
||||
add_ = x_sliced + deepstack_input_embeds
|
||||
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps)
|
||||
|
||||
return result, residual
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
deepstack_input_embeds: torch.Tensor,
|
||||
num_tokens,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
chunk = deepstack_input_embeds.chunk(self.tp_size)[self.tp_rank]
|
||||
add_ = input + chunk
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(input, residual)
|
||||
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps)
|
||||
all_gather = self._all_gather(result)
|
||||
return all_gather, residual
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass, scalar_workaround=self.get_scalar_inputs()
|
||||
)
|
||||
|
||||
|
||||
class AllGatherChunkNoOpPattern(_SequenceParallelPatternHelper):
|
||||
"""Folds all_gather + sequence_parallel_chunk_impl into identity (no-op)."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
|
||||
|
||||
def get_inputs(self):
|
||||
return [self.empty(8, 16)]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(input: torch.Tensor) -> torch.Tensor:
|
||||
gathered = self._all_gather(input)
|
||||
return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)
|
||||
|
||||
def replacement(input: torch.Tensor) -> torch.Tensor:
|
||||
return input
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class SequenceParallelismMoePass(VllmInductorPass):
|
||||
"""Sequence parallelism AllGather epilogue pass.
|
||||
|
||||
Applies AllGather-based patterns: MiddleLayerAllgatherAddRMSNormPattern,
|
||||
LastLayerAllgatherRMSNormPattern, Qwen3VLMiddleLayerAllgatherAddRMSNormPattern,
|
||||
and AllGatherChunkNoOpPattern (all_gather + sequence_parallel_chunk_impl -> identity).
|
||||
"""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_allgather_ep_pass")
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
MiddleLayerAllgatherAddRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
LastLayerAllgatherRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
Qwen3VLMiddleLayerAllgatherAddRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
|
||||
AllGatherChunkNoOpPattern(config).register(self.patterns)
|
||||
|
||||
self.min_tokens = get_sp_min_token_num(config)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
logger.debug(f"before apply replacement {graph}")
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug(f"after apply replacement {graph}")
|
||||
logger.debug("SequenceParallelismMoePass replaced %s patterns", self.matched_count)
|
||||
pattern_idx = 0
|
||||
for pattern_entry in self.patterns.patterns.values():
|
||||
for p in pattern_entry:
|
||||
p_str = PatternPrettyPrinter.run(p.pattern)
|
||||
logger.debug("Pattern %d: %s", pattern_idx, p_str)
|
||||
pattern_idx += 1
|
||||
self.end_and_log()
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
applicable = compile_range.start >= self.min_tokens
|
||||
logger.debug(f"SequenceParallelismMoePass {compile_range=} {applicable=}")
|
||||
return applicable
|
||||
Reference in New Issue
Block a user