From 67aad1fce82265cd62a397e6740b4805e9910392 Mon Sep 17 00:00:00 2001 From: zxr2333 <64738772+nwpu-zxr@users.noreply.github.com> Date: Tue, 24 Mar 2026 15:15:55 +0800 Subject: [PATCH] [BugFix][P/D] fix padding error on FullGraph mode && fix layerwise connector mamba accuracy (#7506) ### What this PR does / why we need it? 1. When the FullGraph mode is used, the branches in the Triton operator are compiled and fixed during the graph capture process, causing the branch condition in the `fused_recurrent_gated_delta_rule` operator, which checks whether `ssm_state_indices >= 0` before writing to the SSM cache, to become invalid. Now, the write operation is performed regardless of the value. This results in the operator performing address offset calculations and writing to the SSM cache based on the -1 offset after -1 is used for padding in vLLM GDN backend. Since the conv cache and SSM cache in vLLM Ascend implementation are actually a single continuous tensor divided into two parts, this leads to data overwriting and the generation of NaN values. This PR addresses two cases where padding -1 is required in the GDN metadata builder. The same logic is used to replace the padding with 0 to avoid the problem of memory overwriting, because block 0 is a reserved block. 2. Fix layerwise connector bug for mamba cache sending on heterogeneous TP. - vLLM version: v0.17.0 - vLLM main: https://github.com/vllm-project/vllm/commit/8b6325758cce5f9c36d38f2462edbd368b97a07c --------- Signed-off-by: nwpu-zxr --- .../kv_transfer/kv_p2p/mooncake_layerwise_connector.py | 9 +++++---- vllm_ascend/worker/model_runner_v1.py | 6 ++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index bd6b0975..b800b9ed 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -317,7 +317,8 @@ class KVCacheSendingLayerThread(threading.Thread): for local_conv_offset, local_conv_size in zip(local_conv_offsets, local_conv_sizes): local_addr_offset = (i * conv_shape[1] + local_conv_offset) * get_dtype_size(conv_dtype) remote_addr_offset = ( - (i * conv_shape[1] * tp_ratio) + (self.tp_rank % tp_ratio) * local_conv_size + (i * conv_shape[1] + local_conv_offset) * tp_ratio + + (self.tp_rank % tp_ratio) * local_conv_size ) * get_dtype_size(conv_dtype) src_list.append(local_conv_addr + local_block_ids[0] * local_conv_len + local_addr_offset) dst_list.append(remote_conv_addr + remote_block_ids[0] * remote_conv_len + remote_addr_offset) @@ -1508,9 +1509,9 @@ class MooncakeLayerwiseConnectorWorker: # get reshape and cache event if layer_name == "": layer_name = self.index_to_name[self.current_layer][0] - if ( - type(attn_metadata) is dict and not getattr(attn_metadata[layer_name], "reshape_cache_event", None) - ) or (not getattr(attn_metadata, "reshape_cache_event", None)): + if (self.use_mla and not hasattr(attn_metadata[layer_name], "reshape_cache_event")) or ( + not self.use_mla and not hasattr(attn_metadata, "reshape_cache_event") + ): reshape_cache_event = torch.npu.Event() reshape_cache_event.record() elif self.use_mla: diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f89404ef..b0101202 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2165,6 +2165,12 @@ class NPUModelRunner(GPUModelRunner): common_attn_metadata=common_attn_metadata, **extra_attn_metadata_args, ) + # NOTE(zxr): Due to the Triton operator does not deal with -1 padding in FullGraph mode, + # the padding needs to be changed from -1 to 0 to avoid writing invalid mamba block. + if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() \ + and isinstance(builder, GDNAttentionMetadataBuilder) and attn_metadata_i.num_prefills == 0: + if attn_metadata_i.num_decodes == 0 and attn_metadata_i.num_spec_decodes > 0: + attn_metadata_i.spec_state_indices_tensor[attn_metadata_i.num_spec_decodes:].fill_(0) if ubid is None: assert isinstance(attn_metadata, dict)