[main][BugFix] Fix an accuracy bug in Qwen3-next-MTP during batched inference (#4932)
### What this PR does / why we need it?
Fixes an accuracy bug in Qwen3-next-MTP during batched inference, as described in
https://github.com/vllm-project/vllm-ascend/issues/4930.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: drslark <slarksblood@qq.com>
@@ -5,6 +5,8 @@ import torch

```python
# TODO When the operator of argsort is ready, this patch must be removed.
def _argsort(tensor, *args, **kwargs):
    if tensor.dtype == torch.bool:
        # If it is not stable, it will have redundant outputs.
        kwargs["stable"] = True
        return torch.argsort(tensor.to(torch.int32), *args, **kwargs)
    else:
        return torch.argsort(tensor, *args, **kwargs)
```
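To illustrate why the patch forces a stable sort, here is a minimal, self-contained sketch that reproduces the helper from the diff and sorts a boolean mask. The `mask` tensor and the expected ordering are illustrative assumptions, not taken from the PR: with `stable=True`, indices of equal elements keep their original relative order, so the result is deterministic across batched calls.

```python
import torch

# Reproduction of the patched helper (assumed to match the diff above):
# bool tensors are cast to int32 and sorted with stable=True so that
# equal elements keep their original relative order.
def _argsort(tensor, *args, **kwargs):
    if tensor.dtype == torch.bool:
        # Without stable=True, ties among equal bool values may be
        # reordered arbitrarily, yielding inconsistent index outputs.
        kwargs["stable"] = True
        return torch.argsort(tensor.to(torch.int32), *args, **kwargs)
    else:
        return torch.argsort(tensor, *args, **kwargs)

# Hypothetical example input: a selection mask over 4 positions.
mask = torch.tensor([True, False, True, False])
order = _argsort(mask)

# Ascending stable sort puts the False entries first, preserving their
# original order (indices 1, 3), followed by the True entries (0, 2).
print(order.tolist())  # [1, 3, 0, 2]
```

A non-stable sort would still group the `False` values before the `True` values, but the order of indices inside each group would be unspecified, which is exactly the "redundant outputs" the comment in the patch warns about.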