# Forked from EngineX-Cambricon/enginex-mlu370-vllm
import weakref
import torch
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.top1_proposer import Top1Proposer


class MLUNGramWorker(NGramWorker):
    """N-gram drafter worker bound to a Cambricon MLU device.

    Inherits NGramWorker's prompt-lookup decoding behavior unchanged;
    only device initialization differs — the worker targets an ``mlu``
    device rather than a CUDA one, and model loading is a no-op because
    prompt lookup requires no draft model weights.
    """

    def init_device(self):
        """Select this rank's MLU device and construct the proposer."""
        # Prompt-lookup decoding has no weights to load, so loading
        # is stubbed out as a no-op callable.
        self.load_model = lambda *args, **kwargs: None
        self.device = torch.device(f"mlu:{self.local_rank}")

        # NGramWorker currently supports only the Top1Proposer strategy.
        self._proposer = Top1Proposer(
            weakref.proxy(self),  # proxy avoids a proposer->worker ref cycle
            device=self.device,
            vocab_size=self.vocab_size,
        )