xc-llm-ascend/vllm_ascend/patch/worker/patch_module.py
drslark 0fb1dc43a1 [BugFix][main] Adapted Qwen3-Next-MTP to chunked prefill (#4770)
### What this PR does / why we need it?
The pad `-1` modification is from
https://github.com/vllm-project/vllm/pull/25743.

It still has bugs for batched chunked prefill.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: drslark <slarksblood@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
2025-12-10 22:54:24 +08:00


import torch


# torch_npu.argsort does not support bool yet; it will be supported in the future.
# TODO: Remove this patch once the argsort operator supports bool.
def _argsort(tensor, *args, **kwargs):
    # Work around missing bool support by casting bool tensors to int32
    # before sorting; all other dtypes are sorted as-is.
    if tensor.dtype == torch.bool:
        return torch.argsort(tensor.to(torch.int32), *args, **kwargs)
    else:
        return torch.argsort(tensor, *args, **kwargs)


class _TorchWrapper:
    # Proxy for the torch module that swaps in the bool-safe argsort and
    # forwards every other attribute to the real torch module.

    def __init__(self):
        self._raw_torch = torch

    def __getattr__(self, name):
        if name == "argsort":
            return _argsort
        else:
            return getattr(self._raw_torch, name)


_is_patched = False


# Patch argsort only for the `torch` reference used inside gdn_attn.
def patch_torch_npu_argsort():
    global _is_patched
    if not _is_patched:
        import vllm.v1.attention.backends.gdn_attn as gdn_attn
        gdn_attn.torch = _TorchWrapper()
        _is_patched = True
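
Below is a minimal usage sketch of the patch, not part of the file above. It assumes vLLM and vllm-ascend are installed and that the module is importable as vllm_ascend.patch.worker.patch_module (derived from the file path); the tensor values are illustrative only. After patching, the module-level `torch` reference inside gdn_attn is the wrapper, so a bool argsort issued by that backend is transparently cast to int32, while every other torch attribute is forwarded to the real module.

# Usage sketch (hypothetical environment with vLLM and vllm-ascend installed).
import torch

from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort

# Install the wrapper; this imports gdn_attn and rebinds its `torch` name.
patch_torch_npu_argsort()

import vllm.v1.attention.backends.gdn_attn as gdn_attn

mask = torch.tensor([True, False, True, False])
# gdn_attn.torch is now the _TorchWrapper instance, so argsort on a bool
# tensor is routed through _argsort and sorted as int32.
order = gdn_attn.torch.argsort(mask)
print(order)  # e.g. tensor([1, 3, 0, 2]) -- False entries sorted first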