### What this PR does / why we need it?
Second PR for https://github.com/vllm-project/vllm-ascend/issues/5712: extends sequence parallelism (SP) to VL MoE models.
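As a quick illustration of what SP means for these layers (a minimal single-process sketch; the `torch.cat` below only stands in for the real all-gather collective):

```python
import torch

tp_size, local_tokens, hidden = 2, 4, 8

# Under sequence parallelism each TP rank keeps only its slice of the tokens.
local_hidden = torch.randn(local_tokens, hidden)

# Before an op that needs the full sequence (the add + RMSNorm in a VL MoE
# block), the slices are gathered back; the pass extended in this series matches
# that all_gather + slice (+ add) + RMSNorm chain and rewrites it into a fused form.
full_hidden = torch.cat([local_hidden] * tp_size, dim=0)  # stand-in for all_gather
assert full_hidden.shape == (tp_size * local_tokens, hidden)
```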
### Does this PR introduce _any_ user-facing change?
Removes `sp_threshold` from `additional_config` and reuses vLLM's `sp_min_token_num` instead.
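A rough before/after sketch of the user-facing change (the model id is illustrative, and only the two key names `sp_threshold` and `sp_min_token_num` come from this PR):

```python
from vllm import LLM

# Before this PR, SP for MoE layers was gated by an Ascend-only knob:
#     additional_config={"sp_threshold": ...}   # removed by this PR
# Now the gating reuses vLLM's own sp_min_token_num, so no extra
# Ascend-specific key is needed.
llm = LLM(
    model="Qwen/Qwen3-VL-30B-A3B-Instruct",  # illustrative model id
    tensor_parallel_size=4,
)
```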
### How was this patch tested?
- Model: Qwen3-VL-30B-A3B
- TP4 DP2
- 100 requests
- max concurrency 1
| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|------------|---------------------|------------------------|
| 4k | 429.40 | 323.3 |
| 16k | 1297.01 | 911.74 |
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Two-card e2e tests for SequenceParallelismMoePass patterns:
# - MiddleLayerAllgatherAddRMSNormPattern (all_gather + slice + RMSNorm)
# - Qwen3VLMiddleLayerAllgatherAddRMSNormPattern (all_gather + slice + add + RMSNorm)
# - AllGatherChunkNoOpPattern (all_gather + sequence_parallel_chunk_impl -> identity)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import queue
import traceback
from collections.abc import Callable, Generator
from dataclasses import dataclass
from typing import Any

import pytest
import torch
import torch.nn as nn
import vllm.config
from vllm.compilation.passes.fx_utils import OpOverload
from vllm.config import ModelConfig, VllmConfig
from vllm.distributed import (
    get_tensor_model_parallel_world_size,
    get_tp_group,
    init_distributed_environment,
    tensor_model_parallel_all_gather,
)
from vllm.distributed.parallel_state import (
    destroy_distributed_environment,
    destroy_model_parallel,
    initialize_model_parallel,
)
from vllm.utils.system_utils import update_environment_variables

import vllm_ascend.ops.register_custom_ops  # noqa
from tests.e2e.singlecard.compile.backend import TestBackend as CompileTestBackend
from vllm_ascend.compilation.passes.sequence_parallelism_moe import (
    SequenceParallelismMoePass,
)
from vllm_ascend.utils import enable_custom_op

MASTER_PORT = 29500
WORLD_SIZE = 2
WORKER_READY = "__ready__"
WORKER_STOP = "__stop__"
WORKER_RESULT_TIMEOUT_S = 180
WORKER_JOIN_TIMEOUT_S = 30


class BaseAllGatherRMSNormModel(nn.Module):
    """Shared relu + all_gather + slice building block for the RMSNorm pattern models."""

    def __init__(
        self,
        hidden_size: int,
        dtype: torch.dtype,
        eps: float = 1e-6,
        device: str = "npu",
    ):
        super().__init__()
        self.eps = eps
        self.norm_w = torch.randn(hidden_size, dtype=dtype, device=device)

    def _all_gather_sliced(self, x: torch.Tensor, num_tokens_helper: torch.Tensor) -> torch.Tensor:
        num_tokens = num_tokens_helper.shape[0]
        activated = torch.relu(x)
        gathered = tensor_model_parallel_all_gather(activated, 0)
        return gathered[:num_tokens]

    @staticmethod
    def ops_in_model_after() -> tuple[tuple[OpOverload, int], ...]:
        return (
            (torch.ops.vllm.all_gather.default, 1),
            (torch.ops._C_ascend.npu_add_rms_norm_bias.default, 1),
            (torch.ops.vllm.maybe_chunk_residual.default, 1),
        )


class AllGatherRMSNormModel(BaseAllGatherRMSNormModel):
    """Exercises MiddleLayerAllgatherAddRMSNormPattern (all_gather + slice + RMSNorm)."""

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor,
        num_tokens_helper: torch.Tensor,
    ) -> torch.Tensor:
        sliced = self._all_gather_sliced(x, num_tokens_helper)
        rms_out = torch.ops._C_ascend.npu_add_rms_norm_bias(sliced, residual, self.norm_w, None, self.eps)
        return rms_out[0]

    @staticmethod
    def ops_in_model_before() -> tuple[tuple[OpOverload, int], ...]:
        return (
            (torch.ops.vllm.all_gather.default, 1),
            (torch.ops._C_ascend.npu_add_rms_norm_bias.default, 1),
        )


class Qwen3VLAllGatherRMSNormModel(BaseAllGatherRMSNormModel):
    """Exercises Qwen3VLMiddleLayerAllgatherAddRMSNormPattern (all_gather + slice + add + RMSNorm)."""

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor,
        num_tokens_helper: torch.Tensor,
        deepstack_input_embeds: torch.Tensor,
    ) -> torch.Tensor:
        sliced = self._all_gather_sliced(x, num_tokens_helper)
        add_ = sliced + deepstack_input_embeds
        result, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(add_, residual, self.norm_w, None, self.eps)
        # Keep the residual output live so the traced graph preserves the full pattern.
        result = result - residual
        return result

    @staticmethod
    def ops_in_model_before() -> tuple[tuple[OpOverload, int], ...]:
        return (
            (torch.ops.vllm.all_gather.default, 1),
            (torch.ops.aten.add.Tensor, 1),
            (torch.ops._C_ascend.npu_add_rms_norm_bias.default, 1),
        )


class AllGatherChunkNoOpModel(nn.Module):
    """Exercises AllGatherChunkNoOpPattern (all_gather + sequence_parallel_chunk_impl -> identity)."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = torch.relu(x)
        gathered = tensor_model_parallel_all_gather(z, 0)
        return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)

    @staticmethod
    def ops_in_model_before() -> tuple[tuple[OpOverload, int], ...]:
        return (
            (torch.ops.vllm.all_gather.default, 1),
            (torch.ops.vllm.sequence_parallel_chunk_impl.default, 1),
        )

    @staticmethod
    def ops_in_model_after() -> tuple[tuple[OpOverload, int], ...]:
        return (
            (torch.ops.vllm.all_gather.default, 0),
            (torch.ops.vllm.sequence_parallel_chunk_impl.default, 0),
        )


def _build_all_gather_rms_norm_inputs(
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    tp_size: int,
) -> tuple[torch.Tensor, ...]:
    local_tokens = batch_size * seq_len
    num_tokens = local_tokens * tp_size
    x = torch.randn(local_tokens, hidden_size, dtype=dtype)
    residual = torch.zeros(num_tokens, hidden_size, dtype=dtype)
    num_tokens_helper = torch.empty(num_tokens, device=x.device, dtype=dtype)
    return (x, residual, num_tokens_helper)


def _build_qwen3vl_inputs(
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    tp_size: int,
) -> tuple[torch.Tensor, ...]:
    x, residual, num_tokens_helper = _build_all_gather_rms_norm_inputs(
        batch_size=batch_size,
        seq_len=seq_len,
        hidden_size=hidden_size,
        dtype=dtype,
        tp_size=tp_size,
    )
    deepstack = torch.randn(num_tokens_helper.shape[0], hidden_size, dtype=dtype)
    return (x, residual, num_tokens_helper, deepstack)


def _build_allgather_chunk_noop_inputs(
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    tp_size: int,
) -> tuple[torch.Tensor, ...]:
    del tp_size
    local_tokens = batch_size * seq_len
    x = torch.randn(local_tokens, hidden_size, dtype=dtype)
    return (x,)


def _create_all_gather_rms_norm_model(
    hidden_size: int,
    dtype: torch.dtype,
    eps: float,
    device: str,
) -> nn.Module:
    return AllGatherRMSNormModel(hidden_size=hidden_size, dtype=dtype, eps=eps, device=device)


def _create_qwen3vl_model(
    hidden_size: int,
    dtype: torch.dtype,
    eps: float,
    device: str,
) -> nn.Module:
    return Qwen3VLAllGatherRMSNormModel(hidden_size=hidden_size, dtype=dtype, eps=eps, device=device)


def _create_allgather_chunk_noop_model(
    hidden_size: int,
    dtype: torch.dtype,
    eps: float,
    device: str,
) -> nn.Module:
    del hidden_size, dtype, eps, device
    return AllGatherChunkNoOpModel()


@dataclass(frozen=True)
class PatternTestCase:
    model_factory: Any
    input_builder: Any
    dynamic_input_indices: tuple[int, ...]
    pre_pass_expected_counts_factory: Any
    post_pass_expected_counts_factory: Any


PATTERN_TEST_CASES = {
    "middle_layer_allgather_add_rms_norm": PatternTestCase(
        model_factory=_create_all_gather_rms_norm_model,
        input_builder=_build_all_gather_rms_norm_inputs,
        dynamic_input_indices=(0, 2),
        pre_pass_expected_counts_factory=AllGatherRMSNormModel.ops_in_model_before,
        post_pass_expected_counts_factory=AllGatherRMSNormModel.ops_in_model_after,
    ),
    "qwen3vl_middle_layer_allgather_add_rms_norm": PatternTestCase(
        model_factory=_create_qwen3vl_model,
        input_builder=_build_qwen3vl_inputs,
        dynamic_input_indices=(0, 2, 3),
        pre_pass_expected_counts_factory=Qwen3VLAllGatherRMSNormModel.ops_in_model_before,
        post_pass_expected_counts_factory=Qwen3VLAllGatherRMSNormModel.ops_in_model_after,
    ),
    "allgather_chunk_noop": PatternTestCase(
        model_factory=_create_allgather_chunk_noop_model,
        input_builder=_build_allgather_chunk_noop_inputs,
        dynamic_input_indices=(0,),
        pre_pass_expected_counts_factory=AllGatherChunkNoOpModel.ops_in_model_before,
        post_pass_expected_counts_factory=AllGatherChunkNoOpModel.ops_in_model_after,
    ),
}


def _assert_op_counts(
    backend: CompileTestBackend,
    expected_counts: tuple[tuple[OpOverload, int], ...],
    before: bool = False,
) -> None:
    """Check that each op appears the expected number of times in the pre- or post-pass graph."""
    for op, expected_count in expected_counts:
        actual_count = backend.op_count(op, before=before)
        stage = "before" if before else "after"
        assert actual_count == expected_count, (
            f"op {stage} pass: {op} expected {expected_count}, but got {actual_count}"
        )


def _run_single_pattern_case(
    local_rank: int,
    case_name: str,
    vllm_config: VllmConfig,
    tp_size: int,
    batch_size: int = 8,
    seq_len: int = 16,
    hidden_size: int = 16,
    dtype: torch.dtype = torch.bfloat16,
    eps: float = 1e-5,
) -> None:
    """Compile one pattern model with SequenceParallelismMoePass and verify op counts before and after the pass."""
    case = PATTERN_TEST_CASES[case_name]
    sp_moe_pass = SequenceParallelismMoePass(vllm_config)
    backend = CompileTestBackend(custom_passes=[sp_moe_pass])
    model = case.model_factory(
        hidden_size=hidden_size,
        dtype=dtype,
        eps=eps,
        device=f"npu:{local_rank}",
    )
    inputs = case.input_builder(
        batch_size=batch_size,
        seq_len=seq_len,
        hidden_size=hidden_size,
        dtype=dtype,
        tp_size=tp_size,
    )
    for dynamic_input_index in case.dynamic_input_indices:
        torch._dynamo.mark_dynamic(inputs[dynamic_input_index], 0)

    unfused = model(*inputs)
    compiled = torch.compile(model, backend=backend)
    fused = compiled(*inputs)
    assert unfused.shape == fused.shape

    assert sp_moe_pass.matched_count == 1
    _assert_op_counts(backend, case.pre_pass_expected_counts_factory(), before=True)
    _assert_op_counts(backend, case.post_pass_expected_counts_factory())


def _run_sequence_parallelism_moe_test(
    local_rank: int,
    world_size: int,
    master_port: int,
    command_queue: Any,
    result_queue: Any,
    batch_size: int = 8,
    seq_len: int = 16,
    hidden_size: int = 16,
    dtype: torch.dtype = torch.bfloat16,
    eps: float = 1e-5,
) -> None:
    """Per-rank worker loop: set up TP over HCCL, then run pattern cases received on the command queue."""
    torch.npu.set_device(local_rank)
    torch.set_default_device(f"npu:{local_rank}")
    torch.set_default_dtype(dtype)
    torch.manual_seed(0)

    update_environment_variables(
        {
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": str(master_port),
        }
    )

    vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))

    try:
        with vllm.config.set_current_vllm_config(vllm_config):
            init_distributed_environment(
                world_size=world_size,
                rank=local_rank,
                local_rank=local_rank,
                backend="hccl",
            )
            initialize_model_parallel(tensor_model_parallel_size=world_size)

            if not enable_custom_op():
                raise RuntimeError("vllm_ascend custom ops are not available")

            _ = get_tp_group().unique_name
            tp_size = get_tensor_model_parallel_world_size()
            result_queue.put((WORKER_READY, local_rank, "ok", ""))

            while True:
                case_name = command_queue.get()
                if case_name == WORKER_STOP:
                    return

                try:
                    _run_single_pattern_case(
                        local_rank=local_rank,
                        case_name=case_name,
                        vllm_config=vllm_config,
                        tp_size=tp_size,
                        batch_size=batch_size,
                        seq_len=seq_len,
                        hidden_size=hidden_size,
                        dtype=dtype,
                        eps=eps,
                    )
                except Exception:
                    result_queue.put((case_name, local_rank, "error", traceback.format_exc()))
                else:
                    result_queue.put((case_name, local_rank, "ok", ""))
    finally:
        destroy_model_parallel()
        destroy_distributed_environment()
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()


def _worker_entrypoint(
    local_rank: int,
    world_size: int,
    master_port: int,
    command_queue: Any,
    result_queue: Any,
) -> None:
    """Process entrypoint; reports startup failures through the result queue instead of dying silently."""
    try:
        _run_sequence_parallelism_moe_test(
            local_rank=local_rank,
            world_size=world_size,
            master_port=master_port,
            command_queue=command_queue,
            result_queue=result_queue,
        )
    except Exception:
        result_queue.put((WORKER_READY, local_rank, "error", traceback.format_exc()))


def _wait_for_worker_reports(
    result_queue: Any,
    case_name: str,
    expected_reports: int,
) -> None:
    """Collect one report per rank for `case_name` and raise if any rank timed out or failed."""
    errors = []
    for _ in range(expected_reports):
        try:
            reported_case_name, local_rank, status, payload = result_queue.get(timeout=WORKER_RESULT_TIMEOUT_S)
        except queue.Empty as exc:
            raise TimeoutError(f"Timed out waiting for worker reports for {case_name}") from exc

        assert reported_case_name == case_name, f"Expected worker report for {case_name}, but got {reported_case_name}"
        if status != "ok":
            errors.append(f"rank {local_rank}:\n{payload}")

    if errors:
        raise AssertionError("\n\n".join(errors))


@pytest.fixture(scope="module")
def sequence_parallelism_moe_workers() -> Generator[Callable[[str], None], None, None]:
    """Spawn WORLD_SIZE worker processes once per module and yield a callable that runs a named case on every rank."""
    ctx = torch.multiprocessing.get_context("spawn")
    command_queues = [ctx.Queue() for _ in range(WORLD_SIZE)]
    result_queue = ctx.Queue()
    workers = []

    for local_rank in range(WORLD_SIZE):
        worker = ctx.Process(
            target=_worker_entrypoint,
            args=(local_rank, WORLD_SIZE, MASTER_PORT, command_queues[local_rank], result_queue),
        )
        worker.start()
        workers.append(worker)

    try:
        _wait_for_worker_reports(result_queue, WORKER_READY, WORLD_SIZE)

        def _run_case(case_name: str) -> None:
            for command_queue in command_queues:
                command_queue.put(case_name)
            _wait_for_worker_reports(result_queue, case_name, WORLD_SIZE)

        yield _run_case
    finally:
        for command_queue in command_queues:
            command_queue.put(WORKER_STOP)
        for worker in workers:
            worker.join(timeout=WORKER_JOIN_TIMEOUT_S)
            if worker.is_alive():
                worker.terminate()
            worker.join()


@pytest.mark.parametrize("case_name", tuple(PATTERN_TEST_CASES), ids=tuple(PATTERN_TEST_CASES))
def test_sequence_parallelism_moe_patterns(
    sequence_parallelism_moe_workers: Callable[[str], None], case_name: str
) -> None:
    sequence_parallelism_moe_workers(case_name)