[Misc] Remove CP Redundant Variables after FIA operator enables for CANN 8.5 (#6013)

### What this PR does / why we need it?
PCP/DCP splits the kv-cache onto different cards. After introducing the
parameter cp-kv-cache-interleave-size, the first size tokens will be
cached at Card 0, and so on.
However, if there are too few tokens, some cards will not store the
key-value pairs, resulting in values ​​of 0, corrupted values, and
precision issues. Currently, additional operations are introduced to
avoid this precision problem.

After we integrate FIA operator in mla_cp._forward_decode and CANN
updates to 8.5.0, we now can remove these additional operations.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
passed all CI by CANN 8.5.0
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

Signed-off-by: dsxsteven <dsxsteven@sina.com>
Signed-off-by: dsxsteven <36877507+dsxsteven@users.noreply.github.com>
This commit is contained in:
dsxsteven
2026-01-23 14:13:12 +08:00
committed by GitHub
parent 418a43e2a2
commit 8378bc28b0
8 changed files with 78 additions and 57 deletions

View File

@@ -477,7 +477,7 @@ class MtpProposer(EagleProposer):
self.positions[:batch_size] = clamped_positions
self.hidden_states[:hidden_states.shape[0]] = hidden_states
if self.pcp_size * self.dcp_size > 1:
# update local seq_len and batch_seq_mask
# update local seq_len
num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens(
ori_seq_len + step + 1,
self.pcp_size,
@@ -486,14 +486,7 @@ class MtpProposer(EagleProposer):
)
cp_seq_len = \
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]
batch_seq_mask = (cp_seq_len == 0)
builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
batch_seq_mask, non_blocking=True)
batch_seq_mask = builder.batch_seq_mask_buf[:batch_seq_mask.
shape[0]]
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
attn_metadata_i.decode.cp_seq_len = cp_seq_len
attn_metadata_i.decode.batch_seq_mask = batch_seq_mask
# update slot_mapping
slot_indices += self.pcp_size
slot_mapping = mtp_slot_mapping[slot_indices]