[Feat] Refactor rejection sampler (#4975)

### What this PR does / why we need it?

Currently, spec decoding uses `AscendRejectionSampler`, which extends
`RejectionSampler`. `AscendRejectionSampler` overrides the `forward`
method of `RejectionSampler` only to replace the `rejection_sample`
function. As a result, much of the code in `RejectionSampler` cannot be reused, for example:
- https://github.com/vllm-project/vllm/pull/19482
- https://github.com/vllm-project/vllm/pull/26060
- https://github.com/vllm-project/vllm/pull/29223

#### Proposed Change:
- Delete `AscendRejectionSampler` and use `RejectionSampler` directly in
the model runner.
- Patch `RejectionSampler.expand_batch_to_tokens` and
`RejectionSampler.rejection_sample`; a better long-term approach may be to
turn them into custom ops.
- Modify `NPUModelRunner` following
https://github.com/vllm-project/vllm/pull/26060
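The patching step above can be sketched as a class-level monkey patch, so every instance (including the one the model runner constructs) picks up the NPU version without subclassing. This is a minimal illustrative sketch, not the actual vllm-ascend implementation; the `RejectionSampler` stand-in and `npu_expand_batch_to_tokens` replacement below are hypothetical placeholders.

```python
class RejectionSampler:
    """Stand-in for vllm.v1.sample.rejection_sampler.RejectionSampler."""

    def expand_batch_to_tokens(self, batch):
        # Placeholder logic: repeat each element twice.
        return [v for v in batch for _ in range(2)]


def npu_expand_batch_to_tokens(self, batch):
    # Hypothetical NPU-friendly replacement; repeats three times here
    # only so the override is observable in this sketch.
    return [v for v in batch for _ in range(3)]


# Patch at class level: no subclass is needed, and callers that
# instantiate RejectionSampler directly get the patched behavior.
RejectionSampler.expand_batch_to_tokens = npu_expand_batch_to_tokens

sampler = RejectionSampler()
print(sampler.expand_batch_to_tokens([1, 2]))  # -> [1, 1, 1, 2, 2, 2]
```

The trade-off of this approach, compared to the deleted subclass, is that the patch applies process-wide; custom ops would localize the replacement to dispatch time.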

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- [x] test logits processor for spec decoding
- [x] test logprobs for spec decoding
- [x] test logprobs for spec decoding + async scheduling (test with
https://github.com/vllm-project/vllm-ascend/pull/4893/)


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Commit 9e24bdd44c by realliujiaxu, 2025-12-16 11:32:26 +08:00 (committed by GitHub)
parent 5f840696c1
6 changed files with 260 additions and 236 deletions


@@ -228,7 +228,7 @@
# Future Plan:
# Remove this patch when the bug is fixed.
#
-# ** File: worker/patch_qwen3_next_mtp.py**
+# ** 11. File: worker/patch_qwen3_next_mtp.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.worker.utils.bind_kv_cache`
# Why:
@@ -241,7 +241,7 @@
# Future Plan:
# Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu.
#
-# ** File: worker/patch_module.py**
+# ** 12. File: worker/patch_module.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
# Why:
@@ -257,3 +257,19 @@
# Remove this patch when bool is supported in 'torch.argsort' func of npu.
# Make 'torch.argsort' in `vllm.v1.attention.backends.gdn_attn` be stable.
#
+# ** 13. File: worker/patch_rejection_sampler.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. `vllm.v1.sample.rejection_sampler`
+# Why:
+# Some functions from `rejection_sampler` are not supported or are slow on npu.
+# How:
+# - add npu_top_k_top_p to the `apply_sampling_constraints` func
+# - add custom triton kernels to `expand_batch_to_tokens` and `rejection_sample`
+# Related PR (if no, explain why):
+# https://github.com/vllm-project/vllm/pull/874
+# https://github.com/vllm-project/vllm/pull/4849
+# Future Plan:
+# 1. Make these functions class methods of RejectionSampler, create AscendRejectionSampler
+# to override them, then delete the patch file `worker/patch_rejection_sampler.py`.
+# 2. Make these functions custom ops, then remove AscendRejectionSampler.
+#
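Future Plan step 1 above (moving the patched functions onto `RejectionSampler` and overriding them in an Ascend subclass) can be sketched as follows. All class and method bodies here are illustrative placeholders under that assumption, not the real vLLM implementations.

```python
class RejectionSampler:
    """Stand-in for the upstream sampler with rejection_sample as a method."""

    def rejection_sample(self, draft_ids, target_ids):
        # Placeholder reference logic: keep draft tokens the target agrees with.
        return [d for d, t in zip(draft_ids, target_ids) if d == t]


class AscendRejectionSampler(RejectionSampler):
    """Hypothetical NPU subclass: only the slow/unsupported methods are
    overridden, so the rest of RejectionSampler is reused unchanged."""

    def rejection_sample(self, draft_ids, target_ids):
        # An NPU-specific kernel (e.g. Triton) would go here; this
        # placeholder delegates to the base class to keep the sketch runnable.
        return super().rejection_sample(draft_ids, target_ids)


sampler = AscendRejectionSampler()
print(sampler.rejection_sample([1, 2, 3], [1, 9, 3]))  # -> [1, 3]
```

Unlike the module-level patch this PR ships, the subclass keeps the override scoped to whoever instantiates `AscendRejectionSampler`, at the cost of requiring upstream to expose the functions as methods first.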