[Bugfix] Fix MTP torchair in PD disaggregation scenario (#2951)
### What this PR does / why we need it?
1. Following up on #2509, fix MTP torchair in the PD disaggregation scenario.
2. Fix an MLA bug in the speculative decoding scenario, where `num_decodes != num_decode_tokens` (see the sketch after this list).
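
For context on item 2, a minimal sketch of why the two counts diverge once speculative decoding is enabled. This is illustrative only, not the PR's code; `num_spec_tokens` and `decode_lens` are made-up names and values:

```python
# Sketch: with speculative decoding (e.g. MTP), each decode request
# carries its next token plus draft tokens, so request count and
# token count are no longer equal.
num_spec_tokens = 1  # e.g. one MTP draft token per request

decode_lens = [1 + num_spec_tokens] * 4  # 4 decode requests

num_decodes = len(decode_lens)        # 4 -> number of decode *requests*
num_decode_tokens = sum(decode_lens)  # 8 -> number of decode *tokens*

# MLA metadata that indexes per-request buffers with the token count
# (or vice versa) breaks as soon as these two values differ.
assert num_decodes != num_decode_tokens
```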
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.10.2
- vLLM main:
5206ab20ba
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
```diff
@@ -424,7 +424,13 @@ class NPUTorchairModelRunner(NPUModelRunner):
     def update_torchair_graph_batch_sizes(self):
         # return graph_batch_sizes according to the max number of tokens
         # first pad according to the number of requests
-        if len(self.torchair_graph_batch_sizes) == 0:
+        if self.is_kv_consumer and self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
+            # pd disaggregation scenario may incorrectly calculate the batch in mtp scenario, so we force set it to max_num_reqs
+            self.torchair_graph_batch_sizes = [self.max_num_reqs]
+            logger.warning(
+                "is kv_consumer, torch_graph_batch_sizes sets to [max_num_seqs]"
+            )
+        elif len(self.torchair_graph_batch_sizes) == 0:
             self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
         else:
             self.torchair_graph_batch_sizes = sorted(
```
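
To make the new control flow easier to read outside the diff, here is a standalone sketch. `select_graph_batch_sizes` is a hypothetical helper, not the real API; its parameters mirror the runner attributes in the diff (`is_kv_consumer`, `speculative_config.method`, `max_num_reqs`):

```python
# Standalone sketch of the batch-size selection logic above; not the
# actual runner code. Names mirror the attributes used in the diff.
from typing import List, Optional


def select_graph_batch_sizes(
    current_sizes: List[int],
    is_kv_consumer: bool,
    spec_method: Optional[str],
    max_num_reqs: int,
) -> List[int]:
    # A kv-consumer (decode-side) instance running deepseek_mtp cannot
    # trust sizes derived from token counts, because each request also
    # carries draft tokens; pin the graph batch size to the request cap.
    if is_kv_consumer and spec_method == "deepseek_mtp":
        return [max_num_reqs]
    # No sizes configured yet: fall back to the smallest and largest batch.
    if not current_sizes:
        return [1, max_num_reqs]
    # Otherwise keep the configured sizes, sorted ascending.
    return sorted(current_sizes)


# Example: a decode-side (kv-consumer) instance with MTP enabled.
print(select_graph_batch_sizes([4, 16], True, "deepseek_mtp", 32))  # [32]
print(select_graph_batch_sizes([], False, None, 32))                # [1, 32]
print(select_graph_batch_sizes([16, 4], False, None, 32))           # [4, 16]
```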