[Feat][SP] Support SP for VL MoE models (#7044)

### What this PR does / why we need it?

This is the 2nd PR for https://github.com/vllm-project/vllm-ascend/issues/5712,
extending SP to VL MoE models.


### Does this PR introduce _any_ user-facing change?
Yes: `sp_threshold` is removed from `additional_config`; `sp_min_token_num`
from vLLM's `pass_config` is reused instead.


### How was this patch tested?
- Model: Qwen3-VL-30B-A3B
- TP4 DP2
- 100 reqs
- max concurrency 1

| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|------------|---------------------|------------------------|
| 4k         | 429.40               | 323.3                  |
| 16k        | 1297.01              | 911.74                |

- vLLM version: v0.16.0
- vLLM main: 4034c3d32e

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Commit 5d12446573 (parent 9615bc33fd) by realliujiaxu, committed via GitHub on 2026-03-24 17:16:00 +08:00.
21 changed files with 947 additions and 54 deletions

View File

@@ -141,6 +141,8 @@ e2e-multicard-2-cards:
estimated_time: 164
- name: tests/e2e/multicard/2-cards/test_sp_pass.py
estimated_time: 198
- name: tests/e2e/multicard/2-cards/test_sequence_parallelism_moe.py
estimated_time: 120
e2e-multicard-4-cards:
- name: tests/e2e/multicard/4-cards/test_qwen3_next.py
estimated_time: 1868

Binary file not shown. New image added (576 KiB).

View File

@@ -44,7 +44,6 @@ The following table lists additional configuration options available in vLLM Asc
| `pa_shape_list` | list | `[]` | The custom shape list of page attention ops. |
| `enable_kv_nz` | bool | `False` | Whether to enable KV cache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). |
| `layer_sharding` | dict | `{}` | Configuration options for Layer Sharding Linear |
| `sp_threshold` | int | `1000` | For dense models, only num_tokens > threshold will enable sequence parallelism. |
The details of each configuration option are as follows:

View File

@@ -14,14 +14,13 @@ Currently, vllm-ascend has implemented Sequence Parallelism for VL-class models
```bash
vllm serve Qwen/Qwen3-VL-2B-Instruct \
--tensor-parallel-size 2 \
--compilation-config '{"pass_config": {"enable_sp": true}}' \
--additional_config={"sp_threshold": 1000}
--compilation-config '{"pass_config": {"enable_sp": true, , "sp_min_token_num": 1000}}'
```
- `"pass_config": {"enable_sp": true}`: This is the switch for SP. Since SP relies on graph mode, it must be enabled and is not supported in eager mode.
- `--additional_config={"sp_threshold": 1000}`: Based on our experiments, when the number of tokens is small (empirical value is less than 1000), SP can actually bring negative benefits. This is because when the communication volume is small, the fixed overhead of the communication operator becomes the dominant factor. Therefore, when one communication operator (Allreduce) is split into two communication operators (ReduceScatter+Allgather), the end-to-end latency often becomes longer. Thus, we have reserved the `sp_threshold` parameter; SP will only take effect when `num_tokens >= sp_threshold`. **The default value is 1000, which generally does not need to be modified.** `sp_threshold` will be appended into `compile_ranges_split_points`, which is a parameter provided by vLLM that splits the graph compilation range `[1, max_num_batched_tokens]` into `{[1, split_points[0]], [split_points[0] + 1, split_points[1]], ..., [split_points[-1] + 1, max_num_batched_tokens]}`, and sequentially checks whether the `is_applicable_for_range` of the pass returns `True`.
- `"enable_sp"`: This is the switch for SP. Since SP relies on graph mode, it is not supported in eager mode.
- `sp_min_token_num` (from upstream vllm's `pass_config`): Based on our experiments, when the number of tokens is small (empirical value is less than 1000), SP can actually bring negative benefits. This is because when the communication volume is small, the fixed overhead of the communication operator becomes the dominant factor. SP will only take effect when `num_tokens >= sp_min_token_num`. **The default value is 1000 on Ascend, which generally does not need to be modified.** To customize, use `--compilation-config '{"pass_config": {"enable_sp": true, "sp_min_token_num": 512}}'`. The value will be appended into `compile_ranges_split_points`, which splits the graph compilation range and checks whether the pass is applicable per range.
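To make the range splitting concrete, here is a minimal sketch (a hypothetical helper, not the actual vLLM implementation) of how a split point partitions `[1, max_num_batched_tokens]`:
```python
# Hedged sketch: how compile_ranges_split_points partitions
# [1, max_num_batched_tokens]; the function name is made up for illustration.
def split_compile_ranges(max_num_batched_tokens: int, split_points: list[int]):
    points = sorted(split_points)
    starts = [1] + [p + 1 for p in points]
    ends = points + [max_num_batched_tokens]
    return list(zip(starts, ends))

# With sp_min_token_num=1000 and max_num_batched_tokens=8192 this yields
# [(1, 1000), (1001, 8192)]; SP applies only to the second range, since
# is_applicable_for_range checks compile_range.start >= sp_min_token_num.
print(split_compile_ranges(8192, [1000]))
```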
Without modifying `sp_threshold`, the simplest and recommended way to enable SP is:
Without modifying `sp_min_token_num`, the simplest and recommended way to enable SP is:
```bash
vllm serve Qwen/Qwen3-VL-2B-Instruct \
@@ -44,7 +43,7 @@ FC1 is a unique optimization in vllm-ascend, currently implemented based on Cust
| | VL + Dense | VL + MoE | non-VL + Dense | non-VL + MoE |
| -------------------- | ---------- | -------- | -------------- | ------------ |
| Sequence Parallelism | graph | x | x | x |
| Sequence Parallelism | graph | graph | x | x |
| Flash Comm V1 | x | x | eager/graph | eager/graph |
### With Quantization
@@ -55,3 +54,56 @@ SP currently does not support quantization and is under adaptation.
| -------------------- | ---------- | -------- | -------------- | ------------ |
| Sequence Parallelism | x | x | x | x |
| Flash Comm V1 | x | x | eager/graph | eager/graph |
## Pass Design
When SP is enabled, the following passes run in order: `SequenceParallelismPass` then `SequenceParallelismMoePass`.
### SequenceParallelismPass
Runs `NoOpEliminationPass` first to eliminate redundant view-like operations, then applies AllReduce-based patterns:
| Pattern | Match | Replacement |
| -------------------------------------- | -------------------------------- | ------------------------------------------------------------------------------------- |
| `MiddleAllReduceRMSNormPattern` | `all_reduce` + `layernorm` | `reduce_scatter` + `layernorm` + `all_gather` |
| `LastAllReduceRMSNormPattern` | Same (last layer, no residual) | Same |
| `Qwen3VLMiddleAllReduceRMSNormPattern` | `all_reduce` + add + `layernorm` | `reduce_scatter` + chunk(`deepstack_input_embeds`) + add + `layernorm` + `all_gather` |
**Why Qwen3-VL needs special handling by Qwen3VLMiddleAllReduceRMSNormPattern**
Qwen3-VL middle layers insert an extra add between `all_reduce` and `layernorm`: `hidden_states = hidden_states + deepstack_input_embeds`. Under SP, `hidden_states` (i.e., `input`) is reduce-scattered to shape `[seq_len/tp, hidden]` per rank, while `deepstack_input_embeds` comes from the vision/deepstack path and stays full-sequence `[seq_len, hidden]` (typically replicated across TP ranks). Simply doing `reduce_scatter(input) + deepstack_input_embeds` would cause a shape mismatch.
The fix is to chunk `deepstack_input_embeds` by `tp_size` so each rank uses `add(reduce_scatter, chunk(deepstack_input_embeds)[tp_rank])`, keeping shapes consistent before `layernorm` and `all_gather`.
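A minimal shape sketch of the fix (the sizes below are illustrative, not from any model config):
```python
# Hedged sketch of the shape bookkeeping; tp_size/seq_len/hidden are made up.
import torch

tp_size, tp_rank = 2, 0
seq_len, hidden = 8, 16
hidden_states = torch.randn(seq_len // tp_size, hidden)  # after reduce_scatter: [4, 16]
deepstack_input_embeds = torch.randn(seq_len, hidden)    # full sequence: [8, 16]

# A naive `hidden_states + deepstack_input_embeds` would fail: [4, 16] vs [8, 16].
chunk = deepstack_input_embeds.chunk(tp_size)[tp_rank]   # this rank's slice: [4, 16]
out = hidden_states + chunk                              # shapes consistent again
```
This mirrors what the `Qwen3VLMiddleAllReduceRMSNormPattern` replacement does via `deepstack_input_embeds.chunk(self.tp_size)[self.tp_rank]`.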
### SequenceParallelismMoePass
After `SequenceParallelismPass` is applied, the MoE model's computation graph looks like:
![AllGather EP computation graph](../../assets/sp_moe.png)
**Overview**
1. **Postponing allgather**: Under SP, `residual` is chunked across tensor-parallel ranks. This causes a shape mismatch between hidden states and residual in the next layer's layernorm: hidden states are gathered (full sequence) while residual remains chunked. The fix is to move `all_gather` to *after* layernorm so that layernorm operates on consistent shapes per rank. `MiddleLayerAllgatherAddRMSNormPattern`, `LastLayerAllgatherRMSNormPattern`, and `Qwen3VLMiddleLayerAllgatherAddRMSNormPattern` are designed for this purpose, each handling different layer and structure variants (see the table below).
2. **AllGatherChunkNoOp cleanup**: When MoE SP is enabled, vllm introduces a `sequence_parallel_chunk` op (corresponding to `sp_chunk` in the diagram). Together with the preceding `all_gather`, the pair forms a redundant no-op (all_gather gathers, then chunk re-splits). `AllGatherChunkNoOpPattern` replaces this pair with identity to eliminate the redundant communication and computation.
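The identity can be checked directly; a small sketch assuming `tp_size=2` and toy shapes:
```python
# Hedged sketch: all_gather along dim 0 followed by taking this rank's chunk
# returns the original local shard, so the pair is a no-op.
import torch

tp_size, tp_rank = 2, 0
x_local = torch.randn(4, 16)                         # this rank's shard
other = torch.randn(4, 16)                           # the peer rank's shard
shards = [x_local, other] if tp_rank == 0 else [other, x_local]
gathered = torch.cat(shards, dim=0)                  # what all_gather produces
rechunked = gathered.chunk(tp_size, dim=0)[tp_rank]  # sequence_parallel_chunk
assert torch.equal(rechunked, x_local)               # identity: redundant pair
```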
**Pattern details:**
| Pattern | Match | Replacement |
| ---------------------------------- | ---------------------------------------- | --------------------------------------- |
| `MiddleLayerAllgatherAddRMSNormPattern` | `all_gather` + slice + `layernorm` | `layernorm` + `all_gather` |
| `LastLayerAllgatherRMSNormPattern` | Same (last layer, no residual) | Same |
| `Qwen3VLMiddleLayerAllgatherAddRMSNormPattern` | `all_gather` + slice + add + `layernorm` | add(chunk) + `layernorm` + `all_gather` |
| `AllGatherChunkNoOpPattern` | `all_gather` + `sequence_parallel_chunk_impl` | identity (no-op) |
### FAQ
#### Q1: Is SP enabled by default?
No, SP is not enabled by default. SP is currently in the experimental stage and will be enabled by default in the future.
The processing flow of `enable_sp` in the code is (see the sketch after this list):
- In `pass_config`, `enable_sp` and `sp_min_token_num` default to `None`
- `NPUPlatform.apply_config_platform_defaults`: If `enable_sp` is `True` and `sp_min_token_num` is None, set default `sp_min_token_num` (1000 for Dense models, 1 for MoE models)
- `VllmConfig._apply_optimization_level_defaults`: `enable_sp` is set to `True` for dense models.
- `VllmConfig.__post_init__`: If `sp_min_token_num` is still `None`, then `enable_sp` is set to `False`
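A condensed, hedged sketch of the defaulting step above (the surrounding vLLM plumbing is simplified; only the names in the list are from this PR):
```python
# Hedged sketch of the sp_min_token_num defaulting applied in
# NPUPlatform.apply_config_platform_defaults; is_moe stands in for
# vllm_ascend.utils.is_moe_model(vllm_config).
def apply_ascend_sp_defaults(pass_config, is_moe: bool) -> None:
    if pass_config.enable_sp and pass_config.sp_min_token_num is None:
        pass_config.sp_min_token_num = 1 if is_moe else 1000
```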

View File

@@ -0,0 +1,473 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Two-card e2e tests for SequenceParallelismMoePass patterns:
# - MiddleLayerAllgatherAddRMSNormPattern (all_gather + slice + RMSNorm)
# - Qwen3VLMiddleLayerAllgatherAddRMSNormPattern (all_gather + slice + add + RMSNorm)
# - AllGatherChunkNoOpPattern (all_gather + sequence_parallel_chunk_impl -> identity)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import queue
import traceback
from collections.abc import Callable, Generator
from dataclasses import dataclass
from typing import Any
import pytest
import torch
import torch.nn as nn
import vllm.config
from vllm.compilation.passes.fx_utils import OpOverload
from vllm.config import ModelConfig, VllmConfig
from vllm.distributed import (
get_tensor_model_parallel_world_size,
get_tp_group,
init_distributed_environment,
tensor_model_parallel_all_gather,
)
from vllm.distributed.parallel_state import (
destroy_distributed_environment,
destroy_model_parallel,
initialize_model_parallel,
)
from vllm.utils.system_utils import update_environment_variables
import vllm_ascend.ops.register_custom_ops # noqa
from tests.e2e.singlecard.compile.backend import TestBackend as CompileTestBackend
from vllm_ascend.compilation.passes.sequence_parallelism_moe import (
SequenceParallelismMoePass,
)
from vllm_ascend.utils import enable_custom_op
MASTER_PORT = 29500
WORLD_SIZE = 2
WORKER_READY = "__ready__"
WORKER_STOP = "__stop__"
WORKER_RESULT_TIMEOUT_S = 180
WORKER_JOIN_TIMEOUT_S = 30
class BaseAllGatherRMSNormModel(nn.Module):
def __init__(
self,
hidden_size: int,
dtype: torch.dtype,
eps: float = 1e-6,
device: str = "npu",
):
super().__init__()
self.eps = eps
self.norm_w = torch.randn(hidden_size, dtype=dtype, device=device)
def _all_gather_sliced(self, x: torch.Tensor, num_tokens_helper: torch.Tensor) -> torch.Tensor:
num_tokens = num_tokens_helper.shape[0]
activated = torch.relu(x)
gathered = tensor_model_parallel_all_gather(activated, 0)
return gathered[:num_tokens]
@staticmethod
def ops_in_model_after() -> tuple[tuple[OpOverload, int], ...]:
return (
(torch.ops.vllm.all_gather.default, 1),
(torch.ops._C_ascend.npu_add_rms_norm_bias.default, 1),
(torch.ops.vllm.maybe_chunk_residual.default, 1),
)
class AllGatherRMSNormModel(BaseAllGatherRMSNormModel):
def forward(
self,
x: torch.Tensor,
residual: torch.Tensor,
num_tokens_helper: torch.Tensor,
) -> torch.Tensor:
sliced = self._all_gather_sliced(x, num_tokens_helper)
rms_out = torch.ops._C_ascend.npu_add_rms_norm_bias(sliced, residual, self.norm_w, None, self.eps)
return rms_out[0]
@staticmethod
def ops_in_model_before() -> tuple[tuple[OpOverload, int], ...]:
return (
(torch.ops.vllm.all_gather.default, 1),
(torch.ops._C_ascend.npu_add_rms_norm_bias.default, 1),
)
class Qwen3VLAllGatherRMSNormModel(BaseAllGatherRMSNormModel):
"""Exercises Qwen3VLMiddleLayerAllgatherAddRMSNormPattern (all_gather + slice + add + RMSNorm)."""
def forward(
self,
x: torch.Tensor,
residual: torch.Tensor,
num_tokens_helper: torch.Tensor,
deepstack_input_embeds: torch.Tensor,
) -> torch.Tensor:
sliced = self._all_gather_sliced(x, num_tokens_helper)
add_ = sliced + deepstack_input_embeds
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, self.norm_w, None, self.eps)
# Keep the residual output live so the traced graph preserves the full pattern.
result = result - residual
return result
@staticmethod
def ops_in_model_before() -> tuple[tuple[OpOverload, int], ...]:
return (
(torch.ops.vllm.all_gather.default, 1),
(torch.ops.aten.add.Tensor, 1),
(torch.ops._C_ascend.npu_add_rms_norm_bias.default, 1),
)
class AllGatherChunkNoOpModel(nn.Module):
"""Exercises AllGatherChunkNoOpPattern (all_gather + sequence_parallel_chunk_impl -> identity)."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
z = torch.relu(x)
gathered = tensor_model_parallel_all_gather(z, 0)
return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)
@staticmethod
def ops_in_model_before() -> tuple[tuple[OpOverload, int], ...]:
return (
(torch.ops.vllm.all_gather.default, 1),
(torch.ops.vllm.sequence_parallel_chunk_impl.default, 1),
)
@staticmethod
def ops_in_model_after() -> tuple[tuple[OpOverload, int], ...]:
return (
(torch.ops.vllm.all_gather.default, 0),
(torch.ops.vllm.sequence_parallel_chunk_impl.default, 0),
)
def _build_all_gather_rms_norm_inputs(
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
tp_size: int,
) -> tuple[torch.Tensor, ...]:
local_tokens = batch_size * seq_len
num_tokens = local_tokens * tp_size
x = torch.randn(local_tokens, hidden_size, dtype=dtype)
residual = torch.zeros(num_tokens, hidden_size, dtype=dtype)
num_tokens_helper = torch.empty(num_tokens, device=x.device, dtype=dtype)
return (x, residual, num_tokens_helper)
def _build_qwen3vl_inputs(
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
tp_size: int,
) -> tuple[torch.Tensor, ...]:
x, residual, num_tokens_helper = _build_all_gather_rms_norm_inputs(
batch_size=batch_size,
seq_len=seq_len,
hidden_size=hidden_size,
dtype=dtype,
tp_size=tp_size,
)
deepstack = torch.randn(num_tokens_helper.shape[0], hidden_size, dtype=dtype)
return (x, residual, num_tokens_helper, deepstack)
def _build_allgather_chunk_noop_inputs(
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
tp_size: int,
) -> tuple[torch.Tensor, ...]:
del tp_size
local_tokens = batch_size * seq_len
x = torch.randn(local_tokens, hidden_size, dtype=dtype)
return (x,)
def _create_all_gather_rms_norm_model(
hidden_size: int,
dtype: torch.dtype,
eps: float,
device: str,
) -> nn.Module:
return AllGatherRMSNormModel(hidden_size=hidden_size, dtype=dtype, eps=eps, device=device)
def _create_qwen3vl_model(
hidden_size: int,
dtype: torch.dtype,
eps: float,
device: str,
) -> nn.Module:
return Qwen3VLAllGatherRMSNormModel(hidden_size=hidden_size, dtype=dtype, eps=eps, device=device)
def _create_allgather_chunk_noop_model(
hidden_size: int,
dtype: torch.dtype,
eps: float,
device: str,
) -> nn.Module:
del hidden_size, dtype, eps, device
return AllGatherChunkNoOpModel()
@dataclass(frozen=True)
class PatternTestCase:
model_factory: Any
input_builder: Any
dynamic_input_indices: tuple[int, ...]
pre_pass_expected_counts_factory: Any
post_pass_expected_counts_factory: Any
PATTERN_TEST_CASES = {
"middle_layer_allgather_add_rms_norm": PatternTestCase(
model_factory=_create_all_gather_rms_norm_model,
input_builder=_build_all_gather_rms_norm_inputs,
dynamic_input_indices=(0, 2),
pre_pass_expected_counts_factory=AllGatherRMSNormModel.ops_in_model_before,
post_pass_expected_counts_factory=AllGatherRMSNormModel.ops_in_model_after,
),
"qwen3vl_middle_layer_allgather_add_rms_norm": PatternTestCase(
model_factory=_create_qwen3vl_model,
input_builder=_build_qwen3vl_inputs,
dynamic_input_indices=(0, 2, 3),
pre_pass_expected_counts_factory=Qwen3VLAllGatherRMSNormModel.ops_in_model_before,
post_pass_expected_counts_factory=Qwen3VLAllGatherRMSNormModel.ops_in_model_after,
),
"allgather_chunk_noop": PatternTestCase(
model_factory=_create_allgather_chunk_noop_model,
input_builder=_build_allgather_chunk_noop_inputs,
dynamic_input_indices=(0,),
pre_pass_expected_counts_factory=AllGatherChunkNoOpModel.ops_in_model_before,
post_pass_expected_counts_factory=AllGatherChunkNoOpModel.ops_in_model_after,
),
}
def _assert_op_counts(
backend: CompileTestBackend,
expected_counts: tuple[tuple[OpOverload, int], ...],
before: bool = False,
) -> None:
for op, expected_count in expected_counts:
actual_count = backend.op_count(op, before=before)
stage = "before" if before else "after"
assert actual_count == expected_count, (
f"op {stage} pass: {op} expected {expected_count}, but got {actual_count}"
)
def _run_single_pattern_case(
local_rank: int,
case_name: str,
vllm_config: VllmConfig,
tp_size: int,
batch_size: int = 8,
seq_len: int = 16,
hidden_size: int = 16,
dtype: torch.dtype = torch.bfloat16,
eps: float = 1e-5,
) -> None:
case = PATTERN_TEST_CASES[case_name]
sp_moe_pass = SequenceParallelismMoePass(vllm_config)
backend = CompileTestBackend(custom_passes=[sp_moe_pass])
model = case.model_factory(
hidden_size=hidden_size,
dtype=dtype,
eps=eps,
device=f"npu:{local_rank}",
)
inputs = case.input_builder(
batch_size=batch_size,
seq_len=seq_len,
hidden_size=hidden_size,
dtype=dtype,
tp_size=tp_size,
)
for dynamic_input_index in case.dynamic_input_indices:
torch._dynamo.mark_dynamic(inputs[dynamic_input_index], 0)
unfused = model(*inputs)
compiled = torch.compile(model, backend=backend)
fused = compiled(*inputs)
assert unfused.shape == fused.shape
assert sp_moe_pass.matched_count == 1
_assert_op_counts(backend, case.pre_pass_expected_counts_factory(), before=True)
_assert_op_counts(backend, case.post_pass_expected_counts_factory())
def _run_sequence_parallelism_moe_test(
local_rank: int,
world_size: int,
master_port: int,
command_queue: Any,
result_queue: Any,
batch_size: int = 8,
seq_len: int = 16,
hidden_size: int = 16,
dtype: torch.dtype = torch.bfloat16,
eps: float = 1e-5,
) -> None:
torch.npu.set_device(local_rank)
torch.set_default_device(f"npu:{local_rank}")
torch.set_default_dtype(dtype)
torch.manual_seed(0)
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "127.0.0.1",
"MASTER_PORT": str(master_port),
}
)
vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
try:
with vllm.config.set_current_vllm_config(vllm_config):
init_distributed_environment(
world_size=world_size,
rank=local_rank,
local_rank=local_rank,
backend="hccl",
)
initialize_model_parallel(tensor_model_parallel_size=world_size)
if not enable_custom_op():
raise RuntimeError("vllm_ascend custom ops are not available")
_ = get_tp_group().unique_name
tp_size = get_tensor_model_parallel_world_size()
result_queue.put((WORKER_READY, local_rank, "ok", ""))
while True:
case_name = command_queue.get()
if case_name == WORKER_STOP:
return
try:
_run_single_pattern_case(
local_rank=local_rank,
case_name=case_name,
vllm_config=vllm_config,
tp_size=tp_size,
batch_size=batch_size,
seq_len=seq_len,
hidden_size=hidden_size,
dtype=dtype,
eps=eps,
)
except Exception:
result_queue.put((case_name, local_rank, "error", traceback.format_exc()))
else:
result_queue.put((case_name, local_rank, "ok", ""))
finally:
destroy_model_parallel()
destroy_distributed_environment()
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
def _worker_entrypoint(
local_rank: int,
world_size: int,
master_port: int,
command_queue: Any,
result_queue: Any,
) -> None:
try:
_run_sequence_parallelism_moe_test(
local_rank=local_rank,
world_size=world_size,
master_port=master_port,
command_queue=command_queue,
result_queue=result_queue,
)
except Exception:
result_queue.put((WORKER_READY, local_rank, "error", traceback.format_exc()))
def _wait_for_worker_reports(
result_queue: Any,
case_name: str,
expected_reports: int,
) -> None:
errors = []
for _ in range(expected_reports):
try:
reported_case_name, local_rank, status, payload = result_queue.get(timeout=WORKER_RESULT_TIMEOUT_S)
except queue.Empty as exc:
raise TimeoutError(f"Timed out waiting for worker reports for {case_name}") from exc
assert reported_case_name == case_name, f"Expected worker report for {case_name}, but got {reported_case_name}"
if status != "ok":
errors.append(f"rank {local_rank}:\n{payload}")
if errors:
raise AssertionError("\n\n".join(errors))
@pytest.fixture(scope="module")
def sequence_parallelism_moe_workers() -> Generator[Callable[[str], None], None, None]:
ctx = torch.multiprocessing.get_context("spawn")
command_queues = [ctx.Queue() for _ in range(WORLD_SIZE)]
result_queue = ctx.Queue()
workers = []
for local_rank in range(WORLD_SIZE):
worker = ctx.Process(
target=_worker_entrypoint,
args=(local_rank, WORLD_SIZE, MASTER_PORT, command_queues[local_rank], result_queue),
)
worker.start()
workers.append(worker)
try:
_wait_for_worker_reports(result_queue, WORKER_READY, WORLD_SIZE)
def _run_case(case_name: str) -> None:
for command_queue in command_queues:
command_queue.put(case_name)
_wait_for_worker_reports(result_queue, case_name, WORLD_SIZE)
yield _run_case
finally:
for command_queue in command_queues:
command_queue.put(WORKER_STOP)
for worker in workers:
worker.join(timeout=WORKER_JOIN_TIMEOUT_S)
if worker.is_alive():
worker.terminate()
worker.join()
@pytest.mark.parametrize("case_name", tuple(PATTERN_TEST_CASES), ids=tuple(PATTERN_TEST_CASES))
def test_sequence_parallelism_moe_patterns(
sequence_parallelism_moe_workers: Callable[[str], None], case_name: str
) -> None:
sequence_parallelism_moe_workers(case_name)

View File

@@ -39,9 +39,9 @@ def test_qwen3_vl_sp_tp2(model: str) -> None:
compilation_config={
"cudagraph_capture_sizes": [2, 4],
"cudagraph_mode": "FULL_DECODE_ONLY",
"pass_config": {"enable_sp": True},
"pass_config": {"enable_sp": True, "sp_min_token_num": 10},
},
additional_config={"sp_threshold": 10, "ascend_compilation_config": {"enable_npugraph_ex": False}},
additional_config={"ascend_compilation_config": {"enable_npugraph_ex": False}},
) as runner:
sp_outputs = runner.model.generate(prompts, sampling_params)

View File

@@ -100,6 +100,11 @@ class TestBackend:
"""Helper to find all FX nodes that call a specific operator."""
return [node for node in graph.graph.nodes if hasattr(node, "target") and node.target == target]
def op_count(self, op: OpOverload, before: bool = False) -> int:
"""Return the number of nodes that call the given operator."""
graph = self.graph_pre_pass if before else self.graph_post_pass
return len(self.find_nodes_by_target(graph, op))
def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced: bool = True):
"""
Verify that the original (unfused) operators exist before the pass

View File

@@ -188,7 +188,10 @@ class TestPrepareAndFinalize(unittest.TestCase):
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
return_value=False)
def test_allgather_prepare_finalize(self, mock_enable_sp,
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp_by_pass",
return_value=False)
def test_allgather_prepare_finalize(self, mock_enable_sp_by_pass,
mock_enable_sp,
mock_get_forward_context,
mock_get_dp_group):
# Mock forward context

View File

@@ -159,6 +159,12 @@ class AscendConfig:
and get_ascend_device_type() != AscendDeviceType.A5
)
self.enable_sp_by_pass = (
vllm_config.model_config is not None
and not vllm_config.model_config.enforce_eager
and vllm_config.compilation_config.pass_config.enable_sp
)
@staticmethod
def _get_compile_ranges(compilation_config):
return compilation_config.compile_ranges_endpoints or []
@@ -195,14 +201,6 @@ class AscendConfig:
"{new_compile_ranges_split_points} for matmul and allreduce fusion"
)
from vllm_ascend.utils import is_moe_model
if vllm_config.compilation_config.pass_config.enable_sp and not is_moe_model(vllm_config):
from vllm_ascend.compilation.passes.sequence_parallelism import get_sp_threshold
sp_threshold = get_sp_threshold(vllm_config)
new_compile_ranges_split_points.append(sp_threshold)
logger.debug(f"add {sp_threshold} to compile_ranges_split_points for sequence parallelism")
if len(new_compile_ranges_split_points) > len(self._get_compile_ranges(vllm_config.compilation_config)):
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
self._set_compile_ranges(vllm_config.compilation_config, new_compile_ranges_split_points)

View File

@@ -70,6 +70,8 @@ class GraphFusionPassManager:
self.passes.append(MulsAddFusionPass(config))
if config.compilation_config.pass_config.enable_sp:
from .passes.sequence_parallelism import AscendSequenceParallelismPass
from .passes.sequence_parallelism import SequenceParallelismPass
from .passes.sequence_parallelism_moe import SequenceParallelismMoePass
self.passes.append(AscendSequenceParallelismPass(config))
self.passes.append(SequenceParallelismPass(config))
self.passes.append(SequenceParallelismMoePass(config))

View File

@@ -0,0 +1,40 @@
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.logger import logger
class AllGatherChunkNoOpCleanupPass(VllmInductorPass):
"""Fold all_gather + sequence_parallel_chunk_impl into identity."""
def __init__(self, config: VllmConfig):
super().__init__(config)
self.tp_group = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_allgather_chunk_noop_cleanup_pass")
self._register_patterns()
def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.all_gather(x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name)
def _empty(self, *args, **kwargs):
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)
def _register_patterns(self) -> None:
def pattern(input: torch.Tensor) -> torch.Tensor:
gathered = self._all_gather(input)
return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)
def replacement(input: torch.Tensor) -> torch.Tensor:
return input
pm.register_replacement(pattern, replacement, [self._empty(8, 16)], pm.fwd_only, self.patterns)
def __call__(self, graph: torch.fx.Graph) -> None:
self.begin()
matched_count = self.patterns.apply(graph)
logger.debug("AllGatherChunkNoOpCleanupPass replaced %s patterns", matched_count)
self.end_and_log()

View File

@@ -0,0 +1,62 @@
from collections.abc import Iterable
import torch
import torch.fx
from torch import SymInt
from torch.fx.experimental.symbolic_shapes import statically_known_true
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
from vllm.logger import logger
class NoOpEliminationPass(VllmInductorPass):
"""Remove no-op view/reshape nodes after pattern rewrites."""
def __call__(self, graph: torch.fx.Graph) -> None:
fx_graph = graph.graph if hasattr(graph, "graph") else graph
removed = 0
for node in list(fx_graph.nodes):
if not self._is_view_like(node):
continue
input_node = node.args[0]
if not isinstance(input_node, torch.fx.Node):
continue
input_meta = input_node.meta.get("val")
output_meta = node.meta.get("val")
if input_meta is None or output_meta is None:
continue
input_shape = getattr(input_meta, "shape", None)
output_shape = getattr(output_meta, "shape", None)
if input_shape is None or output_shape is None:
continue
if self._all_dims_equivalent(input_shape, output_shape):
node.replace_all_uses_with(input_node)
fx_graph.erase_node(node)
removed += 1
logger.debug("NoOpEliminationPass removed %s no-op views", removed)
@staticmethod
def _is_view_like(node: torch.fx.Node) -> bool:
return (node.op == "call_method" and node.target in {"view", "reshape"}) or (
node.op == "call_function"
and node.target
in {
torch.ops.aten.view.default,
torch.ops.aten.reshape.default,
}
)
@staticmethod
def _dims_equivalent(dim: int | SymInt, i_dim: int | SymInt) -> bool:
return statically_known_true(dim == i_dim) # type: ignore[no-any-return]
def _all_dims_equivalent(self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]) -> bool:
dims_ = list(dims)
i_dims_ = list(i_dims)
if len(dims_) != len(i_dims_):
return False
return all(self._dims_equivalent(s, i_s) for s, i_s in zip(dims_, i_dims_))

View File

@@ -7,21 +7,25 @@ from vllm.config.utils import Range
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce
from vllm.logger import logger
from vllm_ascend.compilation.passes.noop_elimination import NoOpEliminationPass
from vllm_ascend.utils import is_moe_model
SP_THRESHOLD = 1000
SP_MIN_TOKEN_NUM_DEFAULT = 1000
def get_sp_threshold(config: VllmConfig):
def get_sp_min_token_num(config: VllmConfig) -> int:
if is_moe_model(config):
return 1
additional_config = config.additional_config if config.additional_config is not None else {}
return additional_config.get("sp_threshold", SP_THRESHOLD)
return SP_MIN_TOKEN_NUM_DEFAULT
class _SequenceParallelPatternHelper:
"""Helper for sequence parallelism patterns."""
"""Helper for sequence parallelism patterns.
Provides TP communication helper methods: _all_reduce, _reduce_scatter,
_all_gather, and tensor creation utilities.
"""
def __init__(
self,
@@ -49,7 +53,10 @@ class _SequenceParallelPatternHelper:
return torch.empty(*args, dtype=self.dtype, device="npu", **kws)
class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
"""Replaces all_reduce + AddRMSNormBias with reduce_scatter + AddRMSNormBias
+ all_gather for middle-layer sequence parallelism."""
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
@@ -92,7 +99,10 @@ class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
"""Same as MiddleAllReduceRMSNormPattern but for the last layer
(no residual backprop)."""
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
@@ -127,7 +137,13 @@ class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class Qwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
"""For Qwen3-VL middle layers with hidden_states + deepstack_input_embeds add.
Replaces all_reduce + add + AddRMSNormBias with reduce_scatter +
chunk(deepstack_input_embeds) + add + AddRMSNormBias + all_gather.
"""
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
@@ -168,25 +184,45 @@ class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AscendSequenceParallelismPass(VllmInductorPass):
class SequenceParallelismPass(VllmInductorPass):
"""Sequence parallelism compilation pass.
Registers and applies the above patterns. Runs noop cleanup first, then
uses token range to determine whether to enable SP.
"""
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_pass")
self.noop_cleanup = NoOpEliminationPass(config)
for epsilon in [1e-5, 1e-6]:
AscendMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
MiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
AscendLastAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
LastAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
AscendQwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
Qwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
self.min_tokens = get_sp_threshold(config)
self.min_tokens = get_sp_min_token_num(config)
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.noop_cleanup(graph) # Eliminate redundant view-like operations
logger.debug(f"after noop_cleanup {graph.graph}")
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
logger.debug(f"after apply replacement {graph.graph}")
from torch._inductor.pattern_matcher import PatternPrettyPrinter
pattern_idx = 0
for pattern_entry in self.patterns.patterns.values():
for p in pattern_entry:
p_str = PatternPrettyPrinter.run(p.pattern)
logger.debug("Pattern %d: %s", pattern_idx, p_str)
pattern_idx += 1
self.end_and_log()
def is_applicable_for_range(self, compile_range: Range) -> bool:

View File

@@ -0,0 +1,204 @@
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.passes.vllm_inductor_pass import PatternPrettyPrinter, VllmInductorPass
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.logger import logger
from vllm_ascend.compilation.passes.sequence_parallelism import (
_SequenceParallelPatternHelper,
get_sp_min_token_num,
)
class MiddleLayerAllgatherAddRMSNormPattern(_SequenceParallelPatternHelper):
"""Replaces all_gather + slice + AddRMSNormBias with AddRMSNormBias +
all_gather to avoid middle-layer shape mismatch."""
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
def get_inputs(self):
input = self.empty(5, 16)
weight = self.empty(16)
residual = self.empty(8, 16)
# num_tokens = 8
return [input, weight, residual]
def get_scalar_inputs(self):
return {"num_tokens": 8}
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
) -> tuple[torch.Tensor, torch.Tensor]:
all_gather = self._all_gather(input)
x_sliced = all_gather[:num_tokens]
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(x_sliced, residual, weight, None, self.eps)
return result, residual
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
) -> tuple[torch.Tensor, torch.Tensor]:
residual = torch.ops.vllm.maybe_chunk_residual(input, residual)
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(input, residual, weight, None, self.eps)
all_gather = self._all_gather(result)
return all_gather, residual
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass, scalar_workaround=self.get_scalar_inputs()
)
class LastLayerAllgatherRMSNormPattern(_SequenceParallelPatternHelper):
"""Same as MiddleLayerAllgatherAddRMSNormPattern but for the last layer (no residual)
all_gather + RMSNorm fusion."""
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
def get_inputs(self):
input = self.empty(5, 16)
weight = self.empty(16)
residual = self.empty(8, 16)
return [input, weight, residual]
def get_scalar_inputs(self):
return {"num_tokens": 8}
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
) -> torch.Tensor:
all_gather = self._all_gather(input)
x_sliced = all_gather[:num_tokens]
result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(x_sliced, residual, weight, None, self.eps)
return result
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
) -> torch.Tensor:
residual = torch.ops.vllm.maybe_chunk_residual(input, residual)
result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(input, residual, weight, None, self.eps)
all_gather = self._all_gather(result)
return all_gather
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass, scalar_workaround=self.get_scalar_inputs()
)
class Qwen3VLMiddleLayerAllgatherAddRMSNormPattern(_SequenceParallelPatternHelper):
"""Replaces all_gather + slice + add + AddRMSNormBias with add(chunk) +
AddRMSNormBias + all_gather for Qwen3-VL-style all_gather path."""
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
def get_inputs(self):
input = self.empty(5, 16)
weight = self.empty(16)
residual = self.empty(8, 16)
deepstack_input_embeds = self.empty(8, 16)
return [input, weight, residual, deepstack_input_embeds]
def get_scalar_inputs(self):
return {"num_tokens": 8}
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
deepstack_input_embeds: torch.Tensor,
num_tokens,
) -> tuple[torch.Tensor, torch.Tensor]:
all_gather = self._all_gather(input)
x_sliced = all_gather[:num_tokens]
add_ = x_sliced + deepstack_input_embeds
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps)
return result, residual
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
deepstack_input_embeds: torch.Tensor,
num_tokens,
) -> tuple[torch.Tensor, torch.Tensor]:
chunk = deepstack_input_embeds.chunk(self.tp_size)[self.tp_rank]
add_ = input + chunk
residual = torch.ops.vllm.maybe_chunk_residual(input, residual)
result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps)
all_gather = self._all_gather(result)
return all_gather, residual
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass, scalar_workaround=self.get_scalar_inputs()
)
class AllGatherChunkNoOpPattern(_SequenceParallelPatternHelper):
"""Folds all_gather + sequence_parallel_chunk_impl into identity (no-op)."""
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
def get_inputs(self):
return [self.empty(8, 16)]
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor) -> torch.Tensor:
gathered = self._all_gather(input)
return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)
def replacement(input: torch.Tensor) -> torch.Tensor:
return input
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class SequenceParallelismMoePass(VllmInductorPass):
"""Sequence parallelism AllGather epilogue pass.
Applies AllGather-based patterns: MiddleLayerAllgatherAddRMSNormPattern,
LastLayerAllgatherRMSNormPattern, Qwen3VLMiddleLayerAllgatherAddRMSNormPattern,
and AllGatherChunkNoOpPattern (all_gather + sequence_parallel_chunk_impl -> identity).
"""
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_allgather_ep_pass")
for epsilon in [1e-5, 1e-6]:
MiddleLayerAllgatherAddRMSNormPattern(config, epsilon).register(self.patterns)
LastLayerAllgatherRMSNormPattern(config, epsilon).register(self.patterns)
Qwen3VLMiddleLayerAllgatherAddRMSNormPattern(config, epsilon).register(self.patterns)
AllGatherChunkNoOpPattern(config).register(self.patterns)
self.min_tokens = get_sp_min_token_num(config)
def __call__(self, graph: torch.fx.Graph):
self.begin()
logger.debug(f"before apply replacement {graph}")
self.matched_count = self.patterns.apply(graph)
logger.debug(f"after apply replacement {graph}")
logger.debug("SequenceParallelismMoePass replaced %s patterns", self.matched_count)
pattern_idx = 0
for pattern_entry in self.patterns.patterns.values():
for p in pattern_entry:
p_str = PatternPrettyPrinter.run(p.pattern)
logger.debug("Pattern %d: %s", pattern_idx, p_str)
pattern_idx += 1
self.end_and_log()
def is_applicable_for_range(self, compile_range: Range) -> bool:
applicable = compile_range.start >= self.min_tokens
logger.debug(f"SequenceParallelismMoePass {compile_range=} {applicable=}")
return applicable

View File

@@ -33,7 +33,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEPrepareOutput
from vllm_ascend.quantization.quant_type import QuantType
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
from vllm_ascend.utils import enable_sp, enable_sp_by_pass, npu_stream_switch, prefill_context_parallel_enable
class PrepareAndFinalize(ABC):
@@ -324,7 +324,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
Returns:
MoEPrepareOutput with global tensors.
"""
if enable_sp():
if enable_sp() or enable_sp_by_pass():
return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)
return self._prepare_with_dp_group(hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce)
@@ -433,7 +433,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
Returns:
Tensor with shape [local_num_tokens, hidden_size]
"""
if enable_sp():
if enable_sp() or enable_sp_by_pass():
return self._finalize_with_ep_group(hidden_states)
return self._finalize_with_dp_group(hidden_states, reduce_results)

View File

@@ -17,7 +17,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.ops.rotary_embedding import rope_forward_oot
from vllm_ascend.ops.triton.muls_add import muls_add_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
from vllm_ascend.utils import enable_sp_by_pass, npu_stream_switch, prefetch_stream
def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
@@ -43,7 +43,7 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
except AssertionError:
return x
flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled or (enable_sp_by_pass() and is_ep_comm)
if flash_comm_v1_enabled and label:
dp_metadata = forward_context.dp_metadata
if dp_metadata is None or not is_ep_comm:
@@ -53,6 +53,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
x = x[:-pad_size]
else:
x = get_ep_group().all_gather(x, 0)
if enable_sp_by_pass(): # TODO: do unpad
return x
# unpad
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype)
@@ -74,7 +76,11 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
except AssertionError:
return tensor_model_parallel_all_reduce(x)
if not getattr(forward_context, "flash_comm_v1_enabled", False):
flash_comm_v1_enabled = getattr(forward_context, "flash_comm_v1_enabled", False) or (
enable_sp_by_pass() and is_ep_comm
)
if not flash_comm_v1_enabled:
return tensor_model_parallel_all_reduce(x)
dp_metadata = forward_context.dp_metadata
@@ -84,6 +90,8 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
x = F.pad(x, (0, 0, 0, pad_size))
return tensor_model_parallel_reduce_scatter(x, 0)
else:
if enable_sp_by_pass():
return get_ep_group().reduce_scatter(x.view(-1, *x.shape[1:]), 0)
# padding
dp_size = get_dp_group().world_size
num_tokens_across_dp_cpu = get_forward_context().dp_metadata.num_tokens_across_dp_cpu
@@ -107,7 +115,7 @@ def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_c
def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
if _EXTRA_CTX.flash_comm_v1_enabled:
if _EXTRA_CTX.flash_comm_v1_enabled or enable_sp_by_pass():
return torch.empty(
(x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
)

View File

@@ -167,15 +167,12 @@
# 1. `vllm.distributed.parallel_state.GroupCoordinator`
# Why:
# vllm doesn't support all_to_all for GroupCoordinator.
# all_reduce in vLLM is not a custom op, which makes the MatmulAllReduceAddRMSNorm fusion fail.
# How
# Add all_to_all implementation for GroupCoordinator.
# Make all_reduce a custom op.
# Related PR (if no, explain why):
# No, we should use the vLLM all2all manager to support all_to_all for NPU.
# Future Plan:
# Remove this patch when the refactor of all2all manager is done.
# Remove this patch when vLLM supports all_reduce as a custom op.
#
# ** 2. File: worker/patch_multimodal_merge.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@@ -84,7 +84,7 @@ class GroupCoordinatorPatch(GroupCoordinator):
if use_message_queue_broadcaster and self.world_size > 1:
self.mq_broadcaster = MessageQueue.create_from_process_group(self.cpu_group, 1 << 22, 6)
self.use_custom_op_call = False
self.use_custom_op_call = True
self.use_cpu_custom_send_recv = False
def all_to_all(
@@ -106,10 +106,5 @@ class GroupCoordinatorPatch(GroupCoordinator):
assert self.device_communicator is not None, "device_communicator should be initialized when world_size > 1"
return self.device_communicator.all_to_all(input_, scatter_dim, gather_dim, scatter_sizes, gather_sizes)
def all_reduce(self, input_):
if self.world_size == 1:
return input_
return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch

View File

@@ -156,6 +156,16 @@ class NPUPlatform(Platform):
def get_device_capability(cls, device_id: int = 0):
return None
@classmethod
def apply_config_platform_defaults(cls, vllm_config: VllmConfig) -> None:
"""Apply Ascend-specific defaults. Set sp_min_token_num=1 when enable_sp and not set."""
pass_config = vllm_config.compilation_config.pass_config
if pass_config.enable_sp and pass_config.sp_min_token_num is None:
from vllm_ascend.compilation.passes.sequence_parallelism import get_sp_min_token_num
pass_config.sp_min_token_num = get_sp_min_token_num(vllm_config)
logger.info(f"set sp_min_token_num to {pass_config.sp_min_token_num}")
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
return torch.npu.get_device_name(device_id)
@@ -198,6 +208,7 @@ class NPUPlatform(Platform):
# initialize ascend config from vllm additional_config
cls._fix_incompatible_config(vllm_config)
ascend_config = init_ascend_config(vllm_config)
if vllm_config.kv_transfer_config is not None:
@@ -218,6 +229,7 @@ class NPUPlatform(Platform):
if not isinstance(ascend_compilation_config, dict)
else ascend_compilation_config
)
ascend_config.update_compile_ranges_split_points()
if model_config and hasattr(model_config.hf_text_config, "index_topk"):
@@ -363,7 +375,8 @@ class NPUPlatform(Platform):
if parallel_config and parallel_config.worker_cls == "auto":
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
parallel_config.all2all_backend = "flashinfer_all2allv"
if not vllm_config.compilation_config.pass_config.enable_sp:
parallel_config.all2all_backend = "flashinfer_all2allv"
if is_310p():
parallel_config.worker_cls = "vllm_ascend._310p.worker_310p.NPUWorker310"
elif ascend_config.xlite_graph_config.enabled:
@@ -805,3 +818,7 @@ class NPUPlatform(Platform):
"ignored on Ascend. Resetting to default (32)."
)
att_config.flash_attn_max_num_splits_for_cuda_graph = 32
@classmethod
def use_custom_op_collectives(cls) -> bool:
return True

View File

@@ -764,8 +764,8 @@ def matmul_allreduce_enable() -> bool:
return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
def enable_sp_by_pass(vllm_config: VllmConfig):
return not vllm_config.model_config.enforce_eager and vllm_config.compilation_config.pass_config.enable_sp
def enable_sp_by_pass():
return get_ascend_config().enable_sp_by_pass
def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
@@ -791,7 +791,7 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
# TODO remove it after vllm has this func
def shared_expert_dp_enabled() -> bool:
return get_ascend_config().enable_shared_expert_dp or enable_sp()
return get_ascend_config().enable_shared_expert_dp or enable_sp() or enable_sp_by_pass()
def prefill_context_parallel_enable() -> bool:

View File

@@ -1846,7 +1846,7 @@ class NPUModelRunner(GPUModelRunner):
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if enable_sp(self.vllm_config) or enable_sp_by_pass(self.vllm_config):
if enable_sp(self.vllm_config) or enable_sp_by_pass():
return round_up(num_scheduled_tokens, tp_size)
return num_scheduled_tokens
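For reference, a hedged sketch of the padding arithmetic (`round_up` comes from vLLM utilities; the example values are made up):
```python
# Hedged sketch: under SP the sequence dimension is reduce-scattered across
# TP ranks, so the token count must be divisible by tp_size; round_up pads it.
def round_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple

assert round_up(13, 4) == 16  # 13 scheduled tokens, tp_size=4 -> padded to 16
assert round_up(16, 4) == 16  # already aligned: no padding needed
```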