update

vllm/model_executor/models/mistral.py (new file, 342 lines)
@@ -0,0 +1,342 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Mistral adaptation of the LLaMA architecture."""

from collections.abc import Iterable

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
)
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType

from .utils import AutoWeightsLoader


class MistralMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        gate_up_proj_bias: bool | None = None,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        super().__init__()
        gate_up_proj_bias = bias if gate_up_proj_bias is None else gate_up_proj_bias
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=gate_up_proj_bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=disable_tp,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


class MistralAttention(LlamaAttention):
    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__(
            config=config,
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=prefix,
            attn_type=attn_type,
        )

        llama_4_scaling_config: dict[str, int | float | str] | None = getattr(
            config, "llama_4_scaling", None
        )
        self.do_llama_4_scaling = llama_4_scaling_config is not None
        if self.do_llama_4_scaling:
            assert llama_4_scaling_config is not None
            self.llama_4_scaling_original_max_position_embeddings = (
                llama_4_scaling_config["original_max_position_embeddings"]
            )
            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]

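    # Attention-scale schedule:
    #   1 + beta * log(1 + floor(pos / original_max_position_embeddings)),
    # i.e. the scale is 1.0 within the original context window and grows
    # logarithmically beyond it.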
    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        # Llama4 scaling
        scaling = 1 + self.llama_4_scaling_beta * torch.log(
            1
            + torch.floor(
                positions / self.llama_4_scaling_original_max_position_embeddings
            )
        )
        # Broadcast over head_dim
        return scaling.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
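        # Only the queries are rescaled; keys and values pass through
        # unchanged, so this acts as a per-position temperature on the
        # attention logits.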
        if self.do_llama_4_scaling:
            attn_scale = self._get_llama_4_attn_scale(positions)
            q = (q * attn_scale).to(q.dtype)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class MistralDecoderLayer(LlamaDecoderLayer):
    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            config=config,
            attn_layer_type=MistralAttention,
        )

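        # The prefix has the form "model.layers.<idx>", so the last dotted
        # component is this layer's index.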
        self.layer_idx = int(prefix.split(sep=".")[-1])
        config = config or vllm_config.model_config.hf_config

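        # Optional time-conditioning: a small Linear -> GELU -> Linear MLP
        # that maps a conditioning vector t_cond to a per-hidden-dim
        # modulation, applied in forward after the post-attention norm.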
        if getattr(config, "ada_rms_norm_t_cond", False):
            self.ada_rms_norm_t_cond = nn.Sequential(
                ColumnParallelLinear(
                    input_size=config.hidden_size,
                    output_size=config.ada_rms_norm_t_cond_dim,
                    bias=False,
                    return_bias=False,
                ),
                nn.GELU(),
                RowParallelLinear(
                    input_size=config.ada_rms_norm_t_cond_dim,
                    output_size=config.hidden_size,
                    bias=False,
                    return_bias=False,
                ),
            )
        else:
            self.ada_rms_norm_t_cond = None

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        t_cond: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

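        # Elementwise modulation akin to adaptive layer norm: scale the
        # normalized activations by (1 + f(t_cond)).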
        if self.ada_rms_norm_t_cond is not None:
            assert t_cond is not None
            hidden_states = hidden_states * (1 + self.ada_rms_norm_t_cond(t_cond))

        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class MistralModel(LlamaModel):
    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = MistralDecoderLayer,
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
        t_cond: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        return super().forward(
            input_ids, positions, intermediate_tensors, inputs_embeds, t_cond=t_cond
        )


class MistralForCausalLM(LlamaForCausalLM):
    # Mistral: We don't support LoRA on the embedding layers.
    embedding_modules: dict[str, str] = {}

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints.
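    # For example, "layers.0.attention.wq.weight" is remapped to
    # "model.layers.0.self_attn.q_proj.weight".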
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = MistralDecoderLayer,
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = MistralDecoderLayer,
    ):
        return MistralModel(
            vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights
        )

    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
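        # Convert q/k weights from the interleaved rotary layout used by
        # Mistral-format checkpoints to the half-split layout expected by
        # vLLM's rotary embedding: adjacent pairs (x_{2i}, x_{2i+1}) become
        # the two halves (x_i, x_{i + d/2}) of each head.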
        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
            attn_in = self.config.head_dim * n_heads

            return (
                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                .transpose(1, 2)
                .reshape(attn_in, attn_out)
            )

        mapping = self.mistral_mapping
        modules = name.split(".")

        # Rotary q/k weights must be permuted into vLLM's layout (see
        # permute above). If using a quantized model in mistral format,
        # per-channel quantization scales (qscale_weight) need the same
        # permutation.
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
            )
        elif (
            "wk" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_attention_heads, self.config.hidden_size
            )
        elif (
            "wq" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)

        num_modules = len(modules)
        for i in range(num_modules):
            item = modules[i]
            next_item = modules[i + 1] if i < num_modules - 1 else None

            combined_item = f"{item}.{next_item}" if next_item is not None else None

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight
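For context, a checkpoint in the consolidated safetensors format is loaded
with the mistral load format, so that maybe_remap_mistral sees the raw
names in mistral_mapping above. A minimal sketch of that entry point, with
a placeholder model path (the flags mirror vLLM's documented Mistral
loading options):

    from vllm import LLM

    # Placeholder model path; any repo shipping consolidated.safetensors
    # in the Mistral format is loaded the same way.
    llm = LLM(
        model="mistralai/<model-name>",
        tokenizer_mode="mistral",
        config_format="mistral",
        load_format="mistral",
    )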