|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# torch_npu.argsort does not sipport bool now, it will support it in the future.
|
|
|
|
|
# TODO When the operator of argsort is ready, this patch must be removed.
|
|
|
|
|
def _argsort(tensor, *args, **kwargs):
|
|
|
|
|
if tensor.dtype == torch.bool:
|
2025-12-15 13:22:30 +08:00
|
|
|
# If it is not stable, it will have redundant outputs.
|
|
|
|
|
kwargs["stable"] = True
|
2025-12-10 22:54:24 +08:00
|
|
|
return torch.argsort(tensor.to(torch.int32), *args, **kwargs)
|
|
|
|
|
else:
|
|
|
|
|
return torch.argsort(tensor, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _TorchWrapper:
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self._raw_torch = torch
|
|
|
|
|
|
|
|
|
|
def __getattr__(self, name):
|
|
|
|
|
if name == "argsort":
|
|
|
|
|
return _argsort
|
|
|
|
|
else:
|
|
|
|
|
return getattr(self._raw_torch, name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level guard so the gdn_attn patch is applied at most once.
_is_patched = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# patch argsort only for torch in gdn_attn
def patch_torch_npu_argsort():
    """Replace the ``torch`` reference inside gdn_attn with the proxy.

    Only the gdn_attn module's view of ``torch`` is swapped; the global
    ``torch`` module itself is untouched.  Idempotent: once applied,
    subsequent calls are no-ops.
    """
    global _is_patched
    if _is_patched:
        return

    import vllm.v1.attention.backends.gdn_attn as gdn_attn

    gdn_attn.torch = _TorchWrapper()
    _is_patched = True
|