[Bugfix] Fix mtp torchair in pd Disaggregation scenario (#2951)

### What this PR does / why we need it?
1. Following up on #2509, fix MTP torchair in the PD disaggregation scenario.
2. Fix an MLA bug in the speculative decoding scenario, since `num_decodes != num_decode_tokens` (see the sketch below).
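
To see why the second point matters, here is a minimal sketch (hypothetical numbers, not the actual vllm-ascend code) of the request-vs-token mismatch that speculative decoding introduces:

```python
# With speculative decoding, every decode request carries the verified token
# plus k draft tokens, so token count and request count diverge.
num_spec_tokens = 1   # e.g. deepseek_mtp proposing one draft token per step
num_decodes = 4       # decode requests in the batch
num_decode_tokens = num_decodes * (1 + num_spec_tokens)  # 8, not 4

# MLA metadata sized by num_decodes (positions, slot mappings, ...) would be
# too short for the 8 decode tokens actually scheduled.
assert num_decode_tokens != num_decodes
```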


### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.10.2
- vLLM main:
5206ab20ba

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
Commit: ae758dda05 (parent: 6b7117dbb7)
Author: xuyexiong
Date: 2025-09-17 09:07:58 +08:00
3 changed files with 58 additions and 9 deletions


```diff
@@ -424,7 +424,13 @@ class NPUTorchairModelRunner(NPUModelRunner):
     def update_torchair_graph_batch_sizes(self):
         # return graph_batch_sizes according to the max number of tokens
         # first pad according to the number of requests
-        if len(self.torchair_graph_batch_sizes) == 0:
+        if self.is_kv_consumer and self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
+            # pd disaggregation scenario may incorrectly calculate the batch in mtp scenario, so we force set it to max_num_reqs
+            self.torchair_graph_batch_sizes = [self.max_num_reqs]
+            logger.warning(
+                "is kv_consumer, torch_graph_batch_sizes sets to [max_num_seqs]"
+            )
+        elif len(self.torchair_graph_batch_sizes) == 0:
             self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
         else:
             self.torchair_graph_batch_sizes = sorted(
```
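
For intuition, a minimal sketch (hypothetical helper and names, not the actual vllm-ascend code) of why capturing a single graph batch size of `max_num_reqs` is safe on the KV consumer: every decode batch is padded up to the one captured size, so a mis-estimated batch count can no longer select a graph that is too small:

```python
def pick_graph_batch_size(batch_size: int, captured_sizes: list[int]) -> int:
    """Hypothetical helper: return the smallest captured graph batch size
    that fits the current batch; the runner then pads the batch up to it."""
    for size in sorted(captured_sizes):
        if size >= batch_size:
            return size
    raise ValueError(f"batch of {batch_size} exceeds all captured sizes")

max_num_reqs = 16
# With the fix, a kv_consumer running deepseek_mtp captures exactly one size,
# so any decode batch (1..max_num_reqs requests) pads to the same graph.
assert pick_graph_batch_size(5, [max_num_reqs]) == max_num_reqs
```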