# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from importlib.util import find_spec
from typing import List, Optional

import torch

from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors
from vllm.worker.neuron_model_runner import (ModelInputForNeuron,
                                             NeuronModelRunner)


class MultiStepNeuronModelRunner(NeuronModelRunner):
    """Model runner for multi-step (speculative) decoding on Neuron.

    Built on the transformers_neuronx framework. On construction it installs
    per-sequence sampling parameters on the draft model config. ``load_model``
    then builds either an EAGLE speculation model (when a speculative token
    tree is configured) or a generic speculation model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        super().__init__(vllm_config)
        # Alias the speculative config (presumably populated by the parent
        # constructor) under the name the rest of this runner uses.
        self.speculation_config = self.speculative_config

        # Deferred import: transformers_neuronx is an optional dependency
        # (load_model probes for it with find_spec), so it is not pulled in
        # at module import time.
        from transformers_neuronx.config import GenerationConfig

        batch_slots = self.scheduler_config.max_num_seqs
        top_k_cap = self._MAX_NEURON_SAMPLING_TOP_K

        # One sampling entry per schedulable sequence: top-k at the cap,
        # top-p and temperature at 1.0.
        draft_sampling = GenerationConfig(
            max_length=self.scheduler_config.max_model_len,
            do_sample=True,
            per_batch_line=True,
            top_k=[top_k_cap] * batch_slots,
            top_p=[1.0] * batch_slots,
            temperature=[1.0] * batch_slots,
            dynamic=True,
            global_top_k=top_k_cap,
        )
        self.speculation_config.draft_model_config.neuron_sampling_params = (
            draft_sampling)

    def load_model(self) -> None:
        """Build the speculation model and store it on ``self.model``.

        Raises:
            NotImplementedError: if transformers_neuronx is not installed.
        """
        if find_spec("transformers_neuronx") is None:
            raise NotImplementedError(
                "Supports only Transformer-NeuronX based models.")

        from vllm.model_executor.model_loader.neuron import (
            get_neuron_eagle_speculation_model, get_neuron_speculation_model)

        # A configured speculative token tree selects the EAGLE loader;
        # otherwise the generic speculation loader is used. Both take the
        # same arguments.
        has_token_tree = (
            self.speculation_config.speculative_token_tree is not None)
        build = (get_neuron_eagle_speculation_model
                 if has_token_tree else get_neuron_speculation_model)
        self.model = build(
            self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            speculation_config=self.speculation_config,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForNeuron,
        kv_caches: Optional[List[torch.Tensor]] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run the speculation model on one batch and sample from its logits.

        Args:
            model_input: Supplies input token ids, positions, block ids,
                optional multi-modal kwargs and the sampling metadata.
            kv_caches: Part of the runner interface; not read here.
            intermediate_tensors: Part of the runner interface; not read here.
            num_steps: Part of the runner interface; not read here.

        Returns:
            The result of ``self.model.sample`` for this batch.
        """
        # Missing multi-modal inputs are treated as an empty mapping.
        extra_inputs = MultiModalKwargs.as_kwargs(
            model_input.multi_modal_kwargs or {},
            device=self.device,
        )
        logits = self.model(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            input_block_ids=model_input.input_block_ids,
            **extra_inputs,
        )
        return self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )