[main][BugFix] Fixed an accuracy bug of Qwen3-next-MTP when batched inferring (#4932)
### What this PR does / why we need it?
Fixes an accuracy bug in Qwen3-next-MTP during batched inference.
It is described in
https://github.com/vllm-project/vllm-ascend/issues/4930.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
@@ -61,9 +61,14 @@ def test_qwen3_next_distributed_mp_full_decode_only_tp4():
|
|||||||
del vllm_model
|
del vllm_model
|
||||||
|
|
||||||
|
|
||||||
# TODO: Fix the accuracy of batch chunked prefill
|
|
||||||
def test_qwen3_next_distributed_mp_eager_mtp_similarity_tp4():
|
def test_qwen3_next_distributed_mp_eager_mtp_similarity_tp4():
|
||||||
example_prompts = ["Hello, my name is"]
|
example_prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
|
||||||
max_tokens = 20
|
max_tokens = 20
|
||||||
|
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
|
|||||||
@@ -237,7 +237,7 @@
|
|||||||
# Replace with a new bind_kv_cache.
|
# Replace with a new bind_kv_cache.
|
||||||
# Skip the raise.
|
# Skip the raise.
|
||||||
# Related PR (if no, explain why):
|
# Related PR (if no, explain why):
|
||||||
# https://github.com/vllm-project/vllm/pull/4770
|
# It needs discussion.
|
||||||
# Future Plan:
|
# Future Plan:
|
||||||
# Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu.
|
# Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu.
|
||||||
#
|
#
|
||||||
@@ -245,11 +245,15 @@
|
|||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
# 1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
|
# 1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
|
||||||
# Why:
|
# Why:
|
||||||
# 'torch.argsort' func of npu does not support bool.
|
# 1. 'torch.argsort' func of npu does not support bool.
|
||||||
|
# 2. Without `stable=True`, the output will have a lot of redundant tokens.
|
||||||
# How:
|
# How:
|
||||||
# Replace with a new torch.argsort that will cast the input to torch.int32.
|
# Replace with a new torch.argsort that will cast the input to torch.int32
|
||||||
|
# and do stable sort.
|
||||||
# Related PR (if no, explain why):
|
# Related PR (if no, explain why):
|
||||||
# https://github.com/vllm-project/vllm/pull/4770
|
# 1. It depends on torch_npu.
|
||||||
|
# 2. https://github.com/vllm-project/vllm/pull/30632
|
||||||
# Future Plan:
|
# Future Plan:
|
||||||
# Remove this patch when bool is supported in 'torch.argsort' func of npu.
|
# Remove this patch when bool is supported in 'torch.argsort' func of npu.
|
||||||
|
# Make 'torch.argsort' in `vllm.v1.attention.backends.gdn_attn` be stable.
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import torch
|
|||||||
# TODO When the operator of argsort is ready, this patch must be removed.
|
# TODO When the operator of argsort is ready, this patch must be removed.
|
||||||
def _argsort(tensor, *args, **kwargs):
|
def _argsort(tensor, *args, **kwargs):
|
||||||
if tensor.dtype == torch.bool:
|
if tensor.dtype == torch.bool:
|
||||||
|
# Without a stable sort, the output will contain redundant tokens.
|
||||||
|
kwargs["stable"] = True
|
||||||
return torch.argsort(tensor.to(torch.int32), *args, **kwargs)
|
return torch.argsort(tensor.to(torch.int32), *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return torch.argsort(tensor, *args, **kwargs)
|
return torch.argsort(tensor, *args, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user