sglang/python/sglang/srt/models/mllama4.py
Commit: f04c80dc42 "Add Llama4 support (#5092)" by Chang Su
Co-authored-by: Cheng Wan <cwan39@gatech.edu>
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
Co-authored-by: ispobock <ispobaoke@163.com>
Date: 2025-04-07 00:29:36 -07:00

# TODO: add "Adapted from vllm/mllama4.py"
from collections.abc import Iterable
from typing import Optional, Tuple
import torch
from torch import nn
from transformers import Llama4Config
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
class Llama4ForConditionalGeneration(nn.Module):
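    """Llama 4 (multimodal) wrapper.

    At this commit only the text path is wired up: forward() delegates to
    Llama4ForCausalLM, and load_weights() skips vision_model and
    multi_modal_projector tensors.
    """
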
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
}
def __init__(
self,
config: Llama4Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.quant_config = quant_config
# Initialize the language model
from sglang.srt.models.llama4 import Llama4ForCausalLM
self.language_model = Llama4ForCausalLM(
config.text_config,
quant_config=quant_config,
prefix=add_prefix("language_model", prefix),
)
self.logits_processor = LogitsProcessor(config.text_config)
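
    # Vision inputs are not consumed yet; generation runs through the text
    # model only (vision weights are skipped in load_weights below).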
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
**kwargs: object,
) -> torch.Tensor:
return self.language_model(input_ids, positions, forward_batch)
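
    # Checkpoints store q/k rows with rotary (real, imag) pairs interleaved
    # within each head; the runtime expects the two halves split, so reshape
    # (n_heads, head_dim // 2, 2, hidden) -> (n_heads, 2, head_dim // 2, hidden)
    # and flatten back.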
def permute_qk_weight_for_rotary(
self,
name: str,
loaded_weight: torch.Tensor,
) -> Tuple[str, torch.Tensor]:
def permute(w: torch.Tensor, n_heads: int):
attn_in = self.language_model.config.head_dim * n_heads
attn_out = self.language_model.config.hidden_size
return (
w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
.transpose(1, 2)
.reshape(attn_in, attn_out)
)
modules = name.split(".")
        # q/k projection weights are stored in the interleaved rotary layout
        # and must be permuted before loading
if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
loaded_weight = permute(
loaded_weight, self.language_model.config.num_key_value_heads
)
elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
loaded_weight = permute(
loaded_weight, self.language_model.config.num_attention_heads
)
return name, loaded_weight
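
    # A hedged sanity check for the permutation above (illustrative sketch,
    # not part of the upstream file): with n_heads=1, head_dim=4, hidden=4,
    # rows ordered (x0, y0, x1, y1) come back as (x0, x1, y0, y1):
    #
    #   w = torch.arange(16.0).view(4, 4)
    #   out = w.view(1, 2, 2, 4).transpose(1, 2).reshape(4, 4)
    #   assert torch.equal(out, w[[0, 2, 1, 3]])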
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
(".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
(".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
(".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
(".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
]
params_dict = dict(self.named_parameters())
num_experts = self.config.text_config.num_local_experts
for name, loaded_weight in weights:
if name.startswith("vision_model") or name.startswith(
"multi_modal_projector"
):
continue
name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
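                # No fused-module match: remap MoE expert tensors (stored
                # stacked per layer, transposed relative to the runtime
                # layout) or fall back to a plain copy.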
if ".experts" in name:
if ".gate_up_proj" in name:
name_list = [
name.replace(".experts.gate_up_proj", ".experts.w13_weight")
] * 2
loaded_weight_list = loaded_weight.chunk(2, dim=-1)
shard_id_list = ["w1", "w3"]
else:
name_list = [
name.replace(".experts.down_proj", ".experts.w2_weight")
]
shard_id_list = ["w2"]
loaded_weight_list = [loaded_weight]
for name, loaded_weight, shard_id in zip(
name_list, loaded_weight_list, shard_id_list
):
param = params_dict[name]
weight_loader = param.weight_loader
for expert_id in range(num_experts):
weight_loader(
param,
loaded_weight[expert_id].T,
name,
shard_id=shard_id,
expert_id=expert_id,
)
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
EntryClass = Llama4ForConditionalGeneration
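
# A minimal end-to-end sketch (an assumption based on the sglang offline
# Engine API; MODEL_PATH is a hypothetical local Llama 4 checkpoint, and this
# snippet is not part of the upstream file):
#
#   import sglang as sgl
#   llm = sgl.Engine(model_path=MODEL_PATH)
#   print(llm.generate("Hello", {"max_new_tokens": 8}))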