From 5645ca839202428e9478762726cfc398b39fac27 Mon Sep 17 00:00:00 2001
From: zxr2333 <64738772+nwpu-zxr@users.noreply.github.com>
Date: Tue, 17 Mar 2026 23:03:45 +0800
Subject: [PATCH] [BugFix] A2 MOE method && layerwise MTP bugfix && Mamba gdn_metadata bugfix (#7364)

### What this PR does / why we need it?

Some bug fixes, mainly including:

1. For A2, the number of experts on each card cannot be greater than 16 when using MC2. This PR fixes the error in the A2 MoE communication method selection, which could pick an incorrect communication method when the number of model experts exceeds 256. For example, when using an A2 16-card setup to load the PD-disaggregation D node with Qwen3.5 series models, the incorrect MC2 method would be chosen.
2. Fixed the issue where the layerwise connector sends the kv-cache of the MTP layer multiple times when `num_spec_tokens` > 1. Now the kv-cache is sent only when the MTP layer is forwarded for the first time.
3. Fix the accuracy issue of Qwen3.5 when using MTP for PD disaggregation. The cause is that `num_decode_draft_tokens` does not account for the fact that `spec_tokens` do not exist during the first inference under PD disaggregation (`spec_tokens` are only generated during that first inference), while `spec_tokens_padding` is still added by `recomputed_scheduler`. As a result, `gdn_metadata` incorrectly treats the request as a prefill of length 2.

---------

Signed-off-by: nwpu-zxr
Signed-off-by: zxr2333 <64738772+nwpu-zxr@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 vllm_ascend/ascend_forward_context.py          | 11 ++++++-----
 .../kv_p2p/mooncake_layerwise_connector.py     |  3 +++
 vllm_ascend/worker/model_runner_v1.py          | 14 +++++++-------
 3 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 7c7242ef..243dede6 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -233,11 +233,12 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo
     if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group().world_size == 1:
         moe_comm_type = MoECommType.ALLGATHER
     elif soc_version in {AscendDeviceType.A2}:
-        if (
-            num_tokens <= mc2_tokens_capacity
-            and vllm_config.parallel_config.world_size_across_dp / vllm_config.parallel_config.pipeline_parallel_size
-            >= 16
-        ):
+        num_experts = vllm_config.model_config.get_num_experts()
+        ep_world_size = (
+            vllm_config.parallel_config.world_size_across_dp // vllm_config.parallel_config.pipeline_parallel_size
+        )
+        num_experts_per_device = num_experts // ep_world_size
+        if num_experts_per_device <= 24 and ep_world_size >= 16 and num_tokens <= mc2_tokens_capacity:
             moe_comm_type = MoECommType.MC2
         else:
             moe_comm_type = MoECommType.ALLGATHER
diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
index ea82322b..f4de1323 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
@@ -1493,6 +1493,9 @@ class MooncakeLayerwiseConnectorWorker:
     ) -> None:
         """MooncakeLayerwiseConnector does not save explicitly."""
         if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys():
+            if self.current_layer >= self.total_layers:
+                self.current_layer += 1
+                return
             # get reshape and cache event
             if layer_name == "":
                 layer_name = self.index_to_name[self.current_layer][0]
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 8fc168ac..f2ab3074 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -831,6 +831,7 @@ class NPUModelRunner(GPUModelRunner):
             num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
             # For chunked prefills, use -1 as mask rather than 0, as guided
             # decoding may rollback speculative tokens.
+            new_schedule_reqs = [x.req_id for x in scheduler_output.scheduled_new_reqs]
             num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
             for (
                 req_id,
@@ -838,13 +839,12 @@
             ) in scheduler_output.scheduled_spec_decode_tokens.items():
                 req_idx = self.input_batch.req_id_to_index[req_id]
                 num_draft_tokens[req_idx] = len(draft_token_ids)
-                num_decode_draft_tokens[req_idx] = (
-                    len(draft_token_ids)
-                    if (
-                        self.input_batch.num_computed_tokens_cpu[req_idx] >= self.input_batch.num_prompt_tokens[req_idx]
-                    )
-                    else -1
-                )
+                if (self.is_kv_consumer and req_id in new_schedule_reqs) or \
+                        (self.input_batch.num_computed_tokens_cpu[req_idx] >= \
+                         self.input_batch.num_prompt_tokens[req_idx]):
+                    num_decode_draft_tokens[req_idx] = len(draft_token_ids)
+                else:
+                    num_decode_draft_tokens[req_idx] = -1
             spec_decode_metadata = self._calc_spec_decode_metadata(
                 num_draft_tokens,