# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import os
from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn

from vllm import _custom_ops as ops
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed import tensor_model_parallel_gather
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

SQRT2 = 2**0.5


class MLPSpeculatorLayerNorm(nn.Module):
    """
    An L2 normalization implementation
    ...
    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme
        (i.e. fp16 requires eps >= 6e-8).
    elementwise_scale_and_shift : bool
        Include a learned scaling and shift term after normalization.
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
        elementwise_scale_and_shift=True,
    ):
        super().__init__()
        self.elementwise_scale_and_shift = elementwise_scale_and_shift
        if self.elementwise_scale_and_shift:
            self.weight = nn.Parameter(torch.empty(normalized_shape))
            self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        if self.elementwise_scale_and_shift:
            x = self.weight * x
            x = x + self.bias
        return x


class MLPSpeculator(nn.Module):
    """
    An implementation of the speculative models introduced in
    "Accelerating Production LLMs with Combined Token/Embedding Speculators"
    https://arxiv.org/pdf/2404.19124

    Trained speculators of this type are available on HF hub at:
    https://huggingface.co/ibm-ai-platform and
    https://huggingface.co/ibm-granite
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        config = vllm_config.model_config.hf_config
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.inner_dim = config.inner_dim if config.inner_dim != 0 \
            else config.emb_dim

        self.max_speculative_tokens = config.num_lookahead_tokens

        self.tie_weights = config.tie_weights
        self.scale_input = config.scale_input

        if self.tie_weights:
            assert (
                self.n_predict > 1
            ), "You cannot tie weights between stages when only 1 exists"
            embedding = VocabParallelEmbedding(
                config.vocab_size,
                self.inner_dim,
                org_num_embeddings=config.vocab_size)
            self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)

            # the initial projection from the base model may
            # have a different size, so that stays separate.
            # proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
            # proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
            proj_first = ColumnParallelLinear(input_size=self.emb_dim,
                                              output_size=self.inner_dim,
                                              bias=False,
                                              gather_output=True)
            proj_tied = ColumnParallelLinear(input_size=self.inner_dim,
                                             output_size=self.inner_dim,
                                             bias=False,
                                             gather_output=True)
            self.proj = nn.ModuleList([proj_first] + [proj_tied] *
                                      (self.max_speculative_tokens - 1))

            head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
            self.head = nn.ModuleList([head] * self.max_speculative_tokens)

            ln = MLPSpeculatorLayerNorm(self.inner_dim,
                                        elementwise_scale_and_shift=True)
            self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)
        else:
            self.emb = nn.ModuleList([
                VocabParallelEmbedding(config.vocab_size,
                                       self.inner_dim,
                                       org_num_embeddings=config.vocab_size)
                for _ in range(self.max_speculative_tokens)
            ])
            self.proj = nn.ModuleList([
                ColumnParallelLinear(
                    input_size=(self.emb_dim if i == 0 else self.inner_dim),
                    output_size=self.inner_dim,
                    bias=False,
                    gather_output=True)
                for i in range(self.max_speculative_tokens)
            ])
            self.head = nn.ModuleList([
                ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
                for _ in range(self.max_speculative_tokens)
            ])
            self.ln = nn.ModuleList([
                MLPSpeculatorLayerNorm(self.inner_dim,
                                       elementwise_scale_and_shift=True)
                for _ in range(self.max_speculative_tokens)
            ])
        if self.scale_input:
            self.ln0 = MLPSpeculatorLayerNorm(
                self.emb_dim, elementwise_scale_and_shift=False)

        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.vocab_size, 1.0)
        self.sampler = get_sampler()

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
        head_index: int,
    ) -> tuple[Optional[SamplerOutput], Optional[torch.Tensor]]:
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        if self.scale_input:
            previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2

        # Project and predict
        z = self.emb[head_index](input_ids)  # b k d
        states, _ = self.proj[head_index](previous_hidden_states)

        # Weighted add of state_weight*state and emb_weight*z
        # Let subsequent LN take care of denominator
        # state_weight is close to 1, so shouldn't be any precision issues
        states.add_(z, alpha=self.emb_weight / self.state_weight)

        states = self.activation(self.ln[head_index](states))  # b k d
        previous_hidden_states = states
        # TODO: not yet supporting top_k_tokens_per_head
        states = states.flatten(0, 1)

        # a non-None sampling_metadata indicates that the driver worker
        # is running
        if sampling_metadata is not None:
            logits = self.logits_processor(self.head[head_index], states,
                                           sampling_metadata)
            output = self.sampler(logits, sampling_metadata)
            return output, previous_hidden_states
        else:
            # non-driver workers only compute logits so they can join the
            # tensor-parallel gather; the result itself is discarded
            logits = self.head[head_index].linear_method.apply(
                self.head[head_index], states, bias=None)
            logits = tensor_model_parallel_gather(logits)
            return None, None

    def load_weights(self,
                     weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            name = name.replace("speculator.", "")
            param = params_dict.get(name)
            if param is not None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
                if self.use_llama_nn or envs.VLLM_USE_NN:
                    if (os.environ.get('LM_NN') == '1'
                            and "head" in name) or "proj" in name:
                        # repack head/proj weights via the custom
                        # trans_w16_gemm op, then keep the result reshaped
                        # with its two dimensions swapped
                        _weight = torch.zeros_like(param.data)
                        ori_shape = _weight.shape
                        ops.trans_w16_gemm(_weight, param.data,
                                           _weight.shape[0], _weight.shape[1])
                        param.data.copy_(_weight)
                        param.data = param.data.reshape(ori_shape[1], -1)
        return loaded_params
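
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of this module's API: what
# MLPSpeculatorLayerNorm computes when elementwise_scale_and_shift=False
# (a pure RMS-style normalization). Shapes below are arbitrary examples.
#
#   ln = MLPSpeculatorLayerNorm(64, elementwise_scale_and_shift=False)
#   x = torch.randn(2, 3, 64)
#   y = ln(x)  # y = x * rsqrt(mean(x ** 2, dim=-1) + eps)
#   # every feature vector now has (approximately) unit root-mean-square
#   assert torch.allclose(y.pow(2).mean(-1), torch.ones(2, 3), atol=1e-3)
# ---------------------------------------------------------------------------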
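
# ---------------------------------------------------------------------------
# Illustrative sketch of an assumed caller (hypothetical names such as
# `speculator`, `base_hidden`, and `last_tokens`; not code from this repo):
# each speculative step feeds the sampled token ids and the returned hidden
# state back into the next head. Only the driver worker passes a non-None
# sampling_metadata and therefore receives a SamplerOutput.
#
#   hidden = base_hidden                     # [batch, 1, emb_dim]
#   tokens = last_tokens                     # [batch, 1]
#   proposals = []
#   for head_index in range(num_predict_tokens):
#       output, hidden = speculator.generate_proposals(
#           tokens, hidden, num_predict_tokens, sampling_metadata, head_index)
#       if output is not None:               # driver worker only
#           tokens = output.sampled_token_ids.reshape(-1, 1)
#           proposals.append(output)
# ---------------------------------------------------------------------------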