[Feature] optimize sp & qwen3 next support sp. (#3225)

This PR will accomplish the following tasks: 
**optimize SP**
In the old version implementation, the first layer was all_reduce, which
used rms to split chunks. We changed it to perform reduce_scatter on the
embedding side, replace one all_reduce operation and one chunk with one
reduce_scatter operation.
**Support qwen3 next**
Since Qwen3 Next includes a linear attention module, the prefix name of
this module cannot take effect directly.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2025-10-13 23:02:12 +08:00
committed by GitHub
parent 31682961af
commit 6972df5951
10 changed files with 140 additions and 193 deletions

View File

@@ -1,9 +1,7 @@
import torch
import torch.nn.functional as F
import torch_npu
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context
@@ -15,27 +13,6 @@ from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
def _maybe_chunk_residual_impl(x: torch.Tensor,
residual: torch.Tensor) -> torch.Tensor:
try:
forward_context = get_forward_context()
except AssertionError:
return residual
if x.size(0) != residual.size(0):
sp_enabled = forward_context.sp_enabled
assert sp_enabled is True, ("Currently, this situation only occurs "
"when sp is enabled")
pad_size = forward_context.pad_size
if pad_size > 0:
residual = F.pad(residual, (0, 0, 0, pad_size))
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]
return residual
def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
label: bool) -> torch.Tensor:
try:
@@ -187,12 +164,6 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
return tensor_model_parallel_all_reduce(final_hidden_states)
direct_register_custom_op(op_name="maybe_chunk_residual",
op_func=_maybe_chunk_residual_impl,
fake_impl=lambda x, residual: residual,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
op_func=_maybe_all_gather_and_maybe_unpad_impl,
fake_impl=lambda x, label: x,