diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py
index 7978b58..51c9508 100644
--- a/vllm_ascend/torchair/torchair_model_runner.py
+++ b/vllm_ascend/torchair/torchair_model_runner.py
@@ -467,7 +467,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
             # pd disaggregation scenario may incorrectly calculate the batch in mtp scenario, so we force set it to max_num_reqs
             self.torchair_graph_batch_sizes = [self.max_num_reqs]
             logger.warning(
-                "is kv_consumer, torch_graph_batch_sizes sets to [max_num_seqs]"
+                f"is kv_consumer, torch_graph_batch_sizes sets to [max_num_seqs] {[self.max_num_reqs]}"
             )
         elif len(self.torchair_graph_batch_sizes) == 0:
             self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
@@ -517,11 +517,16 @@ class NPUTorchairModelRunner(NPUModelRunner):
                     f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
                     f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
                 )
-        new_max_num_reqs = max(new_graph_batch_sizes)
+        new_max_num_reqs = math.ceil(
+            max(new_graph_batch_sizes) / self.decode_token_per_req)
         if self.max_num_reqs != new_max_num_reqs:
             logger.warning(f"max_num_reqs is updated to {new_max_num_reqs}")
             self.max_num_reqs = new_max_num_reqs
-            self.scheduler_config.max_num_seqs = new_max_num_reqs
+            if not (self.decode_token_per_req > 1 and self.is_kv_consumer):
+                # Do not update scheduler_config.max_num_seqs in the KV consumer + MTP case,
+                # since FIA needs extra space for padding.
+                # Enforce self.max_num_seqs > self.scheduler_config.max_num_seqs in KV consumer + MTP.
+                self.scheduler_config.max_num_seqs = new_max_num_reqs
 
         if new_graph_batch_sizes != self.torchair_graph_batch_sizes:
             logger.warning(
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 784062c..fb7bbb3 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2390,7 +2390,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # for dummy run with LoRA so that the num_reqs collectively
         # has num_tokens in total.
         assert num_tokens <= self.scheduler_config.max_num_batched_tokens
-        max_num_reqs = self.scheduler_config.max_num_seqs
+        max_num_reqs = self.max_num_reqs
         if uniform_decode:
             num_reqs = cdiv(num_tokens, max_query_len)
             num_scheduled_tokens_list = [max_query_len] * num_reqs
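
A minimal sketch of the new sizing rule for reviewers (the helper name `max_reqs_from_graph_batch_sizes` and the sample values are illustrative only, not part of the patch): torchair graph batch sizes are counted in tokens, so with MTP each decode request occupies `decode_token_per_req` tokens and the request cap has to be derived by division rather than taken verbatim from the largest graph batch size.

```python
import math

def max_reqs_from_graph_batch_sizes(graph_batch_sizes: list[int],
                                     decode_token_per_req: int) -> int:
    # The largest captured graph batch size is measured in tokens; each decode
    # request contributes decode_token_per_req tokens (e.g. 2 when MTP adds one
    # speculative token), so the request cap is the ceiling of the quotient.
    return math.ceil(max(graph_batch_sizes) / decode_token_per_req)

# Without MTP (1 token per decode request) the cap matches the old behaviour.
assert max_reqs_from_graph_batch_sizes([1, 64], decode_token_per_req=1) == 64
# With MTP (2 tokens per decode request) 64 graph tokens only cover 32 requests.
assert max_reqs_from_graph_batch_sizes([1, 64], decode_token_per_req=2) == 32
```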