### What this PR does / why we need it?
Support shared expert DP for the deepseek_mtp feature. `shared_expert_dp` requires `SP==True`, with the corresponding parameter restrictions. Previously, because `shared_expert_dp` was coupled with torchair and `deepseek_mtp` had been removed from vllm_ascend, shared expert DP for deepseek_mtp was temporarily dropped. This PR restores it: in `mtp_proposer.py` we perform a `reduce_scatter` on the deepseek_mtp input so that it matches the dimensions of `input_embedding`, and then an `all_gather` on the MTP output (a minimal sketch of this pattern follows the description).

### How was this patch tested?
baseline:
<img width="1184" height="692" alt="image" src="https://github.com/user-attachments/assets/9680d53a-7b1d-481a-accc-b8f3dae2b9e3" />

enable shared_expert_dp and multistream_overlap_shared_expert:
<img width="1167" height="687" alt="image" src="https://github.com/user-attachments/assets/2531d06b-dfda-4e24-8628-6f4b0f677ddc" />

TPOT: 48 ms -> 45.4 ms
Average TPS per rank: 117.6 -> 126.1

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: chenmenglong <chenmenglong1@huawei.com>
Signed-off-by: zengran <zengran2@huawei.com>
Co-authored-by: zengran <zengran2@huawei.com>
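The shape handling behind this change can be pictured with the single-process sketch below. It is an illustration only, not the actual `mtp_proposer.py` code: `torch.stack(...).sum()` plus `torch.chunk` stand in for `reduce_scatter` over the TP group, `torch.cat` stands in for `all_gather`, and the names `tp_size`, `num_tokens`, `hidden_size`, and `rank_inputs` are assumed for the example.

```python
import torch

tp_size = 2
num_tokens, hidden_size = 8, 16

# Each TP rank enters the MTP proposer with a full-length hidden state.
rank_inputs = [torch.randn(num_tokens, hidden_size) for _ in range(tp_size)]

# reduce_scatter (emulated): sum across ranks, then hand each rank one chunk
# of the token dimension, so the MTP input matches the sequence-parallel
# input_embedding shape of (num_tokens // tp_size, hidden_size).
reduced = torch.stack(rank_inputs).sum(dim=0)
scattered_chunks = list(torch.chunk(reduced, tp_size, dim=0))
assert scattered_chunks[0].shape == (num_tokens // tp_size, hidden_size)

# ... the MTP forward pass runs on each rank's local chunk ...
mtp_outputs = scattered_chunks

# all_gather (emulated): concatenate every rank's slice to restore the full
# token dimension before the MTP output is consumed downstream.
gathered = torch.cat(mtp_outputs, dim=0)
assert gathered.shape == (num_tokens, hidden_size)
```

The e2e test added for this feature, which compares a baseline eager run against `enable_shared_expert_dp` in both eager and aclgraph mode, is reproduced below: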
import os

import pytest
from vllm import SamplingParams

from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal

MODELS = [
    "vllm-ascend/DeepSeek-V2-Lite",
]

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


@pytest.mark.parametrize("model", MODELS)
def test_models_with_enable_shared_expert_dp(model: str) -> None:

    # Make sure HCCL_OP_EXPANSION_MODE does not leak in from the environment.
    if 'HCCL_OP_EXPANSION_MODE' in os.environ:
        del os.environ['HCCL_OP_EXPANSION_MODE']

    prompts = [
        "Hello, my name is", "The capital of the United States is",
        "The capital of France is", "The future of AI is"
    ]
    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)

    # Baseline: eager mode, TP=2 with expert parallel, no shared expert DP.
    with VllmRunner(
            model,
            max_model_len=1024,
            enforce_eager=True,
            tensor_parallel_size=2,
            enable_expert_parallel=True,
    ) as runner:
        vllm_eager_outputs = runner.model.generate(prompts, sampling_params)

    # Same eager config with enable_shared_expert_dp (and FLASHCOMM1) enabled.
    os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM1"] = "1"
    with VllmRunner(
            model,
            max_model_len=1024,
            enforce_eager=True,
            tensor_parallel_size=2,
            enable_expert_parallel=True,
            additional_config={
                "enable_shared_expert_dp": True,
            },
    ) as runner:
        shared_expert_dp_eager_outputs = runner.model.generate(
            prompts, sampling_params)

    # shared_expert_dp with graph mode (full-decode-only capture) enabled.
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            enforce_eager=False,
            compilation_config={
                "cudagraph_capture_sizes": [1, 4, 8, 16],
                "cudagraph_mode": "FULL_DECODE_ONLY",
            },
            additional_config={
                "enable_shared_expert_dp": True,
            },
    ) as runner:
        shared_expert_dp_aclgraph_outputs = runner.model.generate(
            prompts, sampling_params)

    vllm_eager_outputs_list = []
    for output in vllm_eager_outputs:
        vllm_eager_outputs_list.append(
            (output.outputs[0].index, output.outputs[0].text))

    shared_expert_dp_eager_outputs_list = []
    for output in shared_expert_dp_eager_outputs:
        shared_expert_dp_eager_outputs_list.append(
            (output.outputs[0].index, output.outputs[0].text))

    shared_expert_dp_aclgraph_outputs_list = []
    for output in shared_expert_dp_aclgraph_outputs:
        shared_expert_dp_aclgraph_outputs_list.append(
            (output.outputs[0].index, output.outputs[0].text))

    # Both shared_expert_dp runs must match the baseline outputs exactly.
    check_outputs_equal(
        outputs_0_lst=vllm_eager_outputs_list,
        outputs_1_lst=shared_expert_dp_eager_outputs_list,
        name_0="vllm_eager_outputs",
        name_1="shared_expert_dp_eager_outputs",
    )

    check_outputs_equal(
        outputs_0_lst=vllm_eager_outputs_list,
        outputs_1_lst=shared_expert_dp_aclgraph_outputs_list,
        name_0="vllm_eager_outputs",
        name_1="shared_expert_dp_aclgraph_outputs",
    )