[Feature] optimize sp & qwen3 next support sp. (#3225)
This PR will accomplish the following tasks: **optimize SP** In the old version implementation, the first layer was all_reduce, which used rms to split chunks. We changed it to perform reduce_scatter on the embedding side, replace one all_reduce operation and one chunk with one reduce_scatter operation. **Support qwen3 next** Since Qwen3 Next includes a linear attention module, the prefix name of this module cannot take effect directly. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: weijinqian_v1 <weijinqian@huawei.com> Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -104,9 +104,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
|
||||
|
||||
input_ = torch.tensor([1, 2, 3])
|
||||
|
||||
with patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
|
||||
side_effect=lambda x: x) as mock_reduce_tp1:
|
||||
with patch("torch.ops.vllm.maybe_pad_and_reduce",
|
||||
side_effect=lambda x: x) as mock_reduce_tp1:
|
||||
output = layer.forward(input_)
|
||||
|
||||
# Should just pass through without masking
|
||||
@@ -123,9 +122,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
|
||||
|
||||
input_ = torch.tensor([15, 35]) # one org vocab, one added vocab
|
||||
|
||||
with patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
|
||||
side_effect=lambda x: x) as mock_reduce_tp:
|
||||
with patch("torch.ops.vllm.maybe_pad_and_reduce",
|
||||
side_effect=lambda x: x) as mock_reduce_tp:
|
||||
# Call the forward method
|
||||
output = layer.forward(input_)
|
||||
|
||||
@@ -150,9 +148,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
|
||||
return_value=mock_output.clone())
|
||||
|
||||
# Patch tensor_model_parallel_all_reduce to mock its behavior
|
||||
with patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
|
||||
side_effect=lambda x: x):
|
||||
with patch("torch.ops.vllm.maybe_pad_and_reduce",
|
||||
side_effect=lambda x: x):
|
||||
# Call the forward method
|
||||
output = layer.forward(input_)
|
||||
# Check that invalid positions (0, 2, 4) were zeroed out
|
||||
@@ -176,9 +173,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
|
||||
|
||||
for input_, expected_shape in test_cases:
|
||||
with self.subTest(input=input_):
|
||||
with patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
|
||||
side_effect=lambda x: x):
|
||||
with patch("torch.ops.vllm.maybe_pad_and_reduce",
|
||||
side_effect=lambda x: x):
|
||||
# Call the forward method
|
||||
output = layer.forward(input_)
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
|
||||
Reference in New Issue
Block a user