diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 86fa7a4b..0850c382 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -934,23 +934,27 @@ class NPUModelRunner(GPUModelRunner): # TODO: We should make this official ASAP. Also note that if we pad here, # the builders won’t need to add any extra padding. - max_decode_tokens = self.scheduler_config.max_num_seqs * self.uniform_decode_query_len if self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ - uniform_decode and self.uniform_decode_query_len <= num_input_tokens <= max_decode_tokens: - num_reqs_padded = num_input_tokens // self.uniform_decode_query_len - pad_size = num_reqs_padded - num_reqs - if pad_size > 0: - last_query_loc = self.query_start_loc.np[num_reqs] + uniform_decode: + max_decode_tokens = min( + self.scheduler_config.max_num_seqs * + self.uniform_decode_query_len, + self.cudagraph_batch_sizes[-1]) + if self.uniform_decode_query_len <= num_input_tokens <= max_decode_tokens: + num_reqs_padded = num_input_tokens // self.uniform_decode_query_len + pad_size = num_reqs_padded - num_reqs + if pad_size > 0: + last_query_loc = self.query_start_loc.np[num_reqs] - self.query_start_loc.np[ - num_reqs + 1:num_reqs_padded + 1] = self.arange_np[ - 1:pad_size + - 1] * self.uniform_decode_query_len + last_query_loc - self.query_start_loc.copy_to_gpu(num_reqs_padded + 1) + self.query_start_loc.np[ + num_reqs + 1:num_reqs_padded + 1] = self.arange_np[ + 1:pad_size + + 1] * self.uniform_decode_query_len + last_query_loc + self.query_start_loc.copy_to_gpu(num_reqs_padded + 1) - # So we are trying to simulate the behavior of GPUModelRunner's - # prepare_inputs for uniform decode mode by padding query_start_loc - num_reqs = num_reqs_padded + # So we are trying to simulate the behavior of GPUModelRunner's + # prepare_inputs for uniform decode mode by padding query_start_loc + num_reqs = num_reqs_padded # Make AscendCommonAttentionMetadata common_attn_metadata = AscendCommonAttentionMetadata(