[Feat][SP] Support SP for VL MoE models (#7044)
### What this PR does / why we need it?
Second PR for https://github.com/vllm-project/vllm-ascend/issues/5712;
extends SP to VL MoE models.
### Does this PR introduce _any_ user-facing change?
Yes: removes `sp_threshold` from `additional_config` and reuses `sp_min_token_num`
from vLLM's `pass_config`.
### How was this patch tested?
- Model: Qwen3-VL-30B-A3B
- TP4 DP2
- 100 requests
- max concurrency 1
| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|------------|---------------------|------------------------|
| 4k | 429.40 | 323.3 |
| 16k | 1297.01 | 911.74 |
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
.github/workflows/scripts/config.yaml (2 changes, vendored)
@@ -141,6 +141,8 @@ e2e-multicard-2-cards:
      estimated_time: 164
    - name: tests/e2e/multicard/2-cards/test_sp_pass.py
      estimated_time: 198
+   - name: tests/e2e/multicard/2-cards/test_sequence_parallelism_moe.py
+     estimated_time: 120
  e2e-multicard-4-cards:
    - name: tests/e2e/multicard/4-cards/test_qwen3_next.py
      estimated_time: 1868
BIN docs/source/assets/sp_moe.png (new file, 576 KiB; binary not shown)
@@ -44,7 +44,6 @@ The following table lists additional configuration options available in vLLM Ascend:

| `pa_shape_list` | list | `[]` | The custom shape list of page attention ops. |
| `enable_kv_nz` | bool | `False` | Whether to enable KV cache NZ layout. This option only takes effect on models using MLA (e.g., DeepSeek). |
| `layer_sharding` | dict | `{}` | Configuration options for Layer Sharding Linear. |
-| `sp_threshold` | int | `1000` | For dense models, sequence parallelism is enabled only when `num_tokens` > threshold. |

The details of each configuration option are as follows:
@@ -14,14 +14,13 @@ Currently, vllm-ascend has implemented Sequence Parallelism for VL-class models

```bash
vllm serve Qwen/Qwen3-VL-2B-Instruct \
    --tensor-parallel-size 2 \
-   --compilation-config '{"pass_config": {"enable_sp": true}}' \
-   --additional_config={"sp_threshold": 1000}
+   --compilation-config '{"pass_config": {"enable_sp": true, "sp_min_token_num": 1000}}'
```

-- `"pass_config": {"enable_sp": true}`: This is the switch for SP. Since SP relies on graph mode, it is not supported in eager mode.
-- `--additional_config={"sp_threshold": 1000}`: Based on our experiments, when the number of tokens is small (empirically, fewer than 1000), SP can actually bring negative benefits. When the communication volume is small, the fixed overhead of the communication operator dominates, so splitting one communication operator (AllReduce) into two (ReduceScatter + AllGather) often increases end-to-end latency. We therefore reserved the `sp_threshold` parameter; SP only takes effect when `num_tokens >= sp_threshold`. **The default value is 1000, which generally does not need to be modified.** `sp_threshold` is appended to `compile_ranges_split_points`, a parameter provided by vLLM that splits the graph compilation range `[1, max_num_batched_tokens]` into `{[1, split_points[0]], [split_points[0] + 1, split_points[1]], ..., [split_points[-1] + 1, max_num_batched_tokens]}` and sequentially checks whether each pass's `is_applicable_for_range` returns `True`.
+- `"enable_sp"`: This is the switch for SP. Since SP relies on graph mode, it is not supported in eager mode.
+- `sp_min_token_num` (from upstream vLLM's `pass_config`): Based on our experiments, when the number of tokens is small (empirically, fewer than 1000), SP can actually bring negative benefits, because when the communication volume is small the fixed overhead of the communication operator dominates. SP only takes effect when `num_tokens >= sp_min_token_num`. **The default value is 1000 on Ascend, which generally does not need to be modified.** To customize it, use `--compilation-config '{"pass_config": {"enable_sp": true, "sp_min_token_num": 512}}'`. The value is appended to `compile_ranges_split_points`, which splits the graph compilation range and checks whether the pass is applicable per range.
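The range-splitting behavior described above can be sketched in a few lines of plain Python (an illustrative sketch of the scheme, not vLLM's actual implementation):

```python
def compile_ranges(max_num_batched_tokens: int, split_points: list[int]) -> list[tuple[int, int]]:
    """Split [1, max_num_batched_tokens] at the given points, per the scheme above."""
    start, ranges = 1, []
    for p in sorted(split_points):
        ranges.append((start, p))
        start = p + 1
    ranges.append((start, max_num_batched_tokens))
    return ranges

# With sp_min_token_num = 1000 appended as a split point, the SP pass can be
# marked applicable only for the compilation range beyond the split point.
assert compile_ranges(8192, [1000]) == [(1, 1000), (1001, 8192)]
```

Each resulting range is compiled separately, so a pass can opt in per range rather than per batch.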

-Without modifying `sp_threshold`, the simplest and recommended way to enable SP is:
+Without modifying `sp_min_token_num`, the simplest and recommended way to enable SP is:

```bash
vllm serve Qwen/Qwen3-VL-2B-Instruct \
```
@@ -44,7 +43,7 @@ FC1 is a unique optimization in vllm-ascend, currently implemented based on Cust

| | VL + Dense | VL + MoE | non-VL + Dense | non-VL + MoE |
| -------------------- | ---------- | -------- | -------------- | ------------ |
-| Sequence Parallelism | graph | x | x | x |
+| Sequence Parallelism | graph | graph | x | x |
| Flash Comm V1 | x | x | eager/graph | eager/graph |
### With Quantization

@@ -55,3 +54,56 @@ SP currently does not support quantization and is under adaptation.

| | VL + Dense | VL + MoE | non-VL + Dense | non-VL + MoE |
| -------------------- | ---------- | -------- | -------------- | ------------ |
| Sequence Parallelism | x | x | x | x |
| Flash Comm V1 | x | x | eager/graph | eager/graph |
## Pass Design

When SP is enabled, the following passes run in order: `SequenceParallelismPass`, then `SequenceParallelismMoePass`.

### SequenceParallelismPass

Runs `NoOpEliminationPass` first to eliminate redundant view-like operations, then applies AllReduce-based patterns:

| Pattern | Match | Replacement |
| -------------------------------------- | -------------------------------- | ------------------------------------------------------------------------------------- |
| `MiddleAllReduceRMSNormPattern` | `all_reduce` + `layernorm` | `reduce_scatter` + `layernorm` + `all_gather` |
| `LastAllReduceRMSNormPattern` | Same (last layer, no residual) | Same |
| `Qwen3VLMiddleAllReduceRMSNormPattern` | `all_reduce` + add + `layernorm` | `reduce_scatter` + chunk(`deepstack_input_embeds`) + add + `layernorm` + `all_gather` |
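The rewrites in the table rest on a standard collective identity: an `all_reduce` is equivalent to a `reduce_scatter` followed by an `all_gather`. A minimal numeric sketch, with plain Python lists standing in for tensors and collectives (no real communication):

```python
def all_reduce(per_rank):
    # Element-wise sum of every rank's tensor; every rank gets the full result.
    return [sum(vals) for vals in zip(*per_rank)]

def reduce_scatter(per_rank, tp_size):
    # Same sum, but each rank keeps only its shard of the sequence.
    full = all_reduce(per_rank)
    shard = len(full) // tp_size
    return [full[r * shard:(r + 1) * shard] for r in range(tp_size)]

def all_gather(shards):
    # Concatenate every rank's shard back into the full sequence.
    return [x for shard in shards for x in shard]

per_rank = [[1, 2, 3, 4], [10, 20, 30, 40]]  # 2 ranks, seq_len 4
assert all_gather(reduce_scatter(per_rank, 2)) == all_reduce(per_rank)
```

The payoff is that the `layernorm` between the two halves then runs on a `1/tp`-sized shard per rank.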

**Why Qwen3-VL needs special handling via `Qwen3VLMiddleAllReduceRMSNormPattern`**

Qwen3-VL middle layers insert an extra add between `all_reduce` and `layernorm`: `hidden_states = hidden_states + deepstack_input_embeds`. Under SP, `hidden_states` (i.e., `input`) is reduce-scattered to shape `[seq_len/tp, hidden]` per rank, while `deepstack_input_embeds` comes from the vision/deepstack path and stays full-sequence `[seq_len, hidden]` (typically replicated across TP ranks). Simply doing `reduce_scatter(input) + deepstack_input_embeds` would cause a shape mismatch.

The fix is to chunk `deepstack_input_embeds` by `tp_size` so each rank computes `add(reduce_scatter, chunk(deepstack_input_embeds)[tp_rank])`, keeping shapes consistent before `layernorm` and `all_gather`.
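The shape bookkeeping can be sketched with plain Python lists standing in for tensors (a toy illustration of the description above, not the pass implementation):

```python
tp_size, seq_len = 2, 8
hidden_states = list(range(seq_len))          # full-sequence values after the sum
deepstack_input_embeds = [100] * seq_len      # full sequence, replicated on every rank

shard = seq_len // tp_size
per_rank_out = []
for tp_rank in range(tp_size):
    lo, hi = tp_rank * shard, (tp_rank + 1) * shard
    reduce_scattered = hidden_states[lo:hi]   # [seq_len/tp] rows on this rank
    chunk = deepstack_input_embeds[lo:hi]     # matching chunk, not the full [seq_len]
    assert len(reduce_scattered) == len(chunk)  # shapes now agree per rank
    per_rank_out.append([h + d for h, d in zip(reduce_scattered, chunk)])

# all_gather restores the full sequence, equal to the unsharded add.
assert [x for out in per_rank_out for x in out] == [h + 100 for h in hidden_states]
```

Without the chunk, the per-rank add would pair a `seq_len/tp`-row shard with a full `seq_len`-row tensor and fail.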

### SequenceParallelismMoePass

After `SequenceParallelismPass` applies, the MoE model computation graph looks like:



**Overview**

1. **Postponing allgather**: Under SP, `residual` is chunked by tensor parallelism. This causes a shape mismatch between hidden states and residual in the next layer's layernorm: hidden states are gathered (full sequence) while residual remains chunked. The fix is to move `all_gather` to *after* layernorm so that layernorm operates on consistent shapes per rank. `MiddleLayerAllgatherAddRMSNormPattern`, `LastLayerAllgatherRMSNormPattern`, and `Qwen3VLMiddleLayerAllgatherAddRMSNormPattern` are designed for this purpose, each handling a different layer and structure variant (see the table below).

2. **AllGatherChunkNoOp cleanup**: When MoE SP is enabled, vLLM introduces a `sequence_parallel_chunk` op (`sp_chunk` in the diagram). Together with the preceding `all_gather`, the pair forms a redundant no-op (`all_gather` gathers, then chunk re-splits). `AllGatherChunkNoOpPattern` replaces this pair with identity, eliminating the redundant communication and computation.
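The no-op nature of the pair is easy to see in a toy sketch (plain Python lists standing in for tensors; illustrative only): gathering every shard and then taking this rank's chunk returns exactly the shard the rank started with.

```python
def all_gather(shards):
    # Concatenate every rank's shard into the full sequence.
    return [x for shard in shards for x in shard]

def sequence_parallel_chunk(full, tp_size, tp_rank):
    # Re-split the full sequence and keep this rank's shard.
    n = len(full) // tp_size
    return full[tp_rank * n:(tp_rank + 1) * n]

shards = [[1, 2], [3, 4]]  # per-rank inputs, tp_size = 2
for tp_rank in range(2):
    assert sequence_parallel_chunk(all_gather(shards), 2, tp_rank) == shards[tp_rank]
```

Since the composition is the identity on each rank's data, the pattern can safely delete both ops from the graph.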

**Pattern details:**

| Pattern | Match | Replacement |
| ---------------------------------------------- | --------------------------------------------- | --------------------------------------- |
| `MiddleLayerAllgatherAddRMSNormPattern` | `all_gather` + slice + `layernorm` | `layernorm` + `all_gather` |
| `LastLayerAllgatherRMSNormPattern` | Same (last layer, no residual) | Same |
| `Qwen3VLMiddleLayerAllgatherAddRMSNormPattern` | `all_gather` + slice + add + `layernorm` | add(chunk) + `layernorm` + `all_gather` |
| `AllGatherChunkNoOpPattern` | `all_gather` + `sequence_parallel_chunk_impl` | identity (no-op) |
### FAQ

#### Q1: Is SP enabled by default?

No, SP is not enabled by default. It is currently experimental and will be enabled by default in the future.

The processing flow of `enable_sp` in the code is:

- In `pass_config`, `enable_sp` and `sp_min_token_num` default to `None`.
- `NPUPlatform.apply_config_platform_defaults`: if `enable_sp` is `True` and `sp_min_token_num` is `None`, set the default `sp_min_token_num` (1000 for dense models, 1 for MoE models).
- `VllmConfig._apply_optimization_level_defaults`: `enable_sp` is set to `True` for dense models.
- `VllmConfig.__post_init__`: if `sp_min_token_num` is still `None`, `enable_sp` is set to `False`.
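Condensed into plain Python, the resolution order above looks roughly like this (the function and its arguments are illustrative, not the actual vLLM/vllm-ascend code):

```python
def resolve_sp_defaults(enable_sp, sp_min_token_num, is_moe, is_dense):
    # NPUPlatform.apply_config_platform_defaults
    if enable_sp and sp_min_token_num is None:
        sp_min_token_num = 1 if is_moe else 1000
    # VllmConfig._apply_optimization_level_defaults
    if is_dense and enable_sp is None:
        enable_sp = True
    # VllmConfig.__post_init__
    if sp_min_token_num is None:
        enable_sp = False
    return enable_sp, sp_min_token_num

# Explicitly enabling SP picks up the platform-default threshold.
assert resolve_sp_defaults(True, None, is_moe=True, is_dense=False) == (True, 1)
assert resolve_sp_defaults(True, None, is_moe=False, is_dense=True) == (True, 1000)
```

The final `__post_init__` step is the safety net: SP stays off unless a threshold was resolved earlier in the chain.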

tests/e2e/multicard/2-cards/test_sequence_parallelism_moe.py (new file, 473 lines)
@@ -0,0 +1,473 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Two-card e2e tests for SequenceParallelismMoePass patterns:
# - MiddleLayerAllgatherAddRMSNormPattern (all_gather + slice + RMSNorm)
# - Qwen3VLMiddleLayerAllgatherAddRMSNormPattern (all_gather + slice + add + RMSNorm)
# - AllGatherChunkNoOpPattern (all_gather + sequence_parallel_chunk_impl -> identity)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import queue
import traceback
from collections.abc import Callable, Generator
from dataclasses import dataclass
from typing import Any

import pytest
import torch
import torch.nn as nn
import vllm.config
from vllm.compilation.passes.fx_utils import OpOverload
from vllm.config import ModelConfig, VllmConfig
from vllm.distributed import (
    get_tensor_model_parallel_world_size,
    get_tp_group,
    init_distributed_environment,
    tensor_model_parallel_all_gather,
)
from vllm.distributed.parallel_state import (
    destroy_distributed_environment,
    destroy_model_parallel,
    initialize_model_parallel,
)
from vllm.utils.system_utils import update_environment_variables

import vllm_ascend.ops.register_custom_ops  # noqa
from tests.e2e.singlecard.compile.backend import TestBackend as CompileTestBackend
from vllm_ascend.compilation.passes.sequence_parallelism_moe import (
    SequenceParallelismMoePass,
)
from vllm_ascend.utils import enable_custom_op

MASTER_PORT = 29500
WORLD_SIZE = 2
WORKER_READY = "__ready__"
WORKER_STOP = "__stop__"
WORKER_RESULT_TIMEOUT_S = 180
WORKER_JOIN_TIMEOUT_S = 30

class BaseAllGatherRMSNormModel(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        dtype: torch.dtype,
        eps: float = 1e-6,
        device: str = "npu",
    ):
        super().__init__()
        self.eps = eps
        self.norm_w = torch.randn(hidden_size, dtype=dtype, device=device)

    def _all_gather_sliced(self, x: torch.Tensor, num_tokens_helper: torch.Tensor) -> torch.Tensor:
        num_tokens = num_tokens_helper.shape[0]
        activated = torch.relu(x)
        gathered = tensor_model_parallel_all_gather(activated, 0)
        return gathered[:num_tokens]

    @staticmethod
    def ops_in_model_after() -> tuple[tuple[OpOverload, int], ...]:
        return (
            (torch.ops.vllm.all_gather.default, 1),
            (torch.ops._C_ascend.npu_add_rms_norm_bias.default, 1),
            (torch.ops.vllm.maybe_chunk_residual.default, 1),
        )


class AllGatherRMSNormModel(BaseAllGatherRMSNormModel):
    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor,
        num_tokens_helper: torch.Tensor,
    ) -> torch.Tensor:
        sliced = self._all_gather_sliced(x, num_tokens_helper)
        rms_out = torch.ops._C_ascend.npu_add_rms_norm_bias(sliced, residual, self.norm_w, None, self.eps)
        return rms_out[0]

    @staticmethod
    def ops_in_model_before() -> tuple[tuple[OpOverload, int], ...]:
        return (
            (torch.ops.vllm.all_gather.default, 1),
            (torch.ops._C_ascend.npu_add_rms_norm_bias.default, 1),
        )


class Qwen3VLAllGatherRMSNormModel(BaseAllGatherRMSNormModel):
    """Exercises Qwen3VLMiddleLayerAllgatherAddRMSNormPattern (all_gather + slice + add + RMSNorm)."""

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor,
        num_tokens_helper: torch.Tensor,
        deepstack_input_embeds: torch.Tensor,
    ) -> torch.Tensor:
        sliced = self._all_gather_sliced(x, num_tokens_helper)
        add_ = sliced + deepstack_input_embeds
        result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, self.norm_w, None, self.eps)
        # Keep the residual output live so the traced graph preserves the full pattern.
        result = result - residual
        return result

    @staticmethod
    def ops_in_model_before() -> tuple[tuple[OpOverload, int], ...]:
        return (
            (torch.ops.vllm.all_gather.default, 1),
            (torch.ops.aten.add.Tensor, 1),
            (torch.ops._C_ascend.npu_add_rms_norm_bias.default, 1),
        )


class AllGatherChunkNoOpModel(nn.Module):
    """Exercises AllGatherChunkNoOpPattern (all_gather + sequence_parallel_chunk_impl -> identity)."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = torch.relu(x)
        gathered = tensor_model_parallel_all_gather(z, 0)
        return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)

    @staticmethod
    def ops_in_model_before() -> tuple[tuple[OpOverload, int], ...]:
        return (
            (torch.ops.vllm.all_gather.default, 1),
            (torch.ops.vllm.sequence_parallel_chunk_impl.default, 1),
        )

    @staticmethod
    def ops_in_model_after() -> tuple[tuple[OpOverload, int], ...]:
        return (
            (torch.ops.vllm.all_gather.default, 0),
            (torch.ops.vllm.sequence_parallel_chunk_impl.default, 0),
        )

def _build_all_gather_rms_norm_inputs(
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    tp_size: int,
) -> tuple[torch.Tensor, ...]:
    local_tokens = batch_size * seq_len
    num_tokens = local_tokens * tp_size
    x = torch.randn(local_tokens, hidden_size, dtype=dtype)
    residual = torch.zeros(num_tokens, hidden_size, dtype=dtype)
    num_tokens_helper = torch.empty(num_tokens, device=x.device, dtype=dtype)
    return (x, residual, num_tokens_helper)


def _build_qwen3vl_inputs(
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    tp_size: int,
) -> tuple[torch.Tensor, ...]:
    x, residual, num_tokens_helper = _build_all_gather_rms_norm_inputs(
        batch_size=batch_size,
        seq_len=seq_len,
        hidden_size=hidden_size,
        dtype=dtype,
        tp_size=tp_size,
    )
    deepstack = torch.randn(num_tokens_helper.shape[0], hidden_size, dtype=dtype)
    return (x, residual, num_tokens_helper, deepstack)


def _build_allgather_chunk_noop_inputs(
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    tp_size: int,
) -> tuple[torch.Tensor, ...]:
    del tp_size
    local_tokens = batch_size * seq_len
    x = torch.randn(local_tokens, hidden_size, dtype=dtype)
    return (x,)


def _create_all_gather_rms_norm_model(
    hidden_size: int,
    dtype: torch.dtype,
    eps: float,
    device: str,
) -> nn.Module:
    return AllGatherRMSNormModel(hidden_size=hidden_size, dtype=dtype, eps=eps, device=device)


def _create_qwen3vl_model(
    hidden_size: int,
    dtype: torch.dtype,
    eps: float,
    device: str,
) -> nn.Module:
    return Qwen3VLAllGatherRMSNormModel(hidden_size=hidden_size, dtype=dtype, eps=eps, device=device)


def _create_allgather_chunk_noop_model(
    hidden_size: int,
    dtype: torch.dtype,
    eps: float,
    device: str,
) -> nn.Module:
    del hidden_size, dtype, eps, device
    return AllGatherChunkNoOpModel()


@dataclass(frozen=True)
class PatternTestCase:
    model_factory: Any
    input_builder: Any
    dynamic_input_indices: tuple[int, ...]
    pre_pass_expected_counts_factory: Any
    post_pass_expected_counts_factory: Any


PATTERN_TEST_CASES = {
    "middle_layer_allgather_add_rms_norm": PatternTestCase(
        model_factory=_create_all_gather_rms_norm_model,
        input_builder=_build_all_gather_rms_norm_inputs,
        dynamic_input_indices=(0, 2),
        pre_pass_expected_counts_factory=AllGatherRMSNormModel.ops_in_model_before,
        post_pass_expected_counts_factory=AllGatherRMSNormModel.ops_in_model_after,
    ),
    "qwen3vl_middle_layer_allgather_add_rms_norm": PatternTestCase(
        model_factory=_create_qwen3vl_model,
        input_builder=_build_qwen3vl_inputs,
        dynamic_input_indices=(0, 2, 3),
        pre_pass_expected_counts_factory=Qwen3VLAllGatherRMSNormModel.ops_in_model_before,
        post_pass_expected_counts_factory=Qwen3VLAllGatherRMSNormModel.ops_in_model_after,
    ),
    "allgather_chunk_noop": PatternTestCase(
        model_factory=_create_allgather_chunk_noop_model,
        input_builder=_build_allgather_chunk_noop_inputs,
        dynamic_input_indices=(0,),
        pre_pass_expected_counts_factory=AllGatherChunkNoOpModel.ops_in_model_before,
        post_pass_expected_counts_factory=AllGatherChunkNoOpModel.ops_in_model_after,
    ),
}

def _assert_op_counts(
    backend: CompileTestBackend,
    expected_counts: tuple[tuple[OpOverload, int], ...],
    before: bool = False,
) -> None:
    for op, expected_count in expected_counts:
        actual_count = backend.op_count(op, before=before)
        stage = "before" if before else "after"
        assert actual_count == expected_count, (
            f"op {stage} pass: {op} expected {expected_count}, but got {actual_count}"
        )


def _run_single_pattern_case(
    local_rank: int,
    case_name: str,
    vllm_config: VllmConfig,
    tp_size: int,
    batch_size: int = 8,
    seq_len: int = 16,
    hidden_size: int = 16,
    dtype: torch.dtype = torch.bfloat16,
    eps: float = 1e-5,
) -> None:
    case = PATTERN_TEST_CASES[case_name]
    sp_moe_pass = SequenceParallelismMoePass(vllm_config)
    backend = CompileTestBackend(custom_passes=[sp_moe_pass])
    model = case.model_factory(
        hidden_size=hidden_size,
        dtype=dtype,
        eps=eps,
        device=f"npu:{local_rank}",
    )
    inputs = case.input_builder(
        batch_size=batch_size,
        seq_len=seq_len,
        hidden_size=hidden_size,
        dtype=dtype,
        tp_size=tp_size,
    )
    for dynamic_input_index in case.dynamic_input_indices:
        torch._dynamo.mark_dynamic(inputs[dynamic_input_index], 0)

    unfused = model(*inputs)
    compiled = torch.compile(model, backend=backend)
    fused = compiled(*inputs)
    assert unfused.shape == fused.shape

    assert sp_moe_pass.matched_count == 1
    _assert_op_counts(backend, case.pre_pass_expected_counts_factory(), before=True)
    _assert_op_counts(backend, case.post_pass_expected_counts_factory())


def _run_sequence_parallelism_moe_test(
    local_rank: int,
    world_size: int,
    master_port: int,
    command_queue: Any,
    result_queue: Any,
    batch_size: int = 8,
    seq_len: int = 16,
    hidden_size: int = 16,
    dtype: torch.dtype = torch.bfloat16,
    eps: float = 1e-5,
) -> None:
    torch.npu.set_device(local_rank)
    torch.set_default_device(f"npu:{local_rank}")
    torch.set_default_dtype(dtype)
    torch.manual_seed(0)

    update_environment_variables(
        {
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": str(master_port),
        }
    )

    vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))

    try:
        with vllm.config.set_current_vllm_config(vllm_config):
            init_distributed_environment(
                world_size=world_size,
                rank=local_rank,
                local_rank=local_rank,
                backend="hccl",
            )
            initialize_model_parallel(tensor_model_parallel_size=world_size)

            if not enable_custom_op():
                raise RuntimeError("vllm_ascend custom ops are not available")

            _ = get_tp_group().unique_name
            tp_size = get_tensor_model_parallel_world_size()
            result_queue.put((WORKER_READY, local_rank, "ok", ""))

            while True:
                case_name = command_queue.get()
                if case_name == WORKER_STOP:
                    return

                try:
                    _run_single_pattern_case(
                        local_rank=local_rank,
                        case_name=case_name,
                        vllm_config=vllm_config,
                        tp_size=tp_size,
                        batch_size=batch_size,
                        seq_len=seq_len,
                        hidden_size=hidden_size,
                        dtype=dtype,
                        eps=eps,
                    )
                except Exception:
                    result_queue.put((case_name, local_rank, "error", traceback.format_exc()))
                else:
                    result_queue.put((case_name, local_rank, "ok", ""))
    finally:
        destroy_model_parallel()
        destroy_distributed_environment()
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

def _worker_entrypoint(
    local_rank: int,
    world_size: int,
    master_port: int,
    command_queue: Any,
    result_queue: Any,
) -> None:
    try:
        _run_sequence_parallelism_moe_test(
            local_rank=local_rank,
            world_size=world_size,
            master_port=master_port,
            command_queue=command_queue,
            result_queue=result_queue,
        )
    except Exception:
        result_queue.put((WORKER_READY, local_rank, "error", traceback.format_exc()))


def _wait_for_worker_reports(
    result_queue: Any,
    case_name: str,
    expected_reports: int,
) -> None:
    errors = []
    for _ in range(expected_reports):
        try:
            reported_case_name, local_rank, status, payload = result_queue.get(timeout=WORKER_RESULT_TIMEOUT_S)
        except queue.Empty as exc:
            raise TimeoutError(f"Timed out waiting for worker reports for {case_name}") from exc

        assert reported_case_name == case_name, f"Expected worker report for {case_name}, but got {reported_case_name}"
        if status != "ok":
            errors.append(f"rank {local_rank}:\n{payload}")

    if errors:
        raise AssertionError("\n\n".join(errors))


@pytest.fixture(scope="module")
def sequence_parallelism_moe_workers() -> Generator[Callable[[str], None], None, None]:
    ctx = torch.multiprocessing.get_context("spawn")
    command_queues = [ctx.Queue() for _ in range(WORLD_SIZE)]
    result_queue = ctx.Queue()
    workers = []

    for local_rank in range(WORLD_SIZE):
        worker = ctx.Process(
            target=_worker_entrypoint,
            args=(local_rank, WORLD_SIZE, MASTER_PORT, command_queues[local_rank], result_queue),
        )
        worker.start()
        workers.append(worker)

    try:
        _wait_for_worker_reports(result_queue, WORKER_READY, WORLD_SIZE)

        def _run_case(case_name: str) -> None:
            for command_queue in command_queues:
                command_queue.put(case_name)
            _wait_for_worker_reports(result_queue, case_name, WORLD_SIZE)

        yield _run_case
    finally:
        for command_queue in command_queues:
            command_queue.put(WORKER_STOP)
        for worker in workers:
            worker.join(timeout=WORKER_JOIN_TIMEOUT_S)
            if worker.is_alive():
                worker.terminate()
                worker.join()


@pytest.mark.parametrize("case_name", tuple(PATTERN_TEST_CASES), ids=tuple(PATTERN_TEST_CASES))
def test_sequence_parallelism_moe_patterns(
    sequence_parallelism_moe_workers: Callable[[str], None], case_name: str
) -> None:
    sequence_parallelism_moe_workers(case_name)

@@ -39,9 +39,9 @@ def test_qwen3_vl_sp_tp2(model: str) -> None:
        compilation_config={
            "cudagraph_capture_sizes": [2, 4],
            "cudagraph_mode": "FULL_DECODE_ONLY",
-           "pass_config": {"enable_sp": True},
+           "pass_config": {"enable_sp": True, "sp_min_token_num": 10},
        },
-       additional_config={"sp_threshold": 10, "ascend_compilation_config": {"enable_npugraph_ex": False}},
+       additional_config={"ascend_compilation_config": {"enable_npugraph_ex": False}},
    ) as runner:
        sp_outputs = runner.model.generate(prompts, sampling_params)

@@ -100,6 +100,11 @@ class TestBackend:
        """Helper to find all FX nodes that call a specific operator."""
        return [node for node in graph.graph.nodes if hasattr(node, "target") and node.target == target]

+   def op_count(self, op: OpOverload, before: bool = False) -> int:
+       """Return the number of nodes that call the given operator."""
+       graph = self.graph_pre_pass if before else self.graph_post_pass
+       return len(self.find_nodes_by_target(graph, op))
+
    def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced: bool = True):
        """
        Verify that the original (unfused) operators exist before the pass

@@ -188,7 +188,10 @@ class TestPrepareAndFinalize(unittest.TestCase):
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
           return_value=False)
-   def test_allgather_prepare_finalize(self, mock_enable_sp,
+   @patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp_by_pass",
+          return_value=False)
+   def test_allgather_prepare_finalize(self, mock_enable_sp_by_pass,
+                                       mock_enable_sp,
                                        mock_get_forward_context,
                                        mock_get_dp_group):
        # Mock forward context

@@ -159,6 +159,12 @@ class AscendConfig:
            and get_ascend_device_type() != AscendDeviceType.A5
        )

+       self.enable_sp_by_pass = (
+           vllm_config.model_config is not None
+           and not vllm_config.model_config.enforce_eager
+           and vllm_config.compilation_config.pass_config.enable_sp
+       )
+
    @staticmethod
    def _get_compile_ranges(compilation_config):
        return compilation_config.compile_ranges_endpoints or []

@@ -195,14 +201,6 @@ class AscendConfig:
            "{new_compile_ranges_split_points} for matmul and allreduce fusion"
        )

-       from vllm_ascend.utils import is_moe_model
-
-       if vllm_config.compilation_config.pass_config.enable_sp and not is_moe_model(vllm_config):
-           from vllm_ascend.compilation.passes.sequence_parallelism import get_sp_threshold
-
-           sp_threshold = get_sp_threshold(vllm_config)
-           new_compile_ranges_split_points.append(sp_threshold)
-           logger.debug(f"add {sp_threshold} to compile_ranges_split_points for sequence parallelism")
        if len(new_compile_ranges_split_points) > len(self._get_compile_ranges(vllm_config.compilation_config)):
            new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
            self._set_compile_ranges(vllm_config.compilation_config, new_compile_ranges_split_points)

@@ -70,6 +70,8 @@ class GraphFusionPassManager:
        self.passes.append(MulsAddFusionPass(config))

        if config.compilation_config.pass_config.enable_sp:
-           from .passes.sequence_parallelism import AscendSequenceParallelismPass
+           from .passes.sequence_parallelism import SequenceParallelismPass
+           from .passes.sequence_parallelism_moe import SequenceParallelismMoePass

-           self.passes.append(AscendSequenceParallelismPass(config))
+           self.passes.append(SequenceParallelismPass(config))
+           self.passes.append(SequenceParallelismMoePass(config))

vllm_ascend/compilation/passes/allgather_chunk_noop_pass.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.logger import logger


class AllGatherChunkNoOpCleanupPass(VllmInductorPass):
    """Fold all_gather + sequence_parallel_chunk_impl into identity."""

    def __init__(self, config: VllmConfig):
        super().__init__(config)
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_allgather_chunk_noop_cleanup_pass")
        self._register_patterns()

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.vllm.all_gather(x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name)

    def _empty(self, *args, **kwargs):
        return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)

    def _register_patterns(self) -> None:
        def pattern(input: torch.Tensor) -> torch.Tensor:
            gathered = self._all_gather(input)
            return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)

        def replacement(input: torch.Tensor) -> torch.Tensor:
            return input

        pm.register_replacement(pattern, replacement, [self._empty(8, 16)], pm.fwd_only, self.patterns)

    def __call__(self, graph: torch.fx.Graph) -> None:
        self.begin()
        matched_count = self.patterns.apply(graph)
        logger.debug("AllGatherChunkNoOpCleanupPass replaced %s patterns", matched_count)
        self.end_and_log()
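The pattern this pass folds away is sound because gathering every rank's shard and then taking this rank's chunk returns exactly the shard the rank started with. A pure-Python sketch of that identity (toy lists stand in for tensors; no NPU or torch required, function names are illustrative only):

```python
def all_gather(shards):
    # every rank ends up with the concatenation of all shards
    full = [v for s in shards for v in s]
    return [full[:] for _ in shards]

def chunk_for_rank(full, tp_size, rank):
    # sequence_parallel_chunk: take this rank's slice of the full tensor
    size = len(full) // tp_size
    return full[rank * size:(rank + 1) * size]

shards = [[1, 2], [3, 4]]          # per-rank shards, TP=2
gathered = all_gather(shards)
roundtrip = [chunk_for_rank(gathered[r], 2, r) for r in range(2)]
print(roundtrip)  # [[1, 2], [3, 4]]
```

Since the round trip reproduces the input, the compiler can delete both collective ops instead of executing them.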
62
vllm_ascend/compilation/passes/noop_elimination.py
Normal file
@@ -0,0 +1,62 @@
from collections.abc import Iterable

import torch
import torch.fx
from torch import SymInt
from torch.fx.experimental.symbolic_shapes import statically_known_true
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
from vllm.logger import logger


class NoOpEliminationPass(VllmInductorPass):
    """Remove no-op view/reshape nodes after pattern rewrites."""

    def __call__(self, graph: torch.fx.Graph) -> None:
        fx_graph = graph.graph if hasattr(graph, "graph") else graph
        removed = 0
        for node in list(fx_graph.nodes):
            if not self._is_view_like(node):
                continue

            input_node = node.args[0]
            if not isinstance(input_node, torch.fx.Node):
                continue

            input_meta = input_node.meta.get("val")
            output_meta = node.meta.get("val")
            if input_meta is None or output_meta is None:
                continue

            input_shape = getattr(input_meta, "shape", None)
            output_shape = getattr(output_meta, "shape", None)
            if input_shape is None or output_shape is None:
                continue

            if self._all_dims_equivalent(input_shape, output_shape):
                node.replace_all_uses_with(input_node)
                fx_graph.erase_node(node)
                removed += 1

        logger.debug("NoOpEliminationPass removed %s no-op views", removed)

    @staticmethod
    def _is_view_like(node: torch.fx.Node) -> bool:
        return (node.op == "call_method" and node.target in {"view", "reshape"}) or (
            node.op == "call_function"
            and node.target
            in {
                torch.ops.aten.view.default,
                torch.ops.aten.reshape.default,
            }
        )

    @staticmethod
    def _dims_equivalent(dim: int | SymInt, i_dim: int | SymInt) -> bool:
        return statically_known_true(dim == i_dim)  # type: ignore[no-any-return]

    def _all_dims_equivalent(self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]) -> bool:
        dims_ = list(dims)
        i_dims_ = list(i_dims)
        if len(dims_) != len(i_dims_):
            return False
        return all(self._dims_equivalent(s, i_s) for s, i_s in zip(dims_, i_dims_))
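The core idea (rewire consumers past a view whose input and output shapes already match, then erase the view) is independent of torch.fx. A minimal pure-Python sketch with toy `Node` objects standing in for FX nodes (all names here are illustrative, not the pass's API):

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    op: str                       # "view" or any compute op
    shape: tuple                  # output shape
    inputs: list = field(default_factory=list)

def eliminate_noop_views(nodes):
    """Drop view nodes whose input and output shapes already match."""
    kept = []
    for node in nodes:
        if node.op == "view" and node.inputs and node.inputs[0].shape == node.shape:
            # rewire every consumer to the view's input instead of the view
            for other in nodes:
                other.inputs = [node.inputs[0] if i is node else i for i in other.inputs]
        else:
            kept.append(node)
    return kept

a = Node("matmul", (8, 16))
v = Node("view", (8, 16), [a])    # no-op: same shape in and out
b = Node("add", (8, 16), [v])
result = eliminate_noop_views([a, v, b])
print([n.op for n in result])     # ['matmul', 'add']
```

Running this cleanup before the pattern matcher matters because a stray `view` between `all_reduce` and the norm op would prevent the sequence-parallelism patterns from matching.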
@@ -7,21 +7,25 @@ from vllm.config.utils import Range
 from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce
 from vllm.logger import logger

+from vllm_ascend.compilation.passes.noop_elimination import NoOpEliminationPass
 from vllm_ascend.utils import is_moe_model

-SP_THRESHOLD = 1000
+SP_MIN_TOKEN_NUM_DEFAULT = 1000


-def get_sp_threshold(config: VllmConfig):
+def get_sp_min_token_num(config: VllmConfig) -> int:
     if is_moe_model(config):
         return 1
-
-    additional_config = config.additional_config if config.additional_config is not None else {}
-    return additional_config.get("sp_threshold", SP_THRESHOLD)
+    return SP_MIN_TOKEN_NUM_DEFAULT


 class _SequenceParallelPatternHelper:
-    """Helper for sequence parallelism patterns."""
+    """Helper for sequence parallelism patterns.
+
+    Provides TP communication helper methods: _all_reduce, _reduce_scatter,
+    _all_gather, and tensor creation utilities.
+    """

     def __init__(
         self,
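The threshold logic above reduces to a two-way choice: MoE models enable SP from the first token, while dense models keep the 1000-token floor inherited from the old `sp_threshold` default. A standalone sketch of that selection (a bare `is_moe` flag stands in for `is_moe_model(config)`):

```python
SP_MIN_TOKEN_NUM_DEFAULT = 1000

def get_sp_min_token_num(is_moe: bool) -> int:
    # MoE models benefit from SP even at small batches, so the pass is
    # enabled from the first token; dense models keep the 1000-token floor.
    return 1 if is_moe else SP_MIN_TOKEN_NUM_DEFAULT

print(get_sp_min_token_num(True))   # 1   (e.g. Qwen3-VL-30B-A3B)
print(get_sp_min_token_num(False))  # 1000
```

This is also why the PR can drop the `sp_threshold` additional-config key: the value is now derived from the model type and handed to vLLM's `sp_min_token_num`.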
@@ -49,7 +53,10 @@ class _SequenceParallelPatternHelper:
         return torch.empty(*args, dtype=self.dtype, device="npu", **kws)


-class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
+class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
+    """Replaces all_reduce + AddRMSNormBias with reduce_scatter + AddRMSNormBias
+    + all_gather for middle-layer sequence parallelism."""

     def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
         super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
@@ -92,7 +99,10 @@ class AscendMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
         pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)


-class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
+class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
+    """Same as MiddleAllReduceRMSNormPattern but for the last layer
+    (no residual backprop)."""

     def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
         super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
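Both patterns rely on the same collective identity: `all_gather(reduce_scatter(x)) == all_reduce(x)`. The win is that the ops sandwiched between the reduce_scatter and the all_gather (the RMSNorm here) run on only `1/tp_size` of the tokens. A pure-Python sketch of the equivalence (toy lists, TP=2; no NPU required):

```python
def all_reduce(shards):
    # every rank ends up with the full elementwise sum
    total = [sum(vals) for vals in zip(*shards)]
    return [total[:] for _ in shards]

def reduce_scatter(shards):
    # rank r keeps only its chunk of the summed tensor
    n = len(shards)
    total = [sum(vals) for vals in zip(*shards)]
    chunk = len(total) // n
    return [total[r * chunk:(r + 1) * chunk] for r in range(n)]

def all_gather(chunks):
    # concatenate every rank's chunk back into the full tensor
    full = [v for c in chunks for v in c]
    return [full[:] for _ in chunks]

shards = [[1, 2, 3, 4], [10, 20, 30, 40]]  # per-rank partial sums, TP=2
assert all_gather(reduce_scatter(shards)) == all_reduce(shards)
print(all_reduce(shards)[0])  # [11, 22, 33, 44]
```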
@@ -127,7 +137,13 @@ class AscendLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
         pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)


-class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
+class Qwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
+    """For Qwen3-VL middle layers with hidden_states + deepstack_input_embeds add.
+
+    Replaces all_reduce + add + AddRMSNormBias with reduce_scatter +
+    chunk(deepstack_input_embeds) + add + AddRMSNormBias + all_gather.
+    """

     def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
         super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())
@@ -168,25 +184,45 @@ class AscendQwen3VLMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper)
         pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)


-class AscendSequenceParallelismPass(VllmInductorPass):
+class SequenceParallelismPass(VllmInductorPass):
+    """Sequence parallelism compilation pass.
+
+    Registers and applies the above patterns. Runs noop cleanup first, then
+    uses token range to determine whether to enable SP.
+    """

     def __init__(self, config: VllmConfig):
         super().__init__(config)

         self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_pass")
+        self.noop_cleanup = NoOpEliminationPass(config)

         for epsilon in [1e-5, 1e-6]:
-            AscendMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
+            MiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
-            AscendLastAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
+            LastAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
-            AscendQwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)
+            Qwen3VLMiddleAllReduceRMSNormPattern(config, epsilon).register(self.patterns)

-        self.min_tokens = get_sp_threshold(config)
+        self.min_tokens = get_sp_min_token_num(config)

     def __call__(self, graph: torch.fx.Graph):
         self.begin()
+        self.noop_cleanup(graph)  # Eliminate redundant view-like operations
+        logger.debug(f"after noop_cleanup {graph.graph}")
         self.matched_count = self.patterns.apply(graph)
         logger.debug("Replaced %s patterns", self.matched_count)
+        logger.debug(f"after apply replacement {graph.graph}")
+
+        from torch._inductor.pattern_matcher import PatternPrettyPrinter
+
+        pattern_idx = 0
+        for pattern_entry in self.patterns.patterns.values():
+            for p in pattern_entry:
+                p_str = PatternPrettyPrinter.run(p.pattern)
+                logger.debug("Pattern %d: %s", pattern_idx, p_str)
+                pattern_idx += 1
+
         self.end_and_log()

     def is_applicable_for_range(self, compile_range: Range) -> bool:
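The Qwen3-VL pattern above moves the deepstack add before the gather, which works because adding the full `deepstack_input_embeds` to the gathered hidden states equals adding each rank's chunk of it to that rank's shard and gathering afterwards. A toy check of that equivalence (plain lists, TP=2, 8 tokens; names illustrative only):

```python
def chunk(xs, n, r):
    # rank r's slice of a tensor split into n equal chunks
    size = len(xs) // n
    return xs[r * size:(r + 1) * size]

deepstack = list(range(8))            # full deepstack_input_embeds
per_rank = [[10] * 4, [20] * 4]       # each rank's hidden-state shard

# original: all_gather the shards, then add the full embeds
gathered = per_rank[0] + per_rank[1]
original = [h + d for h, d in zip(gathered, deepstack)]

# rewritten: each rank adds only its chunk, then all_gather the results
rewritten = []
for r in range(2):
    rewritten += [h + d for h, d in zip(per_rank[r], chunk(deepstack, 2, r))]

assert original == rewritten
print(original)  # [10, 11, 12, 13, 24, 25, 26, 27]
```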
204
vllm_ascend/compilation/passes/sequence_parallelism_moe.py
Normal file
@@ -0,0 +1,204 @@
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.passes.vllm_inductor_pass import PatternPrettyPrinter, VllmInductorPass
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.logger import logger

from vllm_ascend.compilation.passes.sequence_parallelism import (
    _SequenceParallelPatternHelper,
    get_sp_min_token_num,
)


class MiddleLayerAllgatherAddRMSNormPattern(_SequenceParallelPatternHelper):
    """Replaces all_gather + slice + AddRMSNormBias with AddRMSNormBias +
    all_gather to avoid middle-layer shape mismatch."""

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())

    def get_inputs(self):
        input = self.empty(5, 16)
        weight = self.empty(16)
        residual = self.empty(8, 16)
        # num_tokens = 8
        return [input, weight, residual]

    def get_scalar_inputs(self):
        return {"num_tokens": 8}

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_gather = self._all_gather(input)
            x_sliced = all_gather[:num_tokens]
            result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(x_sliced, residual, weight, None, self.eps)
            return result, residual

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
        ) -> tuple[torch.Tensor, torch.Tensor]:
            residual = torch.ops.vllm.maybe_chunk_residual(input, residual)
            result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(input, residual, weight, None, self.eps)
            all_gather = self._all_gather(result)
            return all_gather, residual

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass, scalar_workaround=self.get_scalar_inputs()
        )


class LastLayerAllgatherRMSNormPattern(_SequenceParallelPatternHelper):
    """Same as MiddleLayerAllgatherAddRMSNormPattern, but for the last layer
    (no residual output): all_gather + RMSNorm fusion."""

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())

    def get_inputs(self):
        input = self.empty(5, 16)
        weight = self.empty(16)
        residual = self.empty(8, 16)
        return [input, weight, residual]

    def get_scalar_inputs(self):
        return {"num_tokens": 8}

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
        ) -> torch.Tensor:
            all_gather = self._all_gather(input)
            x_sliced = all_gather[:num_tokens]
            result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(x_sliced, residual, weight, None, self.eps)
            return result

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, num_tokens
        ) -> torch.Tensor:
            residual = torch.ops.vllm.maybe_chunk_residual(input, residual)
            result, _, _ = torch.ops._C_ascend.npu_add_rms_norm_bias(input, residual, weight, None, self.eps)
            all_gather = self._all_gather(result)
            return all_gather

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass, scalar_workaround=self.get_scalar_inputs()
        )


class Qwen3VLMiddleLayerAllgatherAddRMSNormPattern(_SequenceParallelPatternHelper):
    """Replaces all_gather + slice + add + AddRMSNormBias with add(chunk) +
    AddRMSNormBias + all_gather for Qwen3-VL-style all_gather path."""

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())

    def get_inputs(self):
        input = self.empty(5, 16)
        weight = self.empty(16)
        residual = self.empty(8, 16)
        deepstack_input_embeds = self.empty(8, 16)
        return [input, weight, residual, deepstack_input_embeds]

    def get_scalar_inputs(self):
        return {"num_tokens": 8}

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            deepstack_input_embeds: torch.Tensor,
            num_tokens,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_gather = self._all_gather(input)
            x_sliced = all_gather[:num_tokens]
            add_ = x_sliced + deepstack_input_embeds
            result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps)
            return result, residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            deepstack_input_embeds: torch.Tensor,
            num_tokens,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            chunk = deepstack_input_embeds.chunk(self.tp_size)[self.tp_rank]
            add_ = input + chunk
            residual = torch.ops.vllm.maybe_chunk_residual(input, residual)
            result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, weight, None, self.eps)
            all_gather = self._all_gather(result)
            return all_gather, residual

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass, scalar_workaround=self.get_scalar_inputs()
        )


class AllGatherChunkNoOpPattern(_SequenceParallelPatternHelper):
    """Folds all_gather + sequence_parallel_chunk_impl into identity (no-op)."""

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        super().__init__(eps, vllm_config.model_config.dtype, torch.npu.current_device())

    def get_inputs(self):
        return [self.empty(8, 16)]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(input: torch.Tensor) -> torch.Tensor:
            gathered = self._all_gather(input)
            return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)

        def replacement(input: torch.Tensor) -> torch.Tensor:
            return input

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)


class SequenceParallelismMoePass(VllmInductorPass):
    """Sequence parallelism AllGather epilogue pass.

    Applies AllGather-based patterns: MiddleLayerAllgatherAddRMSNormPattern,
    LastLayerAllgatherRMSNormPattern, Qwen3VLMiddleLayerAllgatherAddRMSNormPattern,
    and AllGatherChunkNoOpPattern (all_gather + sequence_parallel_chunk_impl -> identity).
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="npu_sequence_parallelism_allgather_ep_pass")

        for epsilon in [1e-5, 1e-6]:
            MiddleLayerAllgatherAddRMSNormPattern(config, epsilon).register(self.patterns)
            LastLayerAllgatherRMSNormPattern(config, epsilon).register(self.patterns)
            Qwen3VLMiddleLayerAllgatherAddRMSNormPattern(config, epsilon).register(self.patterns)

        AllGatherChunkNoOpPattern(config).register(self.patterns)

        self.min_tokens = get_sp_min_token_num(config)

    def __call__(self, graph: torch.fx.Graph):
        self.begin()
        logger.debug(f"before apply replacement {graph}")
        self.matched_count = self.patterns.apply(graph)
        logger.debug(f"after apply replacement {graph}")
        logger.debug("SequenceParallelismMoePass replaced %s patterns", self.matched_count)
        pattern_idx = 0
        for pattern_entry in self.patterns.patterns.values():
            for p in pattern_entry:
                p_str = PatternPrettyPrinter.run(p.pattern)
                logger.debug("Pattern %d: %s", pattern_idx, p_str)
                pattern_idx += 1
        self.end_and_log()

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        applicable = compile_range.start >= self.min_tokens
        logger.debug(f"SequenceParallelismMoePass {compile_range=} {applicable=}")
        return applicable
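The `is_applicable_for_range` gate at the end of the pass decides, per compiled token range, whether the rewrites apply. A standalone sketch of that gating with a toy `Range` standing in for vllm's `vllm.config.utils.Range`:

```python
from dataclasses import dataclass

@dataclass
class Range:          # toy stand-in for vllm.config.utils.Range
    start: int
    end: int

def is_applicable_for_range(compile_range: Range, min_tokens: int) -> bool:
    # the pass only rewrites graphs compiled for batches of at least
    # min_tokens tokens; smaller ranges keep the plain all_reduce path
    return compile_range.start >= min_tokens

# MoE models use min_tokens = 1, so every compile range qualifies:
print(is_applicable_for_range(Range(1, 512), 1))        # True
# a dense model with the 1000-token floor skips small ranges:
print(is_applicable_for_range(Range(1, 512), 1000))     # False
print(is_applicable_for_range(Range(1024, 4096), 1000))  # True
```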
@@ -33,7 +33,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
 from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEPrepareOutput
 from vllm_ascend.quantization.quant_type import QuantType
-from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
+from vllm_ascend.utils import enable_sp, enable_sp_by_pass, npu_stream_switch, prefill_context_parallel_enable


 class PrepareAndFinalize(ABC):
@@ -324,7 +324,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         Returns:
             MoEPrepareOutput with global tensors.
         """
-        if enable_sp():
+        if enable_sp() or enable_sp_by_pass():
             return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)

         return self._prepare_with_dp_group(hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce)
@@ -433,7 +433,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         Returns:
             Tensor with shape [local_num_tokens, hidden_size]
         """
-        if enable_sp():
+        if enable_sp() or enable_sp_by_pass():
             return self._finalize_with_ep_group(hidden_states)

         return self._finalize_with_dp_group(hidden_states, reduce_results)
@@ -17,7 +17,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
 from vllm_ascend.ops.rotary_embedding import rope_forward_oot
 from vllm_ascend.ops.triton.muls_add import muls_add_triton
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
-from vllm_ascend.utils import npu_stream_switch, prefetch_stream
+from vllm_ascend.utils import enable_sp_by_pass, npu_stream_switch, prefetch_stream


 def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
@@ -43,7 +43,7 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
     except AssertionError:
         return x

-    flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
+    flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled or (enable_sp_by_pass() and is_ep_comm)
     if flash_comm_v1_enabled and label:
         dp_metadata = forward_context.dp_metadata
         if dp_metadata is None or not is_ep_comm:
@@ -53,6 +53,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
                 x = x[:-pad_size]
         else:
             x = get_ep_group().all_gather(x, 0)
+            if enable_sp_by_pass():  # TODO: do unpad
+                return x
             # unpad
             num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
             result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype)
@@ -74,7 +76,11 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
     except AssertionError:
         return tensor_model_parallel_all_reduce(x)

-    if not getattr(forward_context, "flash_comm_v1_enabled", False):
+    flash_comm_v1_enabled = getattr(forward_context, "flash_comm_v1_enabled", False) or (
+        enable_sp_by_pass() and is_ep_comm
+    )
+
+    if not flash_comm_v1_enabled:
         return tensor_model_parallel_all_reduce(x)

     dp_metadata = forward_context.dp_metadata
@@ -84,6 +90,8 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
             x = F.pad(x, (0, 0, 0, pad_size))
         return tensor_model_parallel_reduce_scatter(x, 0)
     else:
+        if enable_sp_by_pass():
+            return get_ep_group().reduce_scatter(x.view(-1, *x.shape[1:]), 0)
         # padding
         dp_size = get_dp_group().world_size
         num_tokens_across_dp_cpu = get_forward_context().dp_metadata.num_tokens_across_dp_cpu
@@ -107,7 +115,7 @@ def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_c


 def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
-    if _EXTRA_CTX.flash_comm_v1_enabled:
+    if _EXTRA_CTX.flash_comm_v1_enabled or enable_sp_by_pass():
         return torch.empty(
             (x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
         )
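The pad-and-reduce path above exists because reduce_scatter needs the token dimension to divide evenly across ranks. A pure-Python sketch of that bookkeeping (toy lists instead of tensors; names illustrative, not the module's API):

```python
def pad_and_reduce_scatter(per_rank_tokens, tp_size):
    # pad the token dimension up to a multiple of tp_size, sum across
    # ranks, then hand every rank an equal-sized chunk of the sum
    num_tokens = len(per_rank_tokens[0])
    pad_size = (-num_tokens) % tp_size
    padded = [rank + [0] * pad_size for rank in per_rank_tokens]
    summed = [sum(vals) for vals in zip(*padded)]
    chunk = len(summed) // tp_size
    return [summed[r * chunk:(r + 1) * chunk] for r in range(tp_size)]

# 5 tokens, TP=2 -> padded to 6, each rank holds 3 summed tokens
chunks = pad_and_reduce_scatter([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]], 2)
print(chunks)  # [[11, 22, 33], [44, 55, 0]]
```

The matching unpad on the all_gather side then strips the trailing zero tokens; the `TODO: do unpad` branch above defers that step when SP is driven by the compilation pass.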
@@ -167,15 +167,12 @@
 # 1. `vllm.distributed.parallel_state.GroupCoordinator`
 #    Why:
-#       vllm doesn't support all_to_all for GroupCoordinator.
+#       all_reduce in vLLM is not a custom op, which makes the MatmulAllReduceAddRMSNorm fusion fail.
 #    How:
-#       Add all_to_all implementation for GroupCoordinator.
+#       Make all_reduce a custom op.
 #    Related PR (if no, explain why):
-#       No, we should use vlLM all2all manager to support all_to_all for npu.
 #    Future Plan:
-#       Remove this patch when the refactor of all2all manager is done.
+#       Remove this patch when vLLM supports all_reduce as a custom op.
 #
 # ** 2. File: worker/patch_multimodal_merge.py**
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -84,7 +84,7 @@ class GroupCoordinatorPatch(GroupCoordinator):
         if use_message_queue_broadcaster and self.world_size > 1:
             self.mq_broadcaster = MessageQueue.create_from_process_group(self.cpu_group, 1 << 22, 6)

-        self.use_custom_op_call = False
+        self.use_custom_op_call = True
         self.use_cpu_custom_send_recv = False

     def all_to_all(
@@ -106,10 +106,5 @@ class GroupCoordinatorPatch(GroupCoordinator):
         assert self.device_communicator is not None, "device_communicator should be initialized when world_size > 1"
         return self.device_communicator.all_to_all(input_, scatter_dim, gather_dim, scatter_sizes, gather_sizes)

-    def all_reduce(self, input_):
-        if self.world_size == 1:
-            return input_
-        return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
-

 vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch
@@ -156,6 +156,16 @@ class NPUPlatform(Platform):
     def get_device_capability(cls, device_id: int = 0):
         return None

+    @classmethod
+    def apply_config_platform_defaults(cls, vllm_config: VllmConfig) -> None:
+        """Apply Ascend-specific defaults: set sp_min_token_num when enable_sp is on and it is unset."""
+        pass_config = vllm_config.compilation_config.pass_config
+        if pass_config.enable_sp and pass_config.sp_min_token_num is None:
+            from vllm_ascend.compilation.passes.sequence_parallelism import get_sp_min_token_num
+
+            pass_config.sp_min_token_num = get_sp_min_token_num(vllm_config)
+            logger.info(f"set sp_min_token_num to {pass_config.sp_min_token_num}")
+
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         return torch.npu.get_device_name(device_id)
@@ -198,6 +208,7 @@ class NPUPlatform(Platform):

         # initialize ascend config from vllm additional_config
+        cls._fix_incompatible_config(vllm_config)
         ascend_config = init_ascend_config(vllm_config)

         if vllm_config.kv_transfer_config is not None:
@@ -218,6 +229,7 @@ class NPUPlatform(Platform):
             if not isinstance(ascend_compilation_config, dict)
             else ascend_compilation_config
         )

+        ascend_config.update_compile_ranges_split_points()

         if model_config and hasattr(model_config.hf_text_config, "index_topk"):
@@ -363,7 +375,8 @@ class NPUPlatform(Platform):

         if parallel_config and parallel_config.worker_cls == "auto":
             # TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
-            parallel_config.all2all_backend = "flashinfer_all2allv"
+            if not vllm_config.compilation_config.pass_config.enable_sp:
+                parallel_config.all2all_backend = "flashinfer_all2allv"
             if is_310p():
                 parallel_config.worker_cls = "vllm_ascend._310p.worker_310p.NPUWorker310"
             elif ascend_config.xlite_graph_config.enabled:
@@ -805,3 +818,7 @@ class NPUPlatform(Platform):
                 "ignored on Ascend. Resetting to default (32)."
             )
             att_config.flash_attn_max_num_splits_for_cuda_graph = 32
+
+    @classmethod
+    def use_custom_op_collectives(cls) -> bool:
+        return True
@@ -764,8 +764,8 @@ def matmul_allreduce_enable() -> bool:
    return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE


-def enable_sp_by_pass(vllm_config: VllmConfig):
-    return not vllm_config.model_config.enforce_eager and vllm_config.compilation_config.pass_config.enable_sp
+def enable_sp_by_pass():
+    return get_ascend_config().enable_sp_by_pass


 def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
@@ -791,7 +791,7 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:

 # TODO remove it after vllm has this func
 def shared_expert_dp_enabled() -> bool:
-    return get_ascend_config().enable_shared_expert_dp or enable_sp()
+    return get_ascend_config().enable_shared_expert_dp or enable_sp() or enable_sp_by_pass()


 def prefill_context_parallel_enable() -> bool:
@@ -1846,7 +1846,7 @@ class NPUModelRunner(GPUModelRunner):
         # Pad tokens to multiple of tensor_parallel_size when
         # enabled collective fusion for SP
         tp_size = self.vllm_config.parallel_config.tensor_parallel_size
-        if enable_sp(self.vllm_config) or enable_sp_by_pass(self.vllm_config):
+        if enable_sp(self.vllm_config) or enable_sp_by_pass():
             return round_up(num_scheduled_tokens, tp_size)
         return num_scheduled_tokens
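The model-runner hunk above pads the scheduled token count up to a multiple of `tp_size` so the fused reduce_scatter/all_gather collectives see evenly divisible shapes. A minimal sketch of that rounding (`round_up` mirrors the utility of the same name that vllm provides):

```python
def round_up(x: int, multiple: int) -> int:
    # smallest multiple of `multiple` that is >= x
    return ((x + multiple - 1) // multiple) * multiple

print(round_up(1000, 4))  # 1000 (already a multiple of tp_size=4)
print(round_up(1001, 4))  # 1004
print(round_up(1, 4))     # 4
```

The at-most `tp_size - 1` padding tokens are throwaway work; the per-token savings inside the SP region outweigh it, which is consistent with the TTFT numbers in the PR description.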