diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 04682a2..b432db0 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1520,13 +1520,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
+            # TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
             if self.vllm_config.model_config.use_mla:
                 # FIXME: Try using `auto_dispatch_capture=True`
                 update_mla_attn_params(self.update_stream, forward_context,
-                                       positions.shape[0])
+                                       maybe_padded_num_tokens)
             else:
                 update_attn_params(self.update_stream, forward_context,
-                                   positions.shape[0])
+                                   maybe_padded_num_tokens)
         if get_forward_context().sp_enabled:
             hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)