1012 lines
36 KiB
Python
1012 lines
36 KiB
Python
import json as json_lib
|
|
import logging
|
|
import math
|
|
import os
|
|
import re
|
|
from collections.abc import Iterable
|
|
from typing import List, Optional, Set, Tuple
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import Llama4Config, Llama4VisionConfig
|
|
from transformers.models.llama4.modeling_llama4 import (
|
|
Llama4MultiModalProjector,
|
|
vision_apply_rotary_emb,
|
|
)
|
|
|
|
from sglang.srt.layers.attention.vision import VisionAttention
|
|
from sglang.srt.layers.linear import (
|
|
ColumnParallelLinear,
|
|
ReplicatedLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
|
from sglang.srt.layers.quantization import QuantizationConfig
|
|
from sglang.srt.managers.mm_utils import (
|
|
MultiModalityDataPaddingPatternMultimodalTokens,
|
|
general_mm_embed_routine,
|
|
)
|
|
from sglang.srt.managers.schedule_batch import (
|
|
Modality,
|
|
MultimodalDataItem,
|
|
MultimodalInputs,
|
|
global_server_args_dict,
|
|
)
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
from sglang.srt.utils import is_cpu
|
|
|
|
_is_cpu = is_cpu()
|
|
|
|
from sglang.srt.model_loader.weight_utils import (
|
|
default_weight_loader,
|
|
maybe_remap_kv_scale_name,
|
|
)
|
|
from sglang.srt.utils import add_prefix
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Llama4VisionMLP(nn.Module):
    """Two-layer feed-forward block: ``fc1`` -> GELU -> ``fc2`` (-> optional GELU).

    When ``use_data_parallel`` is set both projections are ``ReplicatedLinear``;
    otherwise ``fc1`` is a ``ColumnParallelLinear`` and ``fc2`` a
    ``RowParallelLinear``.
    """

    def __init__(
        self,
        input_size: int,
        intermediate_size: int,
        output_size: int,
        bias: bool,
        output_activation: bool,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        """Build the two projections and the shared GELU activation.

        Args:
            input_size: Feature size fed into ``fc1``.
            intermediate_size: Output size of ``fc1`` / input size of ``fc2``.
            output_size: Feature size produced by ``fc2``.
            bias: Whether both linear layers carry a bias term.
            output_activation: If true, GELU is applied to the ``fc2`` output too.
            quant_config: Quantization config forwarded to both linear layers.
            prefix: Name prefix; ``.fc1`` / ``.fc2`` are appended for the sub-layers.
            use_data_parallel: Select replicated instead of tensor-parallel layers.
        """
        super().__init__()

        if use_data_parallel:
            first_linear_cls = second_linear_cls = ReplicatedLinear
        else:
            first_linear_cls = ColumnParallelLinear
            second_linear_cls = RowParallelLinear

        self.fc1 = first_linear_cls(
            input_size=input_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = second_linear_cls(
            input_size=intermediate_size,
            output_size=output_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.activation_fn = nn.GELU()
        self.output_activation = output_activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the MLP; the second element returned by each linear layer is discarded."""
        projected, _ = self.fc1(hidden_states)
        projected, _ = self.fc2(self.activation_fn(projected))
        if not self.output_activation:
            return projected
        return self.activation_fn(projected)
|
|
|
|
|
|
def pixel_shuffle(input_tensor, shuffle_ratio):
    """Trade spatial resolution for channel depth on a square grid of patches.

    Args:
        input_tensor: Tensor of shape ``[batch_size, num_patches, channels]``;
            ``num_patches`` is interpreted as a ``side x side`` grid where
            ``side = int(sqrt(num_patches))``.
        shuffle_ratio: Scale applied to each spatial axis; the channel axis is
            divided by ``shuffle_ratio ** 2``.

    Returns:
        Tensor of shape ``[batch_size, new_num_patches, new_channels]`` where
        the last dimension is ``int(channels / shuffle_ratio**2)``.
    """
    batch, num_patches, _ = input_tensor.shape
    side = int(math.sqrt(num_patches))

    # Lay the patch sequence out as a 2-D grid.
    grid = input_tensor.view(batch, side, side, -1)
    batch, height, width, channels = grid.size()

    scaled_width = int(width * shuffle_ratio)
    scaled_height = int(height * shuffle_ratio)

    # First pass: rescale the width axis into channels, then swap the two
    # spatial axes so the second pass can treat the height axis the same way.
    shuffled = grid.view(batch, height, scaled_width, int(channels / shuffle_ratio))
    shuffled = shuffled.permute(0, 2, 1, 3).contiguous()

    # Second pass: rescale the height axis, then swap the spatial axes back.
    shuffled = shuffled.view(
        batch,
        scaled_height,
        scaled_width,
        int(channels / (shuffle_ratio**2)),
    )
    shuffled = shuffled.permute(0, 2, 1, 3).contiguous()

    # Flatten the grid back into a patch sequence.
    return shuffled.view(batch, -1, shuffled.shape[-1])
|
|
|
|
|
|
class Llama4VisionPixelShuffleMLP(nn.Module):
    """Pixel-shuffle the encoded patches, then project them with a ``Llama4VisionMLP``."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        """Create the projection MLP from the vision config.

        Args:
            config: Vision config; reads ``pixel_shuffle_ratio``,
                ``intermediate_size``, ``projector_input_dim``,
                ``projector_output_dim`` and ``multi_modal_projector_bias``.
            quant_config: Quantization config forwarded to the MLP.
            prefix: Name prefix; ``.mlp`` is appended for the sub-module.
            use_data_parallel: Forwarded to the MLP.
        """
        super().__init__()
        self.pixel_shuffle_ratio = config.pixel_shuffle_ratio

        mlp_kwargs = dict(
            input_size=config.intermediate_size,
            intermediate_size=config.projector_input_dim,
            output_size=config.projector_output_dim,
            bias=config.multi_modal_projector_bias,
            # The adapter applies GELU on its output as well.
            output_activation=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )
        self.mlp = Llama4VisionMLP(**mlp_kwargs)

    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
        """Apply ``pixel_shuffle`` with the configured ratio and run the MLP."""
        shuffled_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
        return self.mlp(shuffled_patches)
|
|
|
|
|
|
def apply_position_embedding(q, k, freqs_ci, shape):
    """Reshape ``q``/``k`` and apply the vision rotary embedding to both.

    Args:
        q: Query tensor; its last two dimensions are kept as-is.
        k: Key tensor; viewed to the same target shape as ``q``.
        freqs_ci: Rotary frequencies forwarded to ``vision_apply_rotary_emb``.
        shape: Shape whose first two entries become the leading dimensions.

    Returns:
        Tuple ``(q, k)`` after the rotary embedding has been applied.
    """
    # Leading dims come from `shape`; trailing (heads, head_dim) dims come from q.
    leading_dims = shape[:2]
    trailing_dims = q.shape[-2:]
    target_shape = (*leading_dims, *trailing_dims)

    rotated_q, rotated_k = vision_apply_rotary_emb(
        q.view(target_shape), k.view(target_shape), freqs_ci
    )
    return rotated_q, rotated_k
|
|
|
|
|
|
class Llama4VisionEncoderLayer(nn.Module):
    """One pre-norm vision encoder layer: self-attention and MLP, each with a residual."""

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        """Build the attention, MLP and the two layer norms.

        Args:
            config: Vision config; reads ``hidden_size``, ``num_attention_heads``
                and ``intermediate_size``.
            quant_config: Quantization config, applied to the MLP only.
            prefix: Name prefix for the sub-modules.
            use_data_parallel: Forwarded to the MLP.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.intermediate_size = config.intermediate_size

        attention_options = dict(
            use_qkv_parallel=True,
            # vision_model is explicitly ignored in Maverick-17B-128E-Instruct-FP8
            quant_config=None,
            dropout=0.0,
            qkv_backend="sdpa",
            softmax_in_single_precision=False,
            flatten_batch=False,
            prefix=add_prefix("self_attn", prefix),
            qkv_bias=True,
            customized_position_embedding_applier=apply_position_embedding,
        )
        self.self_attn = VisionAttention(
            self.hidden_size,
            self.num_attention_heads,
            self.hidden_size,
            **attention_options,
        )

        self.mlp = Llama4VisionMLP(
            input_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            output_size=config.hidden_size,
            bias=True,
            output_activation=False,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_state: torch.Tensor,
        freqs_ci: torch.Tensor,
    ):
        """Apply the attention and feed-forward sub-blocks.

        Args:
            hidden_state: Layer input.
            freqs_ci: Rotary frequencies passed to attention as
                ``position_embeddings``.

        Returns:
            The layer output (same variable flow as the input, after both residuals).
        """
        # Attention sub-block: norm -> attention -> residual add.
        attention_out = self.self_attn(
            self.input_layernorm(hidden_state), position_embeddings=freqs_ci
        )
        hidden_state = hidden_state + attention_out

        # Feed-forward sub-block: norm -> MLP -> residual add.
        feed_forward_out = self.mlp(self.post_attention_layernorm(hidden_state))
        return hidden_state + feed_forward_out
|
|
|
|
|
|
class Llama4VisionEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``Llama4VisionEncoderLayer`` modules."""

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        """Instantiate every encoder layer under ``{prefix}.layers.{index}``."""
        super().__init__()
        self.config = config

        encoder_layers = []
        for layer_idx in range(config.num_hidden_layers):
            encoder_layers.append(
                Llama4VisionEncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                )
            )
        self.layers = nn.ModuleList(encoder_layers)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_ci: torch.Tensor,  # TODO: move this to an attribute instead of keeping it around
    ) -> torch.Tensor:
        r"""Feed ``hidden_states`` through every layer in order.

        Args:
            hidden_states: Input to the first encoder layer.
            freqs_ci: Rotary frequencies, passed unchanged to each layer.

        Returns:
            Output of the last encoder layer (the input itself if there are
            no layers).
        """
        current = hidden_states
        for encoder_layer in self.layers:
            current = encoder_layer(current, freqs_ci=freqs_ci)
        return current
|
|
|
|
|
|
class Llama4UnfoldConvolution(nn.Module):
    """Patch embedding implemented as ``nn.Unfold`` followed by a bias-free linear layer."""

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        """Build the unfold op and the projection.

        Args:
            config: Vision config; reads ``patch_size``, ``num_channels`` and
                ``hidden_size``.
            quant_config: Quantization config forwarded to the linear layer.
            prefix: Name prefix; ``.linear`` is appended for the projection.
            use_data_parallel: Use ``ReplicatedLinear`` instead of a
                ``ColumnParallelLinear`` with ``gather_output=True``.
        """
        super().__init__()

        patch_size = config.patch_size
        # Normalise an int patch size to an (h, w) pair.
        kernel_size = (
            (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
        )
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)

        linear_kwargs = dict(
            input_size=config.num_channels * kernel_size[0] * kernel_size[1],
            output_size=config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.linear",
        )
        if use_data_parallel:
            linear_cls = ReplicatedLinear
        else:
            linear_cls = ColumnParallelLinear
            linear_kwargs["gather_output"] = True
        self.linear = linear_cls(**linear_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Unfold into patches, move the patch axis to dim 1, then project."""
        patches = self.unfold(hidden_states).permute(0, 2, 1).contiguous()
        projected, _ = self.linear(patches)
        return projected
|
|
|
|
|
|
class Llama4VisionRotaryEmbedding(nn.Module):
    """Precomputes complex 2-D rotary frequencies for the patch grid plus one CLS slot.

    ``freqs_ci`` is stored as a plain attribute (not a registered buffer) and is
    moved to the input's device on every ``forward`` call.
    """

    def __init__(self, config):
        """Compute ``self.freqs_ci`` from the vision config.

        Reads ``image_size``, ``patch_size``, ``hidden_size``,
        ``num_attention_heads`` and ``rope_theta`` from ``config``.
        """
        super().__init__()
        grid_side = config.image_size // config.patch_size
        num_cells = grid_side**2

        # One index per grid cell, plus a trailing entry for the CLS token that
        # is tagged with a negative id so it can be masked below.
        positions = torch.arange(num_cells, dtype=torch.int32).reshape(num_cells, 1)
        positions = torch.cat([positions, positions[:1]], dim=0)
        positions[-1, -1] = -2  # ID_CLS_TOKEN

        column_index = positions % grid_side  # coordinate along x
        row_index = positions // grid_side  # coordinate along y

        freq_dim = config.hidden_size // config.num_attention_heads // 2
        exponents = torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim
        rope_freq = 1.0 / (config.rope_theta**exponents)

        def axis_angles(coordinates):
            # Coordinates are shifted by one before scaling; every frequency is
            # duplicated along the last axis.
            scaled = (coordinates + 1)[..., None] * rope_freq[None, None, :]
            return scaled.repeat_interleave(2, dim=-1)

        angles = torch.cat([axis_angles(column_index), axis_angles(row_index)], dim=-1)
        angles = angles.float().contiguous()[..., ::2]
        # Zero rotation for the CLS slot (negative position id).
        angles = angles.masked_fill(positions.reshape(-1, 1, 1) < 0, 0)

        # Shape: (grid_side**2 + 1, 1, 2 * (freq_dim // 2)), complex valued.
        self.freqs_ci = torch.view_as_complex(
            torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)
        )

    def forward(self, hidden_states):
        """Return the cached frequencies on ``hidden_states.device``."""
        target_device = hidden_states.device
        return self.freqs_ci.to(target_device)
|
|
|
|
|
|
class Llama4VisionModel(nn.Module):
    """Vision tower: patch embedding, CLS + positional embeddings, encoder, adapter."""

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build all sub-modules from the vision config.

        Args:
            config: Vision config; reads ``image_size``, ``patch_size``,
                ``hidden_size`` and ``num_channels`` here and is forwarded to
                every sub-module.
            quant_config: Quantization config forwarded to the patch embedding,
                the encoder and the adapter.
            prefix: Name prefix for the sub-modules.
        """
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels

        # Grid patches plus one slot for the CLS token.
        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = Llama4UnfoldConvolution(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.patch_embedding",
        )

        # Randomly initialised here; values are replaced when weights are loaded.
        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
        self.positional_embedding_vlm = nn.Parameter(
            self.scale * torch.randn(self.num_patches, self.hidden_size)
        )

        self.rotary_embedding = Llama4VisionRotaryEmbedding(config)

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
        self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)

        # encoders
        self.model = Llama4VisionEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.model",
        )
        self.vision_adapter = Llama4VisionPixelShuffleMLP(
            config,
            quant_config,
            prefix=f"{prefix}.vision_adapter",
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode image tiles into adapter-projected patch features.

        Args:
            pixel_values: Input passed to the patch embedding and (for its
                device only) to the rotary embedding.

        Returns:
            Output of the vision adapter computed on the encoder output with
            the trailing CLS position removed.
        """
        # Patch embedding
        patch_tokens = self.patch_embedding(pixel_values)
        num_tiles, num_patches, hidden_dim = patch_tokens.shape

        # Append the CLS token after the patch tokens.
        cls_tokens = self.class_embedding.expand(num_tiles, 1, hidden_dim)
        tokens = torch.cat([patch_tokens, cls_tokens], dim=1)
        num_patches += 1

        # Position embeddings are added on a (tiles, 1, patches, dim) view.
        tokens = tokens.reshape(num_tiles, 1, num_patches, hidden_dim)
        positional_embedding = self.positional_embedding_vlm.to(
            dtype=tokens.dtype, device=tokens.device
        )
        tokens = self.layernorm_pre(tokens + positional_embedding)
        tokens = tokens.view(num_tiles, -1, hidden_dim)

        # Apply encoder
        freqs_ci = self.rotary_embedding(pixel_values)
        encoded = self.layernorm_post(self.model(tokens, freqs_ci=freqs_ci))

        # Remove CLS token output (it was appended last).
        encoded = encoded[:, :-1, :]

        # Pixel shuffle + MLP projection.
        return self.vision_adapter(encoded)
|
|
|
|
|
|
class Llama4ForConditionalGeneration(nn.Module):
    """Llama-4 language model with an optional vision tower and projector.

    The vision components are only built when the checkpoint index lists vision
    weights and ``global_server_args_dict["enable_multimodal"]`` is truthy;
    otherwise the model runs text-only.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # Pattern to match language model layers only (skip vision_model and multi_modal_projector)
    lora_pattern = re.compile(
        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
    )

    def __init__(
        self,
        config: Llama4Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the (optional) vision stack, the language model and helpers.

        Args:
            config: Model config; ``text_config`` / ``vision_config`` are used
                when present, otherwise ``config`` itself is the text config.
            quant_config: Quantization config for the language model and, unless
                the checkpoint's ignore list excludes them, the vision modules.
            prefix: Name prefix for the sub-modules.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        # Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
        self.has_vision_weights = self._has_vision_weights(config)
        if not self.has_vision_weights:
            logger.warning(
                "No vision weights found in checkpoint. Model will run in text-only mode. "
                "Multimodal capabilities (vision understanding) will be unavailable. "
                "Please note that this warning might be inaccurate if the weights haven't been fully downloaded"
            )

        self.has_vision = (
            self.has_vision_weights and global_server_args_dict["enable_multimodal"]
        )

        if self.has_vision:
            # TODO: make this more general
            # Tolerate a missing or None `quantization_config` / `ignore` entry.
            quantization_config = getattr(config, "quantization_config", None) or {}
            ignore_quant_layers = quantization_config.get("ignore", None) or ()
            if (
                "model.layers.vision_model*" in ignore_quant_layers
                and "model.layers.multi_modal_projector*" in ignore_quant_layers
            ):
                vision_quant_config = None
            else:
                vision_quant_config = quant_config
            self.vision_model = Llama4VisionModel(
                config.vision_config,
                quant_config=vision_quant_config,
                prefix=add_prefix("vision_model", prefix),
            )

            self.multi_modal_projector = Llama4MultiModalProjector(config)
        else:
            self.vision_model = None
            self.multi_modal_projector = None

        # Initialize the language model (imported lazily, as in the original code)
        from sglang.srt.models.llama4 import Llama4ForCausalLM

        text_config = config.text_config if hasattr(config, "text_config") else config

        self.language_model = Llama4ForCausalLM(
            text_config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )

        self.logits_processor = LogitsProcessor(text_config)
        self.padding_pattern = MultiModalityDataPaddingPatternMultimodalTokens()

    def _has_vision_weights(self, config) -> bool:
        """Check if the model has vision components by examining the checkpoint.

        Looks for ``model.safetensors.index.json`` first in ``config._name_or_path``
        when that is a local directory, then in the HuggingFace cache. Returns
        ``False`` (text-only) whenever no index file can be located.
        """
        model_path = getattr(config, "_name_or_path", None)
        if not model_path:
            return False

        # Check if this is a local path first
        if os.path.isdir(model_path):
            index_file = os.path.join(model_path, "model.safetensors.index.json")
            if os.path.exists(index_file):
                return self._check_vision_weights_in_index(index_file)

        # For HuggingFace models, we need to check the actual checkpoint
        # The config might say it's multimodal, but the checkpoint might be text-only
        try:
            # Try to access the HuggingFace cache directory
            from huggingface_hub import try_to_load_from_cache

            # Check if index file exists in cache
            index_file_path = try_to_load_from_cache(
                repo_id=model_path,
                filename="model.safetensors.index.json",
                cache_dir=None,
            )
            # The lookup may return a non-path sentinel when the file is known
            # to be absent; only a real string path can be inspected.
            if isinstance(index_file_path, str) and os.path.exists(index_file_path):
                return self._check_vision_weights_in_index(index_file_path)
        except Exception as exc:
            # Best effort only: record why the cache lookup failed and fall
            # through to the text-only default.
            logger.debug("Could not inspect HuggingFace cache for %s: %s", model_path, exc)

        # Fallback, assume text-only
        return False

    def _check_vision_weights_in_index(self, index_file: str) -> bool:
        """Check if the model.safetensors.index.json contains vision weights.

        Returns ``False`` when the file cannot be read or parsed, or when it does
        not contain a ``weight_map`` object.
        """
        try:
            with open(index_file, "r", encoding="utf-8") as f:
                index_data = json_lib.load(f)
        except (OSError, ValueError):
            # ValueError covers both JSON and text-decoding errors.
            return False

        weight_map = index_data.get("weight_map") if isinstance(index_data, dict) else None
        if not isinstance(weight_map, dict):
            return False

        vision_patterns = ("vision_model", "vision_tower", "multi_modal_projector")
        return any(
            pattern in weight_name
            for weight_name in weight_map
            for pattern in vision_patterns
        )

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Delegate multimodal token padding to the configured padding pattern."""
        return self.padding_pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(
        self,
        items: List[MultimodalDataItem],
    ) -> torch.Tensor:
        """Encode image items and project them with the multimodal projector.

        Args:
            items: Multimodal items whose ``feature`` tensors are concatenated.

        Returns:
            Projected vision features flattened to two dimensions.

        Raises:
            ValueError: If the model was built without a vision tower.
        """
        # For text-only models there is nothing to encode with
        if not self.has_vision or self.vision_model is None:
            raise ValueError("Vision model not available for text-only checkpoint")

        # Match the device and dtype of the vision tower's parameters.
        reference_param = next(self.vision_model.parameters())
        pixel_values = torch.concat([item.feature for item in items]).to(
            device=reference_param.device, dtype=reference_param.dtype
        )
        image_features = self.vision_model(pixel_values)

        vision_flat = image_features.view(-1, image_features.size(-1))
        return self.multi_modal_projector(vision_flat)

    def should_apply_lora(self, module_name: str) -> bool:
        """Skip vision model and multi_modal_projector for LoRA."""
        return bool(self.lora_pattern.match(module_name))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run the shared multimodal embed routine with the language model.

        Extra ``kwargs`` are accepted but not used here.
        """
        # For text-only models, pass None for image_data_embedding_func
        image_embedding_func = self.get_image_feature if self.has_vision else None

        return general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            data_embedding_funcs={
                Modality.IMAGE: image_embedding_func,
            },
            positions=positions,
        )

    def permute_qk_weight_for_rotary(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        """Permute q/k projection weights into the layout used for rotary embedding.

        Only tensors whose dotted name ends in ``weight`` and contains a
        ``wq``/``q_proj`` or ``wk``/``k_proj`` component are changed; every other
        tensor is returned untouched. The name is always returned unchanged.
        """

        def permute(w: torch.Tensor, n_heads: int):
            attn_in = self.language_model.config.head_dim * n_heads
            attn_out = self.language_model.config.hidden_size

            # Swap the (half_head_dim, 2) axes inside every head.
            return (
                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                .transpose(1, 2)
                .reshape(attn_in, attn_out)
            )

        modules = name.split(".")

        # rotary embeds should be sliced
        if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
            if _is_cpu:
                dim = self.language_model.config.original_total_num_kv_heads
            else:
                dim = self.language_model.config.num_key_value_heads
            loaded_weight = permute(loaded_weight, dim)
        elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
            if _is_cpu:
                dim = self.language_model.config.original_num_attention_heads
            else:
                dim = self.language_model.config.num_attention_heads
            loaded_weight = permute(loaded_weight, dim)

        return name, loaded_weight

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names recorded as loaded. A warning is logged
            for model parameters that are not in this set.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
            (".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
            (".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        num_experts = (
            self.config.text_config.num_local_experts
            if hasattr(self.config, "text_config")
            else self.config.num_local_experts
        )

        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=num_experts,
        )

        loaded_params = set()

        for name, loaded_weight in weights:
            if self._should_skip_weight(name):
                continue

            name = self._transform_weight_name(name)

            if "vision" in name:
                # The vision attention module names its output projection `proj`.
                name = name.replace(".self_attn.o_proj", ".self_attn.proj")
            else:
                name, loaded_weight = self.permute_qk_weight_for_rotary(
                    name, loaded_weight
                )

            if self._handle_scale_remapping(name, params_dict):
                loaded_params.add(name)
                continue

            if self._handle_stacked_params(
                name, loaded_weight, stacked_params_mapping, params_dict, loaded_params
            ):
                continue

            if self._handle_expert_weights(
                name,
                loaded_weight,
                expert_params_mapping,
                params_dict,
                num_experts,
                loaded_params,
            ):
                continue

            loaded_params.add(name)
            self._handle_default_weight(name, loaded_weight, params_dict)

        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            logger.warning(
                "Some weights are not initialized from checkpoints %s", unloaded_params
            )

        # Honour the declared return type; callers that ignored the previous
        # implicit None are unaffected.
        return loaded_params

    def _should_skip_weight(self, name: str) -> bool:
        """Check if we should skip loading this weight."""
        return not self.has_vision and (
            "vision" in name or "multi_modal_projector" in name
        )

    def _transform_weight_name(self, name: str) -> str:
        """Transform weight name by adding language_model prefix if needed."""
        if (
            not name.startswith("language_model.")
            and "vision" not in name
            and "multi_modal_projector" not in name
        ):
            return f"language_model.{name}"
        return name

    def _handle_scale_remapping(self, name: str, params_dict: dict) -> bool:
        """Handle scale parameter remapping. Returns True if handled.

        NOTE(review): when the remapped name differs from ``name`` this returns
        True and the caller skips the tensor without loading it under the
        remapped name — confirm this is intended.
        """
        if "scale" in name and "expert" not in name:
            remapped_name = maybe_remap_kv_scale_name(name, params_dict)
            return remapped_name is not None and remapped_name != name
        return False

    def _handle_stacked_params(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        stacked_params_mapping: list,
        params_dict: dict,
        loaded_params: set,
    ) -> bool:
        """Handle stacked parameter loading. Returns True if handled.

        The first mapping entry whose checkpoint shard name occurs in ``name``
        wins; its fused parameter's ``weight_loader`` receives the shard id.
        """
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name in name:
                transformed_name = name.replace(weight_name, param_name)
                loaded_params.add(transformed_name)
                param = params_dict[transformed_name]
                param.weight_loader(param, loaded_weight, shard_id)
                return True
        return False

    def _handle_expert_weights(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        expert_params_mapping: list,
        params_dict: dict,
        num_experts: int,
        loaded_params: set,
    ) -> bool:
        """Handle expert weight loading for MoE (Mixture of Experts) layers.

        Args:
            name: Parameter name from the checkpoint
            loaded_weight: The weight tensor to be loaded
            expert_params_mapping: Mapping of parameter names to expert configurations
            params_dict: Dictionary of model parameters
            num_experts: Total number of experts in the MoE layer
            loaded_params: Set of loaded parameter names (updated in place)

        Returns:
            bool: True if the parameter was handled (is an expert parameter), False otherwise
        """
        if ".experts" not in name:
            return False

        if "experts.gate_up_proj" not in name and "experts.down_proj" not in name:
            return self._handle_other_expert_params(
                name, loaded_weight, expert_params_mapping, params_dict, loaded_params
            )

        if "scale" in name:
            return self._handle_expert_scale_params(
                name, loaded_weight, params_dict, num_experts, loaded_params
            )
        return self._handle_expert_weight_params(
            name, loaded_weight, params_dict, num_experts, loaded_params
        )

    def _handle_other_expert_params(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        expert_params_mapping: list,
        params_dict: dict,
        loaded_params: set,
    ) -> bool:
        """Handle expert parameters that are not gate_up_proj or down_proj weights.

        Args:
            name: Parameter name from the checkpoint
            loaded_weight: The weight tensor to be loaded
            expert_params_mapping: List of tuples mapping checkpoint names to model parameters
            params_dict: Dictionary of model parameters
            loaded_params: Set of loaded parameter names

        Returns:
            bool: True if parameter was found and handled, False otherwise
        """
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name in name:
                transformed_name = name.replace(weight_name, param_name)
                param = params_dict[transformed_name]
                param.weight_loader(
                    param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
                )
                loaded_params.add(transformed_name)
                return True
        return False

    def _transform_expert_name(
        self, name: str, is_weight: bool = False
    ) -> Tuple[str, str, List[str]]:
        """Transform expert parameter name and get shard information.

        Args:
            name: The original parameter name
            is_weight: Whether this is a weight parameter (adds _weight suffix)

        Returns:
            Tuple of (transformed_name, shard_id, shard_id_list)
        """
        suffix = "_weight" if is_weight else ""

        if ".gate_up_proj" in name:
            transformed_name = name.replace(
                ".experts.gate_up_proj", f".experts.w13{suffix}"
            )
            shard_id = "w13"
            shard_id_list = ["w1", "w3"]
        else:  # down_proj
            transformed_name = name.replace(
                ".experts.down_proj", f".experts.w2{suffix}"
            )
            shard_id = "w2"
            shard_id_list = ["w2"]

        return transformed_name, shard_id, shard_id_list

    def _handle_expert_scale_params(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: dict,
        num_experts: int,
        loaded_params: set,
    ) -> bool:
        """Handle quantization scale parameters for expert weights.

        Args:
            name: Parameter name containing scale information
            loaded_weight: Scale tensor to be loaded
            params_dict: Dictionary of model parameters
            num_experts: Total number of experts for broadcast operations
            loaded_params: Set of loaded parameter names

        Returns:
            bool: True (always handles scale parameters)
        """
        # Check if this matches the expert parameter pattern: experts.{expert_id}.{param_name}
        expert_match = re.search(r"experts\.(\d+)\.", name)

        # Transform name
        transformed_name, _, _ = self._transform_expert_name(name)

        if transformed_name not in params_dict:
            # Nothing to load into; still report the name as handled.
            return True

        param = params_dict[transformed_name]

        # Handle scale parameters
        if expert_match:
            # If we have a specific expert ID, only load for that expert
            expert_id = int(expert_match.group(1))
            # For scale parameters, we can directly set the value
            param.data[expert_id] = loaded_weight
        else:
            # No expert ID found - this is a single scale for all experts
            # Load the same scale for all experts
            for expert_id in range(num_experts):
                param.data[expert_id] = loaded_weight
        loaded_params.add(transformed_name)

        return True

    def _handle_expert_weight_params(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: dict,
        num_experts: int,
        loaded_params: set,
    ) -> bool:
        """Handle actual weight tensors for expert layers (gate_up_proj and down_proj).

        Args:
            name: Parameter name (should contain gate_up_proj or down_proj)
            loaded_weight: Weight tensor(s) to be loaded
            params_dict: Dictionary of model parameters
            num_experts: Total number of experts for tensor distribution
            loaded_params: Set of loaded parameter names

        Returns:
            bool: True (always handles weight parameters)
        """
        # Transform name and get shard info
        param_name, _, shard_id_list = self._transform_expert_name(
            name, is_weight=True
        )

        if param_name not in params_dict:
            return True

        if ".gate_up_proj" in name:
            # The fused tensor holds both shards along the last dimension.
            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
        else:  # down_proj
            loaded_weight_list = [loaded_weight]

        param = params_dict[param_name]
        weight_loader = param.weight_loader
        loaded_params.add(param_name)

        for weight_chunk, shard_id in zip(loaded_weight_list, shard_id_list):
            # A 2-D chunk is one tensor shared by all experts; otherwise the
            # leading dimension indexes the expert.
            shared_across_experts = weight_chunk.dim() == 2
            for expert_id in range(num_experts):
                expert_weight = (
                    weight_chunk if shared_across_experts else weight_chunk[expert_id]
                )
                weight_loader(
                    param,
                    expert_weight.T,
                    param_name,
                    shard_id=shard_id,
                    expert_id=expert_id,
                )

        return True

    def _handle_default_weight(
        self, name: str, loaded_weight: torch.Tensor, params_dict: dict
    ):
        """Handle default weight loading."""
        # Skip loading extra bias for GPTQ models
        if name.endswith(".bias") and name not in params_dict:
            return

        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)

    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
        """Forward the request to the language model when it supports it."""
        if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
            self.language_model.set_eagle3_layers_to_capture(layer_ids)

    def get_embed_and_head(self):
        """Return ``(embed, head)`` from the language model.

        Delegates to the language model's own ``get_embed_and_head`` when it
        exists; otherwise uses ``get_embed()`` plus ``lm_head.weight`` when
        available, or ``None`` for the head.
        """
        # For EAGLE3, we delegate to the language model which should have this method
        if hasattr(self.language_model, "get_embed_and_head"):
            return self.language_model.get_embed_and_head()

        embed = self.language_model.get_embed()
        if hasattr(self.language_model, "lm_head"):
            return embed, self.language_model.lm_head.weight
        # For EAGLE3, head might not be needed
        return embed, None

    def set_embed_and_head(self, embed, head):
        """Set embed and head on the language model (embed only if unsupported)."""
        if hasattr(self.language_model, "set_embed_and_head"):
            return self.language_model.set_embed_and_head(embed, head)
        # For EAGLE3, only set embed
        return self.language_model.set_embed(embed)

    def get_embed(self):
        """Return the language model's embedding."""
        return self.language_model.get_embed()

    def set_embed(self, embed):
        """Set the language model's embedding."""
        return self.language_model.set_embed(embed)

    def get_hidden_dim(self, module_name, layer_idx):
        """Return ``(input_dim, output_dim)`` for a LoRA target module.

        Dimensions are read from ``config.text_config`` when present (matching
        how the rest of this class resolves the text config), else from
        ``config`` itself.

        Raises:
            NotImplementedError: For an unsupported ``module_name``.
        """
        text_config = (
            self.config.text_config
            if hasattr(self.config, "text_config")
            else self.config
        )

        if module_name == "qkv_proj":
            return (
                text_config.hidden_size,
                text_config.head_dim
                * (
                    text_config.num_attention_heads
                    + text_config.num_key_value_heads * 2
                ),
            )
        elif module_name == "o_proj":
            return (
                text_config.head_dim * text_config.num_attention_heads,
                text_config.hidden_size,
            )
        elif module_name == "gate_up_proj":
            return text_config.hidden_size, text_config.intermediate_size * 2
        elif module_name == "down_proj":
            # The intermediate size is queried from the specific decoder layer.
            decoder_layer = self.language_model.get_layers()[layer_idx]
            intermediate_size = decoder_layer.get_intermediate_size()
            return intermediate_size, text_config.hidden_size
        else:
            raise NotImplementedError()
|
|
|
|
|
|
# Module-level alias for the model class defined above.
# NOTE(review): presumably this is the name the model loader looks up to
# discover the architecture implemented by this file — confirm against the loader.
EntryClass = Llama4ForConditionalGeneration
|