[Feat][SP] Suport SP for VL MoE models (#7044)
### What this PR does / why we need it?
2nd PR for https://github.com/vllm-project/vllm-ascend/issues/5712,
extend SP to VL MoE models.
### Does this PR introduce _any_ user-facing change?
remove `sp_threshold` in additional config and reuse `sp_min_token_num`
from vLLM.
### How was this patch tested?
- Model: Qwen3-VL-30B-A3B,
- TP4 DP2
- 100 reqs
- max concurrency 1
| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|------------|---------------------|------------------------|
| 4k | 429.40 | 323.3 |
| 16k | 1297.01 | 911.74 |
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -159,6 +159,12 @@ class AscendConfig:
|
||||
and get_ascend_device_type() != AscendDeviceType.A5
|
||||
)
|
||||
|
||||
self.enable_sp_by_pass = (
|
||||
vllm_config.model_config is not None
|
||||
and not vllm_config.model_config.enforce_eager
|
||||
and vllm_config.compilation_config.pass_config.enable_sp
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_compile_ranges(compilation_config):
|
||||
return compilation_config.compile_ranges_endpoints or []
|
||||
@@ -195,14 +201,6 @@ class AscendConfig:
|
||||
"{new_compile_ranges_split_points} for matmul and allreduce fusion"
|
||||
)
|
||||
|
||||
from vllm_ascend.utils import is_moe_model
|
||||
|
||||
if vllm_config.compilation_config.pass_config.enable_sp and not is_moe_model(vllm_config):
|
||||
from vllm_ascend.compilation.passes.sequence_parallelism import get_sp_threshold
|
||||
|
||||
sp_threshold = get_sp_threshold(vllm_config)
|
||||
new_compile_ranges_split_points.append(sp_threshold)
|
||||
logger.debug(f"add {sp_threshold} to compile_ranges_split_points for sequence parallelism")
|
||||
if len(new_compile_ranges_split_points) > len(self._get_compile_ranges(vllm_config.compilation_config)):
|
||||
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
|
||||
self._set_compile_ranges(vllm_config.compilation_config, new_compile_ranges_split_points)
|
||||
|
||||
@@ -70,6 +70,8 @@ class GraphFusionPassManager:
|
||||
self.passes.append(MulsAddFusionPass(config))
|
||||
|
||||
if config.compilation_config.pass_config.enable_sp:
|
||||
from .passes.sequence_parallelism import AscendSequenceParallelismPass
|
||||
from .passes.sequence_parallelism import SequenceParallelismPass
|
||||
from .passes.sequence_parallelism_moe import SequenceParallelismMoePass
|
||||
|
||||
self.passes.append(AscendSequenceParallelismPass(config))
|
||||
self.passes.append(SequenceParallelismPass(config))
|
||||
self.passes.append(SequenceParallelismMoePass(config))
|
||||
|
||||
40
vllm_ascend/compilation/passes/allgather_chunk_noop_pass.py
Normal file
40
vllm_ascend/compilation/passes/allgather_chunk_noop_pass.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
class AllGatherChunkNoOpCleanupPass(VllmInductorPass):
|
||||
"""Fold all_gather + sequence_parallel_chunk_impl into identity."""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
self.tp_group = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_allgather_chunk_noop_cleanup_pass")
|
||||
self._register_patterns()
|
||||
|
||||
def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.vllm.all_gather(x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name)
|
||||
|
||||
def _empty(self, *args, **kwargs):
|
||||
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)
|
||||
|
||||
def _register_patterns(self) -> None:
|
||||
def pattern(input: torch.Tensor) -> torch.Tensor:
|
||||
gathered = self._all_gather(input)
|
||||
return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)
|
||||
|
||||
def replacement(input: torch.Tensor) -> torch.Tensor:
|
||||
return input
|
||||
|
||||
pm.register_replacement(pattern, replacement, [self._empty(8, 16)], pm.fwd_only, self.patterns)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
self.begin()
|
||||
matched_count = self.patterns.apply(graph)
|
||||
logger.debug("AllGatherChunkNoOpCleanupPass replaced %s patterns", matched_count)
|
||||
self.end_and_log()
|
||||
62
vllm_ascend/compilation/passes/noop_elimination.py
Normal file
62
vllm_ascend/compilation/passes/noop_elimination.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch import SymInt
|
||||
from torch.fx.experimental.symbolic_shapes import statically_known_true
|
||||
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
class NoOpEliminationPass(VllmInductorPass):
|
||||
"""Remove no-op view/reshape nodes after pattern rewrites."""
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
fx_graph = graph.graph if hasattr(graph, "graph") else graph
|
||||
removed = 0
|
||||
for node in list(fx_graph.nodes):
|
||||
if not self._is_view_like(node):
|
||||
continue
|
||||
|
||||
input_node = node.args[0]
|
||||
if not isinstance(input_node, torch.fx.Node):
|
||||
continue
|
||||
|
||||
input_meta = input_node.meta.get("val")
|
||||
output_meta = node.meta.get("val")
|
||||
if input_meta is None or output_meta is None:
|
||||
continue
|
||||
|
||||
input_shape = getattr(input_meta, "shape", None)
|
||||
output_shape = getattr(output_meta, "shape", None)
|
||||
if input_shape is None or output_shape is None:
|
||||
continue
|
||||
|
||||
if self._all_dims_equivalent(input_shape, output_shape):
|
||||
node.replace_all_uses_with(input_node)
|
||||
fx_graph.erase_node(node)
|
||||
removed += 1
|
||||
|
||||
logger.debug("NoOpEliminationPass removed %s no-op views", removed)
|
||||
|
||||
@staticmethod
|
||||
def _is_view_like(node: torch.fx.Node) -> bool:
|
||||
return (node.op == "call_method" and node.target in {"view", "reshape"}) or (
|
||||
node.op == "call_function"
|
||||
and node.target
|
||||
in {
|
||||
torch.ops.aten.view.default,
|
||||
torch.ops.aten.reshape.default,
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _dims_equivalent(dim: int | SymInt, i_dim: int | SymInt) -> bool:
|
||||
return statically_known_true(dim == i_dim) # type: ignore[no-any-return]
|
||||
|
||||
def _all_dims_equivalent(self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]) -> bool:
|
||||
dims_ = list(dims)
|
||||
i_dims_ = list(i_dims)
|
||||
if len(dims_) != len(i_dims_):
|
||||
return False
|
||||
return all(self._dims_equivalent(s, i_s) for s, i_s in zip(dims_, i_dims_))
|
||||
@@ -7,21 +7,25 @@ from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.compilation.passes.noop_elimination import NoOpEliminationPass
|
||||
from vllm_ascend.utils import is_moe_model
|
||||
|
||||
SP_THRESHOLD = 1000
|
||||
SP_MIN_TOKEN_NUM_DEFAULT = 1000
|
||||
|
||||
|
||||
def get_sp_threshold(config: VllmConfig):
|
||||
def get_sp_min_token_num(config: VllmConfig) -> int:
|
||||
if is_moe_model(config):
|
||||
return 1
|
||||
|
||||
additional_config = config.additional_config if config.additional_config is not None else {}
|
||||
return additional_config.get("sp_threshold", SP_THRESHOLD)
|
||||
return SP_MIN_TOKEN_NUM_DEFAULT
|
||||
|
||||
|
||||
class _SequenceParallelPatternHelper:
|
||||
"""Helper for sequence parallelism patterns."""
|
||||
"""Helper for sequence parallelism patterns.
|
||||
|
||||
Provides TP communication helper methods: _all_reduce, _reduce_scatter,
|
||||
_all_gather, and tensor creation utilities.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -49,7 +53,10 @@ class _SequenceParallelPatternHelper:
|
||||
return torch.empty(*args, dtype=self.dtype, device="npu", **kws)
|
||||
|
||||
|
||||
class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
"""Replaces all_reduce + AddRMSNormBias with reduce_scatter + AddRMSNormBias
|
||||
+ all_gather for middle-layer sequence parallelism."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
|
||||
|
||||
@@ -92,7 +99,10 @@ class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
"""Same as MiddleAllReduceRMSNormPattern but for the last layer
|
||||
(no residual backprop)."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
|
||||
|
||||
@@ -127,7 +137,13 @@ class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
class Qwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
"""For Qwen3-VL middle layers with hidden_states + deepstack_input_embeds add.
|
||||
|
||||
Replaces all_reduce + add + AddRMSNormBias with reduce_scatter +
|
||||
chunk(deepstack_input_embeds) + add + AddRMSNormBias + all_gather.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
|
||||
|
||||
@@ -168,25 +184,45 @@ class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AscendSequenceParallelismPass(VllmInductorPass):
|
||||
class SequenceParallelismPass(VllmInductorPass):
|
||||
"""Sequence parallelism compilation pass.
|
||||
|
||||
Registers and applies the above patterns. Runs noop cleanup first, then
|
||||
uses token range to determine whether to enable SP.
|
||||
"""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_pass")
|
||||
self.noop_cleanup = NoOpEliminationPass(config)
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
AscendMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
MiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
|
||||
AscendLastAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
LastAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
|
||||
AscendQwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
Qwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
|
||||
self.min_tokens = get_sp_threshold(config)
|
||||
self.min_tokens = get_sp_min_token_num(config)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
self.noop_cleanup(graph) # Eliminate redundant view-like operations
|
||||
logger.debug(f"after noop_cleanup {graph.graph}")
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
logger.debug(f"after apply replacement {graph.graph}")
|
||||
|
||||
from torch._inductor.pattern_matcher import PatternPrettyPrinter
|
||||
|
||||
pattern_idx = 0
|
||||
for pattern_entry in self.patterns.patterns.values():
|
||||
for p in pattern_entry:
|
||||
p_str = PatternPrettyPrinter.run(p.pattern)
|
||||
logger.debug("Pattern %d: %s", pattern_idx, p_str)
|
||||
pattern_idx += 1
|
||||
|
||||
self.end_and_log()
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
|
||||
204
vllm_ascend/compilation/passes/sequence_parallelism_moe.py
Normal file
204
vllm_ascend/compilation/passes/sequence_parallelism_moe.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from vllm.compilation.passes.vllm_inductor_pass import PatternPrettyPrinter, VllmInductorPass
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.compilation.passes.sequence_parallelism import (
|
||||
_SequenceParallelPatternHelper,
|
||||
get_sp_min_token_num,
|
||||
)
|
||||
|
||||
|
||||
class MiddleLayerAllgatherAddRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
"""Replaces all_gather + slice + AddRMSNormBias with AddRMSNormBias +
|
||||
all_gather to avoid middle-layer shape mismatch."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
|
||||
|
||||
def get_inputs(self):
|
||||
input = self.empty(5, 16)
|
||||
weight = self.empty(16)
|
||||
residual = self.empty(8, 16)
|
||||
# num_tokens = 8
|
||||
return [input, weight, residual]
|
||||
|
||||
def get_scalar_inputs(self):
|
||||
return {"num_tokens": 8}
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_gather = self._all_gather(input)
|
||||
x_sliced = all_gather[:num_tokens]
|
||||
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(x_sliced, residual, weight, None, self.eps)
|
||||
|
||||
return result, residual
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(input, residual)
|
||||
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(input, residual, weight, None, self.eps)
|
||||
all_gather = self._all_gather(result)
|
||||
return all_gather, residual
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass, scalar_workaround=self.get_scalar_inputs()
|
||||
)
|
||||
|
||||
|
||||
class LastLayerAllgatherRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
"""Same as MiddleLayerAllgatherAddRMSNormPattern but for the last layer (no residual)
|
||||
all_gather + RMSNorm fusion."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
|
||||
|
||||
def get_inputs(self):
|
||||
input = self.empty(5, 16)
|
||||
weight = self.empty(16)
|
||||
residual = self.empty(8, 16)
|
||||
return [input, weight, residual]
|
||||
|
||||
def get_scalar_inputs(self):
|
||||
return {"num_tokens": 8}
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_gather = self._all_gather(input)
|
||||
x_sliced = all_gather[:num_tokens]
|
||||
result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(x_sliced, residual, weight, None, self.eps)
|
||||
|
||||
return result
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(input, residual)
|
||||
result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(input, residual, weight, None, self.eps)
|
||||
all_gather = self._all_gather(result)
|
||||
return all_gather
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass, scalar_workaround=self.get_scalar_inputs()
|
||||
)
|
||||
|
||||
|
||||
class Qwen3VLMiddleLayerAllgatherAddRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
"""Replaces all_gather + slice + add + AddRMSNormBias with add(chunk) +
|
||||
AddRMSNormBias + all_gather for Qwen3-VL-style all_gather path."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
|
||||
|
||||
def get_inputs(self):
|
||||
input = self.empty(5, 16)
|
||||
weight = self.empty(16)
|
||||
residual = self.empty(8, 16)
|
||||
deepstack_input_embeds = self.empty(8, 16)
|
||||
return [input, weight, residual, deepstack_input_embeds]
|
||||
|
||||
def get_scalar_inputs(self):
|
||||
return {"num_tokens": 8}
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
deepstack_input_embeds: torch.Tensor,
|
||||
num_tokens,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_gather = self._all_gather(input)
|
||||
x_sliced = all_gather[:num_tokens]
|
||||
add_ = x_sliced + deepstack_input_embeds
|
||||
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps)
|
||||
|
||||
return result, residual
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
deepstack_input_embeds: torch.Tensor,
|
||||
num_tokens,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
chunk = deepstack_input_embeds.chunk(self.tp_size)[self.tp_rank]
|
||||
add_ = input + chunk
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(input, residual)
|
||||
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps)
|
||||
all_gather = self._all_gather(result)
|
||||
return all_gather, residual
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass, scalar_workaround=self.get_scalar_inputs()
|
||||
)
|
||||
|
||||
|
||||
class AllGatherChunkNoOpPattern(_SequenceParallelPatternHelper):
|
||||
"""Folds all_gather + sequence_parallel_chunk_impl into identity (no-op)."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
|
||||
|
||||
def get_inputs(self):
|
||||
return [self.empty(8, 16)]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(input: torch.Tensor) -> torch.Tensor:
|
||||
gathered = self._all_gather(input)
|
||||
return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)
|
||||
|
||||
def replacement(input: torch.Tensor) -> torch.Tensor:
|
||||
return input
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class SequenceParallelismMoePass(VllmInductorPass):
|
||||
"""Sequence parallelism AllGather epilogue pass.
|
||||
|
||||
Applies AllGather-based patterns: MiddleLayerAllgatherAddRMSNormPattern,
|
||||
LastLayerAllgatherRMSNormPattern, Qwen3VLMiddleLayerAllgatherAddRMSNormPattern,
|
||||
and AllGatherChunkNoOpPattern (all_gather + sequence_parallel_chunk_impl -> identity).
|
||||
"""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_allgather_ep_pass")
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
MiddleLayerAllgatherAddRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
LastLayerAllgatherRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
Qwen3VLMiddleLayerAllgatherAddRMSNormPattern(config, epsilon).register(self.patterns)
|
||||
|
||||
AllGatherChunkNoOpPattern(config).register(self.patterns)
|
||||
|
||||
self.min_tokens = get_sp_min_token_num(config)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
logger.debug(f"before apply replacement {graph}")
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug(f"after apply replacement {graph}")
|
||||
logger.debug("SequenceParallelismMoePass replaced %s patterns", self.matched_count)
|
||||
pattern_idx = 0
|
||||
for pattern_entry in self.patterns.patterns.values():
|
||||
for p in pattern_entry:
|
||||
p_str = PatternPrettyPrinter.run(p.pattern)
|
||||
logger.debug("Pattern %d: %s", pattern_idx, p_str)
|
||||
pattern_idx += 1
|
||||
self.end_and_log()
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
applicable = compile_range.start >= self.min_tokens
|
||||
logger.debug(f"SequenceParallelismMoePass {compile_range=} {applicable=}")
|
||||
return applicable
|
||||
@@ -33,7 +33,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
|
||||
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEPrepareOutput
|
||||
from vllm_ascend.quantization.quant_type import QuantType
|
||||
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
|
||||
from vllm_ascend.utils import enable_sp, enable_sp_by_pass, npu_stream_switch, prefill_context_parallel_enable
|
||||
|
||||
|
||||
class PrepareAndFinalize(ABC):
|
||||
@@ -324,7 +324,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
Returns:
|
||||
MoEPrepareOutput with global tensors.
|
||||
"""
|
||||
if enable_sp():
|
||||
if enable_sp() or enable_sp_by_pass():
|
||||
return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)
|
||||
|
||||
return self._prepare_with_dp_group(hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce)
|
||||
@@ -433,7 +433,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
Returns:
|
||||
Tensor with shape [local_num_tokens, hidden_size]
|
||||
"""
|
||||
if enable_sp():
|
||||
if enable_sp() or enable_sp_by_pass():
|
||||
return self._finalize_with_ep_group(hidden_states)
|
||||
|
||||
return self._finalize_with_dp_group(hidden_states, reduce_results)
|
||||
|
||||
@@ -17,7 +17,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
||||
from vllm_ascend.ops.rotary_embedding import rope_forward_oot
|
||||
from vllm_ascend.ops.triton.muls_add import muls_add_triton
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
|
||||
from vllm_ascend.utils import enable_sp_by_pass, npu_stream_switch, prefetch_stream
|
||||
|
||||
|
||||
def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
|
||||
@@ -43,7 +43,7 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
|
||||
except AssertionError:
|
||||
return x
|
||||
|
||||
flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
|
||||
flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled or (enable_sp_by_pass() and is_ep_comm)
|
||||
if flash_comm_v1_enabled and label:
|
||||
dp_metadata = forward_context.dp_metadata
|
||||
if dp_metadata is None or not is_ep_comm:
|
||||
@@ -53,6 +53,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
|
||||
x = x[:-pad_size]
|
||||
else:
|
||||
x = get_ep_group().all_gather(x, 0)
|
||||
if enable_sp_by_pass(): # TODO: do unpad
|
||||
return x
|
||||
# unpad
|
||||
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
|
||||
result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype)
|
||||
@@ -74,7 +76,11 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
|
||||
except AssertionError:
|
||||
return tensor_model_parallel_all_reduce(x)
|
||||
|
||||
if not getattr(forward_context, "flash_comm_v1_enabled", False):
|
||||
flash_comm_v1_enabled = getattr(forward_context, "flash_comm_v1_enabled", False) or (
|
||||
enable_sp_by_pass() and is_ep_comm
|
||||
)
|
||||
|
||||
if not flash_comm_v1_enabled:
|
||||
return tensor_model_parallel_all_reduce(x)
|
||||
|
||||
dp_metadata = forward_context.dp_metadata
|
||||
@@ -84,6 +90,8 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
|
||||
x = F.pad(x, (0, 0, 0, pad_size))
|
||||
return tensor_model_parallel_reduce_scatter(x, 0)
|
||||
else:
|
||||
if enable_sp_by_pass():
|
||||
return get_ep_group().reduce_scatter(x.view(-1, *x.shape[1:]), 0)
|
||||
# padding
|
||||
dp_size = get_dp_group().world_size
|
||||
num_tokens_across_dp_cpu = get_forward_context().dp_metadata.num_tokens_across_dp_cpu
|
||||
@@ -107,7 +115,7 @@ def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_c
|
||||
|
||||
|
||||
def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
|
||||
if _EXTRA_CTX.flash_comm_v1_enabled:
|
||||
if _EXTRA_CTX.flash_comm_v1_enabled or enable_sp_by_pass():
|
||||
return torch.empty(
|
||||
(x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
|
||||
)
|
||||
|
||||
@@ -167,15 +167,12 @@
|
||||
# 1. `vllm.distributed.parallel_state.GroupCoordinator`
|
||||
# Why:
|
||||
# vllm doesn't support all_to_all for GroupCoordinator.
|
||||
# all_reduce in vLLM not is a customop, which will make MatmulAllReduceAddRMSNorm fusion failure.
|
||||
# How:
|
||||
# Add all_to_all implementation for GroupCoordinator.
|
||||
# make all_reduce as a customop.
|
||||
# Related PR (if no, explain why):
|
||||
# No, we should use vlLM all2all manager to support all_to_all for npu.
|
||||
# Future Plan:
|
||||
# Remove this patch when the refactor of all2all manager is done.
|
||||
# Remove this patch when vLLM support all_reduce as customop.
|
||||
#
|
||||
# ** 2. File: worker/patch_multimodal_merge.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -84,7 +84,7 @@ class GroupCoordinatorPatch(GroupCoordinator):
|
||||
if use_message_queue_broadcaster and self.world_size > 1:
|
||||
self.mq_broadcaster = MessageQueue.create_from_process_group(self.cpu_group, 1 << 22, 6)
|
||||
|
||||
self.use_custom_op_call = False
|
||||
self.use_custom_op_call = True
|
||||
self.use_cpu_custom_send_recv = False
|
||||
|
||||
def all_to_all(
|
||||
@@ -106,10 +106,5 @@ class GroupCoordinatorPatch(GroupCoordinator):
|
||||
assert self.device_communicator is not None, "device_communicator should be initialized when world_size > 1"
|
||||
return self.device_communicator.all_to_all(input_, scatter_dim, gather_dim, scatter_sizes, gather_sizes)
|
||||
|
||||
def all_reduce(self, input_):
|
||||
if self.world_size == 1:
|
||||
return input_
|
||||
return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
|
||||
|
||||
|
||||
vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch
|
||||
|
||||
@@ -156,6 +156,16 @@ class NPUPlatform(Platform):
|
||||
def get_device_capability(cls, device_id: int = 0):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def apply_config_platform_defaults(cls, vllm_config: VllmConfig) -> None:
|
||||
"""Apply Ascend-specific defaults. Set sp_min_token_num=1 when enable_sp and not set."""
|
||||
pass_config = vllm_config.compilation_config.pass_config
|
||||
if pass_config.enable_sp and pass_config.sp_min_token_num is None:
|
||||
from vllm_ascend.compilation.passes.sequence_parallelism import get_sp_min_token_num
|
||||
|
||||
pass_config.sp_min_token_num = get_sp_min_token_num(vllm_config)
|
||||
logger.info(f"set sp_min_token_num to {pass_config.sp_min_token_num}")
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
return torch.npu.get_device_name(device_id)
|
||||
@@ -198,6 +208,7 @@ class NPUPlatform(Platform):
|
||||
|
||||
# initialize ascend config from vllm additional_config
|
||||
cls._fix_incompatible_config(vllm_config)
|
||||
|
||||
ascend_config = init_ascend_config(vllm_config)
|
||||
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
@@ -218,6 +229,7 @@ class NPUPlatform(Platform):
|
||||
if not isinstance(ascend_compilation_config, dict)
|
||||
else ascend_compilation_config
|
||||
)
|
||||
|
||||
ascend_config.update_compile_ranges_split_points()
|
||||
|
||||
if model_config and hasattr(model_config.hf_text_config, "index_topk"):
|
||||
@@ -363,7 +375,8 @@ class NPUPlatform(Platform):
|
||||
|
||||
if parallel_config and parallel_config.worker_cls == "auto":
|
||||
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
|
||||
parallel_config.all2all_backend = "flashinfer_all2allv"
|
||||
if not vllm_config.compilation_config.pass_config.enable_sp:
|
||||
parallel_config.all2all_backend = "flashinfer_all2allv"
|
||||
if is_310p():
|
||||
parallel_config.worker_cls = "vllm_ascend._310p.worker_310p.NPUWorker310"
|
||||
elif ascend_config.xlite_graph_config.enabled:
|
||||
@@ -805,3 +818,7 @@ class NPUPlatform(Platform):
|
||||
"ignored on Ascend. Resetting to default (32)."
|
||||
)
|
||||
att_config.flash_attn_max_num_splits_for_cuda_graph = 32
|
||||
|
||||
@classmethod
|
||||
def use_custom_op_collectives(cls) -> bool:
|
||||
return True
|
||||
|
||||
@@ -764,8 +764,8 @@ def matmul_allreduce_enable() -> bool:
|
||||
return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
|
||||
|
||||
|
||||
def enable_sp_by_pass(vllm_config: VllmConfig):
|
||||
return not vllm_config.model_config.enforce_eager and vllm_config.compilation_config.pass_config.enable_sp
|
||||
def enable_sp_by_pass():
|
||||
return get_ascend_config().enable_sp_by_pass
|
||||
|
||||
|
||||
def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
|
||||
@@ -791,7 +791,7 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
|
||||
|
||||
# TODO remove it after vllm has this func
|
||||
def shared_expert_dp_enabled() -> bool:
|
||||
return get_ascend_config().enable_shared_expert_dp or enable_sp()
|
||||
return get_ascend_config().enable_shared_expert_dp or enable_sp() or enable_sp_by_pass()
|
||||
|
||||
|
||||
def prefill_context_parallel_enable() -> bool:
|
||||
|
||||
@@ -1846,7 +1846,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# Pad tokens to multiple of tensor_parallel_size when
|
||||
# enabled collective fusion for SP
|
||||
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||||
if enable_sp(self.vllm_config) or enable_sp_by_pass(self.vllm_config):
|
||||
if enable_sp(self.vllm_config) or enable_sp_by_pass():
|
||||
return round_up(num_scheduled_tokens, tp_size)
|
||||
return num_scheduled_tokens
|
||||
|
||||
|
||||
Reference in New Issue
Block a user