### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/layer_shard_linear.py` |
| `vllm_ascend/ops/linear.py` |
| `vllm_ascend/ops/linear_op.py` |
| `vllm_ascend/worker/worker.py` |
| `vllm_ascend/patch/worker/patch_bert.py` |
| `vllm_ascend/patch/worker/patch_deepseek.py` |
| `vllm_ascend/patch/worker/patch_distributed.py` |
| `vllm_ascend/patch/worker/patch_module.py` |
| `vllm_ascend/patch/worker/patch_multimodal_merge.py` |
| `vllm_ascend/patch/worker/patch_qwen3_next.py` |
| `vllm_ascend/patch/worker/patch_qwen3_next_mtp.py` |
| `vllm_ascend/patch/worker/patch_rejection_sampler.py` |
| `vllm_ascend/patch/worker/patch_rope.py` |
| `vllm_ascend/patch/worker/patch_triton.py` |
| `vllm_ascend/patch/worker/patch_unquantized_gemm.py` |
| `vllm_ascend/patch/worker/patch_v2_egale.py` |
| `vllm_ascend/worker/npu_input_batch.py` |
| `vllm_ascend/worker/v2/aclgraph_utils.py` |
| `vllm_ascend/worker/v2/attn_utils.py` |
| `vllm_ascend/worker/v2/model_runner.py` |
| `vllm_ascend/worker/v2/sample/gumbel.py` |
| `vllm_ascend/worker/v2/sample/penalties.py` |
| `vllm_ascend/worker/v2/sample/sampler.py` |
| `vllm_ascend/worker/v2/spec_decode/__init__.py` |
| `vllm_ascend/worker/v2/spec_decode/eagle.py` |
| `vllm_ascend/worker/v2/states.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main: d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
37 lines
985 B
Python
37 lines
985 B
Python
import torch
# torch_npu.argsort does not support bool tensors yet; support is expected in
# a future release.
# TODO: Remove this patch once the NPU argsort operator supports bool input.


def _argsort(tensor, *args, **kwargs):
    """Drop-in replacement for ``torch.argsort`` that also accepts bool input.

    Non-bool tensors are forwarded to ``torch.argsort`` untouched. Bool
    tensors are cast to ``torch.int32`` and sorted with ``stable=True``
    forced on (any caller-supplied ``stable`` value is overridden).

    Args:
        tensor: Tensor to sort.
        *args: Positional ``dim`` and ``descending``, as accepted by
            ``torch.argsort``.
        **kwargs: Keyword arguments forwarded to ``torch.argsort``.

    Returns:
        The index tensor produced by ``torch.argsort``.

    Raises:
        TypeError: For a bool tensor, if ``dim`` or ``descending`` is given
            both positionally and as a keyword.
    """
    if tensor.dtype != torch.bool:
        return torch.argsort(tensor, *args, **kwargs)

    # PyTorch's ``argsort.stable`` overload declares ``dim``/``descending``
    # keyword-only, so positional values cannot be combined with
    # ``stable=True``. Fold the leading positional arguments into keywords.
    for name, value in zip(("dim", "descending"), args):
        if name in kwargs:
            raise TypeError(
                f"argsort() got multiple values for argument '{name}'")
        kwargs[name] = value
    # A bool tensor holds only two distinct values, so ties are the norm. A
    # stable sort keeps tied elements in their original index order; the
    # original note warns that an unstable sort here gives "redundant
    # outputs".
    kwargs["stable"] = True
    return torch.argsort(tensor.to(torch.int32), *args[2:], **kwargs)


class _TorchWrapper:
    """Stand-in for the ``torch`` module that redirects ``argsort``.

    Looking up ``argsort`` yields this module's bool-aware ``_argsort``.
    Every other attribute is fetched from the real ``torch`` module, so
    attribute access on an instance otherwise behaves like access on
    ``torch`` itself.
    """

    def __init__(self):
        # Handle to the genuine module, used for attribute forwarding.
        self._raw_torch = torch

    def __getattr__(self, name):
        # Python only calls __getattr__ after normal lookup fails, so
        # ``_raw_torch`` (an instance attribute) never reaches this path.
        if name != "argsort":
            return getattr(self._raw_torch, name)
        return _argsort


_is_patched = False


# Patch argsort only for the ``torch`` name bound inside vLLM's gdn_attn
# backend; the real ``torch`` module itself is not modified.
def patch_torch_npu_argsort():
    """Install the bool-aware argsort into vLLM's ``gdn_attn`` module once.

    Rebinds ``vllm.v1.attention.backends.gdn_attn.torch`` to a
    ``_TorchWrapper``. The module-level ``_is_patched`` flag makes repeat
    calls no-ops.
    """
    global _is_patched
    if _is_patched:
        return
    # Imported here so gdn_attn is only loaded when the patch is applied.
    import vllm.v1.attention.backends.gdn_attn as gdn_attn

    gdn_attn.torch = _TorchWrapper()
    _is_patched = True