Files
2026-02-04 17:22:39 +08:00

27 lines
823 B
Python

import weakref
import torch
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
class MLUNGramWorker(NGramWorker):
    """NGramWorker variant whose device is ``mlu:<local_rank>``.

    Like its parent, this is a lightweight drafter that needs no model:
    proposals currently come only from prompt lookup decoding. Other
    model-free drafters (e.g. RAG-style) may be added later. This subclass
    only overrides ``init_device`` so the worker and its proposer are bound
    to an MLU device instead of the parent's default device.
    """

    def init_device(self):
        """Bind the worker to its MLU device and build the proposer.

        Side effects on ``self``:
          * ``device`` is set to ``torch.device("mlu:<local_rank>")``.
          * ``load_model`` is replaced on this instance by a no-op that
            accepts any arguments and returns ``None``. No model is loaded
            for an n-gram drafter.
          * ``_proposer`` is set to a ``Top1Proposer`` built on that device
            with this worker's ``vocab_size``.
        """
        mlu_device = torch.device(f"mlu:{self.local_rank}")
        self.device = mlu_device

        def _skip_model_load(*args, **kwargs):
            # The n-gram drafter has no weights, so loading is a no-op.
            return None

        self.load_model = _skip_model_load

        # A Top1Proposer is the only proposer the n-gram worker supports.
        # It receives a weak proxy, so it does not hold a strong reference
        # to this worker (presumably to avoid a worker<->proposer reference
        # cycle; confirm against Top1Proposer's usage).
        worker_ref = weakref.proxy(self)
        self._proposer = Top1Proposer(
            worker_ref,  # type: ignore[arg-type]
            device=mlu_device,
            vocab_size=self.vocab_size,
        )