[Feature] optimize sp & qwen3 next support sp. (#3225)
This PR accomplishes the following tasks: **Optimize SP** In the old implementation, the first layer performed an all_reduce and then used RMSNorm to split the hidden states into chunks. We changed this to perform a reduce_scatter on the embedding side, replacing one all_reduce operation plus one chunk operation with a single reduce_scatter operation. **Support Qwen3 Next** Since Qwen3 Next includes a linear attention module, the prefix name of that module cannot take effect directly. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: weijinqian_v1 <weijinqian@huawei.com> Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -9,12 +9,6 @@ from tests.ut.base import PytestBase
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
|
||||
|
||||
def mock_maybe_chunk_residual(x, residual):
    """Stand-in for torch.ops.vllm.maybe_chunk_residual.

    Returns ``residual`` unchanged when its leading dimension matches
    ``x``; otherwise returns the first 4 rows, emulating the SP
    chunking of an oversized residual.
    """
    sizes_match = x.size(0) == residual.size(0)
    if sizes_match:
        return residual
    return residual[:4]
|
||||
|
||||
|
||||
def mock_rms_norm(x, weight, eps):
    """Stand-in for torch_npu.npu_rms_norm.

    Ignores ``weight`` and ``eps``; simply shifts the input by one and
    returns ``None`` for the (unused) second output.
    """
    shifted = x + 1
    return shifted, None
|
||||
|
||||
@@ -36,8 +30,6 @@ class TestAscendRMSNorm(PytestBase):
|
||||
|
||||
@pytest.fixture(autouse=True)
def context(self, mocker: MockerFixture):
    """Auto-applied fixture: stub the NPU norm kernels and the
    residual-chunking custom op for every test in this class."""
    mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
    mocker.patch("torch_npu.npu_add_rms_norm",
                 side_effect=mock_add_rms_norm)
    mocker.patch("torch.ops.vllm.maybe_chunk_residual",
                 side_effect=mock_maybe_chunk_residual)
|
||||
@@ -66,21 +58,6 @@ class TestAscendRMSNorm(PytestBase):
|
||||
|
||||
assert torch.allclose(x_out, x_out_expected)
|
||||
|
||||
# Test case for flashcomm_v1 scenario
def test_forward_oot_with_flashcomm_v1(self):
    """forward_oot must chunk an oversized residual down to x's batch size.

    With the mocked ops in place, both outputs come back as exactly
    twice their (chunked) inputs.
    """
    layer = RMSNorm(hidden_size=512, eps=1e-05)
    x = torch.randn(4, 512, dtype=torch.bfloat16)
    residual = torch.randn(16, 512, dtype=torch.bfloat16)

    x_out, residual_out = layer.forward_oot(x, residual)

    expected_x = 2 * x
    expected_residual = 2 * residual[:4]

    # Residual was chunked from 16 rows down to x's 4 rows.
    assert residual_out.size(0) == 4
    assert torch.allclose(x_out, expected_x)
    assert torch.allclose(residual_out, expected_residual)
|
||||
|
||||
# Test case for addrmsnorm + w8a8 quant fusion
|
||||
def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture):
|
||||
mock_is_310p = mocker.patch("vllm_ascend.utils.is_310p")
|
||||
|
||||
@@ -104,9 +104,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
|
||||
|
||||
input_ = torch.tensor([1, 2, 3])
|
||||
|
||||
with patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
|
||||
side_effect=lambda x: x) as mock_reduce_tp1:
|
||||
with patch("torch.ops.vllm.maybe_pad_and_reduce",
|
||||
side_effect=lambda x: x) as mock_reduce_tp1:
|
||||
output = layer.forward(input_)
|
||||
|
||||
# Should just pass through without masking
|
||||
@@ -123,9 +122,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
|
||||
|
||||
input_ = torch.tensor([15, 35]) # one org vocab, one added vocab
|
||||
|
||||
with patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
|
||||
side_effect=lambda x: x) as mock_reduce_tp:
|
||||
with patch("torch.ops.vllm.maybe_pad_and_reduce",
|
||||
side_effect=lambda x: x) as mock_reduce_tp:
|
||||
# Call the forward method
|
||||
output = layer.forward(input_)
|
||||
|
||||
@@ -150,9 +148,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
|
||||
return_value=mock_output.clone())
|
||||
|
||||
# Patch tensor_model_parallel_all_reduce to mock its behavior
|
||||
with patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
|
||||
side_effect=lambda x: x):
|
||||
with patch("torch.ops.vllm.maybe_pad_and_reduce",
|
||||
side_effect=lambda x: x):
|
||||
# Call the forward method
|
||||
output = layer.forward(input_)
|
||||
# Check that invalid positions (0, 2, 4) were zeroed out
|
||||
@@ -176,9 +173,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
|
||||
|
||||
for input_, expected_shape in test_cases:
|
||||
with self.subTest(input=input_):
|
||||
with patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
|
||||
side_effect=lambda x: x):
|
||||
with patch("torch.ops.vllm.maybe_pad_and_reduce",
|
||||
side_effect=lambda x: x):
|
||||
# Call the forward method
|
||||
output = layer.forward(input_)
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
|
||||
@@ -276,10 +276,7 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
|
||||
@patch("torch_npu.npu_add_rms_norm")
|
||||
@patch("torch_npu.npu_rms_norm")
|
||||
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
||||
@patch("torch.ops.vllm.maybe_chunk_residual",
|
||||
side_effect=lambda x, residual: residual)
|
||||
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
|
||||
mock_maybe_wait_prefetch_done,
|
||||
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_wait_prefetch_done,
|
||||
mock_rms_norm, mock_add_norm,
|
||||
mock_distributed, base_config,
|
||||
vllm_config, mock_forward_context):
|
||||
|
||||
Reference in New Issue
Block a user