Files
xc-llm-ascend/vllm_ascend/patch/worker/patch_module.py
drslark 8fb0ef5ffa [main][BugFix] Fixed an accuracy bug of Qwen3-next-MTP when batched inferring (#4932)
### What this PR does / why we need it?
Fixes an accuracy bug of Qwen3-next-MTP during batched inference.
The issue is described in
https://github.com/vllm-project/vllm-ascend/issues/4930.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: drslark <slarksblood@qq.com>
2025-12-15 13:22:30 +08:00


import torch


# torch_npu.argsort does not support bool now; it will support it in the future.
# TODO: When the argsort operator is ready, this patch must be removed.
def _argsort(tensor, *args, **kwargs):
    if tensor.dtype == torch.bool:
        # If the sort is not stable, equal elements may be reordered,
        # which produces redundant outputs.
        kwargs["stable"] = True
        return torch.argsort(tensor.to(torch.int32), *args, **kwargs)
    else:
        return torch.argsort(tensor, *args, **kwargs)


# Thin proxy around torch that only overrides argsort and forwards
# every other attribute to the real module.
class _TorchWrapper:

    def __init__(self):
        self._raw_torch = torch

    def __getattr__(self, name):
        if name == "argsort":
            return _argsort
        else:
            return getattr(self._raw_torch, name)


_is_patched = False


# Patch argsort only for the torch reference used inside gdn_attn.
def patch_torch_npu_argsort():
    global _is_patched
    if not _is_patched:
        import vllm.v1.attention.backends.gdn_attn as gdn_attn
        gdn_attn.torch = _TorchWrapper()
        _is_patched = True
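
A minimal usage sketch, not part of the patch file: it assumes vllm and vllm-ascend are installed and that the module is importable as vllm_ascend.patch.worker.patch_module (matching the repository path above); the boolean mask is an invented example and the output shown is simply what a CPU run of torch.argsort would print. After patch_torch_npu_argsort() runs, gdn_attn.torch resolves argsort to _argsort, which sorts boolean tensors as stable int32 sorts, while every other torch attribute is forwarded to the real module.

import torch

from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort

patch_torch_npu_argsort()

import vllm.v1.attention.backends.gdn_attn as gdn_attn

# A boolean mask, e.g. marking which positions take a speculative path
# (hypothetical example data).
mask = torch.tensor([True, False, True, False, True])

# Routed through _argsort: the mask is cast to int32 and sorted with
# stable=True, so equal entries keep their original relative order.
order = gdn_attn.torch.argsort(mask)
print(order)  # tensor([1, 3, 0, 2, 4])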