From 916a9a1913ebdfa338378b385f7ba03618a8bee0 Mon Sep 17 00:00:00 2001 From: Ronald Date: Mon, 8 Dec 2025 09:07:59 +0800 Subject: [PATCH] fix synchronize error of exceeds_max_model_len d2h copy (#4708) ### What this PR does / why we need it? There is a d2h copy that blocks CPU operations in the MTP propose method, which causes a host-bound issue. This PR refactors it to use a CPU tensor instead. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? vllm main f5d3d93c40417c296c20dc301100e55708a17f3f - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: Ronald1995 Co-authored-by: wangxiyuan --- vllm_ascend/spec_decode/mtp_proposer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index b8923a79..d340cdbc 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -885,10 +885,9 @@ class MtpProposer(Proposer): attn_metadata_i.seq_lens = attn_metadata_i.seq_lens + 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. - exceeds_max_model_len_cpu = exceeds_max_model_len.to( - attn_metadata_i.seq_lens.device, non_blocking=False) - attn_metadata_i.seq_lens[:batch_size].masked_fill_( - exceeds_max_model_len_cpu, 1) + exceeds_mask = attn_metadata_i.seq_lens[:batch_size] > \ self.runner.model_config.max_model_len + attn_metadata_i.seq_lens[:batch_size].masked_fill_(exceeds_mask, 1) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens.