From 143e1f46d0bc764713accb94b3ea38267c9881ad Mon Sep 17 00:00:00 2001
From: MengLong Chen <71744434+dragondream-chen@users.noreply.github.com>
Date: Mon, 1 Dec 2025 20:44:11 +0800
Subject: [PATCH] [Feat] shared expert dp for deepseek_mtp (#3811)

### What this PR does / why we need it?
Support shared expert DP for the deepseek_mtp feature. `shared_expert_dp` requires `SP==True`, with the corresponding parameter restrictions.
Previously, because `shared_expert_dp` was coupled with torchair and `deepseek_mtp` had been removed from vllm_ascend, shared expert DP for deepseek_mtp was temporarily removed. Now, `mtp_proposer.py` performs a `reduce_scatter` on the deepseek_mtp input so that it matches the dimensions of `input_embedding`, and then performs an `all_gather` on the MTP output.

### How was this patch tested?
baseline: image
enable shared_expert_dp and multistream_overlap_shared_expert: image
TPOT: 48ms -> 45.4ms
Average TPS per rank: 117.6 -> 126.1

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: chenmenglong
Signed-off-by: zengran
Co-authored-by: zengran
---
 tests/e2e/multicard/test_shared_expert_dp.py | 93 ++++++++++++++++++++
 tests/ut/ops/test_layernorm.py | 7 +-
 vllm_ascend/ascend_config.py | 4 +
 vllm_ascend/attention/mla_v1.py | 2 +
 vllm_ascend/ops/layernorm.py | 1 +
 vllm_ascend/ops/register_custom_ops.py | 28 ++++++
 vllm_ascend/platform.py | 4 +-
 vllm_ascend/spec_decode/mtp_proposer.py | 55 +++++++++---
 vllm_ascend/utils.py | 8 +-
 9 files changed, 185 insertions(+), 17 deletions(-)
 create mode 100644 tests/e2e/multicard/test_shared_expert_dp.py

diff --git a/tests/e2e/multicard/test_shared_expert_dp.py b/tests/e2e/multicard/test_shared_expert_dp.py new file mode 100644 index 00000000..867d3ab6 --- /dev/null +++ b/tests/e2e/multicard/test_shared_expert_dp.py @@ -0,0 +1,93 @@ +import os + +import pytest +from vllm import SamplingParams + +from tests.e2e.conftest import VllmRunner +from tests.e2e.model_utils import check_outputs_equal + +MODELS = [ + "vllm-ascend/DeepSeek-V2-Lite", +] +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +@pytest.mark.parametrize("model", MODELS) +def test_models_with_enable_shared_expert_dp(model: str) -> None: + + if 'HCCL_OP_EXPANSION_MODE' in os.environ: + del os.environ['HCCL_OP_EXPANSION_MODE'] + + prompts = [ + "Hello, my name is", "The capital of the United States is", + "The capital of France is", "The future of AI is" + ] + sampling_params = SamplingParams(max_tokens=32, temperature=0.0) + + with VllmRunner( + model, + max_model_len=1024, + enforce_eager=True, + tensor_parallel_size=2, + enable_expert_parallel=True, + ) as runner: + vllm_eager_outputs = runner.model.generate(prompts, sampling_params) + + os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM1"] = "1" + with VllmRunner( + model, + max_model_len=1024, + enforce_eager=True, + tensor_parallel_size=2, + enable_expert_parallel=True, + additional_config={ + "enable_shared_expert_dp": True, + }, + ) as runner: + shared_expert_dp_eager_outputs = runner.model.generate( + prompts, sampling_params) + + with VllmRunner( + model, + max_model_len=1024, + tensor_parallel_size=2, + enforce_eager=False, + compilation_config={ + "cudagraph_capture_sizes": [1, 4, 8, 16], + "cudagraph_mode": "FULL_DECODE_ONLY", + }, + additional_config={ + "enable_shared_expert_dp": True, + }, + ) as runner: + shared_expert_dp_aclgraph_outputs = runner.model.generate( + prompts, sampling_params) + + vllm_eager_outputs_list = [] +
for output in vllm_eager_outputs: + vllm_eager_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) + + shared_expert_dp_eager_outputs_list = [] + for output in shared_expert_dp_eager_outputs: + shared_expert_dp_eager_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) + + shared_expert_dp_aclgraph_outputs_list = [] + for output in shared_expert_dp_aclgraph_outputs: + shared_expert_dp_aclgraph_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) + + check_outputs_equal( + outputs_0_lst=vllm_eager_outputs_list, + outputs_1_lst=shared_expert_dp_eager_outputs_list, + name_0="vllm_eager_outputs", + name_1="shared_expert_dp_eager_outputs", + ) + + check_outputs_equal( + outputs_0_lst=vllm_eager_outputs_list, + outputs_1_lst=shared_expert_dp_aclgraph_outputs_list, + name_0="vllm_eager_outputs", + name_1="shared_expert_dp_aclgraph_outputs", + ) diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py index 314775f8..77af2649 100644 --- a/tests/ut/ops/test_layernorm.py +++ b/tests/ut/ops/test_layernorm.py @@ -1,4 +1,5 @@ import unittest +from unittest.mock import patch import pytest import torch @@ -42,7 +43,9 @@ class TestAscendRMSNorm(PytestBase): # Test case for the most common and basic scenario @pytest.mark.parametrize( "residual", [None, torch.randn(4, 8, dtype=torch.float16)]) - def test_forward_oot_basic(self, residual): + @patch("torch.ops.vllm.maybe_chunk_residual") + def test_forward_oot_basic(self, mock_maybe_chunk_residual, residual): + mock_maybe_chunk_residual.side_effect = lambda x, residual: residual layer = RMSNorm(hidden_size=8, eps=1e-05) x = torch.randn(4, 8, dtype=torch.float16) if residual is not None: @@ -107,6 +110,8 @@ class TestAscendRMSNorm(PytestBase): mock_forward_context.num_hidden_layers = num_hidden_layers mock_forward_context.fusion_linear = "gate_up_dense" mock_forward_context.weight_prefetch_method = None + mocker.patch("torch.ops.vllm.maybe_chunk_residual", + lambda x, residual: residual) # Ensure fusion and layer_idx increment are handled correctly x = torch.randn(4, 8, dtype=torch.float16) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 16d16a4d..115dbef1 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -72,6 +72,10 @@ class AscendConfig: self.enable_shared_expert_dp = additional_config.get( "enable_shared_expert_dp", False ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel + if self.enable_shared_expert_dp: + from vllm_ascend.utils import enable_sp + assert enable_sp(vllm_config=vllm_config, + enable_shared_expert_dp=True) self.multistream_overlap_shared_expert = additional_config.get( "multistream_overlap_shared_expert", False) self.recompute_scheduler_enable = additional_config.get( diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 19c8025b..5d341d03 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -1677,6 +1677,8 @@ class AscendMLAImpl(MLAAttentionImpl): forward_context = get_forward_context() if (self.enable_mlapo and (attn_metadata is None or not forward_context.with_prefill)): + hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + hidden_states.contiguous(), need_gather_q_kv) decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess( hidden_states, kv_cache, attn_metadata) else: diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 
8c395b54..da5051c0 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -110,6 +110,7 @@ class AscendRMSNorm(RMSNorm): import torch_npu if residual is not None: + residual = torch.ops.vllm.maybe_chunk_residual(x, residual) assert x.size(0) == residual.size(0) x, residual = _addrmsnorm_forward_oot( self, x, residual, self.next_need_quant_fusion_linear, diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index bb16bc00..03bea554 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F import torch_npu from vllm.distributed import (get_dp_group, get_ep_group, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, @@ -15,6 +16,27 @@ from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.utils import npu_stream_switch, prefetch_stream +def _maybe_chunk_residual_impl(x: torch.Tensor, + residual: torch.Tensor) -> torch.Tensor: + try: + forward_context = get_forward_context() + except AssertionError: + return residual + + if x.size(0) != residual.size(0): + sp_enabled = forward_context.sp_enabled + assert sp_enabled is True, ("Currently, this situation only occurs " + "when sp is enabled") + pad_size = forward_context.pad_size + if pad_size > 0: + residual = F.pad(residual, (0, 0, 0, pad_size)) + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + residual = torch.chunk(residual, tp_size, dim=0)[tp_rank] + + return residual + + def _maybe_all_gather_and_maybe_unpad_impl( x: torch.Tensor, label: bool, @@ -259,6 +281,12 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, return output +direct_register_custom_op(op_name="maybe_chunk_residual", + op_func=_maybe_chunk_residual_impl, + fake_impl=lambda x, residual: x, + mutates_args=[], + dispatch_key="PrivateUse1") + direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad", op_func=_maybe_all_gather_and_maybe_unpad_impl, fake_impl=_maybe_all_gather_and_maybe_unpad_fake, diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 7cc84fc6..5ff66926 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -283,7 +283,7 @@ class NPUPlatform(Platform): if parallel_config and parallel_config.worker_cls == "auto": # TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm. 
parallel_config.all2all_backend = "flashinfer_all2allv" - if ascend_config.torchair_graph_config.enabled or ascend_config.enable_shared_expert_dp: + if ascend_config.torchair_graph_config.enabled: parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker" else: parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker" @@ -379,8 +379,6 @@ class NPUPlatform(Platform): ascend_config = get_ascend_config() if use_mla and ascend_config.enable_shared_expert_dp: - if use_mla and not use_sparse: - return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend" if use_mla and use_sparse: return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend" diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 15b2b4cf..cacc2bdf 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -32,7 +32,8 @@ from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, update_mla_attn_params) from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable, - prefill_context_parallel_enable) + prefill_context_parallel_enable, + shared_expert_dp_enabled) if prefill_context_parallel_enable(): from vllm.distributed import get_pcp_group @@ -94,6 +95,7 @@ class MtpProposer(Proposer): # the draft model's hidden size can be different from the target model's # hidden size (e.g., Llama 3.3 70B). self.hidden_size = self.draft_model_config.get_hidden_size() + self.enable_shared_expert_dp = shared_expert_dp_enabled() self.pcp_size = self.runner.pcp_size self.dcp_size = self.runner.dcp_size @@ -286,6 +288,12 @@ class MtpProposer(Proposer): aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, is_mtp_model=True): + if self.enable_shared_expert_dp: + positions = positions.unsqueeze(-1) + positions = torch.ops.vllm.maybe_pad_and_reduce(positions) + positions = positions.squeeze(-1) + previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce( + previous_hidden_states) self.model(input_ids=input_ids, positions=positions, hidden_states=previous_hidden_states) @@ -294,9 +302,13 @@ class MtpProposer(Proposer): not forward_context.capturing: if self.vllm_config.model_config.use_mla: update_mla_attn_params( - self.update_stream, forward_context, - positions.shape[0], + self.update_stream, forward_context, num_tokens, self.vllm_config.speculative_config) + if self.enable_shared_expert_dp: + positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + positions, True) + previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + previous_hidden_states, True) dummy_compute_logits(previous_hidden_states) if with_prefill: break @@ -675,7 +687,8 @@ class MtpProposer(Proposer): moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens) - if scheduler_output: + # Enable shared_expert_dp and MTP FULL graph may cause accuracy issues. 
+ if scheduler_output and not self.enable_shared_expert_dp: max_query_len = common_attn_metadata.max_query_len uniform_decode = (max_query_len in list( range(1, self.num_speculative_tokens + @@ -725,11 +738,22 @@ class MtpProposer(Proposer): with ProfileExecuteDuration().capture_async('mtp_forward'): model_kwargs = {} model_kwargs["attn_metadata"] = attn_metadata + input_ids = self.input_ids[:num_input_tokens] + positions = self.positions[:num_input_tokens] + hidden_states = self.hidden_states[:num_input_tokens] - hidden_states = self.model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - hidden_states=self.hidden_states[:num_input_tokens]) + if self.enable_shared_expert_dp: + # positions [N] -> [N, 1] for padding + positions = positions.unsqueeze(-1) + positions = torch.ops.vllm.maybe_pad_and_reduce( + positions) + positions = positions.squeeze(-1) + hidden_states = torch.ops.vllm.maybe_pad_and_reduce( + hidden_states) + + hidden_states = self.model(input_ids=input_ids, + positions=positions, + hidden_states=hidden_states) forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: if self.vllm_config.model_config.use_mla: @@ -738,6 +762,12 @@ class MtpProposer(Proposer): num_input_tokens, self.vllm_config.speculative_config) + if self.enable_shared_expert_dp: + hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + hidden_states.contiguous(), True) + positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + positions.contiguous(), True) + num_indices = last_token_indices.shape[0] if lmhead_tp_enable(): if not self.runner.with_prefill: @@ -805,20 +835,21 @@ class MtpProposer(Proposer): batch_size, attn_metadata_i.decode.actual_seq_lengths_q) attn_metadata_i.decode.cos = builder.cos_cache[ - positions].unsqueeze(1).unsqueeze(2) + positions[:batch_size]].unsqueeze(1).unsqueeze(2) attn_metadata_i.decode.sin = builder.sin_cache[ - positions].unsqueeze(1).unsqueeze(2) + positions[:batch_size]].unsqueeze(1).unsqueeze(2) # NOTE(woosuk): We should handle the case where the draft model # generates tokens beyond the max model length. Since it is complex # to remove such requests from the batch, we keep them in the batch # but adjust the position ids and slot mappings to avoid the # out-of-range access during the model execution. The draft tokens # generated with this adjustment should be ignored. - exceeds_max_model_len = positions >= self.runner.model_config.max_model_len + exceeds_max_model_len = positions[: + batch_size] >= self.runner.model_config.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) + positions[:batch_size]) # Increment the sequence lengths. attn_metadata_i.seq_lens[:batch_size] += 1 # For the requests that exceed the max model length, we set the diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index b41034a4..e9441e28 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -758,7 +758,7 @@ def dense_optim_enable() -> bool: return envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE -def enable_sp(vllm_config=None) -> bool: +def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool: global _ENABLE_SP if _ENABLE_SP is None: if vllm_config is None: @@ -772,6 +772,12 @@ def enable_sp(vllm_config=None) -> bool: # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility. 
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0')))) + if not _ENABLE_SP and enable_shared_expert_dp: + _ENABLE_SP = True + logger.info( + "shared_expert_dp requires enable_sp=True; enable_sp has been set to True" + ) + if not _ENABLE_SP: return _ENABLE_SP
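
The core of the `mtp_proposer.py` change is shape bookkeeping: with `shared_expert_dp` (and therefore SP) enabled, the MTP inputs (positions and previous hidden states) are padded when needed and reduce-scattered across the TP group so they match the sequence-parallel `input_embedding` layout, and the MTP output hidden states are all-gathered and unpadded before sampling. The snippet below is a minimal single-process sketch of that bookkeeping, not the vllm_ascend implementation: it emulates the collectives with `chunk`/`cat` on one device, whereas the patch uses the custom ops `torch.ops.vllm.maybe_pad_and_reduce` and `torch.ops.vllm.maybe_all_gather_and_maybe_unpad`; the helper names and the `TP_SIZE` value are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

TP_SIZE = 4  # assumed tensor-parallel world size, for illustration only


def pad_and_scatter(x: torch.Tensor, tp_rank: int) -> tuple[torch.Tensor, int]:
    """Pad dim 0 to a multiple of TP_SIZE, then keep this rank's chunk
    (single-process stand-in for pad + reduce-scatter)."""
    pad_size = (-x.size(0)) % TP_SIZE
    if pad_size > 0:
        x = F.pad(x, (0, 0, 0, pad_size))
    return torch.chunk(x, TP_SIZE, dim=0)[tp_rank], pad_size


def gather_and_unpad(chunks: list, pad_size: int) -> torch.Tensor:
    """Concatenate per-rank chunks and drop the padding
    (single-process stand-in for all-gather + unpad)."""
    x = torch.cat(chunks, dim=0)
    return x[:x.size(0) - pad_size] if pad_size > 0 else x


num_tokens, hidden_size = 10, 16
previous_hidden_states = torch.randn(num_tokens, hidden_size)
positions = torch.arange(num_tokens, dtype=torch.float32).unsqueeze(-1)  # [N] -> [N, 1] for padding

outputs, pad_size = [], 0
for tp_rank in range(TP_SIZE):
    # Each rank now sees only its slice of the (padded) token dimension,
    # matching the sequence-parallel input_embedding layout.
    hs_local, pad_size = pad_and_scatter(previous_hidden_states, tp_rank)
    pos_local, _ = pad_and_scatter(positions, tp_rank)
    # The MTP draft model would run here on (input_ids, pos_local, hs_local);
    # identity stands in for its forward pass.
    outputs.append(hs_local)

# After the forward pass the full hidden states are restored for sampling.
restored = gather_and_unpad(outputs, pad_size)
assert restored.shape == previous_hidden_states.shape
```

The same pad-then-chunk bookkeeping is what `_maybe_chunk_residual_impl` in `register_custom_ops.py` applies to the residual stream whenever its token dimension no longer matches the sequence-parallel activations.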