[Refactor] Refactor sampler (#2050)
Refactor the Sampler implementation from the patch-based approach to inheriting from the vLLM
Sampler interface.
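As a rough sketch of the new structure (the class and override shown here are illustrative, not the exact vllm_ascend code), the Ascend sampler now subclasses the vLLM Sampler instead of monkey-patching it:

from vllm.v1.sample.sampler import Sampler

class AscendSampler(Sampler):
    # Hypothetical skeleton: keep vLLM's sampling pipeline and override only
    # the pieces that need an NPU-specific implementation (e.g. top-k/top-p).
    def __init__(self):
        super().__init__()
        # Swap in an Ascend-optimized top-k/top-p sampler here; the exact
        # attribute and class names depend on vllm_ascend.sample.sampler.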
Next step: make the `TopKTopPSampler` op in vLLM support the custom-ops registration
mechanism.
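That registration mechanism would let an out-of-tree backend plug in its own top-k/top-p implementation by name rather than by patching. A toy illustration of the pattern (names are purely hypothetical; vLLM's real custom-op machinery differs in detail):

from typing import Callable, Dict, Type

_OP_REGISTRY: Dict[str, Type] = {}

def register_op(name: str) -> Callable[[Type], Type]:
    # Record an implementation class under a well-known op name so the core
    # engine can look it up instead of hard-coding one implementation.
    def wrap(cls: Type) -> Type:
        _OP_REGISTRY[name] = cls
        return cls
    return wrap

@register_op("topk_topp_sampler")
class NPUTopKTopPSampler:
    # An out-of-tree plugin such as vllm-ascend could register its NPU
    # implementation under the same name and have it picked up automatically.
    def forward(self, logits, k, p):
        raise NotImplementedError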
- vLLM version: v0.10.0
- vLLM main: 61a6905ab0
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
@@ -64,7 +64,6 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                              ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -72,6 +71,7 @@ from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
                                   sanity_check_mm_encoder_outputs,
                                   scatter_mm_placeholders)

+from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
@@ -165,7 +165,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
         self.device = device
         self.dtype = self.model_config.dtype
-        self.sampler = Sampler()
+        if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
+            # TODO: drop the env config to use ascend sampler by default
+            from vllm_ascend.sample.sampler import AscendSampler
+
+            self.sampler = AscendSampler()
+        else:
+            from vllm.v1.sample.sampler import Sampler
+
+            self.sampler = Sampler()

         # Lazy initialization, these will be set after __init__
         self.kv_caches: List[torch.Tensor] = []
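Until the TODO above is resolved, the Ascend sampler stays opt-in via the environment variable used in the diff. A minimal way to enable it, assuming the flag is read when NPUModelRunner is constructed:

import os

# Opt in to AscendSampler; without this the stock vLLM Sampler is used.
os.environ["VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION"] = "1"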