diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index b8923a79..d340cdbc 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -885,10 +885,9 @@ class MtpProposer(Proposer):
         attn_metadata_i.seq_lens = attn_metadata_i.seq_lens + 1
         # For the requests that exceed the max model length, we set the
         # sequence length to 1 to minimize their overheads in attention.
-        exceeds_max_model_len_cpu = exceeds_max_model_len.to(
-            attn_metadata_i.seq_lens.device, non_blocking=False)
-        attn_metadata_i.seq_lens[:batch_size].masked_fill_(
-            exceeds_max_model_len_cpu, 1)
+        exceeds_mask = attn_metadata_i.seq_lens[:batch_size] > \
+            self.runner.model_config.max_model_len
+        attn_metadata_i.seq_lens[:batch_size].masked_fill_(exceeds_mask, 1)
         # Mask out the slot mappings that exceed the max model length.
         # Otherwise, the KV cache will be inadvertently updated with the
         # padding tokens.