[main][BugFix] Fixed an accuracy bug of Qwen3-next-MTP when batched inferring (#4932)

### What this PR does / why we need it?
Fixes an accuracy bug in Qwen3-next-MTP during batched inference.
It is described in
https://github.com/vllm-project/vllm-ascend/issues/4930.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: drslark <slarksblood@qq.com>
drslark
2025-12-15 13:22:30 +08:00
committed by GitHub
parent 545e856971
commit 8fb0ef5ffa
3 changed files with 17 additions and 6 deletions

@@ -237,7 +237,7 @@
 # Replace with a new bind_kv_cache.
 # Skip the raise.
 # Related PR (if no, explain why):
-# https://github.com/vllm-project/vllm/pull/4770
+# It need discuss.
 # Future Plan:
 # Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu.
 #
@@ -245,11 +245,15 @@
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # 1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
 # Why:
-# 'torch.argsort' func of npu does not support bool.
+# 1. 'torch.argsort' func of npu does not support bool.
+# 2. Without `stable=True`, the output will have a lot of redundant tokens.
 # How
-# Replace with a new torch.argsort that will cast the input to torch.int32.
+# Replace with a new torch.argsort that will cast the input to torch.int32
+# and do stable sort.
 # Related PR (if no, explain why):
-# https://github.com/vllm-project/vllm/pull/4770
+# 1. It depends on torch_npu.
+# 2. https://github.com/vllm-project/vllm/pull/30632
 # Future Plan:
 # Remove this patch when bool is supported in 'torch.argsort' func of npu.
+# Make 'torch.argsort' in `vllm.v1.attention.backends.gdn_attn` be stable.
 #

@@ -5,6 +5,8 @@ import torch
 # TODO When the operator of argsort is ready, this patch must be removed.
 def _argsort(tensor, *args, **kwargs):
     if tensor.dtype == torch.bool:
+        # If it is not stable, it will have redundant outputs.
+        kwargs["stable"] = True
         return torch.argsort(tensor.to(torch.int32), *args, **kwargs)
     else:
         return torch.argsort(tensor, *args, **kwargs)
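
The property this patch relies on can be illustrated without an NPU. The following is a minimal pure-Python sketch, not the code path in `gdn_attn`: it assumes the boolean mask is used to split token positions into two groups, and the names `stable_argsort`, `mask`, `non_spec_idx`, and `spec_idx` are hypothetical. It uses Python's built-in `sorted`, which is guaranteed stable, as a stand-in for `torch.argsort(..., stable=True)`.

```python
# Illustrative stand-in for a stable argsort over a boolean mask.
# All names here are hypothetical and not taken from vLLM or vllm-ascend.
def stable_argsort(mask):
    """Return indices that sort `mask` ascending, keeping ties in input order."""
    # Cast bool -> int (mirrors the int32 cast in the patch).
    # sorted() is guaranteed stable, so equal keys keep their original order.
    return sorted(range(len(mask)), key=lambda i: int(mask[i]))


# Assumed usage: True marks one kind of token position, False the other.
mask = [False, True, True, False, True, False]
order = stable_argsort(mask)

# Split the sorted indices into the False group and the True group.
num_false = mask.count(False)
non_spec_idx = order[:num_false]
spec_idx = order[num_false:]

print(non_spec_idx)  # [0, 3, 5]
print(spec_idx)      # [1, 2, 4]
```

With a stable sort, each group keeps its positions in ascending order. An unstable sort may return the positions inside each group in any order, so downstream code that assumes the original token order within a group can misalign tokens.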