From 418a43e2a2f853282c4337ebac2455cd8f316b6f Mon Sep 17 00:00:00 2001 From: ZYang6263 <50876451+ZYang6263@users.noreply.github.com> Date: Fri, 23 Jan 2026 11:29:54 +0800 Subject: [PATCH] [Bugfix] Fix seq_lens reset issue causing performance degradation (#6158) ### What this PR does / why we need it? Previously, `seq_lens` was not reset correctly after each step because the code that clears stale sequence lengths was missing. As a result, when processing a smaller batch after a larger batch, the `seq_lens` values from the larger batch were still carried over. This caused the attention operator to compute using an unnecessarily large sequence length, leading to increased computation load and performance degradation. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: ZYang6263 --- vllm_ascend/worker/model_runner_v1.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 7409d21c..a0e291bf 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -974,6 +974,8 @@ class NPUModelRunner(GPUModelRunner): 1:pad_size + 1] * self.uniform_decode_query_len + last_query_loc self.query_start_loc.copy_to_gpu(num_reqs_padded + 1) + self.seq_lens.np[num_reqs:].fill(0) + self.seq_lens.copy_to_gpu(num_reqs_padded) # So we are trying to simulate the behavior of GPUModelRunner's # prepare_inputs for uniform decode mode by padding query_start_loc