import weakref

import torch

from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.top1_proposer import Top1Proposer


class MLUNGramWorker(NGramWorker):
    """NGramWorker variant that binds its drafter to an MLU device.

    NGramWorker provides a light drafter without need for a model: it only
    implements prompt-lookup decoding today, and in the future may also cover
    RAG-type drafting and other scenarios that don't rely on an LLM model to
    give proposals. This subclass overrides device initialization so the
    worker runs on a Cambricon MLU (``mlu:<local_rank>``) instead of the
    default device.
    """

    def init_device(self):
        """Select this worker's MLU device and build the Top1 proposer.

        NGram drafting needs no model weights, so ``load_model`` is stubbed
        out as a no-op.
        """
        # Each worker claims the MLU matching its local rank.
        self.device = torch.device(f"mlu:{self.local_rank}")
        self.load_model = lambda *args, **kwargs: None

        # Current NGramWorker only supports Top1Proposer.
        # NOTE(review): weakref.proxy presumably avoids a worker<->proposer
        # reference cycle — mirrors the upstream NGramWorker implementation.
        self._proposer = Top1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            device=self.device,
            vocab_size=self.vocab_size,
        )