# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import List, Optional

import torch

from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors
from vllm.worker.neuronx_distributed_model_runner import (
    NeuronxDistributedModelRunner)


class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner):
    """Multi-step decoding model runner built on the
    neuronx-distributed-inference framework.

    Differs from its parent class in two places visible here: the model is
    obtained from ``get_neuron_speculation_model`` and ``execute_model``
    forwards a per-sequence-group sampling-parameter tensor to the model.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Initialise the runner by delegating entirely to the parent class.

        Args:
            vllm_config: Engine configuration, passed through unchanged.
        """
        super().__init__(vllm_config)

    def load_model(self) -> None:
        """Create the speculation model and store it on ``self.model``.

        The loader is imported inside the method rather than at module level,
        so the neuronx loader module is only imported when a model is
        actually loaded.
        """
        from vllm.model_executor.model_loader.neuronx_distributed import (
            get_neuron_speculation_model as build_speculation_model)

        self.model = build_speculation_model(
            self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            speculation_config=self.speculative_config,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        model_input,
        kv_caches: Optional[List[torch.Tensor]] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass of the model and sample from its logits.

        Args:
            model_input: Prepared batch. This method reads its
                ``input_tokens``, ``input_positions``, ``input_block_ids``,
                ``multi_modal_kwargs`` and ``sampling_metadata`` attributes.
            kv_caches: Accepted as part of the signature; not read by this
                method.
            intermediate_tensors: Accepted as part of the signature; not
                read by this method.
            num_steps: Accepted as part of the signature; not read by this
                method.

        Returns:
            Whatever ``self.model.sample`` returns for the computed logits.
        """
        # One row per sequence group, columns ordered
        # (top_k, top_p, temperature).
        param_rows = []
        for group in model_input.sampling_metadata.seq_groups:
            param_rows.append([
                group.sampling_params.top_k,
                group.sampling_params.top_p,
                group.sampling_params.temperature,
            ])
        sampling_params = torch.tensor(param_rows)

        # A missing/empty multimodal payload is normalised to an empty dict
        # before conversion to keyword arguments.
        mm_kwargs = MultiModalKwargs.as_kwargs(
            model_input.multi_modal_kwargs or {},
            device=self.device,
        )

        logits = self.model(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            input_block_ids=model_input.input_block_ids,
            sampling_params=sampling_params,
            **mm_kwargs,
        )

        return self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )