[Feat] Refactor rejection sampler (#4975)
### What this PR does / why we need it?
Currently, we are using `AscendRejctionSampler` that extends from
`RejctionSampler` in spec decoding. `AscendRejctionSampler` override
`forward` of `RejctionSampler`, only aming to replace `rejection_sample`
func. This
causes a lot of code of `RejctionSampler` cannot be reused, for example:
- https://github.com/vllm-project/vllm/pull/19482
- https://github.com/vllm-project/vllm/pull/26060
- https://github.com/vllm-project/vllm/pull/29223
#### Proposed Change:
- Delete `AscendRejctionSampler` and use `RejctionSampler` directly in
model runner.
- Patch `RejctionSampler.expand_batch_to_tokens` and
`RejctionSampler.rejection_sample`, maybe a better way is to make them
as custom ops.
- Modify `NPUModelRunner` following
https://github.com/vllm-project/vllm/pull/26060
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- [x] test logits processor for spec decoding
- [x] test logprobs for spec decoding
- [x] test logprobs for spec decoding + async shcheduling (test with
https://github.com/vllm-project/vllm-ascend/pull/4893/)
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -228,7 +228,7 @@
|
||||
# Future Plan:
|
||||
# Remove this patch when the bug is fixed.
|
||||
#
|
||||
# ** File: worker/patch_qwen3_next_mtp.py**
|
||||
# ** 11. File: worker/patch_qwen3_next_mtp.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.v1.worker.utils.bind_kv_cache`
|
||||
# Why:
|
||||
@@ -241,7 +241,7 @@
|
||||
# Future Plan:
|
||||
# Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu.
|
||||
#
|
||||
# ** File: worker/patch_module.py**
|
||||
# ** 12. File: worker/patch_module.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
|
||||
# Why:
|
||||
@@ -257,3 +257,19 @@
|
||||
# Remove this patch when bool is supported in 'torch.argsort' func of npu.
|
||||
# Make 'torch.argsort' in `vllm.v1.attention.backends.gdn_attn` be stable.
|
||||
#
|
||||
# ** 13. File: worker/patch_rejection_sampler.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.v1.sample.rejection_sampler`
|
||||
# Why:
|
||||
# - some functions from `rejection_sampler` are not supported or slow on npu.
|
||||
# How:
|
||||
# - add npu_top_k_top_p to 'apply_sampling_constraints' func
|
||||
# - add custom triton kernel to `expand_batch_to_tokens` and `rejection_sample`
|
||||
# Related PR (if no, explain why):
|
||||
# https://github.com/vllm-project/vllm/pull/874
|
||||
# https://github.com/vllm-project/vllm/pull/4849
|
||||
# Future Plan:
|
||||
# 1. make these functions as class func of RejectionSampler, create AscendRejectionSampler
|
||||
# to override them, then delete the patch file `worker/patch_rejection_sampler.py`.
|
||||
# 2. make these functions as costom op, then remove AscendRejectionSampler
|
||||
#
|
||||
Reference in New Issue
Block a user