From 8fb0ef5ffa06f8006bed0aa96285564879f669ca Mon Sep 17 00:00:00 2001
From: drslark <96540755+drslark@users.noreply.github.com>
Date: Mon, 15 Dec 2025 13:22:30 +0800
Subject: [PATCH] [main][BugFix] Fixed an accuracy bug of Qwen3-next-MTP during
 batched inference (#4932)

### What this PR does / why we need it?
Fixes an accuracy bug in Qwen3-next-MTP during batched inference. It is described in https://github.com/vllm-project/vllm-ascend/issues/4930.

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

Signed-off-by: drslark
---
 tests/e2e/multicard/test_qwen3_next.py   |  9 +++++++--
 vllm_ascend/patch/__init__.py            | 12 ++++++++----
 vllm_ascend/patch/worker/patch_module.py |  2 ++
 3 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/tests/e2e/multicard/test_qwen3_next.py b/tests/e2e/multicard/test_qwen3_next.py
index 7a7fe64b..83387acc 100644
--- a/tests/e2e/multicard/test_qwen3_next.py
+++ b/tests/e2e/multicard/test_qwen3_next.py
@@ -61,9 +61,14 @@ def test_qwen3_next_distributed_mp_full_decode_only_tp4():
         del vllm_model
 
 
-# TODO: Fix the accuary of batch chunked prefill
 def test_qwen3_next_distributed_mp_eager_mtp_similarity_tp4():
-    example_prompts = ["Hello, my name is"]
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
     max_tokens = 20
 
     with VllmRunner(
diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 08b1c7a4..092e2ce5 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -237,7 +237,7 @@
 #    Replace with a new bind_kv_cache.
 #    Skip the raise.
 #    Related PR (if no, explain why):
-#    https://github.com/vllm-project/vllm/pull/4770
+#    It needs discussion.
 #    Future Plan:
 #    Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu.
 #
@@ -245,11 +245,15 @@
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # 1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
 #    Why:
-#    'torch.argsort' func of npu does not support bool.
+#    1. 'torch.argsort' func of npu does not support bool.
+#    2. Without `stable=True`, the output contains many redundant tokens.
 #    How:
-#    Replace with a new torch.argsort that will cast the input to torch.int32.
+#    Replace with a new torch.argsort that casts the input to torch.int32
+#    and performs a stable sort.
 #    Related PR (if no, explain why):
-#    https://github.com/vllm-project/vllm/pull/4770
+#    1. It depends on torch_npu.
+#    2. https://github.com/vllm-project/vllm/pull/30632
 #    Future Plan:
 #    Remove this patch when bool is supported in 'torch.argsort' func of npu.
+#    Make 'torch.argsort' in `vllm.v1.attention.backends.gdn_attn` stable.
 #
diff --git a/vllm_ascend/patch/worker/patch_module.py b/vllm_ascend/patch/worker/patch_module.py
index e8724473..eeca3a95 100644
--- a/vllm_ascend/patch/worker/patch_module.py
+++ b/vllm_ascend/patch/worker/patch_module.py
@@ -5,6 +5,8 @@ import torch
 # TODO When the operator of argsort is ready, this patch must be removed.
 def _argsort(tensor, *args, **kwargs):
     if tensor.dtype == torch.bool:
+        # Without a stable sort, the output may contain redundant tokens.
+        kwargs["stable"] = True
         return torch.argsort(tensor.to(torch.int32), *args, **kwargs)
     else:
         return torch.argsort(tensor, *args, **kwargs)
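
---

For context, a minimal sketch of the failure mode the patch addresses. This is illustrative only and not part of the patch: it uses plain CPU torch, whereas the patch targets the NPU backend, where `torch.argsort` additionally rejects bool inputs. When argsort is used to reorder token indices by a boolean mask, all the `False` entries compare equal after the int32 cast, so an unstable sort may permute them arbitrarily; a stable sort preserves their original order.

```python
import torch

# A boolean mask over five token slots, as the GDN attention backend
# builds when partitioning tokens in a batch.
mask = torch.tensor([False, True, False, True, False])

# Unstable sort: the tied False positions (0, 2, 4) may come back in
# any order, scrambling the relative token order within the group.
unordered = torch.argsort(mask.to(torch.int32))

# Stable sort: ties keep their original relative order.
ordered = torch.argsort(mask.to(torch.int32), stable=True)
print(ordered)  # tensor([0, 2, 4, 1, 3])
```

With a batch of size 1 the ties all belong to the same request, so the instability is invisible; with multiple requests in a batch, a scrambled order is presumably what surfaced as the redundant output tokens reported in issue #4930.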