[BUGFIX] Mtp torchair pd fix (#3449)
### What this PR does / why we need it?
In memory of https://github.com/vllm-project/vllm-ascend/pull/2610

In the PD disaggregation scenario, the first inference token on the D node after it receives the KV cache follows eager mode.

Fixes: when running MTP torchair graph mode with prefill/decode (PD) disaggregation, if all requests processed by the D node were just transmitted from the P node, the torchair graph breaks.

Reason: during PD disaggregation, the P node transmits only the KV cache and the prompt to the D node, not the tokens it actually inferred (neither the main-model tokens nor the MTP tokens). The D node therefore treats such a request as one without MTP tokens (seq_len=1). The community build has no graph-mode issue because its attention runs with seq_len=1 for every batch in the decode phase. We hit the issue because graph mode pads on the assumption of 2 tokens per request: when some requests have seq_len=1 and others seq_len=2, padding is appended at the end. If every request received by the D node has seq_len=1, the padding cannot satisfy the FIA operator's constraints on attention.

Possible solutions:
1. The KV consumer adds extra torchair graph padding so the FIA graph constraints are never broken (the approach this PR implements).
2. The KV producer transmits the correct tokens to the KV consumer, so our graph-mode constraints hold and all logic stays the same as PD mixed deployment. Since we use the community scheduler, this would require patching the vLLM scheduler; in theory it should perform better. (Maybe later.)

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
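A back-of-the-envelope illustration of the failure mode described above (a hypothetical sketch, not code from this PR; only `FIA_SEQ_LEN_LIMIT = 16` and `decode_token_per_req` come from the patch, the batch size and the single-trailing-pad-slot assumption are illustrative):

```python
# Hypothetical sketch: with MTP, the graph is padded for 2 tokens per request.
# If every request fresh from the P node has seq_len = 1, the missing tokens
# all pile up as trailing padding, whose seq_len can exceed the FIA limit.
FIA_SEQ_LEN_LIMIT = 16   # per-batch seq_len limit from the patch
decode_token_per_req = 2  # main-model token + 1 MTP token per request

num_reqs = 24                                    # all seq_len = 1 (from P node)
graph_tokens = num_reqs * decode_token_per_req   # tokens the padded graph expects
missing = graph_tokens - num_reqs * 1            # tokens that padding must supply

# Naive tail padding puts every missing token into one trailing pad slot:
tail_pad_seq_len = missing
print(tail_pad_seq_len, tail_pad_seq_len > FIA_SEQ_LEN_LIMIT)  # → 24 True
```

This is why the fix reserves redundant batch slots: the padding can then be spread so no slot's seq_len exceeds the limit.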
```diff
@@ -452,10 +452,31 @@ class NPUTorchairModelRunner(NPUModelRunner):
         self.torchair_graph_batch_sizes.append(self.max_num_reqs)
 
-        # padded max number tokens = max_num_req * decode_token_per_req
-        self.torchair_graph_batch_sizes = [
-            graph_batch_size * self.decode_token_per_req
-            for graph_batch_size in self.torchair_graph_batch_sizes
-        ]
+        if self.decode_token_per_req > 1:
+            # pd disaggregation scenario need redundant_batch_sizes to avoid each batch's seq_len exceed 16 tokens
+            if self.is_kv_consumer:
+                FIA_SEQ_LEN_LIMIT = 16
+                self.torchair_graph_batch_sizes = [
+                    (graph_batch_size +
+                     math.ceil(graph_batch_size / FIA_SEQ_LEN_LIMIT) +
+                     math.ceil(graph_batch_size * self.decode_token_per_req /
+                               FIA_SEQ_LEN_LIMIT / FIA_SEQ_LEN_LIMIT)) *
+                    self.decode_token_per_req
+                    for graph_batch_size in self.torchair_graph_batch_sizes
+                ]
+                new_max_num_reqs = math.ceil(
+                    max(self.torchair_graph_batch_sizes) /
+                    self.decode_token_per_req)
+                if self.max_num_reqs < new_max_num_reqs:
+                    logger.warning(
+                        f"max_num_reqs is updated to {new_max_num_reqs}")
+                    self.max_num_reqs = new_max_num_reqs
+                    self.scheduler_config.max_num_seqs = new_max_num_reqs
+            else:
+                # padded max number tokens = max_num_req * decode_token_per_req
+                self.torchair_graph_batch_sizes = [
+                    graph_batch_size * self.decode_token_per_req
+                    for graph_batch_size in self.torchair_graph_batch_sizes
+                ]
 
         # NOTE: when enable_expert_parallel on A3, we need to check if `graph_batch_size` is divisible by `tp_size`
         # Because we use x_active_mask for dispatch/combine op on A3, which requires that input shape should be same
```
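For concreteness, the KV-consumer padding formula from the diff can be exercised standalone (a sketch: `FIA_SEQ_LEN_LIMIT` and `decode_token_per_req` mirror the names in the patch, while the input batch sizes are illustrative values, not the runner's real capture list):

```python
import math

# Values mirroring the patch; batch sizes below are illustrative only.
FIA_SEQ_LEN_LIMIT = 16
decode_token_per_req = 2  # MTP: main-model token + 1 MTP token per request

def padded_graph_batch_sizes(batch_sizes):
    # For each graph batch size b, reserve ceil(b / 16) plus
    # ceil(b * tokens_per_req / 16 / 16) redundant slots, then scale the
    # whole thing by tokens-per-request, as in the PR's list comprehension.
    return [
        (b + math.ceil(b / FIA_SEQ_LEN_LIMIT) +
         math.ceil(b * decode_token_per_req /
                   FIA_SEQ_LEN_LIMIT / FIA_SEQ_LEN_LIMIT)) *
        decode_token_per_req
        for b in batch_sizes
    ]

print(padded_graph_batch_sizes([8, 16, 32]))  # → [20, 36, 70]
```

Note that the padded sizes can exceed `max_num_reqs * decode_token_per_req`, which is why the patch then recomputes `new_max_num_reqs` from the largest padded batch size and bumps `max_num_reqs` (and `scheduler_config.max_num_seqs`) when needed.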