From d01fd1d1c3526ef038f05edc2464a1e77dfb200d Mon Sep 17 00:00:00 2001 From: linfeng-yuan <1102311262@qq.com> Date: Tue, 23 Sep 2025 14:52:42 +0800 Subject: [PATCH] [misc][torchair] fix bugs around `deepseek mtp`, `enable_shared_expert_dp` and `use_cached_kv_cache_bytes` (#3074) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? This miscellaneous​ contains several small fixes: 1) fix initialization and forward bugs of DeepseekMTPLayer with `shared_expert_dp` enabled. 2) fix a tensor shape mismatches after o_proj caused by a work-aroud change in NPUModelRunner. 3) avoid unnecessary decline of kv_cache memory (default: 64MB) with `use_cached_kv_cache_bytes` disabled. 4) fall back `fused_moe_state` from `MC2` to `All2All` since the padding logic of `mc2_mask` is incompatible with input hidden_states when `shared_expert_dp` enabled. Once this PR is merged, users can launch disaggregated_prefill deployments (large_ep) with `deepseek_mtp` and `shared_expert_dp` as `v0.9.1-dev` branch. The remaining problem of kv_cache tokens decline compared to `v0.9.1-dev` will be resolved by https://github.com/vllm-project/vllm-ascend/pull/3073. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? E2E vllm serving about deepseek_mtp with torchair graph mode and `enable_shared_expert_dp` with eager mode. Large ep deployments are also tested with this PR. - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/5aeb9254521023f97aca292b3478aa7ff485ffb2 --------- Signed-off-by: linfeng-yuan <1102311262@qq.com> --- .../torchair/ops/test_torchair_fused_moe.py | 1 + vllm_ascend/spec_decode/mtp_proposer.py | 6 ++- .../torchair/models/torchair_deepseek_v2.py | 2 +- .../torchair/ops/torchair_fused_moe.py | 5 +++ .../quantization/torchair_w8a8_dynamic.py | 3 ++ vllm_ascend/torchair/torchair_model_runner.py | 4 +- vllm_ascend/torchair/torchair_worker.py | 39 +++++++++---------- 7 files changed, 36 insertions(+), 24 deletions(-) diff --git a/tests/ut/torchair/ops/test_torchair_fused_moe.py b/tests/ut/torchair/ops/test_torchair_fused_moe.py index ec2d9e7..155ee78 100644 --- a/tests/ut/torchair/ops/test_torchair_fused_moe.py +++ b/tests/ut/torchair/ops/test_torchair_fused_moe.py @@ -72,6 +72,7 @@ def mock_dist_env(mocker: MockerFixture): return_value=MagicMock( torchair_graph_config=MagicMock(enabled=False), enable_multistream_moe=False, + enable_shared_expert_dp=False, expert_map_path=None )), \ patch('vllm_ascend.torchair.ops.torchair_fused_moe.determine_expert_map', diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 5694c23..ac0b3c5 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -60,6 +60,8 @@ class MtpProposer(Proposer): self.torchair_compiled_models = {} # type: ignore self.torchair_graph_enabled = get_ascend_config( ).torchair_graph_config.enabled + self.enable_shared_expert_dp = get_ascend_config( + ).enable_shared_expert_dp # We need +1 here because the arange is used to set query_start_loc, # which has one more element than batch_size. self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + @@ -79,7 +81,9 @@ class MtpProposer(Proposer): with set_default_torch_dtype( draft_model_config.dtype), set_current_vllm_config( self.vllm_config): - if self.torchair_graph_enabled: + if self.torchair_graph_enabled or ( + self.enable_shared_expert_dp + and self.vllm_config.model_config.use_mla): self.model = TorchairDeepSeekMTP( vllm_config=self.vllm_config).to(target_device) else: diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py index 1b34e84..28be2b7 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py @@ -813,7 +813,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): attn_metadata = get_forward_context().attn_metadata if attn_metadata is not None and isinstance(attn_metadata, dict): - attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] + attn_metadata = next(iter(attn_metadata.values()), None) if attn_metadata is not None: num_tokens = attn_metadata.num_actual_tokens else: diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index 2377b50..0c85c85 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -803,6 +803,7 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp try: device_group = get_mc2_group().device_group @@ -884,6 +885,8 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) fused_moe_state = get_forward_context().fused_moe_state + if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: + fused_moe_state = FusedMoEState.All2All if fused_moe_state == FusedMoEState.MC2: return torchair_fused_experts_with_mc2( @@ -1155,6 +1158,8 @@ class TorchairAscendFusedMoE(FusedMoE): forward_context = get_forward_context() fused_moe_state = forward_context.fused_moe_state mc2_mask = forward_context.mc2_mask + if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: + fused_moe_state = FusedMoEState.All2All # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel. quantized_x_for_share, dynamic_scale_for_share = None, None from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \ diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py index be212e2..23c4699 100644 --- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py @@ -833,6 +833,7 @@ class TorchairAscendW8A8DynamicFusedMoEMethod: ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp try: device_group = get_mc2_group().device_group @@ -946,6 +947,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod: ) fused_moe_state = get_forward_context().fused_moe_state + if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: + fused_moe_state = FusedMoEState.All2All shared_gate_up, shared_dequant_scale = None, None if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: with npu_stream_switch("moe_secondary", 0): diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index c92fa4e..ebf61df 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -87,8 +87,8 @@ class NPUTorchairModelRunner(NPUModelRunner): ) -> tuple[int, Optional[torch.Tensor], bool, bool]: """Override from NPUModelRunner to pad num_tokens""" if self.enable_shared_expert_dp: - return super()._sync_metadata_across_dp(num_tokens, with_prefill, - enable_dbo) + # Padding is not required for shared_expert_dp cases in eager mode. + return num_tokens, None, with_prefill, enable_dbo if self.dp_size == 1: if not with_prefill: maybe_padded_num_tokens = self.select_torchair_padded_batch_size( diff --git a/vllm_ascend/torchair/torchair_worker.py b/vllm_ascend/torchair/torchair_worker.py index 2c8c458..dbee800 100644 --- a/vllm_ascend/torchair/torchair_worker.py +++ b/vllm_ascend/torchair/torchair_worker.py @@ -35,26 +35,25 @@ class NPUTorchairWorker(NPUWorker): ascend_config = get_ascend_config() if ascend_config.enable_shared_expert_dp: return available_kv_cache_memory - if ascend_config.torchair_graph_config.use_cached_kv_cache_bytes and check_kv_cache_bytes_cache_exist( - ): - old_kv_cache_bytes = read_kv_cache_bytes_from_file( - torch.distributed.get_rank()) - if 0 < old_kv_cache_bytes <= available_kv_cache_memory: - logger.info( - f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}" - ) - self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes - return old_kv_cache_bytes - else: - logger.info( - "Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache" - ) - delete_torchair_cache_file() - bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE - available_kv_cache_memory -= bytes_floating_tolerance - logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}") - self.model_runner.new_kv_cache_bytes = available_kv_cache_memory - + if ascend_config.torchair_graph_config.use_cached_kv_cache_bytes: + if check_kv_cache_bytes_cache_exist(): + old_kv_cache_bytes = read_kv_cache_bytes_from_file( + torch.distributed.get_rank()) + if 0 < old_kv_cache_bytes <= available_kv_cache_memory: + logger.info( + f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}" + ) + self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes + return old_kv_cache_bytes + else: + logger.info( + "Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache" + ) + delete_torchair_cache_file() + bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE + available_kv_cache_memory -= bytes_floating_tolerance + logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}") + self.model_runner.new_kv_cache_bytes = available_kv_cache_memory return available_kv_cache_memory def init_device(self):