from typing import Any, Optional, Union

import numpy as np
import torch

from vllm import envs
from vllm.distributed.kv_transfer.kv_transfer_state import (
    get_kv_transfer_group, has_kv_transfer_group)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context
from vllm.profiler.prof import profile
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.v1.model_input_split_v1 import (
    tbo_split_and_execute_model)
from vllm.utils import async_tensor_h2d, round_up
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.zero_overhead.v1.eagle import V1ZeroEagleProposer
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput


class V1ZeroModelRunner(GPUModelRunner):

    def __init__(self, vllm_config, device):
        super().__init__(vllm_config, device)
        # Double-buffered sampler state: the sampled (and drafted) token ids
        # of the previous step stay on the GPU and are copied to the host
        # asynchronously, guarded by CUDA events.
        self.last_sampled_token_ids = None
        self.last_sampled_req_ids = []
        self.last_sampled_token_lens = []
        self.last_sampler_event = torch.cuda.Event(enable_timing=False)
        self.last_sampler_host_tokens = None
        self.token_ids_cpu_fix_record = []
        self.last_draft_token_ids = None
        self.last_draft_host_tokens = None
        self.last_draft_event = torch.cuda.Event(enable_timing=False)
        self.spec_sampler_event = torch.cuda.Event(enable_timing=False)
        self.spec_scheduler_max_num_tokens = 0
        if hasattr(self, 'drafter') and isinstance(self.drafter,
                                                   EagleProposer):
            self.drafter = V1ZeroEagleProposer(self.vllm_config, self.device,
                                               self)

    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[dict[str, Any], bool, torch.Tensor,
               Optional[SpecDecodeMetadata], np.ndarray]:
        """
        :return: tuple[
            attn_metadata: layer-to-attention_metadata mapping,
            attention_cuda_graphs: whether attention can run in cudagraph,
            logits_indices,
            spec_decode_metadata,
            num_scheduled_tokens,
        ]
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens for each request.
        req_ids = self.input_batch.req_ids
        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        max_num_scheduled_tokens = max(tokens)
        self.spec_scheduler_max_num_tokens = max_num_scheduled_tokens

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)

        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        cu_num_tokens, arange = self._get_cumsum_and_arange(
            num_scheduled_tokens)

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)
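        # Worked example (illustrative values, not from the source): with
        # num_computed_tokens_cpu = [3, 10] and num_scheduled_tokens = [2, 5],
        # req_indices is [0, 0, 1, 1, 1, 1, 1], arange is
        # [0, 1, 0, 1, 2, 3, 4], and positions_np becomes
        # [3, 4, 10, 11, 12, 13, 14].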
        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g., Qwen2-VL).
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping for each KV cache group.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):
            block_size = kv_cache_group_spec.kv_cache_spec.block_size
            block_table: BlockTable = self.input_batch.block_table[
                kv_cache_group_id]
            # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
            # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
            # where K is the max_num_blocks_per_req and the block size is 2.
            # NOTE(woosuk): We can't simply use `token_indices // block_size`
            # here because M (max_model_len) is not necessarily divisible by
            # block_size.
            block_table_indices = (
                req_indices * block_table.max_num_blocks_per_req +
                positions_np // block_size)
            block_table_cpu = block_table.get_cpu_tensor()
            block_numbers = block_table_cpu.flatten(
            )[block_table_indices].numpy()
            block_offsets = positions_np % block_size
            np.add(
                block_numbers * block_size,
                block_offsets,
                out=block_table.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens],
            non_blocking=True)
        # Patch in tokens sampled (or drafted) in the previous step; they are
        # still on the GPU and have not been written back to the CPU buffers.
        self.zero_prepare_inputs(scheduler_output, self.input_ids)
        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g., Qwen2-VL).
            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions).
            self.positions[:total_num_scheduled_tokens].copy_(
                self.positions_cpu[:total_num_scheduled_tokens],
                non_blocking=True)
        self.query_start_loc[:num_reqs + 1].copy_(
            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                       non_blocking=True)

        # Fill unused entries with 0. Needed for reshape_and_cache.
        self.seq_lens[num_reqs:].fill_(0)
        # Note: pad query_start_loc to be non-decreasing, as kernels
        # like FlashAttention require that.
        self.query_start_loc[num_reqs + 1:].fill_(
            self.query_start_loc_cpu[num_reqs].item())

        query_start_loc = self.query_start_loc[:num_reqs + 1]
        seq_lens = self.seq_lens[:num_reqs]

        common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
        )

        attn_metadata: dict[str, Any] = {}
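        # Worked example (illustrative, continuing the values above): for
        # num_scheduled_tokens = [2, 5] with num_computed_tokens = [3, 10],
        # query_start_loc is [0, 2, 7], seq_lens is [5, 15], and
        # max_query_len is 5.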
        # Prepare the attention metadata for each KV cache group and make
        # layers in the same group share the same metadata.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):
            # Prepare for cascade attention if enabled & beneficial.
            common_prefix_len = 0
            builder = self.attn_metadata_builders[kv_cache_group_id]
            if self.cascade_attn_enabled:
                common_prefix_len = self._compute_cascade_attn_prefix_len(
                    num_scheduled_tokens,
                    scheduler_output.
                    num_common_prefix_blocks[kv_cache_group_id],
                    kv_cache_group_spec.kv_cache_spec,
                    builder,
                )

            attn_metadata_i = (builder.build(
                common_prefix_len=common_prefix_len,
                common_attn_metadata=common_attn_metadata,
            ))
            for layer_name in kv_cache_group_spec.layer_names:
                attn_metadata[layer_name] = attn_metadata_i

        attention_cuda_graphs = all(
            b.can_run_in_cudagraph(common_attn_metadata)
            for b in self.attn_metadata_builders)

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
        if not use_spec_decode:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = query_start_loc[1:] - 1
            spec_decode_metadata = None
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since not
            # all requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for req_id, draft_token_ids in (
                    scheduler_output.scheduled_spec_decode_tokens.items()):
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)

            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens)
            logits_indices = spec_decode_metadata.logits_indices

        # Hot-swap LoRA model.
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        return (attn_metadata, attention_cuda_graphs, logits_indices,
                spec_decode_metadata, num_scheduled_tokens)

    def zero_prepare_inputs(self, scheduler_output, input_ids):
        req_ids = self.input_batch.req_ids
        update_req_indices = []
        input_ids_indices = []
        token_idx = 0
        # Patch last step's draft tokens into the packed input_ids buffer.
        if self.last_draft_token_ids is not None:
            draft_tokens_num = self.last_draft_token_ids.shape[1]
            for req_id in req_ids:
                if req_id in self.last_sampled_req_ids:
                    req_idx = self.last_sampled_req_ids.index(
                        req_id) * draft_tokens_num
                    for num_idx in range(draft_tokens_num):
                        update_req_indices.append(req_idx + num_idx)
                        input_ids_indices.append(token_idx + num_idx + 1)
                    token_idx += draft_tokens_num + 1
            if len(update_req_indices) > 0:
                update_req_indices_tensor = async_tensor_h2d(
                    update_req_indices, torch.int32, self.device, True)
                input_ids_indices_tensor = async_tensor_h2d(
                    input_ids_indices, torch.int32, self.device, True)
                last_draft_token_ids = self.last_draft_token_ids.flatten().to(
                    torch.int)
                input_ids[input_ids_indices_tensor] = last_draft_token_ids[
                    update_req_indices_tensor]

        update_req_indices = []
        input_ids_indices = []
        token_idx = 0
        # Patch last step's sampled tokens into the packed input_ids buffer.
        if self.last_sampled_token_ids is not None:
            sampled_tokens_num = self.last_sampled_token_ids.shape[1]
            for req_id in req_ids:
                if req_id in self.last_sampled_req_ids:
                    req_idx = self.last_sampled_req_ids.index(
                        req_id) * sampled_tokens_num
                    update_req_indices.append(req_idx)
                    input_ids_indices.append(token_idx)
                token_idx += scheduler_output.num_scheduled_tokens[req_id]
            if len(update_req_indices) > 0:
                update_req_indices_tensor = async_tensor_h2d(
                    update_req_indices, torch.int32, self.device, True)
                input_ids_indices_tensor = async_tensor_h2d(
                    input_ids_indices, torch.int32, self.device, True)
                last_sampled_token_ids = self.last_sampled_token_ids.flatten()
                for i in range(sampled_tokens_num):
                    input_ids[input_ids_indices_tensor +
                              i] = last_sampled_token_ids[
                                  update_req_indices_tensor + i]
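    # A minimal illustration of the patched layout (hypothetical values, not
    # from the source): with one decode request scheduled for 3 tokens and
    # draft_tokens_num = 2, the packed input_ids slots for that request are
    # [last_sampled_token, draft_0, draft_1]; slot 0 is filled by the
    # sampled-token patch loop above and slots 1..2 by the draft-token patch
    # loop (note the `token_idx + num_idx + 1` offset).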
    def propose_draft_token_ids(
        self,
        scheduler_output: "SchedulerOutput",
        num_accepted_tokens_tensor: torch.Tensor,
        sampled_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
        aux_hidden_states: Optional[torch.Tensor],
        spec_decode_metadata: Optional[SpecDecodeMetadata],
        attn_metadata: dict[str, Any],
    ) -> list[list[int]]:
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if self.speculative_config.method == "ngram":
            assert isinstance(self.drafter, NgramProposer)
            spec_token_ids = self.propose_ngram_draft_token_ids(
                sampled_token_ids)
        elif self.speculative_config.method == "medusa":
            assert isinstance(self.drafter, MedusaProposer)
            if sample_hidden_states.shape[0] == len(sampled_token_ids):
                # The input to the target model does not include draft tokens.
                hidden_states = sample_hidden_states
            else:
                indices = []
                offset = 0
                for num_draft, tokens in zip(
                        spec_decode_metadata.num_draft_tokens,
                        sampled_token_ids):
                    indices.append(offset + len(tokens) - 1)
                    offset += num_draft + 1
                indices = torch.tensor(indices, device=self.device)
                hidden_states = sample_hidden_states[indices]

            spec_token_ids = self.drafter.propose(
                target_hidden_states=hidden_states,
                sampling_metadata=sampling_metadata,
            )
        elif self.speculative_config.use_eagle():
            assert isinstance(self.drafter, EagleProposer)
            # TODO(woosuk): Refactor the loop.
            row_indices = torch.arange(sampled_token_ids.size(0),
                                       device=sampled_token_ids.device)
            next_token_ids = sampled_token_ids[
                row_indices, num_accepted_tokens_tensor].flatten()
            # At this moment, we assume all eagle layers belong to the same KV
            # cache group, thus using the same attention metadata.
            eagle_attn_metadata = attn_metadata[
                self.drafter.attn_layer_names[0]]

            # NOTE: deepseek_mtp uses MLA which does not have `block_table`.
            if hasattr(eagle_attn_metadata, "block_table"):
                block_table = eagle_attn_metadata.block_table
            else:
                block_table = None
            spec_scheduler_max_num_tokens = self.spec_scheduler_max_num_tokens
            if spec_decode_metadata is None:
                # input_ids can be None for multimodal models.
                target_token_ids = self.input_ids[:num_scheduled_tokens]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self.positions[:num_scheduled_tokens]
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat(
                        [h[:num_scheduled_tokens] for h in aux_hidden_states],
                        dim=-1)
                else:
                    target_hidden_states = hidden_states[:num_scheduled_tokens]
                target_slot_mapping = eagle_attn_metadata.slot_mapping
                cu_num_tokens = eagle_attn_metadata.query_start_loc
            else:
                # TODO(woosuk): Refactor this.
                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                    eagle_attn_metadata.query_start_loc,
                    num_accepted_tokens_tensor,
                )
                spec_scheduler_max_num_tokens = 1
                target_token_ids = self.input_ids[token_indices]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self.positions[token_indices]
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat(
                        [h[token_indices] for h in aux_hidden_states], dim=-1)
                else:
                    target_hidden_states = hidden_states[token_indices]
                target_slot_mapping = eagle_attn_metadata.slot_mapping[
                    token_indices]
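            # In the spec-decode path above, `prepare_inputs` re-indexes the
            # packed batch using `num_accepted_tokens_tensor`, so positions
            # belonging to rejected draft tokens are dropped and the drafter
            # proposes only from tokens that were actually accepted.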
            self.drafter.spec_scheduler_max_num_tokens = (
                spec_scheduler_max_num_tokens)
            draft_token_ids = self.drafter.propose(
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                target_slot_mapping=target_slot_mapping,
                next_token_ids=next_token_ids,
                cu_num_tokens=cu_num_tokens,
                block_table=block_table,
                sampling_metadata=sampling_metadata,
                decoding=spec_decode_metadata is not None,
            )
            # Return placeholder ids: the real draft tokens stay on the GPU
            # and are copied to the host asynchronously; the caller only
            # needs the correct shape here.
            spec_token_ids = np.ones(draft_token_ids.shape, dtype=int).tolist()
            self.last_draft_token_ids = draft_token_ids
            self.last_draft_host_tokens = draft_token_ids.to(
                'cpu', non_blocking=True)
            self.last_draft_event.record()
        return spec_token_ids

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            if not has_kv_transfer_group():
                # Return empty ModelRunnerOutput if there's no work to do.
                return EMPTY_MODEL_RUNNER_OUTPUT
            return self.kv_connector_no_forward(scheduler_output)

        # Prepare the decoder inputs.
        (attn_metadata, attention_cuda_graphs, logits_indices,
         spec_decode_metadata,
         num_scheduled_tokens_np) = self._prepare_inputs(scheduler_output)
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if (self.use_cuda_graph
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Use piecewise CUDA graphs.
            # Add padding to the batch size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_scheduled_tokens)
        else:
            # Eager mode.
            # Pad tokens to a multiple of tensor_parallel_size when
            # collective fusion is enabled for sequence parallelism.
            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
            if self.compilation_config.pass_config. \
                    enable_sequence_parallelism and tp_size > 1:
                num_input_tokens = round_up(num_scheduled_tokens, tp_size)
            else:
                num_input_tokens = num_scheduled_tokens

        # Padding for DP.
        num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
        num_input_tokens += num_pad

        # _prepare_inputs may reorder the batch, so we must gather multi-
        # modal outputs after that to ensure the correct order.
        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
        else:
            mm_embeds = []

        if self.is_multimodal_model and get_pp_group().is_first_rank:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:num_scheduled_tokens]
            if mm_embeds:
                inputs_embeds = self.model.get_input_embeddings(
                    input_ids, mm_embeds)
            else:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
            positions = self.positions[:num_input_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_input_tokens, intermediate_tensors, True)

        # Some attention backends only support CUDA Graphs in pure decode.
        # If attention doesn't support CUDA Graphs for this batch, but we
        # compiled with full CUDA graphs, we have to skip them entirely.
        skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs

        if envs.VLLM_ENABLE_TBO and (not self.use_cuda_graph
                                     or skip_cuda_graphs):
            model_output, finished_sending, finished_recving = \
                tbo_split_and_execute_model(self, attn_metadata,
                                            num_input_tokens,
                                            num_tokens_across_dp, input_ids,
                                            positions, inputs_embeds,
                                            scheduler_output,
                                            intermediate_tensors,
                                            skip_cuda_graphs)
        else:
            # Run the model.
            # Use persistent buffers for CUDA graphs.
            with set_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_input_tokens,
                    num_tokens_across_dp=num_tokens_across_dp,
                    skip_cuda_graphs=skip_cuda_graphs,
            ):
                self.maybe_setup_kv_connector(scheduler_output)

                model_output = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                )

                self.maybe_wait_for_kv_save()
                finished_sending, finished_recving = (
                    self.get_finished_kv_transfers(scheduler_output))

        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = model_output
        else:
            hidden_states = model_output
            aux_hidden_states = None

        # Broadcast PP output for external_launcher (torchrun)
        # to make sure we are synced across pp ranks.
        # TODO: Support overlapping micro-batches.
        # https://github.com/vllm-project/vllm/issues/18019
        broadcast_pp_output = \
            self.parallel_config.distributed_executor_backend \
            == "external_launcher" and len(get_pp_group().ranks) > 0
        if not get_pp_group().is_last_rank:
            # For mid-pipeline stages, return the hidden states.
            if not broadcast_pp_output:
                return hidden_states
            assert isinstance(hidden_states, IntermediateTensors)
            get_pp_group().send_tensor_dict(hidden_states.tensors,
                                            all_gather_group=get_tp_group())
            logits = None
        else:
            if self.input_batch.pooling_params:
                return self._pool(hidden_states, num_scheduled_tokens,
                                  num_scheduled_tokens_np, finished_sending,
                                  finished_recving)
            sample_hidden_states = hidden_states[logits_indices]
            logits = self.model.compute_logits(sample_hidden_states, None)
        if broadcast_pp_output:
            model_output_broadcast_data = {
                "logits": logits.contiguous(),
            } if logits is not None else {}
            model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
                model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
            assert model_output_broadcast_data is not None
            logits = model_output_broadcast_data["logits"]

        # Apply structured output bitmasks if present.
        if scheduler_output.grammar_bitmask is not None:
            self.apply_grammar_bitmask(scheduler_output, logits)
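        # With spec decode enabled, `logits` holds rows for both draft and
        # bonus positions; `spec_decode_metadata.bonus_logits_indices` and
        # `.target_logits_indices` select the disjoint row sets consumed by
        # the sampler and the rejection sampler below.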
        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is None:
            sampler_output = self.sampler(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
            # When indexing with a tensor (bonus_logits_indices), PyTorch
            # creates a new tensor with separate storage from the original
            # logits tensor. This means any in-place operations on
            # bonus_logits won't affect the original logits tensor.
            assert logits is not None
            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
            sampler_output = self.sampler(
                logits=bonus_logits,
                sampling_metadata=sampling_metadata,
            )
            bonus_token_ids = sampler_output.sampled_token_ids

            # Just like `bonus_logits`, `target_logits` is a new tensor with
            # separate storage from the original `logits` tensor. Therefore,
            # it is safe to update `target_logits` in place.
            target_logits = logits[spec_decode_metadata.target_logits_indices]
            output_token_ids = self.rejection_sampler(
                spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                bonus_token_ids,
                sampling_metadata,
            )
            sampler_output.sampled_token_ids = output_token_ids

        num_nans_in_logits = {}
        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
            num_nans_in_logits = self._get_nans_in_logits(logits)

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        discard_sampled_tokens_req_indices = []
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token for partial prefills.
                # Rewind the generator state as if the token was not sampled.
                # This relies on cuda-specific torch-internal impl details.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)
                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

        # NOTE: GPU -> CPU sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states[:num_scheduled_tokens],
            scheduler_output,
        )

        fix_req_ids = None
        fix_sampled_token_ids = None
        fix_draft_token_ids = None
        fix_draft_req_ids = self.last_sampled_req_ids
        is_output_valid = False
        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]
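        # Zero-overhead scheme: instead of blocking on the device-to-host
        # copy of this step's sampled tokens, return placeholder ids now and
        # keep the real ids on the GPU (`last_sampled_token_ids`). The next
        # step patches them into `input_ids` on the device
        # (`zero_prepare_inputs`) and back-fills the CPU-side buffers via
        # `token_ids_cpu_fix_record` once `last_sampler_event` confirms the
        # asynchronous copy has finished.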
        if not self.speculative_config:
            # Speculative decoding is not enabled.
            spec_token_ids = None
            fix_draft_req_ids = None
        else:
            sampled_token_ids_cpu = sampled_token_ids.to('cpu',
                                                         non_blocking=True)
            self.spec_sampler_event.record()
            if self.last_draft_host_tokens is not None:
                self.last_draft_event.synchronize()
                fix_draft_token_ids = self.last_draft_host_tokens.tolist()
            # Count accepted tokens per request: rejected positions are
            # marked with -1, so the index of the first -1, minus one, is the
            # last accepted position; rows without any -1 accepted everything.
            mask = (sampled_token_ids == -1)
            mask_int = mask.int()
            first_neg_one_indices = torch.argmax(mask_int, dim=1)
            num_accepted_tokens_tensor = torch.where(
                torch.any(mask, dim=1), first_neg_one_indices,
                sampled_token_ids.size(1)) - 1
            spec_token_ids = self.propose_draft_token_ids(
                scheduler_output,
                num_accepted_tokens_tensor,
                sampled_token_ids,
                sampling_metadata,
                hidden_states,
                sample_hidden_states,
                aux_hidden_states,
                spec_decode_metadata,
                attn_metadata,
            )

        if self.speculative_config:
            self.spec_sampler_event.synchronize()
            if max_gen_len == 1:
                valid_sampled_token_ids = sampled_token_ids_cpu.tolist()
            else:
                # Includes spec decode tokens.
                valid_sampled_token_ids = self.rejection_sampler.parse_output(
                    sampled_token_ids_cpu,
                    self.input_batch.vocab_size,
                )
            self.last_sampler_host_tokens = None
            self.last_sampled_token_ids = None
            is_output_valid = True
        else:
            # No spec decode tokens.
            fix_req_ids = self.last_sampled_req_ids
            if self.last_sampler_host_tokens is not None:
                self.last_sampler_event.synchronize()
                fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
                for req_idx, start_idx, end_idx in (
                        self.token_ids_cpu_fix_record):
                    if start_idx == -1:
                        continue
                    req_id = fix_req_ids[req_idx]
                    if req_id in self.input_batch.req_ids:
                        new_req_idx = self.input_batch.req_ids.index(req_id)
                        self.input_batch.token_ids_cpu[
                            new_req_idx,
                            start_idx:end_idx] = fix_sampled_token_ids[req_idx]
                for req_idx, req_id in enumerate(fix_req_ids):
                    if req_id in self.requests:
                        req_state = self.requests[req_id]
                        token_idx = self.last_sampled_token_lens[req_idx]
                        if token_idx == -1:
                            continue
                        fix_len = len(fix_sampled_token_ids[req_idx])
                        req_state.output_token_ids[
                            token_idx:token_idx +
                            fix_len] = fix_sampled_token_ids[req_idx]
            self.last_sampler_host_tokens = sampled_token_ids.to(
                'cpu', non_blocking=True)
            self.last_sampler_event.record()
            self.last_sampled_token_ids = sampled_token_ids
            # Placeholder ids with the right shape; the real ids are patched
            # in on the next step (see zero_prepare_inputs).
            valid_sampled_token_ids = np.ones(sampled_token_ids.shape,
                                              dtype=int).tolist()

        # Mask out the sampled tokens that should not be sampled.
        for i in discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[i].clear()

        # Cache the sampled tokens in the model runner, so that the scheduler
        # doesn't need to send them back.
        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
        # the sampled tokens back, because there's no direct communication
        # between the first-stage worker and the last-stage worker.
        self.token_ids_cpu_fix_record.clear()
        self.last_sampled_req_ids = []
        self.last_sampled_token_lens = []
        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
            req_id = self.input_batch.req_ids[req_idx]
            self.last_sampled_req_ids.append(req_id)
            cache_output_len = -1
            if not sampled_ids:
                self.last_sampled_token_lens.append(-1)
                self.token_ids_cpu_fix_record.append([req_idx, -1, -1])
                continue

            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
            end_idx = start_idx + len(sampled_ids)
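            # The span [start_idx, end_idx) is recorded below in
            # token_ids_cpu_fix_record; on the non-spec path these entries
            # let the next step replace the placeholder ids with the real
            # sampled ids once the async device-to-host copy has completed.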
" f"Total number of tokens: {end_idx} > max_model_len: " f"{self.max_model_len}") self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx]) self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx if req_id in self.requests: req_state = self.requests[req_id] cache_output_len = len(req_state.output_token_ids) req_state.output_token_ids.extend(sampled_ids) self.last_sampled_token_lens.append(cache_output_len) # Clear KVConnector state after all KVs are generated. if has_kv_transfer_group(): get_kv_transfer_group().clear_connector_metadata() self.eplb_step() model_output = ZeroV1ModelRunnerOutput( req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[], finished_sending=finished_sending, finished_recving=finished_recving, num_nans_in_logits=num_nans_in_logits, fix_req_ids = fix_req_ids, fix_sampled_token_ids = fix_sampled_token_ids, fix_draft_tokens_ids = fix_draft_token_ids, fix_draft_req_ids = fix_draft_req_ids, is_output_valid=is_output_valid ) return model_output