[BUGFIX] Mtp torchair pd fix (#3449)

### What this PR does / why we need it?
In memory of https://github.com/vllm-project/vllm-ascend/pull/2610
In the PD disaggregation scenario, the first inference step after the D node
receives the KV cache runs in eager mode.

Fixes:
When running MTP in torchair graph mode with Prefill/Decode disaggregation,
if every request processed by the D node has just been transmitted from the
P node, the torchair graph breaks.

Reason: During PD disaggregation, the P node only transmits the KV cache and
the prompt to the D node, not the tokens it actually inferred (neither the
main-model token nor the MTP tokens are transmitted). The D node therefore
treats such a request as one without MTP tokens (seq_len=1).
Upstream vLLM does not hit this graph-mode issue because its attention runs
with seq_len=1 for every batch during the decode phase.
We do, because our graph mode pads on the assumption of 2 tokens per request:
when some requests have seq_len=1 and others seq_len=2, the padding is
appended at the end. If every request received by the D node has seq_len=1,
that padding can no longer be laid out within the constraints of the
attention's FIA (fused infer attention) operator.
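
To make the failure mode concrete, below is a toy sketch of the token-count
arithmetic, assuming only what is stated above: the graph is captured for 2
decode tokens per request, real tokens are packed first with padding appended
at the tail, and the FIA operator caps each batch's seq_len at 16 tokens. The
helper and the numbers are purely illustrative, not code from the runner.

```python
# Toy illustration of the padding problem described above (not runner code).
FIA_SEQ_LEN_LIMIT = 16      # assumed per-batch seq_len cap of the FIA operator
DECODE_TOKEN_PER_REQ = 2    # main-model token + 1 MTP token


def tail_padding(seq_lens: list[int]) -> int:
    """Tokens that must be appended at the tail to fill the captured graph."""
    captured_tokens = len(seq_lens) * DECODE_TOKEN_PER_REQ
    return captured_tokens - sum(seq_lens)


# Mixed case: most requests already carry 2 tokens, so little padding is needed.
mixed = [2] * 30 + [1] * 2
print(tail_padding(mixed))      # 2 -> fits comfortably under the 16-token cap

# PD disaggregation case: every request just arrived from the P node (seq_len=1),
# so half of the captured token slots become trailing padding.
all_fresh = [1] * 32
print(tail_padding(all_fresh))  # 32 -> far more than one trailing slot may absorb
```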

Solution:

The KV consumer reserves extra torchair graph padding so that the FIA graph
constraints are not violated (the approach this PR implements; see the sketch
below).
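
For reference, here is a minimal sketch of the extra-padding arithmetic that
the diff below adds to `NPUTorchairModelRunner`. The formula and
`FIA_SEQ_LEN_LIMIT = 16` are taken directly from the diff; the helper name and
the worked example are illustrative only.

```python
import math

FIA_SEQ_LEN_LIMIT = 16  # per-batch seq_len limit used in the diff below


def padded_graph_batch_size(graph_batch_size: int,
                            decode_token_per_req: int) -> int:
    """Hypothetical helper reproducing the expression added in this patch:
    redundant request slots are reserved so trailing padding does not push
    any batch's seq_len past FIA_SEQ_LEN_LIMIT."""
    redundant_reqs = (
        graph_batch_size +
        math.ceil(graph_batch_size / FIA_SEQ_LEN_LIMIT) +
        math.ceil(graph_batch_size * decode_token_per_req /
                  FIA_SEQ_LEN_LIMIT / FIA_SEQ_LEN_LIMIT))
    return redundant_reqs * decode_token_per_req


# Worked example: 32 requests with MTP (2 decode tokens per request)
# (32 + ceil(32 / 16) + ceil(32 * 2 / 16 / 16)) * 2 = (32 + 2 + 1) * 2 = 70
print(padded_graph_batch_size(32, 2))  # 70, versus the previous 32 * 2 = 64
```

Because the padded batch sizes grow, the diff also raises `max_num_reqs` (and
`scheduler_config.max_num_seqs`) to `ceil(max(batch_sizes) /
decode_token_per_req)` whenever it would otherwise be smaller, keeping the
request capacity consistent with the enlarged graphs.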

The KV producer could instead transmit the correct tokens to the KV consumer,
so that our graph-mode constraints are never broken and the logic matches
mixed PD deployment. Since we use the community scheduler, this would require
patching the vLLM scheduler, but it should in theory perform better.
(Maybe later.)

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: xuyexiong <xuyexiong@huawei.com>

```diff
@@ -452,10 +452,31 @@ class NPUTorchairModelRunner(NPUModelRunner):
             self.torchair_graph_batch_sizes.append(self.max_num_reqs)
         # padded max number tokens = max_num_req * decode_token_per_req
-        self.torchair_graph_batch_sizes = [
-            graph_batch_size * self.decode_token_per_req
-            for graph_batch_size in self.torchair_graph_batch_sizes
-        ]
+        if self.decode_token_per_req > 1:
+            # pd disaggregation scenario need redundant_batch_sizes to avoid each batch's seq_len exceed 16 tokens
+            if self.is_kv_consumer:
+                FIA_SEQ_LEN_LIMIT = 16
+                self.torchair_graph_batch_sizes = [
+                    (graph_batch_size +
+                     math.ceil(graph_batch_size / FIA_SEQ_LEN_LIMIT) +
+                     math.ceil(graph_batch_size * self.decode_token_per_req /
+                               FIA_SEQ_LEN_LIMIT / FIA_SEQ_LEN_LIMIT)) *
+                    self.decode_token_per_req
+                    for graph_batch_size in self.torchair_graph_batch_sizes
+                ]
+                new_max_num_reqs = math.ceil(
+                    max(self.torchair_graph_batch_sizes) /
+                    self.decode_token_per_req)
+                if self.max_num_reqs < new_max_num_reqs:
+                    logger.warning(
+                        f"max_num_reqs is updated to {new_max_num_reqs}")
+                    self.max_num_reqs = new_max_num_reqs
+                    self.scheduler_config.max_num_seqs = new_max_num_reqs
+            else:
+                self.torchair_graph_batch_sizes = [
+                    graph_batch_size * self.decode_token_per_req
+                    for graph_batch_size in self.torchair_graph_batch_sizes
+                ]
         # NOTE: when enable_expert_parallel on A3, we need to check if `graph_batch_size` is divisible by `tp_size`
         # Because we use x_active_mask for dispatch/combine op on A3, which requires that input shape should be same
```