Sync from v0.13

2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions


@@ -1,4 +1,6 @@
 # coding=utf-8
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # Adapted from
 # https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
 # Copyright 2023 The vLLM team.
@@ -19,65 +21,83 @@
 """PyTorch Falcon model."""
 import math
-from typing import Iterable, List, Optional, Tuple, Union
+from collections.abc import Iterable
+from itertools import islice
+from typing import TypeAlias

 import torch
 from torch import nn
 from torch.nn import LayerNorm
 from transformers import FalconConfig as HF_FalconConfig

-from vllm.attention import Attention, AttentionMetadata
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import (
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import SamplerOutput
+from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import RWConfig

-FalconConfig = Union[HF_FalconConfig, RWConfig]
+from .interfaces import SupportsPP
+from .utils import (
+    AutoWeightsLoader,
+    is_pp_missing_parameter,
+    make_empty_intermediate_tensors_factory,
+    make_layers,
+    maybe_prefix,
+)
+
+FalconConfig: TypeAlias = HF_FalconConfig | RWConfig


 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
-    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
-    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
-                        dtype=torch.float32)
+    closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
+    base = torch.tensor(
+        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32
+    )
     powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
     slopes = torch.pow(base, powers)

     if closest_power_of_2 != total_num_heads:
         extra_base = torch.tensor(
-            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
-            dtype=torch.float32)
-        num_remaining_heads = min(closest_power_of_2,
-                                  total_num_heads - closest_power_of_2)
-        extra_powers = torch.arange(1,
-                                    1 + 2 * num_remaining_heads,
-                                    2,
-                                    dtype=torch.int32)
-        slopes = torch.cat(
-            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32
+        )
+        num_remaining_heads = min(
+            closest_power_of_2, total_num_heads - closest_power_of_2
+        )
+        extra_powers = torch.arange(
+            1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32
+        )
+        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

     return slopes


 class FalconAttention(nn.Module):
     def __init__(
         self,
         config: FalconConfig,
-        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
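
The `_get_alibi_slopes` helper above implements the ALiBi slope schedule: for a head count `n` that is a power of two, the slopes form the geometric sequence `base**k` with `base = 2 ** (-8 / n)`; other head counts take extra odd powers from the schedule for `2 * closest_power_of_2` heads. A minimal standalone check (plain PyTorch, no vLLM imports; names here are illustrative):

```python
# Standalone check of the ALiBi slope schedule above.
import math
import torch

def alibi_slopes(n: int) -> torch.Tensor:
    closest = 2 ** math.floor(math.log2(n))
    base = 2 ** (-8.0 / closest)  # equals 2**(-(2**-(log2(closest) - 3)))
    slopes = torch.tensor([base ** k for k in range(1, closest + 1)])
    if closest != n:
        # Remaining heads take odd powers of the schedule for 2 * closest heads.
        extra_base = 2 ** (-8.0 / (2 * closest))
        extra = [extra_base ** k for k in range(1, 2 * (n - closest) + 1, 2)]
        slopes = torch.cat([slopes, torch.tensor(extra)])
    return slopes

print(alibi_slopes(8))   # tensor([0.5000, 0.2500, ..., 0.0039]): base is 2**-1
print(alibi_slopes(12))  # 8 base slopes plus 4 odd powers of 2**-0.5
```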
@@ -117,65 +137,79 @@ class FalconAttention(nn.Module):
             bias=config.bias,
             skip_bias_add=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim

         # Layer-wise attention scaling
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
-        self.reduce_row_parallel_results = not (config.new_decoder_architecture
-                                                or config.parallel_attn)
+        self.reduce_row_parallel_results = not (
+            config.new_decoder_architecture or config.parallel_attn
+        )
         self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=config.bias,
             skip_bias_add=True,
             quant_config=quant_config,
-            reduce_results=self.reduce_row_parallel_results)
+            reduce_results=self.reduce_row_parallel_results,
+            prefix=f"{prefix}.dense",
+        )

         self.use_rotary = config.rotary
         self.use_alibi = config.alibi
         assert not (self.use_rotary and self.use_alibi), (
-            "Rotary and alibi are mutually exclusive.")
+            "Rotary and alibi are mutually exclusive."
+        )

         if self.use_rotary:
-            rope_theta = getattr(config, "rope_theta", 10000)
-            max_position_embeddings = getattr(config,
-                                              "max_position_embeddings", 8192)
+            max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
             self.rotary_emb = get_rope(
                 self.head_dim,
                 rotary_dim=self.head_dim,
                 max_position=max_position_embeddings,
-                base=rope_theta,
+                rope_parameters=config.rope_parameters,
             )
-            self.attn = Attention(self.num_heads,
-                                  self.head_dim,
-                                  self.inv_norm_factor,
-                                  num_kv_heads=self.num_kv_heads)
+            self.attn = Attention(
+                self.num_heads,
+                self.head_dim,
+                self.inv_norm_factor,
+                num_kv_heads=self.num_kv_heads,
+                quant_config=quant_config,
+                prefix=f"{prefix}.attn",
+            )
         elif self.use_alibi:
             tp_rank = get_tensor_model_parallel_rank()
             head_start = tp_rank * self.num_heads
             head_end = (tp_rank + 1) * self.num_heads
-            alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
-                            self.inv_norm_factor)
+            alibi_slopes = (
+                _get_alibi_slopes(self.total_num_heads) * self.inv_norm_factor
+            )
             alibi_slopes = alibi_slopes[head_start:head_end].tolist()
-            self.attn = Attention(self.num_heads,
-                                  self.head_dim,
-                                  self.inv_norm_factor,
-                                  num_kv_heads=self.num_kv_heads,
-                                  alibi_slopes=alibi_slopes)
+            self.attn = Attention(
+                self.num_heads,
+                self.head_dim,
+                self.inv_norm_factor,
+                num_kv_heads=self.num_kv_heads,
+                alibi_slopes=alibi_slopes,
+                quant_config=quant_config,
+                prefix=f"{prefix}.attn",
+            )
         else:
-            self.attn = Attention(self.num_heads,
-                                  self.head_dim,
-                                  scale=self.inv_norm_factor,
-                                  num_kv_heads=self.num_kv_heads)
+            self.attn = Attention(
+                self.num_heads,
+                self.head_dim,
+                scale=self.inv_norm_factor,
+                num_kv_heads=self.num_kv_heads,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.attn",
+            )

     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, bias = self.query_key_value(hidden_states)
         if bias is not None:
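
The alibi branch above slices the slope table per tensor-parallel rank, since each rank owns a contiguous block of `num_heads = total_num_heads // tp_size` heads. A toy illustration of that slicing (values are made up):

```python
# Each TP rank passes the attention backend only its own slice of the
# ALiBi slope table.
total_num_heads, tp_size = 8, 4
num_heads = total_num_heads // tp_size           # heads owned per rank
slopes = [2 ** -(i + 1) for i in range(total_num_heads)]  # full table

for tp_rank in range(tp_size):
    head_start = tp_rank * num_heads
    head_end = (tp_rank + 1) * num_heads
    print(tp_rank, slopes[head_start:head_end])
# rank 0 -> [0.5, 0.25], rank 1 -> [0.125, 0.0625], ...
```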
@@ -183,36 +217,42 @@ class FalconAttention(nn.Module):
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_rotary:
             q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         attn_output, bias = self.dense(attn_output)
         return attn_output, bias


 class FalconMLP(nn.Module):
     def __init__(
         self,
         config: FalconConfig,
-        quant_config: Optional[QuantizationConfig] = None,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size

-        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
-                                                  4 * hidden_size,
-                                                  bias=config.bias,
-                                                  skip_bias_add=True,
-                                                  quant_config=quant_config)
-        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
-        self.reduce_row_parallel_results = not (config.new_decoder_architecture
-                                                or config.parallel_attn)
+        self.dense_h_to_4h = ColumnParallelLinear(
+            hidden_size,
+            4 * hidden_size,
+            bias=config.bias,
+            skip_bias_add=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
+        )
+        self.act = get_act_fn("gelu")
+        self.reduce_row_parallel_results = not (
+            config.new_decoder_architecture or config.parallel_attn
+        )
         self.dense_4h_to_h = RowParallelLinear(
             4 * hidden_size,
             hidden_size,
             bias=config.bias,
             skip_bias_add=True,
             reduce_results=self.reduce_row_parallel_results,
-            quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
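
For reference, the fused-QKV split at the top of this hunk carves the `query_key_value` output into `[Q | K | V]` along the last dimension; a shape-only sketch with assumed head counts:

```python
# Shape sketch of the fused-QKV split in FalconAttention.forward.
import torch

num_heads, num_kv_heads, head_dim, num_tokens = 8, 2, 64, 4
q_size = num_heads * head_dim      # 512
kv_size = num_kv_heads * head_dim  # 128

qkv = torch.randn(num_tokens, q_size + 2 * kv_size)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
assert q.shape == (num_tokens, q_size)
assert k.shape == v.shape == (num_tokens, kv_size)
```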
@@ -225,45 +265,56 @@ class FalconMLP(nn.Module):
 class FalconDecoderLayer(nn.Module):
     def __init__(
         self,
         config: FalconConfig,
-        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.self_attention = FalconAttention(config, quant_config)
-        self.mlp = FalconMLP(config, quant_config)
+        self.self_attention = FalconAttention(
+            config, cache_config, quant_config, prefix=f"{prefix}.self_attention"
+        )
+        self.mlp = FalconMLP(config, quant_config, prefix=f"{prefix}.mlp")
         self.config = config

-        if config.new_decoder_architecture:
-            # The layer norm before self-attention
-            self.ln_attn = LayerNorm(hidden_size,
-                                     eps=config.layer_norm_epsilon)
-            # The layer norm before the MLP
-            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        else:
-            self.input_layernorm = LayerNorm(hidden_size,
-                                             eps=config.layer_norm_epsilon)
-            if not config.parallel_attn:
-                self.post_attention_layernorm = LayerNorm(
-                    hidden_size, eps=config.layer_norm_epsilon)
-
-        self.reduce_row_parallel_results = not (config.new_decoder_architecture
-                                                or config.parallel_attn)
+        if not hasattr(config, "num_ln_in_parallel_attn"):
+            config.num_ln_in_parallel_attn = None
+
+        if config.num_ln_in_parallel_attn is None and config.new_decoder_architecture:
+            config.num_ln_in_parallel_attn = 2
+
+        if not config.parallel_attn:
+            self.post_attention_layernorm = LayerNorm(
+                hidden_size, eps=config.layer_norm_epsilon
+            )
+            self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        else:
+            if config.num_ln_in_parallel_attn == 2:
+                # The layer norm before self-attention
+                self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+                # The layer norm before the MLP
+                self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+            else:
+                self.input_layernorm = LayerNorm(
+                    hidden_size, eps=config.layer_norm_epsilon
+                )
+
+        self.reduce_row_parallel_results = not (
+            config.new_decoder_architecture or config.parallel_attn
+        )

     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         residual = hidden_states

-        if self.config.new_decoder_architecture:
+        if self.config.num_ln_in_parallel_attn == 2:
             attention_layernorm_out = self.ln_attn(hidden_states)
             mlp_layernorm_out = self.ln_mlp(hidden_states)
         else:
@@ -273,8 +324,6 @@ class FalconDecoderLayer(nn.Module):
         attention_output, attention_bias = self.self_attention(
             positions=positions,
             hidden_states=attention_layernorm_out,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         if self.reduce_row_parallel_results and attention_bias is not None:
             attention_output += attention_bias
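
The residual handling above reflects Falcon's parallel-attention design: attention and MLP read the same normalized input and both outputs feed one residual stream, which is what lets the row-parallel reduction be deferred (`reduce_row_parallel_results`) and performed once. A toy, vLLM-free sketch of that data flow, with plain `nn.Linear` stand-ins:

```python
# Toy sketch of Falcon's parallel_attn data flow (stand-in modules only).
import torch
from torch import nn

hidden_size = 16
ln = nn.LayerNorm(hidden_size)
attn = nn.Linear(hidden_size, hidden_size)  # stand-in for self-attention
mlp = nn.Linear(hidden_size, hidden_size)   # stand-in for the MLP

x = torch.randn(4, hidden_size)
normed = ln(x)
out = x + attn(normed) + mlp(normed)  # one shared residual add
```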
@@ -286,6 +335,13 @@ class FalconDecoderLayer(nn.Module):
                 residual += attention_output
             mlp_layernorm_out = self.post_attention_layernorm(residual)

+        if (
+            self.config.new_decoder_architecture
+            and self.config.parallel_attn
+            and self.config.num_ln_in_parallel_attn == 1
+        ):
+            mlp_layernorm_out = attention_layernorm_out
+
         # MLP.
         mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
         if self.reduce_row_parallel_results and mlp_bias is not None:
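
Taken together, the config flags above select one of three layer-norm layouts; this illustrative helper (not part of the file) summarizes the selection logic as this diff implements it:

```python
# Illustrative summary of the layer-norm layouts chosen by the config flags.
def ln_layout(new_decoder_architecture: bool, parallel_attn: bool,
              num_ln_in_parallel_attn: int | None = None) -> str:
    if num_ln_in_parallel_attn is None and new_decoder_architecture:
        num_ln_in_parallel_attn = 2
    if not parallel_attn:
        return "input_layernorm before attention, post_attention_layernorm before MLP"
    if num_ln_in_parallel_attn == 2:
        return "separate ln_attn and ln_mlp applied to the same input"
    return "one input_layernorm shared by attention and MLP"

print(ln_layout(True, True))    # Falcon-40B/180B style: two parallel norms
print(ln_layout(False, True))   # Falcon-7B style: single shared norm
print(ln_layout(False, False))  # sequential pre/post norms
```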
@@ -306,14 +362,15 @@ class FalconDecoderLayer(nn.Module):
         return output


+@support_torch_compile
 class FalconModel(nn.Module):
-    def __init__(
-        self,
-        config: FalconConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
         self.config = config
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
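
The `__init__` migration above replaces per-module constructor arguments with a single keyword-only `VllmConfig`, which each module unpacks. A toy reconstruction with stand-in dataclasses that mirror only the fields this diff touches:

```python
# Toy reconstruction of the keyword-only VllmConfig plumbing (stand-ins only).
from dataclasses import dataclass, field
from typing import Any

@dataclass
class ModelConfig:
    hf_config: Any = None

@dataclass
class VllmConfig:
    model_config: ModelConfig = field(default_factory=ModelConfig)
    cache_config: Any = None
    quant_config: Any = None

class ToyModel:
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config   # HF model config
        self.cache_config = vllm_config.cache_config  # KV-cache settings
        self.quant_config = vllm_config.quant_config  # quantization settings
        self.prefix = prefix                          # dotted module path

model = ToyModel(vllm_config=VllmConfig(), prefix="transformer")
```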
@@ -326,79 +383,45 @@ class FalconModel(nn.Module):
         )

         # Transformer blocks
-        self.h = nn.ModuleList([
-            FalconDecoderLayer(config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.h = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: FalconDecoderLayer(
+                config, cache_config, quant_config, prefix=prefix
+            ),
+            prefix=f"{prefix}.h",
+        )

         # Final Layer Norm
         self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+            ["hidden_states"], config.hidden_size
+        )
+
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.word_embeddings(input_ids)

     def forward(
         self,
-        input_ids: torch.LongTensor,
+        input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        hidden_states = self.word_embeddings(input_ids)
-        for i in range(len(self.h)):
-            layer = self.h[i]
-            hidden_states = layer(
-                positions,
-                hidden_states,
-                kv_caches[i],
-                attn_metadata,
-            )
+        intermediate_tensors: IntermediateTensors | None,
+        inputs_embeds: torch.Tensor | None = None,
+    ) -> torch.Tensor | IntermediateTensors:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.embed_input_ids(input_ids)
+        else:
+            hidden_states = intermediate_tensors["hidden_states"]
+        for layer in islice(self.h, self.start_layer, self.end_layer):
+            hidden_states = layer(positions, hidden_states)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.ln_f(hidden_states)
         return hidden_states

-
-class FalconForCausalLM(nn.Module):
-    def __init__(
-        self,
-        config: FalconConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
-        super().__init__()
-        self.config = config
-        self.quant_config = quant_config
-        self.transformer = FalconModel(config, quant_config)
-        self.lm_head_weight = self.transformer.word_embeddings.weight
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = Sampler()
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        hidden_states = self.transformer(
-            input_ids,
-            positions,
-            kv_caches,
-            attn_metadata,
-        )
-        return hidden_states
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         total_num_heads = self.config.num_attention_heads
         if self.config.new_decoder_architecture:
             total_num_kv_heads = self.config.num_kv_heads
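
The new `forward` above is pipeline-parallel: a rank runs only its layer slice, a non-first rank starts from received hidden states, and a non-last rank returns an `IntermediateTensors` instead of applying the final norm. A toy, single-process sketch of that control flow (rank flags simulated as plain booleans, a dict standing in for `IntermediateTensors`):

```python
# Toy sketch of the pipeline-parallel layer slicing introduced above.
from itertools import islice
import torch
from torch import nn

h = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
ln_f = nn.LayerNorm(16)
start_layer, end_layer = 2, 4  # this rank owns the last two layers

def forward(hidden_states, is_first_rank, is_last_rank, received=None):
    if not is_first_rank:
        hidden_states = received["hidden_states"]  # from the previous rank
    for layer in islice(h, start_layer, end_layer):
        hidden_states = layer(hidden_states)
    if not is_last_rank:
        return {"hidden_states": hidden_states}    # handed to the next rank
    return ln_f(hidden_states)                     # final norm on last rank only

out = forward(None, False, True, {"hidden_states": torch.randn(2, 16)})
```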
@@ -408,37 +431,113 @@ class FalconForCausalLM(nn.Module):
             total_num_kv_heads = total_num_heads
         num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
         for name, loaded_weight in weights:
             if name == "lm_head.weight":
                 # Falcon uses tied embeddings.
                 continue
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
+            if is_pp_missing_parameter(name, self):
+                continue
             param = params_dict[name]
             if "query_key_value" in name:
                 output_dim = getattr(param, "output_dim", None)
                 loaded_weight_shape = loaded_weight.shape
                 if output_dim is not None:
                     loaded_weight = loaded_weight.view(
-                        loaded_weight_shape[:output_dim] +
-                        (total_num_kv_heads, num_query_heads_per_kv_head + 2,
-                         -1) + loaded_weight_shape[output_dim + 1:])
+                        loaded_weight_shape[:output_dim]
+                        + (total_num_kv_heads, num_query_heads_per_kv_head + 2, -1)
+                        + loaded_weight_shape[output_dim + 1 :]
+                    )
                     wq = loaded_weight.narrow(
-                        output_dim + 1, 0,
-                        num_query_heads_per_kv_head).reshape(
-                            *loaded_weight_shape[:output_dim], -1,
-                            *loaded_weight_shape[output_dim + 1:])
+                        output_dim + 1, 0, num_query_heads_per_kv_head
+                    ).reshape(
+                        *loaded_weight_shape[:output_dim],
+                        -1,
+                        *loaded_weight_shape[output_dim + 1 :],
+                    )
                     wk = loaded_weight.narrow(
-                        output_dim + 1, num_query_heads_per_kv_head,
-                        1).reshape(*loaded_weight_shape[:output_dim], -1,
-                                   *loaded_weight_shape[output_dim + 1:])
+                        output_dim + 1, num_query_heads_per_kv_head, 1
+                    ).reshape(
+                        *loaded_weight_shape[:output_dim],
+                        -1,
+                        *loaded_weight_shape[output_dim + 1 :],
+                    )
                     wv = loaded_weight.narrow(
-                        output_dim + 1, num_query_heads_per_kv_head + 1,
-                        1).reshape(*loaded_weight_shape[:output_dim], -1,
-                                   *loaded_weight_shape[output_dim + 1:])
+                        output_dim + 1, num_query_heads_per_kv_head + 1, 1
+                    ).reshape(
+                        *loaded_weight_shape[:output_dim],
+                        -1,
+                        *loaded_weight_shape[output_dim + 1 :],
+                    )
                     loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)

-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
+class FalconForCausalLM(nn.Module, SupportsPP):
+    packed_modules_mapping = {
+        "query_key_value": ["query_key_value"],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.transformer = FalconModel(
+            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
+        )
+        # Only Falcon-11B does not share the lm_head weight with the word
+        # embeddings, and earlier Falcon models have no tie_word_embeddings
+        # config, so tie_word_embeddings defaults to True.
+        self.tie_word_embeddings = (
+            config.tie_word_embeddings
+            if config.tie_word_embeddings is not None
+            else True
+        )
+        if self.tie_word_embeddings:
+            self.lm_head = self.transformer.word_embeddings
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "lm_head"),
+            )
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.transformer.make_empty_intermediate_tensors
+        )
+
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.embed_input_ids(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(
+            input_ids, positions, intermediate_tensors, inputs_embeds
+        )
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor | None:
+        logits = self.logits_processor(self.lm_head, hidden_states)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
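
The QKV handling in `load_weights` above un-interleaves Falcon's fused checkpoint layout, which groups each KV head with its `num_query_heads_per_kv_head` query heads plus one K and one V head, into the contiguous `[all Q | all K | all V]` layout that `QKVParallelLinear` expects. A small-tensor check of the same `view`/`narrow`/`cat` sequence (sizes are made up):

```python
# Standalone check of the fused-QKV un-interleave from load_weights above.
import torch

num_kv_heads, q_per_kv, head_dim, hidden = 2, 3, 4, 8
rows = num_kv_heads * (q_per_kv + 2) * head_dim
w = torch.randn(rows, hidden)       # fused checkpoint weight
shape, output_dim = w.shape, 0

# Expose the (kv_head, [q..q, k, v], head_dim) grouping, then slice it apart.
grouped = w.view(shape[:output_dim] + (num_kv_heads, q_per_kv + 2, -1)
                 + shape[output_dim + 1:])
wq = grouped.narrow(1, 0, q_per_kv).reshape(-1, *shape[1:])      # 24 rows of Q
wk = grouped.narrow(1, q_per_kv, 1).reshape(-1, *shape[1:])      # 8 rows of K
wv = grouped.narrow(1, q_per_kv + 1, 1).reshape(-1, *shape[1:])  # 8 rows of V
out = torch.cat([wq, wk, wv], dim=output_dim)
assert out.shape == w.shape  # same data, regrouped as Q..Q K..K V..V
```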