Files

426 lines
15 KiB
Python
Raw Permalink Normal View History

2026-01-19 10:38:50 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
2026-01-09 13:34:11 +08:00
# Copyright 2023 The vLLM team.
# Copyright (c) Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
2026-01-19 10:38:50 +08:00
from collections.abc import Iterable
from functools import cache
from itertools import islice
from typing import Any
2026-01-09 13:34:11 +08:00
import torch
from torch import nn
from transformers import GemmaConfig
2026-01-19 10:38:50 +08:00
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
2026-01-09 13:34:11 +08:00
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
2026-01-19 10:38:50 +08:00
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
2026-01-09 13:34:11 +08:00
from vllm.model_executor.layers.logits_processor import LogitsProcessor
2026-01-19 10:38:50 +08:00
from vllm.model_executor.layers.quantization import QuantizationConfig
2026-01-09 13:34:11 +08:00
from vllm.model_executor.layers.rotary_embedding import get_rope
2026-01-19 10:38:50 +08:00
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
2026-01-09 13:34:11 +08:00
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
2026-01-19 10:38:50 +08:00
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
2026-01-09 13:34:11 +08:00
logger = init_logger(__name__)
2026-01-19 10:38:50 +08:00
@cache
2026-01-09 13:34:11 +08:00
def _get_gemma_act_fn(
2026-01-19 10:38:50 +08:00
hidden_act: str | None,
hidden_activation: str | None,
2026-01-09 13:34:11 +08:00
) -> nn.Module:
if hidden_activation is None:
if hidden_act is not None:
logger.warning(
"Gemma's activation function was incorrectly set to exact GeLU "
"in the config JSON file when it was initially released. "
"Changing the activation function to approximate GeLU "
"(`gelu_pytorch_tanh`). If you want to use the legacy "
"`%s`, edit the config JSON to set "
"`hidden_activation=%s` instead of `hidden_act`. "
"See https://github.com/huggingface/transformers/pull/29402 "
2026-01-19 10:38:50 +08:00
"for more details.",
hidden_act,
hidden_act,
)
2026-01-09 13:34:11 +08:00
return GeluAndMul(approximate="tanh")
elif hidden_activation == "gelu_pytorch_tanh":
return GeluAndMul(approximate="tanh")
elif hidden_activation == "gelu":
return GeluAndMul(approximate="none")
else:
2026-01-19 10:38:50 +08:00
raise ValueError(
f"Activation function {hidden_act} is not supported for Gemma models."
)
2026-01-09 13:34:11 +08:00
class GemmaMLP(nn.Module):
    """Gemma feed-forward sublayer.

    A single merged column-parallel projection produces the gate and up
    activations together (output sizes ``[intermediate_size,
    intermediate_size]``). The activation from ``_get_gemma_act_fn`` is
    applied to that merged output, and a row-parallel projection maps the
    result back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str | None = None,
        hidden_activation: str | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections are fused into one matmul. The weight
        # loader fills the two shards separately.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)

    def forward(self, x):
        """Run ``x`` through gate/up projection, activation, and down projection."""
        # Each parallel linear returns (output, bias); only the output is used.
        fused = self.gate_up_proj(x)[0]
        activated = self.act_fn(fused)
        projected = self.down_proj(activated)[0]
        return projected
class GemmaAttention(nn.Module):
    """Gemma self-attention with a fused, tensor-parallel QKV projection.

    Query/key/value come from one ``QKVParallelLinear``. Rotary embeddings
    (neox style) are applied to q and k, attention is computed by the
    ``Attention`` layer, and a row-parallel ``o_proj`` maps back to
    ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rope_parameters: dict[str, Any],
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        world = get_tensor_model_parallel_world_size()

        # Query heads are always split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert num_heads % world == 0
        self.num_heads = num_heads // world

        # KV heads are partitioned when there are at least as many heads as
        # ranks, and replicated when there are fewer. Each rank keeps >= 1.
        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads < world:
            assert world % num_kv_heads == 0
        else:
            assert num_kv_heads % world == 0
        self.num_kv_heads = max(1, num_kv_heads // world)

        # Per-rank widths of the q / k / v slices of the fused projection.
        self.head_dim = head_dim
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            num_heads,
            num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            num_heads * head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply rotary self-attention to ``hidden_states`` at ``positions``."""
        qkv = self.qkv_proj(hidden_states)[0]
        sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(sizes, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attended = self.attn(q, k, v)
        return self.o_proj(attended)[0]
class GemmaDecoderLayer(nn.Module):
    """One Gemma transformer layer: norm -> attention -> norm -> MLP.

    The residual stream is carried through the two-argument calls of
    ``GemmaRMSNorm``, which return ``(normalized, residual)``. It is not
    added explicitly in this class.
    """

    def __init__(
        self,
        config: GemmaConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        width = config.hidden_size
        eps = config.rms_norm_eps
        self.hidden_size = width
        self.self_attn = GemmaAttention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        # Older configs may lack ``hidden_activation``; None selects the
        # fallback handled by the activation factory.
        self.mlp = GemmaMLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            hidden_activation=getattr(config, "hidden_activation", None),
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = GemmaRMSNorm(width, eps=eps)
        self.post_attention_layernorm = GemmaRMSNorm(width, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP; return ``(mlp_output, residual)``."""
        # With no incoming residual, the raw input becomes the residual.
        # Otherwise the two-argument norm returns the normalized states and
        # an updated residual (presumably residual + hidden_states; see
        # GemmaRMSNorm).
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(positions=positions, hidden_states=hidden_states)

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual
2026-01-19 10:38:50 +08:00
@support_torch_compile
class GemmaModel(nn.Module):
    """Gemma decoder stack: token embedding, decoder layers, final norm.

    Supports pipeline parallelism. This rank runs only the layers in
    ``[start_layer, end_layer)``. Non-last ranks return ``IntermediateTensors``
    holding ``hidden_states`` and ``residual`` for the next stage.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GemmaDecoderLayer(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size)
        # The normalizer's data type should be downcasted to the model's
        # data type such as bfloat16, not float32.
        # See https://github.com/huggingface/transformers/pull/29402
        # NOTE(review): torch.tensor() uses torch's current default dtype, so
        # this relies on the model being constructed with the model dtype as
        # the default (or being cast afterwards); confirm in the loader.
        normalizer = self.config.hidden_size**0.5
        self.register_buffer("normalizer", torch.tensor(normalizer), persistent=False)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings. The sqrt(hidden_size) scale is NOT applied here."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's slice of the decoder.

        Args:
            input_ids: Token ids. Ignored when ``inputs_embeds`` is given or on
                non-first pipeline ranks.
            positions: Token positions forwarded to every layer.
            intermediate_tensors: Output of the previous pipeline stage.
                Read only on non-first ranks.
            inputs_embeds: Optional precomputed embeddings used instead of
                ``input_ids`` on the first rank. This tensor is not modified.

        Returns:
            The final normalized hidden states on the last rank. On other
            ranks, ``IntermediateTensors`` with ``hidden_states`` and
            ``residual``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            # Scale out-of-place. An in-place ``*=`` here would silently
            # rescale a caller-provided ``inputs_embeds`` tensor. The result
            # dtype is unchanged because the 0-dim normalizer does not
            # promote a floating-point tensor.
            hidden_states = hidden_states * self.normalizer
            residual = None
        else:
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint names containing a per-shard projection (``q_proj``,
        ``k_proj``, ``v_proj``, ``gate_proj``, ``up_proj``) are rewritten to
        the fused parameter (``qkv_proj`` / ``gate_up_proj``) and loaded into
        the matching shard. All other tensors are loaded by name.

        Two kinds of tensor are skipped: extra ``.bias`` tensors with no
        matching parameter (GPTQ checkpoints), and parameters that
        ``is_pp_missing_parameter`` reports as absent on this pipeline stage.

        Returns:
            The set of (possibly rewritten) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping consumed this tensor. A stacked shard that
                # was skipped above also arrives here with its rewritten name
                # and is skipped again by the same checks.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Gemma causal LM: ``GemmaModel`` plus logits over the tied embedding."""

    # Fused vLLM module name -> checkpoint module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        # currently all existing Gemma models have `tie_word_embeddings` enabled;
        # compute_logits reuses the input embedding as the output projection.
        assert hf_config.tie_word_embeddings
        self.quant_config = vllm_config.quant_config
        self.model = GemmaModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the inner model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the tied embedding."""
        return self.logits_processor(self.model.embed_tokens, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, ignoring any ``lm_head.*`` tensors when embeddings are tied."""
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        return AutoWeightsLoader(self, skip_prefixes=skipped).load_weights(weights)