# File: sglang/python/sglang/srt/models/llama4.py
# Snapshot: 2025-04-27 07:18:10 -07:00
#
# 487 lines
# 17 KiB
# Python

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from
# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/models/llama4.py
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import Llama4TextConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.dp_attention import (
dp_gather_partial,
dp_scatter,
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class Llama4MoE(nn.Module):
    """Mixture-of-experts feed-forward block used by Llama 4 MoE layers.

    The block owns three sub-modules:

    * ``router``: an unquantized ``ReplicatedLinear`` producing one logit per
      local expert.
    * ``experts``: a ``FusedMoE`` that routes tokens with
      :meth:`custom_routing_function`.
    * ``shared_expert``: a ``LlamaMLP`` that is applied to every token.

    Both ``experts`` and ``shared_expert`` are built with
    ``reduce_results=False``; :meth:`forward` sums their outputs and performs a
    single tensor-parallel all-reduce when ``tp_size > 1``.
    """

    # NOTE: ``staticmethod`` must be the outermost decorator. Applying
    # ``torch.compile`` on top of ``staticmethod`` compiles the descriptor
    # object itself, which is not callable before Python 3.10, and leaves a
    # plain function in the class namespace that would bind ``self`` when
    # accessed through an instance.
    @staticmethod
    @torch.compile(dynamic=True, backend=get_compiler_backend())
    def custom_routing_function(
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        topk: int,
        renormalize: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select the top-k experts per token and return sigmoid routing weights.

        Args:
            hidden_states: Only its dtype is used, as the dtype of the returned
                routing weights.
            gating_output: Router logits; the top-k is taken over the last
                dimension.
            topk: Number of experts to select per token.
            renormalize: Unused here; accepted so the signature matches what
                the caller passes to a custom routing function.

        Returns:
            A tuple ``(weights, indices)`` where ``weights`` is the sigmoid of
            the selected logits (computed in float32, then cast to
            ``hidden_states.dtype``) and ``indices`` are the selected expert
            ids as ``torch.int32``.
        """
        router_scores_aK, router_indices_aK = fast_topk(gating_output, topk, dim=-1)
        # Sigmoid is evaluated in float32 before casting back to the
        # activation dtype.
        router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
            hidden_states.dtype
        )
        return router_scores_aK, router_indices_aK.to(torch.int32)

    def __init__(
        self,
        config: Llama4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the router, the fused routed experts and the shared expert.

        Args:
            config: Text config; reads ``num_experts_per_tok``,
                ``intermediate_size``, ``hidden_size`` and
                ``num_local_experts``.
            quant_config: Quantization config forwarded to the experts and the
                shared expert. The router is always built with
                ``quant_config=None``.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.top_k = config.num_experts_per_tok

        intermediate_size_moe = config.intermediate_size
        self.router = ReplicatedLinear(
            config.hidden_size,
            config.num_local_experts,
            bias=False,
            quant_config=None,
            prefix=add_prefix("router", prefix),
        )

        self.experts = FusedMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            custom_routing_function=Llama4MoE.custom_routing_function,
            intermediate_size=intermediate_size_moe,
            reduce_results=False,
            renormalize=False,
            quant_config=quant_config,
            apply_router_weight_on_input=True,
            prefix=add_prefix("experts", prefix),
        )

        self.shared_expert = LlamaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=intermediate_size_moe,
            hidden_act="silu",
            quant_config=quant_config,
            prefix=add_prefix("shared_expert", prefix),
            reduce_results=False,  # We need to do scatter before reduce
        )

    def forward(self, hidden_states):
        """Return the sum of routed-expert and shared-expert outputs.

        The sum is all-reduced across tensor-parallel ranks when
        ``self.tp_size > 1``.
        """
        # router_logits: presumably [num_tokens, num_experts] (per the
        # original inline comment) -- TODO confirm.
        router_logits, _ = self.router(hidden_states)
        shared_out = self.shared_expert(hidden_states)
        routed_out = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        out_aD = routed_out + shared_out

        # Neither branch reduced its own output, so one all-reduce covers both.
        if self.tp_size > 1:
            out_aD = tensor_model_parallel_all_reduce(out_aD)

        return out_aD
class Llama4Attention(nn.Module):
    """Self-attention module for a single Llama 4 decoder layer.

    Layers whose ``layer_id + 1`` is a multiple of 4 are built without rotary
    embeddings (``use_rope`` is 0); on those layers the query may instead be
    scaled by a position-dependent temperature when
    ``config.attn_temperature_tuning`` is set. QK-norm is only created when
    both ``config.use_qk_norm`` and ``use_rope`` are truthy.
    """

    def __init__(
        self,
        config: Llama4TextConfig,
        layer_id: int,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        prefix: str = "",
    ) -> None:
        """Create projections, optional QK-norm / rotary embedding and the attention op.

        Args:
            config: Text config; reads ``use_qk_norm``, ``head_dim``,
                ``rms_norm_eps``, ``attn_temperature_tuning``, ``floor_scale``,
                ``attn_scale`` and ``model_type``.
            layer_id: Index of the owning decoder layer.
            hidden_size: Model hidden size (input/output width).
            num_heads: Total number of query heads before TP partitioning.
            num_kv_heads: Total number of KV heads before TP partitioning.
            rope_theta: Rotary base; passed to ``get_rope`` as ``int``.
            rope_scaling: Rotary scaling configuration forwarded to ``get_rope``.
            max_position_embeddings: Maximum position passed to ``get_rope``.
            quant_config: Quantization config for the linear projections.
            bias: Whether ``qkv_proj`` has a bias.
            bias_o_proj: Whether ``o_proj`` has a bias.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size

        # Stored as an int (0 or 1), exactly as the expression evaluates.
        self.use_rope = int((layer_id + 1) % 4 != 0)
        self.use_qk_norm = config.use_qk_norm and self.use_rope

        self.dp_size = get_attention_dp_size()
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        # Query heads are split evenly over the attention TP group.
        self.total_num_heads = num_heads
        assert self.total_num_heads % attn_tp_size == 0
        self.num_heads = self.total_num_heads // attn_tp_size

        self.total_num_kv_heads = num_kv_heads
        kv_heads_are_partitioned = self.total_num_kv_heads >= attn_tp_size
        if kv_heads_are_partitioned:
            # At least as many KV heads as TP ranks: partition the KV heads
            # across the tensor parallel GPUs.
            assert self.total_num_kv_heads % attn_tp_size == 0
        else:
            # Fewer KV heads than TP ranks: replicate the KV heads across
            # the tensor parallel GPUs.
            assert attn_tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)

        self.head_dim = config.head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.attn_temperature_tuning = config.attn_temperature_tuning
        self.floor_scale = config.floor_scale
        self.attn_scale = config.attn_scale

        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.n_rep = self.num_heads // self.num_kv_heads

        if self.use_qk_norm:
            self.qk_norm = RMSNorm(
                hidden_size=self.head_dim,
                eps=config.rms_norm_eps,
            )
        else:
            self.qk_norm = None

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            reduce_results=False,
        )

        # NeoX-style rotary is the default; it is turned off only for GGUF
        # quantization of the "llama" / "llama4" model types.
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        is_neox_style = not (is_gguf and config.model_type in ["llama", "llama4"])

        if self.use_rope:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=int(rope_theta),
                rope_scaling=rope_scaling if rope_scaling != "default" else None,
                is_neox_style=is_neox_style,
            )
        else:
            self.rotary_emb = None

        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=add_prefix("attn", prefix),
            use_irope=self.use_rope,
        )

    def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Return ``log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1``.

        The result gets a trailing singleton dimension so it can be multiplied
        with the query tensor.
        """
        floored = torch.floor((positions + 1.0) / self.floor_scale)
        scale = torch.log(floored + 1.0) * self.attn_scale + 1.0
        return scale.unsqueeze(-1)

    @torch.compile(dynamic=True, backend=get_compiler_backend())
    def _mul_attn_scale(self, positions, q):
        """Multiply ``q`` by the position-dependent scale, keeping ``q``'s dtype."""
        scale = self._get_attn_scale(positions)
        return (q * scale).to(q.dtype)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Project to QKV, apply rotary / QK-norm / temperature scaling, then attend.

        Returns the output of ``o_proj``, which was built with
        ``reduce_results=False``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        qk_width = self.q_size + self.kv_size
        qk, v = qkv.split([qk_width, self.kv_size], dim=-1)

        if self.rotary_emb is not None:
            rope_q, rope_k = qk.split([self.q_size, self.kv_size], dim=-1)
            # The return values are discarded; ``qk`` is what is used below.
            # NOTE(review): presumably the rotary op rotates these views in
            # place -- confirm against the rotary embedding implementation.
            self.rotary_emb(positions, rope_q, rope_k)
            del rope_q, rope_k

        if self.qk_norm is not None:
            # TODO there are still 2 redundant direct_copy_kernel_cuda for this `reshape` and (in attn backend) q.contiguous(), maybe we can fuse them later
            per_head = qk.reshape(-1, self.head_dim).contiguous().bfloat16()
            per_head = self.qk_norm(per_head).to(torch.bfloat16)
            qk = per_head.reshape(-1, qk_width)

        q, k = qk.split([self.q_size, self.kv_size], dim=-1)

        # We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
        # to NoPE layers, where the inference-time temperature tuning function
        # is customized to not affect short context while working at very long
        # context.
        if self.attn_temperature_tuning and not self.use_rope:
            q = self._mul_attn_scale(positions=positions, q=q)

        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output
class Llama4DecoderLayer(nn.Module):
    """One Llama 4 decoder layer: self-attention followed by a feed-forward block.

    The feed-forward block is a ``Llama4MoE`` when ``layer_id + 1`` is a
    multiple of ``config.interleave_moe_layer_step`` and a dense ``LlamaMLP``
    (sized by ``config.intermediate_size_mlp``) otherwise.
    """

    def __init__(
        self,
        config: Llama4TextConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the attention module, the feed-forward module and both RMSNorms.

        Args:
            config: Text config for the model.
            layer_id: Index of this layer; decides MoE vs. dense feed-forward.
            quant_config: Quantization config forwarded to the sub-modules.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = config.hidden_size

        self.dp_size = get_attention_dp_size()
        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()

        self.self_attn = Llama4Attention(
            config=config,
            layer_id=layer_id,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=config.rope_theta,
            rope_scaling=config.rope_scaling,
            max_position_embeddings=config.max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            bias_o_proj=False,
            prefix=add_prefix("self_attn", prefix),
        )

        feed_forward_prefix = add_prefix("feed_forward", prefix)
        uses_moe = (layer_id + 1) % config.interleave_moe_layer_step == 0
        if uses_moe:
            self.feed_forward = Llama4MoE(
                config=config,
                quant_config=quant_config,
                prefix=feed_forward_prefix,
            )
        else:
            self.feed_forward = LlamaMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size_mlp,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=feed_forward_prefix,
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and feed-forward, returning ``(hidden_states, residual)``.

        When ``hidden_states`` has zero rows, attention is skipped but the
        gather / feed-forward / scatter steps below still execute.
        """
        has_tokens = hidden_states.shape[0] != 0
        if has_tokens:
            # Self Attention
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )
        else:
            residual = hidden_states

        # Gather
        if get_tensor_model_parallel_world_size() > 1:
            if self.dp_size == 1:
                # Plain tensor parallelism: all reduce, then the fused
                # residual-add + norm.
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )
            else:
                # DP attention: all gather and all reduce. Only attention TP
                # rank 0 adds the residual before the gather.
                if self.attn_tp_rank == 0:
                    hidden_states += residual
                gathered = forward_batch.gathered_buffer
                dp_gather_partial(gathered, hidden_states, forward_batch)
                dp_scatter(residual, gathered, forward_batch)
                hidden_states = self.post_attention_layernorm(gathered)
        else:
            hidden_states, residual = self.post_attention_layernorm(
                hidden_states, residual
            )

        # Fully Connected
        hidden_states = self.feed_forward(hidden_states)

        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
        # Scatter
        if self.dp_size != 1:
            # important: forward batch.gathered_buffer is used both after scatter and after gather.
            # be careful about this!
            num_local_tokens = forward_batch.input_ids.shape[0]
            local_out = forward_batch.gathered_buffer[:num_local_tokens]
            dp_scatter(local_out, hidden_states, forward_batch)
            hidden_states = local_out

        return hidden_states, residual
class Llama4Model(nn.Module):
    """Llama 4 text backbone: token embedding, decoder layers and a final RMSNorm.

    ``layers_to_capture`` (empty by default) lists layer indices whose *input*
    hidden states should additionally be returned from :meth:`forward`.
    """

    def __init__(
        self,
        config: Llama4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the embedding table, the decoder layers and the final norm.

        Args:
            config: Text config; reads ``pad_token_id``, ``vocab_size``,
                ``hidden_size``, ``num_hidden_layers`` and ``rms_norm_eps``.
            quant_config: Quantization config forwarded to the embedding and
                every decoder layer.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("embed_tokens", prefix),
            # The embedding is not tensor-parallel when DP attention is on.
            enable_tp=not global_server_args_dict["enable_dp_attention"],
        )
        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Llama4DecoderLayer(
                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
            ),
            prefix=add_prefix("layers", prefix),
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layers_to_capture = []

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """Run the decoder stack.

        Args:
            input_ids: Token ids; embedded only when ``input_embeds`` is None.
            positions: Position tensor passed to every layer.
            forward_batch: Batch metadata passed to every layer; its
                ``forward_mode.is_idle()`` decides whether the final norm runs.
            input_embeds: Optional precomputed embeddings used instead of
                ``embed_tokens(input_ids)``.

        Returns:
            The final hidden states, or ``(hidden_states, aux_hidden_states)``
            when at least one layer index in ``layers_to_capture`` was hit.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        residual = None
        aux_hidden_states = []
        for i, layer in enumerate(self.layers):
            if i in self.layers_to_capture:
                # Before the first layer has run there is no residual yet;
                # adding ``None`` would raise a TypeError, so capture the
                # hidden states as they are.
                aux_hidden_states.append(
                    hidden_states if residual is None else hidden_states + residual
                )
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )

        if not forward_batch.forward_mode.is_idle():
            hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) == 0:
            return hidden_states

        return hidden_states, aux_hidden_states
class Llama4ForCausalLM(LlamaForCausalLM):
    """Causal-LM wrapper that plugs ``Llama4Model`` into ``LlamaForCausalLM``.

    Everything except the backbone construction (``_init_model``) and the
    embedding accessor is inherited from the base class.
    """

    # Fused parameter name -> the per-projection names it packs together.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(
        self,
        config: Llama4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Forward all arguments unchanged to ``LlamaForCausalLM.__init__``."""
        super().__init__(config, quant_config, prefix)

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        backbone = self.model
        return backbone.embed_tokens

    def _init_model(
        self,
        config: Llama4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Construct the Llama 4 backbone.

        NOTE(review): presumably invoked by the base-class constructor to
        create ``self.model`` -- confirm against ``LlamaForCausalLM``.
        """
        backbone = Llama4Model(config, quant_config=quant_config, prefix=prefix)
        return backbone
# Model classes exposed by this module.
# NOTE(review): presumably read by the model loader/registry -- confirm.
EntryClass = [Llama4ForCausalLM]