diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py
index b29fdc2..e6ef03e 100644
--- a/vllm_ascend/torchair/torchair_model_runner.py
+++ b/vllm_ascend/torchair/torchair_model_runner.py
@@ -452,10 +452,31 @@ class NPUTorchairModelRunner(NPUModelRunner):
             self.torchair_graph_batch_sizes.append(self.max_num_reqs)
 
         # padded max number tokens = max_num_req * decode_token_per_req
-        self.torchair_graph_batch_sizes = [
-            graph_batch_size * self.decode_token_per_req
-            for graph_batch_size in self.torchair_graph_batch_sizes
-        ]
+        if self.decode_token_per_req > 1:
+            # The PD disaggregation scenario needs redundant batch sizes to keep each batch's seq_len from exceeding 16 tokens
+            if self.is_kv_consumer:
+                FIA_SEQ_LEN_LIMIT = 16
+                self.torchair_graph_batch_sizes = [
+                    (graph_batch_size +
+                     math.ceil(graph_batch_size / FIA_SEQ_LEN_LIMIT) +
+                     math.ceil(graph_batch_size * self.decode_token_per_req /
+                               FIA_SEQ_LEN_LIMIT / FIA_SEQ_LEN_LIMIT)) *
+                    self.decode_token_per_req
+                    for graph_batch_size in self.torchair_graph_batch_sizes
+                ]
+                new_max_num_reqs = math.ceil(
+                    max(self.torchair_graph_batch_sizes) /
+                    self.decode_token_per_req)
+                if self.max_num_reqs < new_max_num_reqs:
+                    logger.warning(
+                        f"max_num_reqs is updated to {new_max_num_reqs}")
+                    self.max_num_reqs = new_max_num_reqs
+                    self.scheduler_config.max_num_seqs = new_max_num_reqs
+            else:
+                self.torchair_graph_batch_sizes = [
+                    graph_batch_size * self.decode_token_per_req
+                    for graph_batch_size in self.torchair_graph_batch_sizes
+                ]
         # NOTE: when enable_expert_parallel on A3, we need to check if `graph_batch_size` is divisible by `tp_size`
         # Because we use x_active_mask for dispatch/combine op on A3, which requires that input shape should be same
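
For reference, a minimal standalone sketch of the padding arithmetic in the `is_kv_consumer` branch above. The concrete values (`decode_token_per_req = 2`, the batch-size list) are illustrative assumptions, not values from this PR; only the formula mirrors the hunk.

```python
import math

# Illustrative values; the real ones come from the runner/scheduler config.
FIA_SEQ_LEN_LIMIT = 16      # per-batch seq_len limit assumed by the FIA op
decode_token_per_req = 2    # hypothetical, e.g. speculative decoding
graph_batch_sizes = [8, 16, 32, 64]

# Redundant slots: roughly one extra request per FIA_SEQ_LEN_LIMIT requests,
# plus slack for the extra requests themselves, then scaled to token counts.
padded = [(bs + math.ceil(bs / FIA_SEQ_LEN_LIMIT) +
           math.ceil(bs * decode_token_per_req / FIA_SEQ_LEN_LIMIT /
                     FIA_SEQ_LEN_LIMIT)) * decode_token_per_req
          for bs in graph_batch_sizes]
print(padded)  # [20, 36, 70, 138]

# The largest padded size can imply more requests than the configured
# max_num_reqs, which is why the hunk bumps max_num_reqs / max_num_seqs.
new_max_num_reqs = math.ceil(max(padded) / decode_token_per_req)
print(new_max_num_reqs)  # 69
```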