[BugFix][P/D] fix padding error on FullGraph mode && fix layerwise connector mamba accuracy (#7506)

### What this PR does / why we need it?
1. In FullGraph mode, branches in Triton operators are compiled and
frozen during graph capture. This invalidates the branch in the
`fused_recurrent_gated_delta_rule` operator that checks
`ssm_state_indices >= 0` before writing to the SSM cache: the write is
now performed unconditionally. When the vLLM GDN backend pads the
indices with -1, the operator computes an address offset from that -1
and writes to the SSM cache at the wrong location. Because the conv
cache and the SSM cache in the vLLM Ascend implementation are a single
contiguous tensor split into two parts, this out-of-range write
overwrites neighboring cache data and produces NaN values.
This PR handles the two cases in the GDN metadata builder where -1
padding is produced, replacing the padding value with 0 in both. Since
block 0 is a reserved block, writes redirected to it are harmless and
the memory overwrite is avoided.
2. Fix a layerwise connector bug in mamba cache sending under
heterogeneous TP.
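
The padding replacement described in point 1 can be sketched as follows. This is a minimal stand-in using plain Python lists: `sanitize_state_indices` is a hypothetical helper name, standing in for the in-place tensor `fill_` that the actual fix applies to the padded tail of the index tensor.

```python
# Hypothetical sketch of the fix: redirect -1 padding to block 0 (a
# reserved block) so the kernel never derives a negative address offset
# into the contiguous conv/SSM cache tensor.
def sanitize_state_indices(indices):
    # Replace every padded (-1, or any negative) entry with 0.
    return [0 if i < 0 else i for i in indices]

print(sanitize_state_indices([3, 7, -1, -1]))  # [3, 7, 0, 0]
```

In the real code the replacement is done in place on the metadata tensors (e.g. `spec_state_indices_tensor[...].fill_(0)`), but the effect is the same: no index can ever be negative by the time the Triton kernel computes cache offsets.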

- vLLM version: v0.17.0
- vLLM main:
8b6325758c
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
zxr2333
2026-03-24 15:15:55 +08:00
committed by GitHub
parent 475b4b0cea
commit 67aad1fce8
2 changed files with 11 additions and 4 deletions


@@ -2165,6 +2165,12 @@ class NPUModelRunner(GPUModelRunner):
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args,
)
# NOTE(zxr): Because the Triton operator does not handle -1 padding in FullGraph mode,
# the padding must be changed from -1 to 0 to avoid writing to an invalid mamba block.
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() \
and isinstance(builder, GDNAttentionMetadataBuilder) and attn_metadata_i.num_prefills == 0:
if attn_metadata_i.num_decodes == 0 and attn_metadata_i.num_spec_decodes > 0:
attn_metadata_i.spec_state_indices_tensor[attn_metadata_i.num_spec_decodes:].fill_(0)
if ubid is None:
assert isinstance(attn_metadata, dict)