[Refactor]Refactor sampler (#2050)

Refactor the Sampler implementation from a patch-based approach to
inheriting from the vLLM Sampler interface.

Next step: make the `TopKTopPSampler` op in vLLM support the custom-op
registration mechanism.

- vLLM version: v0.10.0
- vLLM main:
61a6905ab0

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-07-30 08:47:22 +08:00
committed by GitHub
parent b6a7f07c70
commit 9b67c87b14
8 changed files with 108 additions and 150 deletions

View File

@@ -64,7 +64,6 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -72,6 +71,7 @@ from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
@@ -165,7 +165,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.device = device
self.dtype = self.model_config.dtype
self.sampler = Sampler()
if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
# TODO: drop the env config to use ascend sampler by default
from vllm_ascend.sample.sampler import AscendSampler
self.sampler = AscendSampler()
else:
from vllm.v1.sample.sampler import Sampler
self.sampler = Sampler()
# Lazy initialization, these will be set after __init__
self.kv_caches: List[torch.Tensor] = []