[Feature] optimize sp & qwen3 next support sp. (#3225)

This PR will accomplish the following tasks: 
**optimize SP**
In the old version implementation, the first layer was all_reduce, which
used rms to split chunks. We changed it to perform reduce_scatter on the
embedding side, replace one all_reduce operation and one chunk with one
reduce_scatter operation.
**Support qwen3 next**
Since Qwen3 Next includes a linear attention module, the prefix name of
this module cannot take effect directly.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2025-10-13 23:02:12 +08:00
committed by GitHub
parent 31682961af
commit 6972df5951
10 changed files with 140 additions and 193 deletions

View File

@@ -9,12 +9,6 @@ from tests.ut.base import PytestBase
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
def mock_maybe_chunk_residual(x, residual):
if x.size(0) != residual.size(0):
return residual[:4]
return residual
def mock_rms_norm(x, weight, eps):
return x + 1, None
@@ -36,8 +30,6 @@ class TestAscendRMSNorm(PytestBase):
@pytest.fixture(autouse=True)
def context(self, mocker: MockerFixture):
mocker.patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=mock_maybe_chunk_residual)
mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
mocker.patch("torch_npu.npu_add_rms_norm",
side_effect=mock_add_rms_norm)
@@ -66,21 +58,6 @@ class TestAscendRMSNorm(PytestBase):
assert torch.allclose(x_out, x_out_expected)
# Test case for flashcomm_v1 scenario
def test_forward_oot_with_flashcomm_v1(self):
layer = RMSNorm(hidden_size=512, eps=1e-05)
x = torch.randn(4, 512, dtype=torch.bfloat16)
residual = torch.randn(16, 512, dtype=torch.bfloat16)
x_out, residual_out = layer.forward_oot(x, residual)
x_out_expected = 2 * x
residual_out_expected = 2 * residual[:4]
assert residual_out.size(0) == 4
assert torch.allclose(x_out, x_out_expected)
assert torch.allclose(residual_out, residual_out_expected)
# Test case for addrmsnorm + w8a8 quant fusion
def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture):
mock_is_310p = mocker.patch("vllm_ascend.utils.is_310p")

View File

@@ -104,9 +104,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
input_ = torch.tensor([1, 2, 3])
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x) as mock_reduce_tp1:
with patch("torch.ops.vllm.maybe_pad_and_reduce",
side_effect=lambda x: x) as mock_reduce_tp1:
output = layer.forward(input_)
# Should just pass through without masking
@@ -123,9 +122,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
input_ = torch.tensor([15, 35]) # one org vocab, one added vocab
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x) as mock_reduce_tp:
with patch("torch.ops.vllm.maybe_pad_and_reduce",
side_effect=lambda x: x) as mock_reduce_tp:
# Call the forward method
output = layer.forward(input_)
@@ -150,9 +148,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
return_value=mock_output.clone())
# Patch tensor_model_parallel_all_reduce to mock its behavior
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x):
with patch("torch.ops.vllm.maybe_pad_and_reduce",
side_effect=lambda x: x):
# Call the forward method
output = layer.forward(input_)
# Check that invalid positions (0, 2, 4) were zeroed out
@@ -176,9 +173,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
for input_, expected_shape in test_cases:
with self.subTest(input=input_):
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x):
with patch("torch.ops.vllm.maybe_pad_and_reduce",
side_effect=lambda x: x):
# Call the forward method
output = layer.forward(input_)
self.assertEqual(output.shape, expected_shape)

View File

@@ -276,10 +276,7 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
@patch("torch_npu.npu_add_rms_norm")
@patch("torch_npu.npu_rms_norm")
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=lambda x, residual: residual)
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
mock_maybe_wait_prefetch_done,
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_wait_prefetch_done,
mock_rms_norm, mock_add_norm,
mock_distributed, base_config,
vllm_config, mock_forward_context):