from typing import List, Optional, Tuple

import torch
from torch import nn

from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad

logger = init_logger(__name__)


class NeuronModelRunner:
    """Build input tensors for, and run, a model loaded via
    ``get_neuron_model``.

    A batch handed to this runner is treated as either all prompts or all
    decodes; the decision is taken from the first element of the batch.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
    ):
        """Store the configs and resolve the target device.

        Args:
            model_config: Model configuration. If it reports a sliding
                window, a warning is logged and execution proceeds
                without it.
            parallel_config: Forwarded to the model loader in
                ``load_model``.
            scheduler_config: Forwarded to the model loader in
                ``load_model``.
            device_config: Device configuration. ``None`` falls back to a
                default-constructed ``DeviceConfig``.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config

        if model_config is not None and model_config.get_sliding_window():
            logger.warning("Sliding window is not supported on Neuron. "
                           "The model will run without sliding window.")
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        # Lazy initialization.
        self.model: nn.Module  # initialize after load_model.

    def load_model(self) -> None:
        """Load the model through ``get_neuron_model`` and keep it on
        ``self.model``."""
        self.model = get_neuron_model(self.model_config,
                                      parallel_config=self.parallel_config,
                                      scheduler_config=self.scheduler_config)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
        """Build padded input tensors for a batch of prompt sequence groups.

        Each group must be a prompt, hold exactly one sequence, and have a
        block table with exactly one entry for that sequence; these
        conditions are asserted.

        Args:
            seq_group_metadata_list: Non-empty list of prompt groups.

        Returns:
            A tuple ``(input_tokens, input_positions, input_block_ids,
            seq_lens)``. Token and position rows are padded with 0 up to the
            longest prompt in the batch. ``seq_lens`` holds the unpadded
            prompt lengths as plain ints.
        """
        assert len(seq_group_metadata_list) > 0
        token_rows: List[List[int]] = []
        position_rows: List[List[int]] = []
        block_ids: List[int] = []
        seq_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            seq_len = len(prompt_tokens)
            seq_lens.append(seq_len)

            token_rows.append(prompt_tokens)
            # Positions are simply 0..seq_len-1 for a fresh prompt.
            position_rows.append(list(range(seq_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            assert len(block_table) == 1
            block_ids.append(block_table[0])

        max_seq_len = max(seq_lens)
        assert max_seq_len > 0
        input_tokens = make_tensor_with_pad(token_rows,
                                            max_seq_len,
                                            pad=0,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(position_rows,
                                               max_seq_len,
                                               pad=0,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids, seq_lens

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build single-token input tensors for a batch of decode groups.

        Every sequence in every group contributes one row: its last token
        id, the position ``seq_len - 1``, and the sole entry of its block
        table. Groups must not be prompts and each block table must have
        exactly one entry; these conditions are asserted.

        Args:
            seq_group_metadata_list: Non-empty list of decode groups.

        Returns:
            A tuple ``(input_tokens, input_positions, input_block_ids)``.
        """
        assert len(seq_group_metadata_list) > 0
        token_rows: List[List[int]] = []
        position_rows: List[List[int]] = []
        block_ids: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            for seq_id, seq_data in seq_group_metadata.seq_data.items():
                token_rows.append([seq_data.get_last_token_id()])
                # The token being fed is the last one in the sequence, so
                # its position is the sequence length minus one.
                position_rows.append([seq_data.get_len() - 1])

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                assert len(block_table) == 1
                block_ids.append(block_table[0])

        input_tokens = make_tensor_with_pad(token_rows,
                                            max_len=1,
                                            pad=0,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(position_rows,
                                               max_len=1,
                                               pad=0,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
        """Prepare model inputs and sampling metadata for one batch.

        Args:
            seq_group_metadata_list: Batch of sequence groups. The first
                element is read to decide between the prompt and decode
                paths, so the list must not be empty.

        Returns:
            A tuple ``(input_tokens, input_positions, input_block_ids,
            sampling_metadata)``.
        """
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids,
             seq_lens) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
            # The decode path collects no sequence lengths; an empty list
            # is passed to SamplingMetadata.prepare.
            seq_lens = []
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            # query_lens is not needed if chunked prefill is not
            # supported. Since neuron worker doesn't support chunked prefill
            # just use seq_lens instead.
            seq_lens,
            self.device,
            self.pin_memory)

        return (input_tokens, input_positions, input_block_ids,
                sampling_metadata)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and sample the next tokens.

        Args:
            seq_group_metadata_list: Batch of sequence groups to execute.

        Returns:
            Whatever ``self.model.sample`` returns for the computed logits.
        """
        (input_tokens, input_positions, input_block_ids, sampling_metadata
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        hidden_states = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            input_block_ids=input_block_ids,
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by ``model_config``."""
        return self.model_config.get_vocab_size()