Files

427 lines
15 KiB
Python
Raw Permalink Normal View History

2026-01-19 10:38:50 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
2026-01-09 13:34:11 +08:00
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
2026-01-19 10:38:50 +08:00
from collections.abc import Iterable
from itertools import islice
2026-01-09 13:34:11 +08:00
import torch
from torch import nn
from transformers import OPTConfig
2026-01-19 10:38:50 +08:00
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
2026-01-09 13:34:11 +08:00
from vllm.model_executor.layers.activation import get_act_fn
2026-01-19 10:38:50 +08:00
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
2026-01-09 13:34:11 +08:00
from vllm.model_executor.layers.logits_processor import LogitsProcessor
2026-01-19 10:38:50 +08:00
from vllm.model_executor.layers.quantization import QuantizationConfig
2026-01-09 13:34:11 +08:00
from vllm.model_executor.layers.vocab_parallel_embedding import (
2026-01-19 10:38:50 +08:00
ParallelLMHead,
VocabParallelEmbedding,
)
2026-01-09 13:34:11 +08:00
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
2026-01-19 10:38:50 +08:00
from vllm.sequence import IntermediateTensors
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
WeightsMapper,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embedding used by OPT.

    The lookup table is allocated with ``num_embeddings + 2`` rows, and every
    incoming position id is shifted by 2 before the lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is laid out so that, when padding_idx is specified, embedding
        # ids are offset by 2 and the table size is grown to match. Other
        # model families do not carry this quirk.
        offset = 2
        self.offset = offset
        super().__init__(num_embeddings + offset, embedding_dim)

    def forward(self, positions: torch.Tensor):
        """Return the embeddings for ``positions`` after applying the offset."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)
class OPTAttention(nn.Module):
    """Self-attention sub-module of an OPT decoder layer.

    Uses a fused ``QKVParallelLinear`` input projection, a
    ``RowParallelLinear`` output projection and a vLLM ``Attention`` op. Each
    tensor-parallel rank handles ``num_heads // tp_world_size`` heads.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim

        tp_world_size = get_tensor_model_parallel_world_size()
        # Every tensor-parallel rank must receive the same number of heads.
        assert num_heads % tp_world_size == 0
        self.num_heads = num_heads // tp_world_size
        # head_dim is derived from the *unsharded* head count.
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states`` and project the result."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection output is split into three equal parts
        # (query, key, value) along the last dimension.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        context = self.attn(query, key, value)
        projected, _ = self.out_proj(context)
        return projected
class OPTDecoderLayer(nn.Module):
    """One OPT transformer block: self-attention followed by a two-layer MLP.

    Both sub-blocks are wrapped in a residual connection. When
    ``config.do_layer_norm_before`` is true, each LayerNorm is applied to the
    sub-block input (pre-norm). Otherwise it is applied to the residual sum
    (post-norm).
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        width = self.embed_dim
        use_bias = config.enable_bias
        affine_norm = config.layer_norm_elementwise_affine

        self.self_attn = OPTAttention(
            embed_dim=width,
            num_heads=config.num_attention_heads,
            bias=use_bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.self_attn_layer_norm = nn.LayerNorm(
            width, elementwise_affine=affine_norm
        )

        # Feed-forward network: width -> ffn_dim -> width.
        self.fc1 = ColumnParallelLinear(
            width,
            config.ffn_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.activation_fn = get_act_fn(config.activation_function)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            width,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.final_layer_norm = nn.LayerNorm(
            width, elementwise_affine=affine_norm
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-blocks to ``hidden_states``."""
        # 125m, 1.7B, ..., 175B normalize *before* each sub-block (pre-norm);
        # 350m normalizes *after* each residual add (post-norm).
        pre_norm = self.do_layer_norm_before

        # --- self-attention sub-block ---
        skip = hidden_states
        if pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = skip + self.self_attn(hidden_states=hidden_states)
        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # --- feed-forward sub-block ---
        skip = hidden_states
        if pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = skip + hidden_states
        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states
class OPTDecoder(nn.Module):
    """Embeddings, the stack of OPT decoder layers, and the output head prep.

    On the first pipeline rank, inputs are embedded (tokens + learned
    positions, with an optional ``project_in``). On other ranks, the hidden
    states arrive via ``intermediate_tensors``. This rank runs layers
    ``[start_layer, end_layer)``. Only the last rank applies the optional
    final LayerNorm and ``project_out``.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size
        )

        # When the word-embedding width differs from hidden_size, replicated
        # linear layers map between the two widths on the way in and out.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(
                config.hidden_size,
                config.word_embed_proj_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.project_out",
            )
            self.project_in = ReplicatedLinear(
                config.word_embed_proj_dim,
                config.hidden_size,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.project_in",
            )
        else:
            self.project_out = None
            self.project_in = None

        # `config._remove_final_layer_norm` exists only to keep backward
        # compatibility with checkpoints fine-tuned before transformers
        # v4.20.1; see https://github.com/facebookresearch/metaseq/pull/164
        keep_final_norm = (
            config.do_layer_norm_before and not config._remove_final_layer_norm
        )
        self.final_layer_norm = (
            nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine,
            )
            if keep_final_norm
            else None
        )

        def build_layer(prefix: str) -> OPTDecoderLayer:
            # The parameter must be named `prefix`: it is the per-layer
            # prefix supplied by make_layers.
            return OPTDecoderLayer(config, cache_config, quant_config, prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (width ``word_embed_proj_dim``)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's share of the decoder.

        Returns ``IntermediateTensors`` on every rank except the last; the
        last rank returns the final hidden states.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            token_embeds = (
                self.embed_input_ids(input_ids)
                if inputs_embeds is None
                else inputs_embeds
            )
            pos_embeds = self.embed_positions(positions)
            if self.project_in is not None:
                token_embeds, _ = self.project_in(token_embeds)
            hidden_states = token_embeds + pos_embeds

        for decoder_layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = decoder_layer(hidden_states)

        if not pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states
2026-01-19 10:38:50 +08:00
@support_torch_compile
class OPTModel(nn.Module):
    """OPT backbone: wraps :class:`OPTDecoder` and loads its checkpoint weights."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.decoder = OPTDecoder(
            config, cache_config, quant_config, prefix=f"{prefix}.decoder"
        )
        # Factory for the (empty) tensors exchanged between pipeline stages;
        # OPT only passes "hidden_states" of width hidden_size.
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the decoder."""
        return self.decoder.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder; see ``OPTDecoder.forward`` for the return contract."""
        return self.decoder(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are written into
        the matching shard of the fused ``qkv_proj`` parameter. Every other
        tensor is loaded by name, using the parameter's ``weight_loader``
        when present and ``default_weight_loader`` otherwise. Tensors are
        skipped if they are extra GPTQ biases with no matching parameter, or
        if they belong to layers that live on another pipeline stage.

        Returns:
            The set of (rewritten) parameter names that were loaded.
        """
        # Match on dot-prefixed module names. The bare substring "v_proj" also
        # occurs inside "qkv_proj". A q shard that was renamed to "qkv_proj"
        # and then skipped would therefore be matched again and rewritten
        # into a bogus "qkqkv_proj" name. Likewise, an already-fused
        # "qkv_proj" checkpoint tensor would be mangled instead of loaded
        # directly.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # A `continue` here moves on to the next mapping entry. The
                # rewritten ".qkv_proj" name cannot match any dotted shard
                # name, so the loop ends without `break`. The `else` branch
                # below then re-applies the same skip checks and skips the
                # tensor.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
class OPTForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """OPT backbone plus a language-modeling head producing vocabulary logits."""

    # Checkpoint q/k/v projections are packed into the fused qkv_proj module.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    # Checkpoint names starting with "decoder." are re-rooted under
    # "model.decoder." to match this module tree.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "decoder.": "model.decoder.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.quant_config = vllm_config.quant_config

        self.model = OPTModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        if not config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.word_embed_proj_dim,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            # Reuse the input token embedding as the output projection.
            self.lm_head = self.model.decoder.embed_tokens

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its hidden states (or PP intermediates)."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to logits via ``lm_head``."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming them with ``hf_to_vllm_mapper``.

        With tied embeddings, a stored ``lm_head.weight`` is skipped because
        the head shares the token embedding.
        """
        skipped = ["lm_head.weight"] if self.config.tie_word_embeddings else None
        loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)