Sync from v0.13
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
# coding=utf-8
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
@@ -16,36 +18,51 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GPTNeoXConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
|
||||
class GPTNeoXAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTNeoXConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
@@ -53,11 +70,9 @@ class GPTNeoXAttention(nn.Module):
|
||||
self.head_size = self.hidden_size // self.total_num_heads
|
||||
self.bias = getattr(config, "attention_bias", True)
|
||||
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = (self.total_num_heads //
|
||||
tensor_model_parallel_world_size)
|
||||
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
|
||||
|
||||
self.query_key_value = QKVParallelLinear(
|
||||
config.hidden_size,
|
||||
@@ -65,62 +80,65 @@ class GPTNeoXAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
bias=self.bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.query_key_value",
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=self.bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense",
|
||||
)
|
||||
scaling = self.head_size**-0.5
|
||||
rotary_dim = int(self.head_size * config.rotary_pct)
|
||||
assert rotary_dim % 2 == 0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_size,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_parameters=config.rope_parameters,
|
||||
)
|
||||
scaling = self.head_size**-0.5
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
scaling,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
self.attn = Attention(self.num_heads, self.head_size, scaling)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k = self.rotary_emb(position_ids, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.dense(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class GPTNeoXMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTNeoXConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.dense_h_to_4h = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense_h_to_4h",
|
||||
)
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense_4h_to_h",
|
||||
)
|
||||
self.act = get_act_fn(config.hidden_act, quant_config,
|
||||
config.intermediate_size)
|
||||
self.act = get_act_fn(config.hidden_act)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states, _ = self.dense_h_to_4h(hidden_states)
|
||||
@@ -130,34 +148,35 @@ class GPTNeoXMLP(nn.Module):
|
||||
|
||||
|
||||
class GPTNeoXLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTNeoXConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.use_parallel_residual = config.use_parallel_residual
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.attention = GPTNeoXAttention(config, quant_config)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
self.attention = GPTNeoXAttention(
|
||||
config, cache_config, quant_config, prefix=f"{prefix}.attention"
|
||||
)
|
||||
self.mlp = GPTNeoXMLP(config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
attn_input = self.input_layernorm(hidden_states)
|
||||
attn_output = self.attention(
|
||||
position_ids=position_ids,
|
||||
hidden_states=attn_input,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
if self.use_parallel_residual:
|
||||
@@ -177,101 +196,75 @@ class GPTNeoXLayer(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class GPTNeoXModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTNeoXConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
|
||||
self.embed_in = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
GPTNeoXLayer(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: GPTNeoXLayer(
|
||||
config, cache_config, quant_config, prefix=prefix
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states"], config.hidden_size
|
||||
)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_in(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_in(input_ids)
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
position_ids,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
attn_metadata,
|
||||
)
|
||||
intermediate_tensors: IntermediateTensors | None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
else:
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states = layer(position_ids, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GPTNeoXForCausalLM(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.gpt_neox = GPTNeoXModel(config, quant_config)
|
||||
self.embed_out = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.embed_out.weight, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if ("attention.bias" in name or "attention.masked_bias" in name
|
||||
or "rotary_emb.inv_freq" in name):
|
||||
if (
|
||||
"attention.bias" in name
|
||||
or "attention.masked_bias" in name
|
||||
or "rotary_emb.inv_freq" in name
|
||||
):
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
# Models trained using OpenRLHF may include
|
||||
# these tensors in the checkpoint. Skip them.
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
|
||||
if "query_key_value" in name:
|
||||
@@ -284,12 +277,64 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
if output_dim is not None:
|
||||
loaded_weight_shape = loaded_weight.shape
|
||||
loaded_weight = loaded_weight.view(
|
||||
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
|
||||
loaded_weight_shape[output_dim + 1:])
|
||||
loaded_weight = loaded_weight.transpose(
|
||||
output_dim, output_dim + 1)
|
||||
loaded_weight_shape[:output_dim]
|
||||
+ (num_heads, 3, -1)
|
||||
+ loaded_weight_shape[output_dim + 1 :]
|
||||
)
|
||||
loaded_weight = loaded_weight.transpose(output_dim, output_dim + 1)
|
||||
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
|
||||
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class GPTNeoXForCausalLM(nn.Module, SupportsPP):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.gpt_neox = GPTNeoXModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "gpt_neox")
|
||||
)
|
||||
self.embed_out = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "embed_out"),
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.embed_out.weight = self.gpt_neox.embed_in.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.gpt_neox.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.gpt_neox.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
hidden_states = self.gpt_neox(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
logits = self.logits_processor(self.embed_out, hidden_states)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
Reference in New Issue
Block a user