import torch

from vllm.forward_context import set_forward_context
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer


class V1ZeroEagleProposer(EagleProposer):
    """EagleProposer variant that also drafts for the ``deepseek_mtp`` method.

    Attention metadata is built differently per method:

    * ``"eagle"`` / ``"eagle3"``: builds ``FlashAttentionMetadata`` directly.
    * ``"deepseek_mtp"``: builds it through the runner's first
      attention-metadata builder. In the drafting loop it is handled as
      ``MLACommonMetadata``.

    When full CUDA graphs apply, each step's metadata is also copied into
    ``self.attn_metadata_cudagraph`` before the forward pass.

    This class does not define many of the attributes it reads, e.g.
    ``self.input_ids``, ``self.positions``, ``self.hidden_states``,
    ``self.arange``, ``self.model``, ``self.method``,
    ``self.use_cuda_graph``, ``self.use_full_cuda_graph`` and
    ``self.attn_metadata_cudagraph``. Presumably they come from
    ``EagleProposer`` or are assigned by the caller. TODO confirm.
    """

    def __init__(self, vllm_config, device, runner=None) -> None:
        """Initialize the base proposer and the MTP query-length hint.

        Args:
            vllm_config: Forwarded unchanged to ``EagleProposer.__init__``.
            device: Forwarded unchanged to ``EagleProposer.__init__``.
            runner: Forwarded unchanged to ``EagleProposer.__init__``.
                ``propose`` reads ``self.runner.attn_metadata_builders[0]``
                on the deepseek_mtp path. Presumably the base class stores
                this argument as ``self.runner``; confirm.
        """
        super().__init__(vllm_config, device, runner)
        # Used as ``max_query_len`` for the deepseek_mtp metadata in
        # ``propose``. Nothing in this class changes it after construction;
        # presumably the model runner updates it per step. TODO confirm.
        self.spec_scheduler_max_num_tokens = 0

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [num_tokens]
        target_slot_mapping: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        # [batch_size + 1] starting with 0
        cu_num_tokens: torch.Tensor,
        # [batch_size, max_num_blocks_per_req]
        block_table: torch.Tensor,
        # [batch_size]
        sampling_metadata: SamplingMetadata,
        decoding: bool = False,
    ) -> torch.Tensor:
        """Draft ``self.num_speculative_tokens`` tokens for every request.

        First pass: the model runs over all ``num_tokens`` target tokens.
        The input ids are shifted left by one, and each request's last
        slot is replaced by its ``next_token_ids`` entry. One draft token
        per request is taken from the hidden state at that request's last
        token.

        Remaining ``num_speculative_tokens - 1`` passes: each feeds one
        token per request (the previous draft). Positions, sequence lengths
        and slot mappings advance by one per step.

        Every draft token is chosen with ``argmax`` over the logits.

        Args:
            target_token_ids: ``[num_tokens]`` input token ids.
            target_positions: ``[num_tokens]`` positions of those tokens.
            target_hidden_states: ``[num_tokens, hidden_size]`` hidden
                states. For eagle3 they are first passed through
                ``self.model.combine_hidden_states``.
            target_slot_mapping: ``[num_tokens]`` KV-cache slot mapping used
                for the first pass.
            next_token_ids: ``[batch_size]`` token written into each
                request's last input slot for the first pass.
            cu_num_tokens: ``[batch_size + 1]`` cumulative token counts
                starting with 0. Used as ``query_start_loc`` for the first
                pass.
            block_table: ``[batch_size, max_num_blocks_per_req]``. Used for
                the eagle metadata and for the slot-mapping gather in the
                drafting loop. When the metadata is ``MLACommonMetadata``,
                it is rebound to the runner builder's device block table
                before the loop.
            sampling_metadata: Not read by this method.
            decoding: Used in three places:

                * forwarded as ``spec_layer_decoding`` to
                  ``CommonAttentionMetadata`` (deepseek_mtp only);
                * gates the first-pass full-CUDA-graph metadata copy;
                * passed as ``skip_cuda_graphs=not decoding`` for the first
                  forward pass.

        Returns:
            ``[batch_size, num_speculative_tokens]`` draft token ids, or
            ``[batch_size, 1]`` when ``num_speculative_tokens == 1``.

        Raises:
            ValueError: If ``self.method`` is not ``"eagle"``, ``"eagle3"``
                or ``"deepseek_mtp"``.
            AssertionError: For eagle3 with a model that is not
                ``Eagle3LlamaForCausalLM`` or with a hidden-size mismatch,
                or for deepseek_mtp when ``self.runner`` is None.

        Side effects:
            Overwrites prefixes of ``self.input_ids``, ``self.positions``
            and ``self.hidden_states``. When full CUDA graphs apply, it
            also overwrites fields of ``self.attn_metadata_cudagraph``. The
            attention metadata built here is mutated in place across
            drafting steps.
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        # Index of each request's last token in the flattened token batch.
        last_token_indices = cu_num_tokens[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        # FA requires seq_len to have dtype int32.
        seq_lens = (target_positions[last_token_indices] + 1).int()

        if self.method in ["eagle", "eagle3"]:
            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
            max_seq_len = seq_lens.max().item()
            max_num_tokens = (cu_num_tokens[1:] -
                              cu_num_tokens[:-1]).max().item()
            attn_metadata = FlashAttentionMetadata(
                num_actual_tokens=num_tokens,
                max_query_len=max_num_tokens,
                query_start_loc=cu_num_tokens,
                max_seq_len=max_seq_len,
                seq_lens=seq_lens,
                block_table=block_table,
                slot_mapping=target_slot_mapping,
                # TODO(woosuk): Support cascade attention.
                use_cascade=False,
                common_prefix_len=0,
                cu_prefix_query_lens=None,
                prefix_kv_lens=None,
                suffix_kv_lens=None,
            )
        elif self.method == "deepseek_mtp":
            # The query-length bound is not derived from cu_num_tokens. It
            # comes from the attribute initialized to 0 in __init__.
            max_query_len = self.spec_scheduler_max_num_tokens
            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=cu_num_tokens,
                seq_lens=seq_lens,
                num_reqs=batch_size,
                num_actual_tokens=num_tokens,
                max_query_len=max_query_len,
                slot_mapping=target_slot_mapping,
                spec_layer_decoding=decoding
            )

            assert self.runner is not None

            # FIXME: need to consider multiple kv_cache_groups
            attn_metadata = self.runner.attn_metadata_builders[0].build(
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata
            )
        else:
            raise ValueError(f"Unsupported method: {self.method}")

        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        # Every entry refers to the same ``attn_metadata`` object, so the
        # in-place updates in the drafting loop below are visible through
        # this dict.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
        # Pad the token count to a CUDA-graph size when it fits within the
        # largest configured size.
        if self.use_cuda_graph and \
                num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
        else:
            num_input_tokens = num_tokens
        # copy inputs to buffer for cudagraph
        self.positions[:num_tokens] = target_positions
        self.hidden_states[:num_tokens] = target_hidden_states

        # On decoding steps that fit a CUDA-graph size, copy this step's
        # metadata into the persistent ``attn_metadata_cudagraph`` buffers.
        # NOTE(review): presumably the captured graph reads these buffers
        # rather than ``attn_metadata``; confirm.
        if (decoding and self.use_full_cuda_graph and
                num_tokens <= self.cudagraph_batch_sizes[-1]):
            assert self.attn_metadata_cudagraph
            if self.method in ["eagle", "eagle3"]:
                self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
                    attn_metadata.seq_lens)
                self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
                    attn_metadata.slot_mapping)
                self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
                    attn_metadata.query_start_loc)
                self.attn_metadata_cudagraph.block_table[:batch_size] = (
                    attn_metadata.block_table)
            elif self.method == "deepseek_mtp":
                self.attn_metadata_cudagraph.num_actual_tokens = (
                    attn_metadata.num_actual_tokens)
                self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
                    attn_metadata.query_start_loc)
                self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
                    attn_metadata.slot_mapping)
                self.attn_metadata_cudagraph.num_decodes = (
                    attn_metadata.num_decodes)
                self.attn_metadata_cudagraph.num_decode_tokens = (
                    attn_metadata.num_decode_tokens)
                self.attn_metadata_cudagraph.num_prefills = (
                    attn_metadata.num_prefills)
                if attn_metadata.decode is not None:
                    # NOTE(review): decode.block_table / decode.seq_lens are
                    # sliced by num_decode_tokens rather than num_decodes.
                    # If these tensors hold one row per decode request and a
                    # request contributes more than one token, the slice
                    # lengths would not match. Confirm the two counts are
                    # equal on this path.
                    self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
                        attn_metadata.decode.block_table)
                    self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
                        attn_metadata.decode.seq_lens)

        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens,
                                 skip_cuda_graphs=not decoding):
            ret_hidden_states = self.model(
                self.input_ids[:num_input_tokens],
                self.positions[:num_input_tokens],
                self.hidden_states[:num_input_tokens],
            )
            # Output shape differs by method: deepseek_mtp output is used as
            # a single tensor, while eagle output is unpacked as a pair.
            if self.method == "deepseek_mtp":
                last_hidden_states = ret_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)
        draft_token_ids = logits.argmax(dim=-1)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1]
            return draft_token_ids.view(-1, 1)

        # TODO: Currently, MTP module released by deepseek only has
        # one layer. Adapt this code to support multiple layers once
        # there's a multi-layer MTP module.

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        # From here on, each step feeds exactly one token per request,
        # starting from each request's last position and hidden state.
        positions = target_positions[last_token_indices]
        if self.method == "deepseek_mtp":
            hidden_states = last_hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]
        if self.use_cuda_graph and \
                batch_size <= self.cudagraph_batch_sizes[-1]:
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
        else:
            input_batch_size = batch_size
        # Reuse the first-pass metadata object for one query token per
        # request. Assumes self.arange holds 0, 1, 2, ... so that each
        # request's query span has length 1. TODO confirm.
        attn_metadata.num_actual_tokens = batch_size
        attn_metadata.max_query_len = 1
        attn_metadata.query_start_loc = self.arange[:batch_size + 1]
        if isinstance(attn_metadata, MLACommonMetadata):
            # Treat every request as a single-token decode. Rebuild the
            # decode sub-metadata with the builder's private _build_decode,
            # using the first-pass seq_lens.
            attn_metadata.num_decodes = batch_size
            attn_metadata.num_decode_tokens = batch_size
            attn_metadata.num_prefills = 0
            # Rebinds the ``block_table`` argument. The slot-mapping gather
            # in the loop below uses this rebound tensor.
            block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...]
            attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode(
                block_table_tensor=block_table,
                seq_lens=seq_lens,
            )
        for i in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1

            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length. Since it is complex
            # to remove such requests from the batch, we keep them in the batch
            # but adjust the position ids and slot mappings to avoid the
            # out-of-range access during the model execution. The draft tokens
            # generated with this adjustment should be ignored.
            exceeds_max_model_len = positions >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            clamped_positions = torch.where(exceeds_max_model_len, 0,
                                            positions)

            if isinstance(attn_metadata, MLACommonMetadata):
                # MLA metadata keeps the per-request lengths under ``decode``.
                # NOTE(review): unlike the other branch, this branch never
                # resets seq_lens to 1 for requests that exceed
                # max_model_len, so decode.seq_lens keeps growing for them.
                # Their slot mapping is still padded below. Confirm the MLA
                # attention path tolerates this.
                attn_metadata.decode.seq_lens += 1
            else:
                # Increment the sequence lengths.
                attn_metadata.seq_lens += 1
                attn_metadata.max_seq_len += 1
                # Consider max model length.
                attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                                self.max_model_len)
                # For the requests that exceed the max model length, we set the
                # sequence length to 1 to minimize their overheads in attention.
                attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            block_numbers = clamped_positions // self.block_size
            block_ids = block_table.gather(dim=1,
                                           index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
            attn_metadata.slot_mapping = (block_ids * self.block_size +
                                          clamped_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                    PADDING_SLOT_ID)

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self.positions[:batch_size] = clamped_positions
            self.hidden_states[:batch_size] = hidden_states

            # Unlike the first pass, this copy is not gated on ``decoding``.
            if (self.use_full_cuda_graph and
                    batch_size <= self.cudagraph_batch_sizes[-1]):
                assert self.attn_metadata_cudagraph
                if self.method in ["eagle", "eagle3"]:
                    self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
                        attn_metadata.seq_lens)
                    self.attn_metadata_cudagraph.slot_mapping[:batch_size] = (
                        attn_metadata.slot_mapping)
                    # query_start_loc and block_table are not reassigned
                    # inside this loop, so copy them only on the first step.
                    if i == 0:
                        self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
                            attn_metadata.query_start_loc)
                        self.attn_metadata_cudagraph.block_table[:batch_size] = (
                            attn_metadata.block_table)
                elif self.method == "deepseek_mtp":
                    self.attn_metadata_cudagraph.num_actual_tokens = (
                        attn_metadata.num_actual_tokens)
                    self.attn_metadata_cudagraph.slot_mapping[:attn_metadata.num_decode_tokens] = (
                        attn_metadata.slot_mapping)
                    self.attn_metadata_cudagraph.num_decodes = (
                        attn_metadata.num_decodes)
                    self.attn_metadata_cudagraph.num_decode_tokens = (
                        attn_metadata.num_decode_tokens)
                    self.attn_metadata_cudagraph.num_prefills = (
                        attn_metadata.num_prefills)
                    self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
                        attn_metadata.decode.seq_lens)
                    # query_start_loc and decode.block_table are fixed for
                    # the whole loop, so copy them only on the first step.
                    if i == 0:
                        self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
                            attn_metadata.query_start_loc)
                        self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
                            attn_metadata.decode.block_table)

            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                ret_hidden_states = self.model(
                    self.input_ids[:input_batch_size],
                    self.positions[:input_batch_size],
                    self.hidden_states[:input_batch_size],
                )
                # Drop CUDA-graph padding rows; the kept hidden states feed
                # the next drafting step.
                if self.method == "deepseek_mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = last_hidden_states[:batch_size]
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
                    hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                               None)

            # TODO(wenlong): get more than one token for tree attention
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids