[BugFix][main] Adapted Qwen3-Next-MTP to chunked prefill (#4770)
### What this PR does / why we need it?
The pad `-1` modification is from
https://github.com/vllm-project/vllm/pull/25743.
The upstream change still has bugs when handling batched chunked prefill.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: drslark <slarksblood@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
34
vllm_ascend/patch/worker/patch_module.py
Normal file
34
vllm_ascend/patch/worker/patch_module.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
|
||||
|
||||
# torch_npu.argsort does not sipport bool now, it will support it in the future.
|
||||
# TODO When the operator of argsort is ready, this patch must be removed.
|
||||
def _argsort(tensor, *args, **kwargs):
|
||||
if tensor.dtype == torch.bool:
|
||||
return torch.argsort(tensor.to(torch.int32), *args, **kwargs)
|
||||
else:
|
||||
return torch.argsort(tensor, *args, **kwargs)
|
||||
|
||||
|
||||
class _TorchWrapper:
|
||||
|
||||
def __init__(self):
|
||||
self._raw_torch = torch
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name == "argsort":
|
||||
return _argsort
|
||||
else:
|
||||
return getattr(self._raw_torch, name)
|
||||
|
||||
|
||||
# Module-level flag guaranteeing the gdn_attn patch is applied at most once.
_is_patched = False


# patch argsort only for torch in gdn_attn
def patch_torch_npu_argsort():
    """Swap the ``torch`` reference inside ``gdn_attn`` for ``_TorchWrapper``.

    Only the ``gdn_attn`` module's view of ``torch`` is replaced, so the
    bool-safe ``argsort`` shim applies there and nowhere else. Idempotent:
    calls after the first are no-ops.
    """
    global _is_patched
    if _is_patched:
        return
    import vllm.v1.attention.backends.gdn_attn as gdn_attn
    gdn_attn.torch = _TorchWrapper()
    _is_patched = True
|
||||
Reference in New Issue
Block a user