# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import List, Optional, Any
import copy

import torch
import torch.nn.functional as F

from vllm.config.vllm import CUDAGraphMode
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
                                                   FlashAttentionMetadata)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, logger
from vllm.distributed.communication_op import tensor_model_parallel_all_gather_into_list
from vllm.distributed import (
    get_logits_tp_world_size,
    get_logits_tp_group,
    get_tensor_model_parallel_world_size,
)

from vllm_mlu.v1.attention.backends.flash_attn import pad_attn_metadata
from vllm_mlu.v1.attention.backends.mla.flashmla import FlashMLAMetadataBuilder
from vllm_mlu.v1.attention.backends.utils import (
    MLUCommonAttentionMetadata, COMMON_METADATA_STR)
from vllm_mlu._mlu_utils import *
from vllm_mlu.v1.attention.backends.utils import MLUInferMode
from vllm_mlu.mlu_forward_context import MLUDPMetadata
from vllm_mlu.v1.spec_decode.eagle import MluEagleProposer
from vllm_mlu.model_executor.models.dp_utils import (
    enable_data_parallel,
    DataParallelRuntimeParams
)


class DPMluEagleProposer(MluEagleProposer):
    """MLU Eagle/MTP drafter with data-parallel (DP) handling.

    Extends ``MluEagleProposer`` with: selection of the full-graph path from
    the DP group's token counts and prefill flags, padding of drafter inputs
    and attention metadata up to the captured graph size, per-step DP
    metadata for the k > 1 draft loop, and exchange of logits batch sizes
    across the logits tensor-parallel group.
    """

    def get_logits_batch_sizes(self, batch_size: int) -> Optional[List[int]]:
        """Gather every logits-TP rank's batch size.

        Args:
            batch_size: number of rows this rank will pass to
                ``compute_logits``.

        Returns:
            ``None`` when the logits TP world size equals the model TP world
            size (no exchange is performed). Otherwise a list with one batch
            size per logits-TP rank, all-gathered over the logits TP group.
        """
        tp_world_size = get_logits_tp_world_size()
        if tp_world_size == get_tensor_model_parallel_world_size():
            return None

        tp_tensor = torch.tensor([batch_size], device=self.runner.device)
        outputs = tensor_model_parallel_all_gather_into_list(
            tp_tensor, get_logits_tp_group())
        # Convert the gathered device tensors to a host-side list of ints.
        outputs = torch.cat(outputs).tolist()
        return outputs[:tp_world_size]

    def propose_ds_execute_dummy_batch(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        dp_params: DataParallelRuntimeParams,
    ) -> None:
        """Run the drafter on a dummy batch with attention skipped.

        Executes the draft model and ``compute_logits`` once with
        ``dp_params``. When ``num_speculative_tokens > 1`` it repeats
        ``num_speculative_tokens - 1`` more times, each with freshly built DP
        metadata. All outputs are discarded.

        NOTE(review): presumably used by DP ranks without real work, so that
        they issue the same model/collective calls as busy ranks. Confirm
        against the caller.

        Args:
            target_token_ids: target-model token ids, shape ``[num_tokens]``.
            target_positions: positions, shape ``[num_tokens]``.
            target_hidden_states: target hidden states,
                shape ``[num_tokens, hidden_size]``.
            dp_params: DP runtime params for the first drafter step. If not
                ``None``, its ``logits_batch_split_list`` is overwritten.
        """
        # num_scheduled_tokens
        num_tokens = target_token_ids.shape[0]
        input_ids = self.input_ids[:num_tokens]
        # Shift the input ids by one token. This writes into the
        # ``self.input_ids`` buffer through the view.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        input_ids[:-1] = target_token_ids[1:]

        # Always skip attention compute.
        attn_metadata: Optional[dict[str, Any]] = None

        # Get graph capture related information for deepseek model.
        with set_forward_context(attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_tokens):
            hidden_states = self.model(
                input_ids=input_ids,
                positions=target_positions,
                hidden_states=target_hidden_states,
                intermediate_tensors=None,
                inputs_embeds=None,
                dp_params=dp_params,
            )
        if dp_params is not None:
            dp_params.logits_batch_split_list = self.get_logits_batch_sizes(num_tokens)
        _ = self.model.compute_logits(hidden_states, dp_params=dp_params)

        if self.num_speculative_tokens == 1:
            return

        '''
        =============================
        Modify by vllm_mlu
        @brief: support k > 1, need run draft model k-1 times
        =============================
        '''
        # Support k > 1. DP metadata is rebuilt every iteration, mirroring
        # the real draft loop in ``propose``. NOTE(review): it may involve
        # cross-rank communication, so it is not hoisted out of the loop.
        for _ in range(self.num_speculative_tokens - 1):
            new_dp_params = self.runner._get_data_parallel_metadata(
                num_tokens, num_tokens, True, [1] * num_tokens)
            with set_forward_context(attn_metadata,
                                     self.vllm_config,
                                     num_tokens=num_tokens):
                hidden_states = self.model(
                    input_ids=input_ids,
                    positions=target_positions,
                    hidden_states=target_hidden_states,
                    intermediate_tensors=None,
                    inputs_embeds=None,
                    dp_params=new_dp_params,
                )
            _ = self.model.compute_logits(hidden_states, dp_params=new_dp_params)
        '''
        =============================
        End of MLU Hijack
        =============================
        '''

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        common_attn_metadata: MLUCommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        # [batch_size]
        num_rejected_tokens: torch.Tensor,
        # [num_tokens]
        token_indices: torch.Tensor,
        whole_block_table: torch.Tensor,
        main_model_dp_params: Optional[DataParallelRuntimeParams] = None,
        time_markers: Optional[List] = None,
    ) -> torch.Tensor:
        """Generate ``num_speculative_tokens`` draft tokens per request.

        The first drafter step may run as a full graph when every DP rank is
        decode-only and the DP group's max token count fits the captured
        sizes. In that case, inputs and attention metadata are padded to the
        graph size. Later steps (k > 1) run one token per request.

        Args:
            target_token_ids: target-model token ids, shape ``[num_tokens]``.
            target_positions: positions for the target tokens.
            target_hidden_states: target hidden states,
                shape ``[num_tokens, hidden_size]``.
            next_token_ids: sampled next tokens, shape ``[batch_size]``.
            last_token_indices: index of each request's last token. If
                ``None``, derived from ``query_start_loc[1:] - 1``.
            common_attn_metadata: attention metadata of the target step. It
                may be padded in place when the padded full-graph path is
                taken.
            sampling_metadata: not read by this implementation.
            num_rejected_tokens: not read by this implementation.
            token_indices: not read by this implementation.
            whole_block_table: forwarded to ``pad_attn_metadata``.
            main_model_dp_params: DP runtime params of the target step. Its
                ``token_split_list`` and ``dp_is_prefill`` drive graph
                selection, and its ``logits_batch_split_list`` is
                overwritten. If ``None``, only local information is used.
            time_markers: when ``VLLM_LATENCY_DEBUG_WITH_DEVICE_EN`` is set,
                a ``[start, end]`` MLU event pair is appended per drafter
                forward. Defaults to a fresh list per call.

        Returns:
            Draft token ids of shape ``[batch_size, num_speculative_tokens]``.

        Raises:
            AttributeError: if ``main_model_dp_params`` has no
                ``logits_batch_split_list`` attribute.
        """
        # A fresh list per call: a shared ``[]`` default would accumulate
        # events across calls forever.
        if time_markers is None:
            time_markers = []

        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        hidden_states_indices = last_token_indices

        assert self.runner is not None

        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder

        # FIXME: need to consider multiple kv_cache_groups
        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata,
            draft_index=0,
        )

        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: Use full graph with draft model and pad batch_size for dp
        '''
        if main_model_dp_params is not None:
            dp_group_max_token_num = max(main_model_dp_params.token_split_list)
            # Full graph only when every DP rank is decoding.
            decode_only = all(
                not prefill for prefill in main_model_dp_params.dp_is_prefill)
        else:
            # No DP metadata: only the local batch is known.
            dp_group_max_token_num = num_tokens
            decode_only = common_attn_metadata.is_decode_only

        if dp_group_max_token_num <= self.vllm_config.compilation_config.max_cudagraph_capture_size:
            batch_descriptor_num_tokens = self.vllm_config.pad_for_cudagraph(dp_group_max_token_num)
            captured_already = True
        else:
            batch_descriptor_num_tokens = num_tokens
            captured_already = False

        # FIXME(wangchao2): disable mtp graph for ds3.2 with dp for now (core dump)
        is_dsv32 = self.vllm_config.model_config.hf_config.model_type == "deepseek_v32"
        use_full_graph = (self.use_cuda_graph and decode_only
                          and captured_already and not is_dsv32)
        if (self.use_cuda_graph and decode_only
                and not use_full_graph and not is_dsv32):
            logger.warning_once(
                f"Select MLU-V1 Full-MLUGraph mode with drafter, however running in " +
                f"eager mode: decode_only={decode_only}, captured_already={captured_already}, " +
                f"num_tokens={num_tokens}."
            )
        cudagraph_runtime_mode = CUDAGraphMode.FULL if use_full_graph else CUDAGraphMode.NONE
        batch_descriptor = BatchDescriptor(
            num_tokens=batch_descriptor_num_tokens,
            uniform_decode=True,
        )

        # DP: pad batch_size up to the captured graph size.
        if use_full_graph:
            K = self.num_speculative_tokens
            num_input_tokens = batch_descriptor_num_tokens
            padded_batch_size = num_input_tokens // (K + 1)
        else:
            padded_batch_size = batch_size
            num_input_tokens = num_tokens

        # Change attn metadata num_actual_tokens.
        attn_metadata.num_actual_tokens = num_input_tokens

        # Copy common_attn_metadata when k > 1 for the draft loop, because
        # DP padding below mutates common_attn_metadata in place.
        common_attn_metadata_copy = None
        if self.num_speculative_tokens > 1:
            common_attn_metadata_copy = copy.deepcopy(common_attn_metadata)

        # Pad attn metadata.
        if use_full_graph and enable_data_parallel() and num_input_tokens != num_tokens:
            assert self.runner is not None
            # Update attention metadata.
            pad_attn_metadata(
                attn_metadata, common_attn_metadata, whole_block_table, self.runner,
                num_tokens, num_input_tokens, batch_size, padded_batch_size,
            )

        # Number of padding tokens appended to the drafter inputs.
        token_pad_size = num_input_tokens - num_tokens
        assert token_pad_size >= 0

        if token_pad_size > 0:
            # Pad target hidden states with zeros.
            target_hidden_states = F.pad(
                target_hidden_states, (0, 0, 0, token_pad_size), value=0.0
            )
            # Pad positions with zeros.
            target_positions = F.pad(
                target_positions, (0, token_pad_size), value=0
            )

        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
        per_layer_attn_metadata[COMMON_METADATA_STR] = common_attn_metadata

        # Copy inputs to buffers for cudagraph.
        self.positions[:num_input_tokens] = target_positions
        self.hidden_states[:num_input_tokens] = target_hidden_states

        kwargs = {} if main_model_dp_params is None else {"dp_params": main_model_dp_params}

        if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
            start = torch.mlu.Event(enable_timing=True)
            start.record()
        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens,
                                 batch_descriptor=batch_descriptor if use_full_graph else None,
                                 cudagraph_runtime_mode=cudagraph_runtime_mode):
            if use_full_graph:
                ret_hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    intermediate_tensors=None,
                    inputs_embeds=None,
                    is_running_drafter=True,
                    **kwargs,
                )
            else:
                ret_hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    **kwargs,
                )
        if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
            end = torch.mlu.Event(enable_timing=True)
            end.record()
            time_markers.append([start, end])

        if self.method == "mtp":
            last_hidden_states = ret_hidden_states
        else:
            last_hidden_states, hidden_states = ret_hidden_states
        '''
        =============================
        End of MLU Hijack
        =============================
        '''

        if main_model_dp_params is not None:
            # Ensure main_model_dp_params has the required attribute before
            # assignment.
            if hasattr(main_model_dp_params, 'logits_batch_split_list'):
                main_model_dp_params.logits_batch_split_list = self.get_logits_batch_sizes(batch_size)
            else:
                raise AttributeError("dp_params must have 'logits_batch_split_list' attribute")

        sample_hidden_states = last_hidden_states[hidden_states_indices]
        logits = self.model.compute_logits(sample_hidden_states, dp_params=main_model_dp_params)
        draft_token_ids = logits.argmax(dim=-1)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1]
            return draft_token_ids.view(-1, 1)

        if self.uses_mrope:
            positions = target_positions[:, last_token_indices]
        else:
            positions = target_positions[last_token_indices]

        '''
        =============================
        Modify by vllm_mlu
        =============================
        '''
        hidden_states = last_hidden_states[hidden_states_indices]
        '''
        =============================
        End of MLU Hijack
        =============================
        '''

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]
        input_batch_size = batch_size

        if common_attn_metadata.infer_mode != MLUInferMode.DECODE_ONLY:
            seq_lens_cpu = torch.ones(input_batch_size, dtype=torch.int32,)
            cu_num_tokens = torch.cumsum(seq_lens_cpu, dim=0)
            query_start_loc_cpu = torch.empty(input_batch_size + 1, dtype=torch.int32)
            query_start_loc_cpu[0] = 0
            query_start_loc_cpu[1:] = cu_num_tokens
            seq_start_loc_cpu = self.arange[:input_batch_size + 1]
            common_attn_metadata_k = MLUCommonAttentionMetadata.build(
                query_start_loc=query_start_loc_cpu.to(self.device, non_blocking=True),
                query_start_loc_cpu=query_start_loc_cpu,
                seq_lens=seq_lens_cpu.to(self.device, non_blocking=True),
                seq_lens_cpu=seq_lens_cpu,
                num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
                num_reqs=common_attn_metadata.num_reqs,
                block_table_tensor=common_attn_metadata.block_table_tensor,
                slot_mapping=common_attn_metadata.slot_mapping,
                seq_start_loc=seq_start_loc_cpu.to(self.device, non_blocking=True),
                is_start_loc_match=False,
                # not prefill
                max_query_len=1,
                num_actual_tokens=input_batch_size,
                num_input_tokens=input_batch_size,
                num_speculative_tokens=self.num_speculative_tokens,
                has_prefill_reqs=common_attn_metadata.infer_mode == MLUInferMode.CHUNKED,
            )
        else:
            # Use the pre-padding deep copy taken above.
            common_attn_metadata_k = common_attn_metadata_copy
            common_attn_metadata_k.num_actual_tokens = batch_size
            common_attn_metadata_k.num_input_tokens = batch_size
            common_attn_metadata_k.max_query_len = 1
            common_attn_metadata_k.query_start_loc = self.arange[: batch_size + 1]
            common_attn_metadata_k.query_start_loc_cpu = torch.from_numpy(
                self.token_arange_np[: batch_size + 1]
            ).clone()

        # In padded drafter batch, we need to adjust the sequence lengths
        # to remove the "padding" (i.e. rejected tokens).
        # Only apply this adjustment when we have rejected tokens
        # (i.e., not the first proposal).
        for token_index in range(self.num_speculative_tokens - 1):
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: get dp_params for draft model
            '''
            # dp_params for the draft model. Initialized to None so the
            # compute_logits call below is well defined without DP metadata
            # (previously a NameError on that path). Rebuilt every iteration;
            # NOTE(review): may involve cross-rank communication, so not
            # hoisted.
            dp_params = None
            if main_model_dp_params is not None:
                dp_params = self.runner._get_data_parallel_metadata(
                    input_batch_size, input_batch_size,
                    common_attn_metadata.is_decode_only,
                    [1] * input_batch_size
                )
            kwargs = {} if main_model_dp_params is None else {"dp_params": dp_params}
            '''
            =============================
            End of MLU Hijack
            =============================
            '''

            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            if self.uses_mrope:
                positions += 1
                # NOTE(woosuk): We should handle the case where the draft model
                # generates tokens beyond the max model length.
                # Since it is complex to remove such requests from the batch,
                # we keep them in the batch but adjust the position ids
                # and slot mappings to avoid the
                # out-of-range access during the model execution.
                # The draft tokens generated with this adjustment
                # should be ignored.
                exceeds_max_model_len = positions[0] >= self.max_model_len
                # Mask out the position ids that exceed the max model length.
                # Otherwise, we may get out-of-range error in RoPE.
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
            else:
                positions += 1
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)

            # For data integrity when async scheduling, we shouldn't use in place
            # operations in case they are modified in next step's `prepare_input`
            # of main model.
            # Increment the sequence lengths.
            common_attn_metadata_k.seq_lens += 1
            # This is an out-of-place operation to avoid modifying the original tensor.
            common_attn_metadata_k.seq_lens_cpu = common_attn_metadata_k.seq_lens_cpu + 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            common_attn_metadata_k.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            common_attn_metadata_k.num_computed_tokens_cpu = (
                common_attn_metadata_k.seq_lens_cpu - 1
            )

            # Compute the slot mapping.
            if self.uses_mrope:
                # all dimensions of positions are the same
                block_numbers = clamped_positions[0] // self.block_size
            else:
                block_numbers = clamped_positions // self.block_size
            block_ids = common_attn_metadata_k.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1)
            )
            block_ids = block_ids.view(-1)
            if self.uses_mrope:
                common_attn_metadata_k.slot_mapping = (
                    block_ids * self.block_size
                    + clamped_positions[0] % self.block_size
                )
            else:
                common_attn_metadata_k.slot_mapping = (
                    block_ids * self.block_size
                    + clamped_positions % self.block_size
                )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata_k.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID
            )

            # Rebuild attention metadata
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata_k,
                draft_index=token_index + 1
            )
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata
            per_layer_attn_metadata[COMMON_METADATA_STR] = common_attn_metadata_k

            # Copy inputs to buffers for cudagraph.
            self.input_ids[:batch_size] = input_ids
            self.positions[:batch_size] = clamped_positions
            self.hidden_states[:batch_size] = hidden_states
            # NOTE(review): input_ids/inputs_embeds computed below are not
            # passed to the model call, which reads self.input_ids directly.
            # The embed_input_ids call still fills self.inputs_embeds.
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: record latency
            '''
            if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
                start = torch.mlu.Event(enable_timing=True)
                start.record()
            '''
            =============================
            End of MLU Hijack
            =============================
            '''

            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                ret_hidden_states = self.model(
                    input_ids=self.input_ids[:input_batch_size],
                    positions=self.positions[:input_batch_size],
                    hidden_states=self.hidden_states[:input_batch_size],
                    **kwargs,
                )
                if self.method == "mtp":
                    last_hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states

            if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
                end = torch.mlu.Event(enable_timing=True)
                end.record()
                time_markers.append([start, end])
            '''
            =============================
            End of MLU Hijack
            =============================
            '''

            hidden_states = last_hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size], dp_params=dp_params)

            # TODO(wenlong): get more than one token for tree attention
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids