From 26b311ccf5ffcd7aa129c6927d29b57247f4acd2 Mon Sep 17 00:00:00 2001 From: dongxinyu03 Date: Tue, 6 Jan 2026 21:37:21 +0800 Subject: [PATCH] [Feature] DeepSeek Support MTP --- vllm_kunlun/__init__.py | 1 + vllm_kunlun/models/__init__.py | 4 + vllm_kunlun/models/deepseek_mtp.py | 289 ++++++++ vllm_kunlun/v1/sample/__init__.py | 0 vllm_kunlun/v1/sample/rejection_sampler.py | 645 ++++++++++++++++++ vllm_kunlun/v1/sample/spec_decode/__init__.py | 0 vllm_kunlun/v1/sample/spec_decode/eagle.py | 312 +++++++++ 7 files changed, 1251 insertions(+) create mode 100644 vllm_kunlun/models/deepseek_mtp.py create mode 100644 vllm_kunlun/v1/sample/__init__.py create mode 100644 vllm_kunlun/v1/sample/rejection_sampler.py create mode 100644 vllm_kunlun/v1/sample/spec_decode/__init__.py create mode 100644 vllm_kunlun/v1/sample/spec_decode/eagle.py diff --git a/vllm_kunlun/__init__.py b/vllm_kunlun/__init__.py index f4fbbbb..4a22dff 100644 --- a/vllm_kunlun/__init__.py +++ b/vllm_kunlun/__init__.py @@ -17,6 +17,7 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0) "vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler", "vllm.model_executor.layers.sampler": "vllm_kunlun.ops.sample.sampler", "vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler", + "vllm.v1.sample.rejection_sampler": "vllm_kunlun.v1.sample.rejection_sampler" } if module_name in module_mappings: diff --git a/vllm_kunlun/models/__init__.py b/vllm_kunlun/models/__init__.py index 9fd12c5..55ab61e 100644 --- a/vllm_kunlun/models/__init__.py +++ b/vllm_kunlun/models/__init__.py @@ -88,6 +88,10 @@ def register_model(): ModelRegistry.register_model( "DeepseekV32ForCausalLM", "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM") + + ModelRegistry.register_model( + "DeepSeekMTPModel", + "vllm_kunlun.models.deepseek_mtp:DeepSeekMTP") def register_quant_method(): """to do""" diff --git a/vllm_kunlun/models/deepseek_mtp.py 
b/vllm_kunlun/models/deepseek_mtp.py new file mode 100644 index 0000000..6385eae --- /dev/null +++ b/vllm_kunlun/models/deepseek_mtp.py @@ -0,0 +1,289 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import VllmConfig +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.sequence import IntermediateTensors + +from .deepseek_v2 import (DeepseekV2DecoderLayer, + get_spec_layer_idx_from_weight_name) +from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.models.utils import maybe_prefix + + +class SharedHead(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(hidden_states) + + +class DeepSeekMultiTokenPredictorLayer(nn.Module): + + def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = 
nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + + self.is_v32 = hasattr(config, "index_topk") + if self.is_v32: + topk_tokens = config.index_topk + topk_indices_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + topk_tokens, + dtype=torch.int32, + device="cuda") + else: + topk_indices_buffer = None + self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix, + topk_indices_buffer) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + residual=None) + hidden_states = residual + hidden_states + return hidden_states + + +class DeepSeekMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + DeepSeekMultiTokenPredictorLayer(vllm_config, + f"{prefix}.layers.{idx}") + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: 
torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + current_step_idx = (spec_step_idx % self.num_mtp_layers) + return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( + input_ids, + positions, + previous_hidden_states, + inputs_embeds, + current_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = (spec_step_idx % self.num_mtp_layers) + mtp_layer = self.layers[str(self.mtp_start_layer_idx + + current_step_idx)] + logits = self.logits_processor(mtp_layer.shared_head.head, + mtp_layer.shared_head(hidden_states)) + return logits + + +class DeepSeekMTP(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, hidden_states, + inputs_embeds, spec_step_idx) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.model.compute_logits(hidden_states, spec_step_idx) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + ] + + 
expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is None: + continue + name = self._rewrite_spec_layer_name(spec_layer, name) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name_mapped = name.replace(weight_name, param_name) + + # QKV fusion is optional, fall back to normal + # weight loading if it's not enabled + if ((param_name == "fused_qkv_a_proj") + and name_mapped not in params_dict): + continue + else: + name = name_mapped + + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # According to DeepSeek-V3 Technical Report, MTP modules + # shares embedding layer. We only load the first weights. + if (spec_layer != self.model.mtp_start_layer_idx + and ".layers" not in name): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. + Add .mtp_block for modules in transformer layer block for spec layer + and rename shared layer weights to be top level. 
+ """ + spec_layer_weight_names = [ + "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + ] + shared_weight_names = ["embed_tokens"] + spec_layer_weight = False + shared_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + if weight_name in shared_weight_names: + shared_weight = True + break + if not spec_layer_weight: + # treat rest weights as weights for transformer layer block + name = name.replace(f"model.layers.{spec_layer}.", + f"model.layers.{spec_layer}.mtp_block.") + elif shared_weight: + # treat shared weights as top level weights + name = name.replace(f"model.layers.{spec_layer}.", "model.") + return name diff --git a/vllm_kunlun/v1/sample/__init__.py b/vllm_kunlun/v1/sample/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_kunlun/v1/sample/rejection_sampler.py b/vllm_kunlun/v1/sample/rejection_sampler.py new file mode 100644 index 0000000..0d34385 --- /dev/null +++ b/vllm_kunlun/v1/sample/rejection_sampler.py @@ -0,0 +1,645 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional +from typing import Union +import torch +import torch.nn as nn + +from vllm.logger import init_logger + +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + +logger = init_logger(__name__) + +PLACEHOLDER_TOKEN_ID = -1 +GREEDY_TEMPERATURE = -1 +# Maximum number of speculative draft tokens allowed per request in a single +# step. This value is chosen to be large enough to handle typical use cases. +MAX_SPEC_LEN = 32 + + +class RejectionSampler(nn.Module): + """ + The implementation strictly follows the algorithm described in + https://arxiv.org/abs/2211.17192. 
+ However, we want to clarify the terminology used in the implementation: + accepted tokens: tokens that are accepted based on the relationship + between the "raw" draft and target probabilities. + recovered tokens: tokens that are sampled based on the adjusted probability + distribution, which is derived from both the draft and target + probabilities. + bonus tokens: + If all proposed tokens are accepted, the bonus token is added to the + end of the sequence. The bonus token is only sampled from the target + probabilities. We pass in the bonus tokens instead of sampling them + in the rejection sampler to allow for more flexibility in the + sampling process. For example, we can use top_p, top_k sampling for + bonus tokens, while spec decode does not support these sampling + strategies. + output tokens: + Tokens are finally generated with the rejection sampler. + output tokens = accepted tokens + recovered tokens + bonus tokens + """ + + def forward( + self, + metadata: SpecDecodeMetadata, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + target_logits: torch.Tensor, + # [batch_size, 1] + bonus_token_ids: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + ''' + Args: + metadata: + Metadata for spec decoding. + draft_probs (Optional[torch.Tensor]): + Probability distribution for the draft tokens. Shape is + [num_tokens, vocab_size]. Can be None if probabilities are + not provided, which is the case for ngram spec decode. + target_logits (torch.Tensor): + Target model's logits probability distribution. + Shape is [num_tokens, vocab_size]. Here, probabilities from + different requests are flattened into a single tensor because + this is the shape of the output logits. + NOTE: `target_logits` can be updated in place to save memory. + bonus_token_ids_tensor (torch.Tensor): + A tensor containing bonus tokens. Shape is [batch_size, 1]. 
+ Bonus tokens are added to the end of the sequence if all + proposed tokens are accepted. We generate the bonus tokens + outside of the rejection sampler with the default sampling + strategy. It allows for more flexibility in the sampling + process such as top_p, top_k sampling. + sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata): + Additional metadata needed for sampling, such as temperature, + top-k/top-p parameters, or other relevant information. + Returns: + output_token_ids (torch.Tensor): + A tensor containing the final output token IDs. + ''' + assert metadata.max_spec_len <= MAX_SPEC_LEN + # [num_tokens, vocab_size] + # NOTE(woosuk): `target_logits` can be updated in place inside the + # `compute_probs` function. + + target_probs = compute_probs( + target_logits, + metadata.cu_num_draft_tokens, + sampling_metadata, + ) + + output_token_ids = rejection_sample( + metadata.draft_token_ids, + metadata.num_draft_tokens, + metadata.max_spec_len, + metadata.cu_num_draft_tokens, + draft_probs, + target_probs, + bonus_token_ids, + sampling_metadata, + ) + return output_token_ids + + @staticmethod + def parse_output( + output_token_ids: torch.Tensor, + vocab_size: int, + ) -> list[list[int]]: + """Parse the output of the rejection sampler. + + Args: + output_token_ids: The sampled token IDs in shape + [batch_size, max_spec_len + 1]. The rejected tokens are + replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler + and will be filtered out in this function. + vocab_size: The size of the vocabulary. + + Returns: + A list of lists of token IDs. + """ + output_token_ids_np = output_token_ids.cpu().numpy() + # Create mask for valid tokens. 
+ valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) & + (output_token_ids_np < vocab_size)) + outputs = [ + row[valid_mask[i]].tolist() + for i, row in enumerate(output_token_ids_np) + ] + return outputs + + +def rejection_sample( + # [num_tokens] + draft_token_ids: torch.Tensor, + # [batch_size] + num_draft_tokens: list[int], + max_spec_len: int, + # [batch_size] + cu_num_draft_tokens: torch.Tensor, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + target_probs: torch.Tensor, + # [batch_size, 1] + bonus_token_ids: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + assert draft_token_ids.ndim == 1 + assert draft_probs is None or draft_probs.ndim == 2 + assert cu_num_draft_tokens.ndim == 1 + assert target_probs.ndim == 2 + + batch_size = len(num_draft_tokens) + num_tokens = draft_token_ids.shape[0] + vocab_size = target_probs.shape[-1] + device = target_probs.device + assert draft_token_ids.is_contiguous() + assert draft_probs is None or draft_probs.is_contiguous() + assert target_probs.is_contiguous() + assert bonus_token_ids.is_contiguous() + assert target_probs.shape == (num_tokens, vocab_size) + + # Create output buffer. + output_token_ids = torch.empty( + (batch_size, max_spec_len + 1), + dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids. + device=device, + ) + output_token_ids.fill_(PLACEHOLDER_TOKEN_ID) + + if sampling_metadata.all_greedy: + is_greedy = None + else: + is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE + if not sampling_metadata.all_random: + # Rejection sampling for greedy sampling requests. 
+ target_argmax = target_probs.argmax(dim=-1) + if min(num_draft_tokens) == 1 and max( + num_draft_tokens) == 1 and sampling_metadata.all_greedy: + rejection_greedy_sample_spec_len_1_pytorch( + output_token_ids, + draft_token_ids, + target_argmax, + bonus_token_ids, + ) + else: + rejection_greedy_sample_pytorch( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + target_argmax, + bonus_token_ids, + num_draft_tokens, + max_spec_len, + is_greedy, + ) + if sampling_metadata.all_greedy: + return output_token_ids + + # Generate uniform probabilities for rejection sampling. + # [num_tokens] + uniform_probs = generate_uniform_probs( + num_tokens, + num_draft_tokens, + sampling_metadata.generators, + device, + ) + + # Sample recovered tokens for each position. + # [num_tokens] + recovered_token_ids = sample_recovered_tokens( + max_spec_len, + num_draft_tokens, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + sampling_metadata, + device, + ) + + rejection_random_sample_pytorch( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + bonus_token_ids, + recovered_token_ids, + uniform_probs, + is_greedy, + max_spec_len, + vocab_size, + IS_NGRAM=draft_probs is None, + # num_warps=1, + ) + return output_token_ids + + +def compute_probs( + logits: torch.Tensor, # [num_tokens, vocab_size] + cu_num_draft_tokens: torch.Tensor, # [batch_size] + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + """Compute probability distribution from logits based on sampling metadata. + + This function applies temperature scaling to the logits and converts + them to probabilities using softmax. For greedy decoding, it returns + the original logits. + + Args: + logits: Input logits tensor to be converted to probabilities. + cu_num_draft_tokens: Cumulative number of draft tokens. + sampling_metadata: Metadata containing sampling parameters such as + temperature and whether greedy sampling is used. 
+ + Returns: + torch.Tensor: Probability distribution (softmax of scaled logits) + if non-greedy sampling is used, otherwise returns the + original logits. + """ + assert logits.ndim == 2 + assert cu_num_draft_tokens.ndim == 1 + if sampling_metadata.all_greedy: + return logits + + num_tokens = logits.shape[0] + temperature = expand_batch_to_tokens( + sampling_metadata.temperature, + cu_num_draft_tokens, + num_tokens, + replace_from=GREEDY_TEMPERATURE, + replace_to=1, + ) + # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor. + logits.div_(temperature.unsqueeze(-1)) + + # Get expanded top_k and top_p tensors. + top_k = None + if sampling_metadata.top_k is not None: + top_k = expand_batch_to_tokens( + sampling_metadata.top_k, + cu_num_draft_tokens, + num_tokens, + ) + top_p = None + if sampling_metadata.top_p is not None: + top_p = expand_batch_to_tokens( + sampling_metadata.top_p, + cu_num_draft_tokens, + num_tokens, + ) + + # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask, + # which is slow for large vocab sizes. This may cause performance issues. + logits = apply_top_k_top_p(logits, top_k, top_p) + output_prob = logits.softmax(dim=-1, dtype=torch.float32) + return output_prob + + +def expand_batch_to_tokens( + x: torch.Tensor, # [batch_size] + cu_num_tokens: torch.Tensor, # [batch_size] + num_tokens: int, + replace_from: int = 0, + replace_to: int = 0, +) -> torch.Tensor: + """Expand [batch_size] tensor to [num_tokens] tensor based on the number of + tokens per batch in cu_num_tokens. + + For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then + num_tokens = 6, and expanded_x = [a, a, b, b, b, c]. + + Args: + x: [batch_size] tensor to expand. + cu_num_tokens: [batch_size] tensor containing the cumulative number of + tokens per batch. Each element represents the total number of + tokens up to and including that batch. + num_tokens: Total number of tokens. 
+ replace_from: int = 0 + Value to be replaced if it is found in x. + replace_to: int = 0 + Value to replace with when replace_from is found. + Returns: + expanded_x: [num_tokens] tensor. + """ + batch_size = x.shape[0] + assert cu_num_tokens.shape[0] == batch_size + expanded_x = x.new_empty(num_tokens) + expand_pytorch( + expanded_x, + x, + cu_num_tokens, + replace_from, + replace_to, + MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation. + ) + return expanded_x + + +def generate_uniform_probs( + num_tokens: int, + num_draft_tokens: list[int], + generators: dict[int, torch.Generator], + device: torch.device, +) -> torch.Tensor: + """ + Generates a batch of uniform random samples, with optional seeding + if available. + + This method creates a tensor of shape `(num_tokens, )` filled + with uniform random values in the range [0, 1). If `generators` is provided, + the requests with their own seeds will use the provided `torch.Generator` + for reproducibility. The samples for the other requests will be generated + without a seed. + + Args: + num_tokens : int + Total number of tokens. + num_draft_tokens : List[List[int]] + Number of draft tokens per request. + generators : Optional[Dict[int, torch.Generator]] + A dictionary mapping indices in the batch to + `torch.Generator` objects. + device : torch.device + The device on which to allocate the tensor. + Returns: + uniform_rand : torch.Tensor + A tensor of shape `(num_tokens, )` containing uniform + random values in the range [0, 1). + """ + uniform_probs = torch.rand( + (num_tokens, ), + dtype=torch.float32, + device=device, + ) + start_idx = 0 + for req_idx, n in enumerate(num_draft_tokens): + # Do not generate random numbers for requests with no draft tokens. + # This can be important for reproducibility. 
+ if n == 0: + continue + end_idx = start_idx + n + generator = generators.get(req_idx) + if generator is not None: + uniform_probs[start_idx:end_idx].uniform_(generator=generator) + start_idx = end_idx + return uniform_probs + + +def sample_recovered_tokens( + max_spec_len: int, + num_draft_tokens: list[int], + # [batch_size] + cu_num_draft_tokens: torch.Tensor, + # [num_tokens] + draft_token_ids: torch.Tensor, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + target_probs: torch.Tensor, + sampling_metadata: SamplingMetadata, + device: torch.device, +) -> torch.Tensor: + # NOTE(woosuk): Create only one distribution for each request. + batch_size = len(num_draft_tokens) + vocab_size = target_probs.shape[-1] + q = torch.empty( + (batch_size, vocab_size), + dtype=torch.float32, + device=device, + ) + q.exponential_() + for i, generator in sampling_metadata.generators.items(): + # Do not generate random numbers for requests with no draft tokens. + # This can be important for reproducibility. 
+ if num_draft_tokens[i] > 0: + q[i].exponential_(generator=generator) + + recovered_token_ids = torch.empty_like(draft_token_ids) + sample_recovered_tokens_pytorch( + recovered_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + q, + vocab_size, + IS_NGRAM=draft_probs is None, + ) + return recovered_token_ids + + +def rejection_greedy_sample_spec_len_1_pytorch( + output_token_ids, # [batch_size, 2] + draft_token_ids, # [num_tokens] + target_argmax, # [num_tokens] + bonus_token_ids, # [batch_size] +): + batch_size = output_token_ids.size(0) + num_tokens = draft_token_ids.size(0) + assert batch_size == num_tokens + accept_req_mask = draft_token_ids == target_argmax + output_token_ids[:, 0] = target_argmax + bonus_token_ids = bonus_token_ids.squeeze(1) + output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids, + output_token_ids[:, 1]) + + +def rejection_greedy_sample_pytorch( + output_token_ids, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens, # [batch_size] + draft_token_ids, # [num_tokens] + target_argmax, # [num_tokens] + bonus_token_ids, # [batch_size] + draft_tokens_per_req, # [batch_size], list + max_spec_len, + is_greedy=None, # [batch_size] or None +): + batch_size = output_token_ids.size(0) + num_tokens = draft_token_ids.size(0) + device = output_token_ids.device + draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to( + device, non_blocking=True) + if is_greedy is None: + is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device) + + start_indices = cu_num_draft_tokens - draft_tokens_per_req + req_ids = torch.arange(batch_size, device=device) + token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req) + token_positions = torch.arange( + num_tokens, device=device) - start_indices[token_req_ids] + + # Find the first mismatch position of each request. 
+ mismatch_global = (draft_token_ids != target_argmax) + if max_spec_len == 0: + first_mismatch_pos_per_req = torch.zeros(batch_size, + dtype=torch.long, + device=device) + else: + # [bs, max_spec_len] + pos_matrix = torch.full((batch_size, max_spec_len), + -1, + dtype=torch.long, + device=device) + pos_matrix[token_req_ids, token_positions] = token_positions + mismatch_matrix = torch.full((batch_size, max_spec_len), + False, + dtype=torch.bool, + device=device) + mismatch_matrix[token_req_ids, token_positions] = mismatch_global + mismatch_positions = torch.where(mismatch_matrix, pos_matrix, + max_spec_len * 2) + first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1) + no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2) + first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[ + no_mismatch_mask] + + # Copy matched target tokens into output. + copy_len = torch.minimum(first_mismatch_pos_per_req + 1, + draft_tokens_per_req) + copy_indices = torch.arange(max_spec_len + 1, + device=device).expand(batch_size, -1) + copy_mask = copy_indices < copy_len.unsqueeze(1) + greedy_mask = is_greedy.unsqueeze(1) + final_copy_mask = copy_mask & greedy_mask + global_idx = start_indices.unsqueeze(1) + copy_indices + output_token_ids[final_copy_mask] = target_argmax[ + global_idx[final_copy_mask]].to(output_token_ids.dtype) + # Fill bonus token. 
+ needs_bonus = is_greedy & (first_mismatch_pos_per_req + >= draft_tokens_per_req) + if torch.any(needs_bonus): + bonus_rows = torch.where(needs_bonus)[0] + bonus_cols = draft_tokens_per_req[bonus_rows] + bonus_token_ids = bonus_token_ids.squeeze(1) + output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows] + + +def rejection_random_sample_pytorch( + output_token_ids, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens, # [batch_size] + draft_token_ids, # [num_tokens] + draft_probs, # [num_tokens, vocab_size] or None + target_probs, # [num_tokens, vocab_size] + bonus_token_ids, # [batch_size] + recovered_token_ids, # [num_tokens] + uniform_probs, # [num_tokens] + is_greedy, # [batch_size] + max_spec_len, + vocab_size, + IS_NGRAM=False, +): + batch_size = output_token_ids.shape[0] + + for req_idx in range(batch_size): + if is_greedy[req_idx]: + continue + + if req_idx == 0: + start_idx = 0 + else: + start_idx = cu_num_draft_tokens[req_idx - 1].item() + end_idx = cu_num_draft_tokens[req_idx].item() + num_draft_tokens = end_idx - start_idx + + rejected = False + for pos in range(num_draft_tokens): + if not rejected: + draft_token_id = draft_token_ids[start_idx + pos].item() + + if IS_NGRAM: + draft_prob = 1.0 + else: + draft_prob = draft_probs[start_idx + pos, + draft_token_id].item() + + target_prob = target_probs[start_idx + pos, + draft_token_id].item() + uniform_prob = uniform_probs[start_idx + pos].item() + + if draft_prob > 0 and target_prob / draft_prob >= uniform_prob: + token_id = draft_token_id + else: + rejected = True + token_id = recovered_token_ids[start_idx + pos].item() + + output_token_ids[req_idx, pos] = token_id + + if not rejected: + bonus_token_id = bonus_token_ids[req_idx].item() + output_token_ids[req_idx, num_draft_tokens] = bonus_token_id + + +def expand_pytorch( + output_ptr, # [num_tokens] + input_ptr, # [batch_size] + cu_num_tokens_ptr, # [batch_size] + replace_from, + replace_to, + MAX_NUM_TOKENS, +): + batch_size = 
len(input_ptr) + + for req_idx in range(batch_size): + start_idx = 0 if req_idx == 0 else cu_num_tokens_ptr[req_idx - 1] + end_idx = cu_num_tokens_ptr[req_idx] + num_tokens = end_idx - start_idx + + src_val = input_ptr[req_idx] + src_val = replace_to if src_val == replace_from else src_val + + offset = torch.arange(MAX_NUM_TOKENS, device=num_tokens.device) + mask = offset < num_tokens + + output_slice = start_idx + offset[mask] + output_ptr[output_slice] = src_val + + +def sample_recovered_tokens_pytorch( + output_token_ids, # [num_tokens] + cu_num_draft_tokens, # [batch_size] + draft_token_ids, # [num_tokens] + draft_probs, # [num_tokens, vocab_size] or None + target_probs, # [num_tokens, vocab_size] + q, # [batch_size, vocab_size] + vocab_size, + IS_NGRAM=False, +): + batch_size = len(cu_num_draft_tokens) + + for req_idx in range(batch_size): + start_idx = 0 if req_idx == 0 else cu_num_draft_tokens[req_idx - 1] + end_idx = cu_num_draft_tokens[req_idx] + num_draft_tokens = end_idx - start_idx + + for pos in range(num_draft_tokens): + token_idx = start_idx + pos + + if IS_NGRAM: + draft_token_id = draft_token_ids[token_idx] + orig_prob = target_probs[token_idx, draft_token_id].item() + target_probs[token_idx, draft_token_id] = 0 + prob = target_probs[token_idx].clone() + else: + draft_p = draft_probs[token_idx].clone() + target_p = target_probs[token_idx].clone() + prob = torch.maximum(target_p - draft_p, + torch.tensor(0.0, device=target_p.device)) + + q_values = torch.full((vocab_size, ), + float('-inf'), + device=q.device) + q_values[:vocab_size] = q[req_idx, :vocab_size] + + recovered_id = torch.argmax(prob / q_values).item() + output_token_ids[token_idx] = recovered_id + + if IS_NGRAM: + target_probs[token_idx, draft_token_id] = orig_prob + diff --git a/vllm_kunlun/v1/sample/spec_decode/__init__.py b/vllm_kunlun/v1/sample/spec_decode/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_kunlun/v1/sample/spec_decode/eagle.py 
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Kunlun (XPU) overrides for vLLM's ``EagleProposer``.

This module monkeypatches ``EagleProposer.propose`` and
``EagleProposer.prepare_next_token_ids_padded`` so speculative decoding
(EAGLE / EAGLE3 / DeepSeek MTP) runs on the XPU backend, where
``index_fill_`` rejects an empty index tensor.
"""
import ast
from typing import Optional

import numpy as np
import torch
import torch.nn as nn

from vllm.attention.layer import Attention
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_aiter_fa import (
    AiterFlashAttentionMetadata)
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
                                                  TreeAttentionMetadataBuilder)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
# NOTE: these two names appear in the signature annotations of
# prepare_next_token_ids_padded below. There is no
# `from __future__ import annotations` in this file, so the annotations are
# evaluated eagerly at `def` time — without this import the module raises
# NameError on import.
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

logger = init_logger(__name__)

# Slot id written for draft positions beyond max_model_len so that padding
# tokens never overwrite real entries in the KV cache.
PADDING_SLOT_ID = -1


def propose(
    self,
    # [num_tokens]
    target_token_ids: torch.Tensor,
    # [num_tokens]
    target_positions: torch.Tensor,
    # [num_tokens, hidden_size]
    target_hidden_states: torch.Tensor,
    # [batch_size]
    next_token_ids: torch.Tensor,
    common_attn_metadata: CommonAttentionMetadata,
    sampling_metadata: SamplingMetadata,
    mm_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
    """Generate draft tokens for speculative decoding.

    Runs the draft model once over the shifted target tokens, then
    autoregressively drafts the remaining ``num_speculative_tokens - 1``
    tokens one at a time (or delegates to tree drafting when the attention
    metadata is ``TreeAttentionMetadata``).

    Returns:
        ``[batch_size, num_speculative_tokens]`` tensor of draft token ids
        (or ``[batch_size, num_tree_tokens]`` for tree drafting).
    """
    num_tokens = target_token_ids.shape[0]
    batch_size = next_token_ids.shape[0]
    last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

    if self.method == "eagle3":
        assert isinstance(self.model, Eagle3LlamaForCausalLM)
        target_hidden_states = self.model.combine_hidden_states(
            target_hidden_states)
        assert target_hidden_states.shape[-1] == self.hidden_size

    # Shift the input ids by one token.
    # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
    self.input_ids[:num_tokens - 1] = target_token_ids[1:]
    # Replace the last token with the next token.
    # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
    self.input_ids[last_token_indices] = next_token_ids

    assert self.runner is not None

    # FIXME: need to consider multiple kv_cache_groups
    attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
        .build_for_drafting(common_attn_metadata=common_attn_metadata,
                            draft_index=0)
    # Kunlun-specific: disable the speculative sequence-length fast path in
    # the decode metadata for the drafting pass.
    if attn_metadata.decode is not None \
            and attn_metadata.decode.spec_num_seq_len is not None:
        attn_metadata.decode.spec_num_seq_len = -1
    # At this moment, we assume all eagle layers belong to the same KV
    # cache group, thus using the same attention metadata.
    per_layer_attn_metadata = {}
    for layer_name in self.attn_layer_names:
        per_layer_attn_metadata[layer_name] = attn_metadata
    if self.use_cuda_graph and \
            num_tokens <= self.cudagraph_batch_sizes[-1]:
        num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
    else:
        num_input_tokens = num_tokens

    # copy inputs to buffer for cudagraph
    self.positions[:num_tokens] = target_positions
    self.hidden_states[:num_tokens] = target_hidden_states
    if self.is_multimodal_model:
        input_ids = self.input_ids[:num_tokens]
        inputs_embeds = self.model.get_input_embeddings(
            input_ids,
            multimodal_embeddings=mm_embeds or None,
        )
        self.inputs_embeds[:num_tokens] = inputs_embeds
        inputs_embeds = self.inputs_embeds[:num_input_tokens]
        input_ids = None
    else:
        inputs_embeds = None
        input_ids = self.input_ids[:num_input_tokens]

    with set_forward_context(per_layer_attn_metadata,
                             self.vllm_config,
                             num_tokens=num_input_tokens):
        ret_hidden_states = self.model(
            input_ids=input_ids,
            positions=self.positions[:num_input_tokens],
            hidden_states=self.hidden_states[:num_input_tokens],
            inputs_embeds=inputs_embeds,
        )
        # DeepSeek MTP returns a single tensor; eagle-style models return
        # (last_hidden_states, hidden_states).
        if self.method == "deepseek_mtp":
            last_hidden_states = ret_hidden_states
            hidden_states = self.hidden_states[:num_input_tokens]
        else:
            last_hidden_states, hidden_states = ret_hidden_states
    sample_hidden_states = last_hidden_states[last_token_indices]
    logits = self.model.compute_logits(sample_hidden_states, None)
    positions = target_positions[last_token_indices]
    hidden_states = hidden_states[last_token_indices]

    if isinstance(attn_metadata, TreeAttentionMetadata):
        # Draft using tree attention.
        draft_token_ids_list = self.propose_tree(
            batch_size=batch_size,
            logits=logits,
            positions=positions,
            hidden_states=hidden_states,
            common_attn_metadata=common_attn_metadata,
        )
        # [batch_size, num_tree_tokens]
        return torch.cat(draft_token_ids_list, dim=1)

    draft_token_ids = logits.argmax(dim=-1)

    # Early exit if there is only one draft token to be generated.
    if self.num_speculative_tokens == 1:
        # [batch_size, 1]
        return draft_token_ids.view(-1, 1)

    # TODO: Currently, MTP module released by deepseek only has
    # one layer. Adapt this code to support multiple layers once
    # there's a multi-layer MTP module.

    # Generate the remaining draft tokens.
    draft_token_ids_list = [draft_token_ids]
    if self.use_cuda_graph and \
            batch_size <= self.cudagraph_batch_sizes[-1]:
        input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
    else:
        input_batch_size = batch_size

    # From here on every request drafts exactly one token per step.
    common_attn_metadata.num_actual_tokens = batch_size
    common_attn_metadata.max_query_len = 1
    common_attn_metadata.query_start_loc = \
        self.arange[:batch_size + 1].to(torch.int32)
    common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
        self.token_arange_np[:batch_size + 1]).clone().to(torch.int32)
    for _ in range(self.num_speculative_tokens - 1):
        # Update the inputs.
        # cast to int32 is crucial when eagle model is compiled.
        # tensor.argmax() returns int64 by default.
        input_ids = draft_token_ids_list[-1].int()
        positions += 1

        # NOTE(woosuk): We should handle the case where the draft model
        # generates tokens beyond the max model length. Since it is complex
        # to remove such requests from the batch, we keep them in the batch
        # but adjust the position ids and slot mappings to avoid the
        # out-of-range access during the model execution. The draft tokens
        # generated with this adjustment should be ignored.
        exceeds_max_model_len = positions >= self.max_model_len
        # Mask out the position ids that exceed the max model length.
        # Otherwise, we may get out-of-range error in RoPE.
        clamped_positions = torch.where(exceeds_max_model_len, 0,
                                        positions)

        # Increment the sequence lengths.
        common_attn_metadata.seq_lens += 1
        common_attn_metadata.seq_lens_cpu += 1
        # For the requests that exceed the max model length, we set the
        # sequence length to 1 to minimize their overheads in attention.
        common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len,
                                                   1)
        common_attn_metadata.num_computed_tokens_cpu = \
            common_attn_metadata.seq_lens_cpu - 1

        # Compute the slot mapping.
        block_numbers = clamped_positions // self.block_size
        block_ids = common_attn_metadata.block_table_tensor.gather(
            dim=1, index=block_numbers.view(-1, 1))
        block_ids = block_ids.view(-1)
        common_attn_metadata.slot_mapping = (
            block_ids * self.block_size +
            clamped_positions % self.block_size)
        # Mask out the slot mappings that exceed the max model length.
        # Otherwise, the KV cache will be inadvertently updated with the
        # padding tokens.
        common_attn_metadata.slot_mapping.masked_fill_(
            exceeds_max_model_len, PADDING_SLOT_ID)

        # Rebuild attention metadata for this single-token drafting step.
        attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
            .build_for_drafting(common_attn_metadata=common_attn_metadata,
                                draft_index=0)
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
        # copy inputs to buffer for cudagraph
        self.input_ids[:batch_size] = input_ids
        self.positions[:batch_size] = clamped_positions
        self.hidden_states[:batch_size] = hidden_states
        if self.is_multimodal_model:
            inputs_embeds = self.model.get_input_embeddings(input_ids)
            self.inputs_embeds[:batch_size] = inputs_embeds
            inputs_embeds = self.inputs_embeds[:input_batch_size]
            input_ids = None
        else:
            inputs_embeds = None
            input_ids = self.input_ids[:input_batch_size]

        # Run the model.
        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=input_batch_size):
            last_hidden_states = self.model(
                input_ids=input_ids,
                positions=self.positions[:input_batch_size],
                hidden_states=self.hidden_states[:input_batch_size],
                inputs_embeds=inputs_embeds,
            )
        logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                           None)
        draft_token_ids = logits.argmax(dim=-1)
        draft_token_ids_list.append(draft_token_ids)

    # [batch_size, num_speculative_tokens]
    draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
    return draft_token_ids


def prepare_next_token_ids_padded(self,
                                  common_attn_metadata: CommonAttentionMetadata,
                                  sampled_token_ids: torch.Tensor,
                                  requests: dict[str, CachedRequestState],
                                  gpu_input_batch: InputBatch,
                                  discard_request_indices: torch.Tensor,
                                  num_discarded_requests: int) -> \
        tuple[torch.Tensor, torch.Tensor]:
    """
    This function is used to prepare the inputs for speculative decoding.
    It calculates the next token ids and the number of valid sampled tokens
    for each request, considering the "discarded" requests whose next token
    is not sampled and comes from `request.get_token_id()` instead.
    It also accounts for the rejected tokens in `sampled_token_ids`.
    This function must use device functions to operate on the inputs, and
    should not introduce any blocking CPU-GPU synchronization.

    Returns:
        A tuple of (next_token_ids [batch_size], valid_sampled_tokens_count
        [batch_size]).
    """
    # TODO(Ben): Combine this into a custom fused kernel

    # Precompute get_token_id for when there is no valid next token
    num_reqs = gpu_input_batch.num_reqs
    self.backup_next_token_ids.np[:num_reqs] = np.array([
        requests[gpu_input_batch.req_ids[i]].get_token_id(
            common_attn_metadata.seq_lens_cpu[i].item())
        for i in range(num_reqs)
    ])
    self.backup_next_token_ids.copy_to_gpu(num_reqs)

    # Mask out the sampled tokens indices that should not be sampled.
    discard_sampled_tokens_req_indices = \
        discard_request_indices[:num_discarded_requests]

    valid_sampled_token_ids_gpu = sampled_token_ids.clone()
    # Kunlun-specific: XPU/XMLIR index_fill_ does NOT accept an empty index
    # tensor, so only apply the mask when at least one request was
    # discarded. index_fill_ also requires an int64 index on the same
    # device; .to() is a no-op when dtype/device already match.
    if num_discarded_requests > 0:
        discard_idx = discard_sampled_tokens_req_indices.to(
            device=valid_sampled_token_ids_gpu.device, dtype=torch.long)
        valid_sampled_token_ids_gpu.index_fill_(0, discard_idx, -1)

    # Generate a mask for all valid tokens within those requests
    max_gen_len = sampled_token_ids.shape[-1]
    if max_gen_len == 1:
        valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
                                     dtype=torch.bool)
    else:
        valid_mask = (
            (valid_sampled_token_ids_gpu != -1) &
            (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))

    # Count the number of valid tokens in each request
    valid_sampled_tokens_count = valid_mask.sum(dim=1)

    # Get the rightmost valid index per row
    last_valid_indices = valid_sampled_tokens_count - 1
    last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

    # Get last valid token from each row
    # (assume undefined state where there is no valid token)
    selected_tokens = torch.gather(
        valid_sampled_token_ids_gpu, 1,
        last_valid_indices_safe.unsqueeze(1)).squeeze(1)

    # Use last token if valid, pre-computed backup if not
    batch_size = valid_sampled_token_ids_gpu.shape[0]
    next_token_ids = torch.where(
        last_valid_indices != -1, selected_tokens,
        self.backup_next_token_ids.gpu[:batch_size])

    return next_token_ids, valid_sampled_tokens_count


# Install the Kunlun overrides on the upstream proposer.
EagleProposer.propose = propose
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded