# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Set, Tuple

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm import _custom_ops as ops
from vllm.distributed import (tensor_model_parallel_all_gather,
                              tensor_model_parallel_gather)

# Number of top-k token candidates taken from each Medusa head when building
# the sparse candidate tree (10 is a placeholder and it is sufficient).
TOPK = 10


class ResidualBlock(nn.Module):
    """Stack of residual SiLU-activated linear layers for one Medusa head.

    Each layer computes ``x = x + SiLU(Linear(x))``, so the hidden size is
    preserved end to end.
    """

    def __init__(self, config: VllmConfig, hidden_size: int,
                 num_layers: int) -> None:
        super().__init__()
        # NOTE(review): Medusa.__init__ passes the draft model's HF config
        # here, not a VllmConfig; only `medusa_fc_bias` is read from it.
        self.layers = nn.ModuleList([
            nn.Linear(hidden_size,
                      hidden_size,
                      bias=getattr(config, "medusa_fc_bias", False))
            for _ in range(num_layers)
        ])
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = x + self.act(layer(x))
        return x


class Medusa(nn.Module):
    """This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774
    Reference implementation: https://github.com/FasterDecoding/Medusa

    Differences from reference implementation:
    1. `sample` generates proposals from the top-1 token of each head only.
       `medusa_sample` (used when tree buffers are supplied) builds tree
       candidates from the top-`TOPK` tokens of each head instead.
    2. We have an optional token_map which reduces draft vocab to most
       frequently used tokens to give some additional speed-up by reducing
       sampling overhead. This is disabled unless the checkpoint file has
       explicit token_map tensor and config has an optional attribute
       truncated_vocab_size < vocab_size. To use this technique, one has to
       find the top-k most frequent tokens in target dataset and add that
       as a tensor in the draft checkpoint (using key token_map). Also, the
       draft config needs to have truncated_vocab_size (=k) as an attribute.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        config = vllm_config.speculative_config.draft_model_config.hf_config
        super().__init__()
        # Opt-in custom weight layout for the LM head (see load_weights).
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.config = config
        self.blocks = nn.ModuleList([
            ResidualBlock(config=config,
                          hidden_size=self.config.hidden_size,
                          num_layers=self.config.num_hidden_layers)
            for _ in range(self.config.num_heads)
        ])
        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size
        self.medusa_choices = config.medusa_choices

        if getattr(config, "original_lm_head", False):
            # One LM head shared by all Medusa heads. A plain list is used so
            # the shared weight is registered only once (as `lm_head`).
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=self.truncated_vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            )
            self.lm_heads = [
                self.lm_head for _ in range(self.config.num_heads)
            ]
        else:
            self.lm_heads = nn.ModuleList([
                ParallelLMHead(
                    self.unpadded_vocab_size,
                    config.hidden_size,
                    org_num_embeddings=self.truncated_vocab_size,
                    padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                ) for _ in range(self.config.num_heads)
            ])

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)

        # Token map is a idx to token mapping to reduce the vocab size for
        # the draft model. Using smaller vocab size for draft, containing
        # only most frequent tokens reduces the speculation overhead. This
        # doesn't affect the acceptance rate much and thus gives more speed
        # -up. By default, this is disabled and is only used if the Medusa
        # checkpoint file has token_map tensor.
        self.token_map = None

    def forward(self, hidden_states: torch.Tensor) -> list[torch.Tensor]:
        """Run every Medusa head's residual block on the same hidden states.

        Returns one tensor per head, in head order.
        """
        return [block(hidden_states) for block in self.blocks]

    def compute_logits(
            self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
        """Project each head's hidden states to logits.

        Args:
            hidden_states: One tensor per Medusa head, as returned by
                `forward`; zipped with `self.lm_heads`.

        Returns:
            One logits tensor per head. When a token_map is loaded, the
            truncated-vocab logits are scattered into a tensor whose last dim
            is `orig_vocab_size`, with every unmapped entry set to -inf.
        """
        logits_lst: List[torch.Tensor] = []

        for hs, lm_head in zip(hidden_states, self.lm_heads):
            # Applies the LM head's quant method directly (LogitsProcessor is
            # not used here), then gathers the vocab shards across TP ranks.
            _logits = lm_head.quant_method.apply(lm_head, hs, bias=None)
            _logits = tensor_model_parallel_all_gather(_logits)

            if _logits is None:
                # _logits should only be None on rank > 0, in which case
                # it should remain true for every lm_head
                # NOTE(review): this guard matches the semantics of
                # tensor_model_parallel_gather; confirm whether all_gather
                # can ever return None.
                assert len(logits_lst) == 0
                continue

            if self.token_map is None:
                logits_lst.append(_logits)
            else:
                full_logits = torch.full(
                    (*_logits.shape[:-1], self.orig_vocab_size),
                    -torch.inf,
                    device=_logits.device,
                    dtype=_logits.dtype)
                full_logits[..., self.token_map] = _logits
                logits_lst.append(full_logits)

        return logits_lst

    def sample(
        self,
        logits: List[torch.Tensor],
        sample_indices_list: List[List[int]],
    ) -> List[SamplerOutput]:
        """Greedy (top-1) proposal from every Medusa head.

        Args:
            logits: One logits tensor per head; stacked along a new leading
                head dimension.
            sample_indices_list: For each output, the indices selected along
                the dimension after the head dimension.

        Returns:
            One SamplerOutput per entry of `sample_indices_list`, carrying the
            argmax token ids, the full softmax probs and log-softmax logprobs
            for the selected rows. Dim 1 is squeezed (a no-op unless exactly
            one index was selected).
        """
        logits = torch.stack(logits, dim=0).float()
        logprobs = torch.log_softmax(logits, dim=-1)
        token_ids = logits.argmax(-1)  # support only top-1 for now
        probs = torch.softmax(logits, dim=-1)

        outputs: List[SamplerOutput] = []
        for sample_indices in sample_indices_list:
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_probs=probs[:, sample_indices].squeeze(1),
                    logprobs=logprobs[:, sample_indices].squeeze(1),
                    sampled_token_ids=token_ids[:, sample_indices].squeeze(1),
                ))
        return outputs

    def medusa_sample(
        self,
        medusa_logits: List[torch.Tensor],
        sample_indices_list: List[List[int]],
        logits: torch.Tensor,
        medusa_buffers: Dict[str, Any],
    ) -> List[SamplerOutput]:
        """Build Medusa tree candidates from base-model and head logits.

        Args:
            medusa_logits: One logits tensor per Medusa head.
            sample_indices_list: For each output, the batch rows to select.
            logits: Base-model logits; dim 0 is the batch.
            medusa_buffers: Must contain `tree_indices` (indices into the
                flattened candidate list) and `retrieve_indices` (indices into
                the extended tree candidates).

        Returns:
            One SamplerOutput per entry of `sample_indices_list`, carrying the
            selected rows of the tree candidates (`sampled_token_ids`) and of
            the per-path candidates (`cart_candidates`).
        """
        batch_size = logits.shape[0]
        # Top-1 token of the base model: [batch_size, 1].
        candidates_logit = torch.argmax(logits, dim=-1).view(batch_size, -1)
        # [medusa_heads, batch_size, vocab_size]
        medusa_logits = torch.stack(medusa_logits, dim=0)

        # Top-TOPK token ids per head: [medusa_heads, batch_size, TOPK].
        candidates_medusa_logits = torch.topk(medusa_logits, TOPK,
                                              dim=-1).indices
        # -> [batch_size, medusa_heads * TOPK]
        candidates_medusa_logits = candidates_medusa_logits.permute(
            1, 0, 2).reshape(batch_size, -1)

        # Base top-1 followed by every head's top-k:
        # [batch_size, 1 + medusa_heads * TOPK]
        candidates = torch.cat([candidates_logit, candidates_medusa_logits],
                               dim=-1)

        # Map the flat candidates onto the tree nodes: [batch_size, choices].
        tree_candidates = torch.index_select(
            candidates, dim=-1, index=medusa_buffers['tree_indices'])

        # Append one zero column: [batch_size, choices + 1]. Presumably this is
        # the target of padded (-1) entries in retrieve_indices — confirm.
        zero_column = torch.zeros((batch_size, 1),
                                  dtype=torch.long,
                                  device=tree_candidates.device)
        tree_candidates_ext = torch.cat([tree_candidates, zero_column],
                                        dim=-1)

        # Per-path candidates: [batch_size, retrieve_size, max_depth].
        cart_candidates = tree_candidates_ext[:,
                                              medusa_buffers['retrieve_indices']]

        outputs: List[SamplerOutput] = []
        for sample_indices in sample_indices_list:
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_ids=tree_candidates[sample_indices, :].
                    squeeze(1),
                    cart_candidates=cart_candidates[sample_indices, :],
                ))
        return outputs

    def generate_proposals(
        self,
        previous_hidden_states: torch.Tensor,
        sample_indices_list: List[List[int]],
        previous_logits: Optional[torch.Tensor] = None,
        medusa_buffers: Optional[Dict[str, Any]] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Generate draft proposals from the target model's hidden states.

        Uses greedy per-head sampling (`sample`) when `medusa_buffers` is
        None, otherwise tree-candidate sampling (`medusa_sample`), which also
        needs the base model's `previous_logits`.

        Returns:
            None for an empty batch, else one SamplerOutput per entry of
            `sample_indices_list`.

        Raises:
            ValueError: `medusa_buffers` was given without `previous_logits`.
        """
        if previous_hidden_states.size(0) == 0:
            # Return None to signal the Top1Proposer that no proposals
            # were generated for this batch, allowing it to handle this
            # special case appropriately
            return None

        if medusa_buffers is not None and previous_logits is None:
            raise ValueError(
                "previous_logits is required when medusa_buffers is provided")

        head_logits = self.compute_logits(
            hidden_states=self.forward(previous_hidden_states))

        if medusa_buffers is None:
            return self.sample(
                logits=head_logits,
                sample_indices_list=sample_indices_list,
            )

        return self.medusa_sample(medusa_logits=head_logits,
                                  sample_indices_list=sample_indices_list,
                                  logits=previous_logits,
                                  medusa_buffers=medusa_buffers)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load Medusa head / LM head weights and the optional token_map.

        Weights are buffered first so the token_map (which may appear anywhere
        in the stream) is known before LM-head rows are truncated with it.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        weights_map = {}

        for name, loaded_weight in weights:
            name = name.replace("medusa_heads.", "")

            if name == "token_map":
                if self.truncated_vocab_size < self.orig_vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name in params_dict:
                weights_map[name] = loaded_weight
            elif (getattr(self.config, "original_lm_head", False)
                  and name == "lm_heads.0.weight"):
                weights_map["lm_head.weight"] = loaded_weight

        # Loop-invariant; `.get` avoids a KeyError when LM_NN is unset.
        transform_lm_head = (self.use_llama_nn
                             and os.environ.get('LM_NN') == '1')

        for name, loaded_weight in weights_map.items():
            if "lm_head" in name and self.token_map is not None and \
                    loaded_weight.shape[0] > self.token_map.shape[0]:
                # Keep only the rows of the vocab selected by token_map.
                loaded_weight = loaded_weight[self.token_map]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

            if transform_lm_head and "lm_head" in name:
                # Rewrite the LM-head weight via the custom trans_w16_gemm op,
                # then reshape it to (ori_shape[1], -1).
                _weight = torch.zeros_like(param.data)
                ori_shape = _weight.shape
                ops.trans_w16_gemm(_weight, param.data, _weight.shape[0],
                                   _weight.shape[1])
                param.data.copy_(_weight)
                param.data = param.data.reshape(ori_shape[1], -1)

        if self.token_map is not None:
            # Tensor.to() is not in-place; rebind the parameter's data so the
            # token map actually lives on the LM-head device.
            self.token_map.data = self.token_map.data.to(
                device=self.lm_heads[0].weight.device)

        assert (self.truncated_vocab_size
                == self.orig_vocab_size) or (self.token_map is not None)

        return loaded_params