# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import os
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm import _custom_ops as ops
from vllm.distributed import tensor_model_parallel_gather
from vllm import envs
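
# 1/sqrt(2), used in generate_proposals to rescale the ln0 output when
# config.scale_input is enabled.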
SQRT2 = 2**0.5


class MLPSpeculatorLayerNorm(nn.Module):
"""
    An L2 normalization implementation
...
Args
----
normalized_shape : int
Dimensionality of input data (size of final tensor axis)
eps : float
Safety term to prevent division by zero. Make sure the chosen value
fits in the range of your encoding scheme
        (e.g. fp16 requires eps >= 6e-8).
elementwise_scale_and_shift : bool
Include a learned scaling and shift term after normalization.
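
    Example
    -------
    A minimal shape check (sizes here are illustrative):

    >>> ln = MLPSpeculatorLayerNorm(8, elementwise_scale_and_shift=False)
    >>> ln(torch.randn(2, 3, 8)).shape
    torch.Size([2, 3, 8])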
"""
def __init__(
self,
normalized_shape,
eps=1e-06,
elementwise_scale_and_shift=True,
):
super().__init__()
self.elementwise_scale_and_shift = elementwise_scale_and_shift
if self.elementwise_scale_and_shift:
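            # left uninitialized here; values are populated by load_weights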
self.weight = nn.Parameter(torch.empty(normalized_shape))
self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
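        # RMS-style normalization: scale by 1/sqrt(mean(x^2) + eps) along
        # the last axis; unlike standard LayerNorm, no mean-centering.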
xf = x
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
x = xf.type_as(x)
if self.elementwise_scale_and_shift:
x = self.weight * x
x = x + self.bias
        return x


class MLPSpeculator(nn.Module):
"""
An implementation of the speculative models introduced in
"Accelerating Production LLMs with Combined Token/Embedding
Speculators"
https://arxiv.org/pdf/2404.19124
Trained speculators of this type are available on HF hub at:
https://huggingface.co/ibm-ai-platform and https://huggingface.co/ibm-granite
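
    Each prediction stage mixes the previous hidden state with the
    embedding of the most recent token, applies a per-stage layernorm and
    GELU, and samples the next speculative token from a per-stage LM head
    (see generate_proposals).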
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
config = vllm_config.model_config.hf_config
self.n_predict = config.n_predict
self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim
self.inner_dim = config.inner_dim if config.inner_dim != 0 \
else config.emb_dim
self.max_speculative_tokens = config.num_lookahead_tokens
self.tie_weights = config.tie_weights
self.scale_input = config.scale_input
if self.tie_weights:
assert (
self.n_predict > 1
), "You cannot tie weights between stages when only 1 exists"
embedding = VocabParallelEmbedding(
config.vocab_size,
self.inner_dim,
org_num_embeddings=config.vocab_size)
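            # [embedding] * n stores n references to the same module, so
            # every stage shares one set of embedding weights.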
self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)
# the initial projection from the base model may
# have a different size, so that stays separate.
proj_first = ColumnParallelLinear(input_size=self.emb_dim,
output_size=self.inner_dim,
bias=False,
gather_output=True)
proj_tied = ColumnParallelLinear(input_size=self.inner_dim,
output_size=self.inner_dim,
bias=False,
gather_output=True)
self.proj = nn.ModuleList([proj_first] + [proj_tied] *
(self.max_speculative_tokens - 1))
head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
self.head = nn.ModuleList([head] * self.max_speculative_tokens)
ln = MLPSpeculatorLayerNorm(self.inner_dim,
elementwise_scale_and_shift=True)
self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)
else:
self.emb = nn.ModuleList([
VocabParallelEmbedding(config.vocab_size,
self.inner_dim,
org_num_embeddings=config.vocab_size)
for _ in range(self.max_speculative_tokens)
])
self.proj = nn.ModuleList([
ColumnParallelLinear(input_size=(self.emb_dim if i == 0 else self.inner_dim),
output_size=self.inner_dim,
bias=False,
gather_output=True)
for i in range(self.max_speculative_tokens)
])
self.head = nn.ModuleList([
ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
for _ in range(self.max_speculative_tokens)
])
self.ln = nn.ModuleList([
MLPSpeculatorLayerNorm(self.inner_dim,
elementwise_scale_and_shift=True)
for _ in range(self.max_speculative_tokens)
])
if self.scale_input:
self.ln0 = MLPSpeculatorLayerNorm(
self.emb_dim, elementwise_scale_and_shift=False)
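        # Fixed mixing weights for combining the projected state with the
        # token embedding; emb_weight appears to be chosen so the weighted
        # sum keeps a roughly constant scale going into the per-stage
        # layernorm (see the paper linked above).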
self.state_weight = 0.5**(0.5 / config.n_predict)
self.emb_weight = math.sqrt(
(1 - self.state_weight**2) * (self.inner_dim / 2))
self.activation = nn.GELU()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
config.vocab_size, 1.0)
        self.sampler = get_sampler()

    def generate_proposals(
self,
input_ids: torch.Tensor,
previous_hidden_states: torch.Tensor,
num_predict_tokens: int,
sampling_metadata: SamplingMetadata,
        head_index: int,
    ) -> tuple[Optional[SamplerOutput], Optional[torch.Tensor]]:
if num_predict_tokens > self.max_speculative_tokens:
raise ValueError(f"Max speculative tokens for model is "
f"{self.max_speculative_tokens}, but "
f"{num_predict_tokens} were requested")
if self.scale_input:
previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
# Project and predict
z = self.emb[head_index](input_ids) # b k d
states, _ = self.proj[head_index](previous_hidden_states)
# Weighted add of state_weight*state and emb_weight*z
# Let subsequent LN take care of denominator
# state_weight is close to 1, so shouldn't be any precision issues
states.add_(z, alpha=self.emb_weight / self.state_weight)
states = self.activation(self.ln[head_index](states)) # b k d
previous_hidden_states = states
# TODO: not yet supporting top_k_tokens_per_head
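        # flatten (batch, k) so the LM head sees a 2-D [b*k, d] input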
states = states.flatten(0, 1)
        # A non-None sampling_metadata means the driver worker is running;
        # it computes logits and samples the proposed tokens.
if sampling_metadata is not None:
logits = self.logits_processor(self.head[head_index], states,
sampling_metadata)
output = self.sampler(logits, sampling_metadata)
return output, previous_hidden_states
else:
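            # Non-driver workers run their shard of the LM head and join
            # the gather so the collective stays in sync; the gathered
            # logits are consumed only on the driver.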
logits = self.head[head_index].linear_method.apply(self.head[head_index],
states,
bias=None)
logits = tensor_model_parallel_gather(logits)
return None, None
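
    # Illustrative driver-side loop (a sketch; variable names below are
    # hypothetical, not part of the vLLM API):
    #
    #   hidden = base_model_hidden_states
    #   tokens = last_accepted_token_ids
    #   for i in range(num_predict_tokens):
    #       out, hidden = speculator.generate_proposals(
    #           tokens, hidden, num_predict_tokens, sampling_metadata, i)
    #       tokens = out.sampled_token_ids
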
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
name = name.replace("speculator.", "")
param = params_dict.get(name)
if param is not None:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
                if self.use_llama_nn or envs.VLLM_USE_NN:
                    # Optional weight-layout transform for the custom GEMM
                    # path, gated by module-specific env toggles.
                    if (os.environ.get('LM_NN') == '1' and "head" in name) \
                            or "proj" in name:
                        _weight = torch.zeros_like(param.data)
                        ori_shape = _weight.shape
                        ops.trans_w16_gemm(_weight, param.data,
                                           _weight.shape[0],
                                           _weight.shape[1])
                        param.data.copy_(_weight)
                        param.data = param.data.reshape(ori_shape[1], -1)
return loaded_params