Enable ModelOpt Llama4 fp8 checkpoint deployment in SGLang (#7129)
This commit is contained in:
@@ -649,6 +649,27 @@ class FusedMoE(torch.nn.Module):
|
|||||||
loaded_weight: torch.tensor,
|
loaded_weight: torch.tensor,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
):
|
):
|
||||||
|
"""Load w2 weights for down projection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expert_data: The expert data tensor to load into
|
||||||
|
shard_dim: The dimension to shard along
|
||||||
|
shard_id: The shard ID (must be "w2")
|
||||||
|
loaded_weight: The weight tensor to load from
|
||||||
|
tp_rank: The tensor parallel rank
|
||||||
|
"""
|
||||||
|
if not isinstance(expert_data, torch.Tensor) or not isinstance(
|
||||||
|
loaded_weight, torch.Tensor
|
||||||
|
):
|
||||||
|
raise ValueError("expert_data and loaded_weight must be torch.Tensor")
|
||||||
|
|
||||||
|
if expert_data.dim() != 2 or loaded_weight.dim() != 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if shard_id != "w2":
|
||||||
|
raise ValueError(f"shard_id must be 'w2', got {shard_id}")
|
||||||
|
|
||||||
# Index the loaded weight for tp sharding.
|
# Index the loaded weight for tp sharding.
|
||||||
# down_proj: "RowParallel" so tp sharding on input_dim
|
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||||
@@ -669,6 +690,10 @@ class FusedMoE(torch.nn.Module):
|
|||||||
if not self.use_presharded_weights:
|
if not self.use_presharded_weights:
|
||||||
if self.use_triton_kernels:
|
if self.use_triton_kernels:
|
||||||
loaded_weight = loaded_weight.transpose(-2, -1)
|
loaded_weight = loaded_weight.transpose(-2, -1)
|
||||||
|
if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
|
||||||
|
)
|
||||||
loaded_weight = loaded_weight.narrow(
|
loaded_weight = loaded_weight.narrow(
|
||||||
shard_dim, shard_size * tp_rank, shard_size
|
shard_dim, shard_size * tp_rank, shard_size
|
||||||
)
|
)
|
||||||
@@ -795,8 +820,21 @@ class FusedMoE(torch.nn.Module):
|
|||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if "ModelOpt" in self.quant_method.__class__.__name__:
|
if "ModelOpt" in self.quant_method.__class__.__name__:
|
||||||
if "weight_scale_2" in weight_name or "input_scale" in weight_name:
|
# Determine per-tensor weight scale patterns based on variant
|
||||||
|
is_fp4_variant = (
|
||||||
|
"ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
|
||||||
|
)
|
||||||
|
|
||||||
|
# FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
|
||||||
|
per_tensor_conditions = (
|
||||||
|
"weight_scale_2" in weight_name
|
||||||
|
if is_fp4_variant
|
||||||
|
else "weight_scale" in weight_name
|
||||||
|
) or "input_scale" in weight_name
|
||||||
|
|
||||||
|
if per_tensor_conditions:
|
||||||
self._load_per_tensor_weight_scale(
|
self._load_per_tensor_weight_scale(
|
||||||
shard_id=shard_id,
|
shard_id=shard_id,
|
||||||
param=param,
|
param=param,
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
|||||||
from sglang.srt.layers.quantization.utils import (
|
from sglang.srt.layers.quantization.utils import (
|
||||||
convert_to_channelwise,
|
convert_to_channelwise,
|
||||||
is_layer_skipped,
|
is_layer_skipped,
|
||||||
|
per_tensor_dequantize,
|
||||||
requantize_with_max_scale,
|
requantize_with_max_scale,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
@@ -110,7 +111,12 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
if self.exclude_modules and any(
|
if self.exclude_modules and any(
|
||||||
module in prefix for module in self.exclude_modules
|
module in prefix
|
||||||
|
or (
|
||||||
|
prefix.startswith("language_model.")
|
||||||
|
and module in prefix.removeprefix("language_model.")
|
||||||
|
)
|
||||||
|
for module in self.exclude_modules
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -119,6 +125,12 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
|
if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
|
||||||
return ModelOptFp8KVCacheMethod(self)
|
return ModelOptFp8KVCacheMethod(self)
|
||||||
|
|
||||||
|
# Add MoE support
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
|
if isinstance(layer, FusedMoE):
|
||||||
|
return ModelOptFp8MoEMethod(self)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_scaled_act_names(self) -> List[str]:
|
def get_scaled_act_names(self) -> List[str]:
|
||||||
@@ -234,6 +246,237 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
|
|||||||
super().__init__(quant_config)
|
super().__init__(quant_config)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelOptFp8MoEMethod:
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static weight scale and activation scale.

    Instantiating this class does not yield a plain ``ModelOptFp8MoEMethod``:
    ``__new__`` builds a same-named class derived from ``FusedMoEMethodBase``
    and returns an instance of that class instead.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __new__(cls, *args, **kwargs):
        """
        Dynamic class composition pattern.

        This allows us to effectively "inject" FusedMoEMethodBase as a parent class
        at runtime while avoiding circular import issues.
        """
        # Imported lazily, inside the call, for the circular-import reason above.
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase

        # NOTE(review): nothing in this class ever sets `_initialized`, so this
        # branch appears to run (and build a fresh class) on every
        # instantiation - confirm whether that is intended.
        if not hasattr(cls, "_initialized"):
            original_init = cls.__init__
            new_cls = type(
                cls.__name__,
                (FusedMoEMethodBase,),
                {
                    # The explicit "__init__" entry is the same function that the
                    # copy of cls.__dict__ below also provides.
                    "__init__": original_init,
                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
                },
            )
            # Allocate through the base class, then run __init__ by hand: the
            # returned object is not an instance of `cls`, so Python will not
            # call __init__ on it automatically.
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)

    def __init__(self, quant_config: ModelOptFp8Config):
        """Store the quantization config and probe for CUTLASS FP8 support."""
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the fused expert weights and, for FP8 checkpoints, their scales.

        Args:
            layer: Module that receives the parameters via ``register_parameter``.
            num_experts: Size of the leading (expert) dimension of every tensor.
            hidden_size: Model hidden size.
            intermediate_size: Expert MLP intermediate size; ``w13_weight`` holds
                ``2 * intermediate_size`` rows per expert.
            params_dtype: Weight dtype used when the checkpoint is not
                FP8-serialized.
            **extra_weight_attrs: Only ``"weight_loader"`` is read here; it is
                attached to every parameter created below.
        """
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        # Use FP8 dtype if checkpoint is serialized, otherwise use the default dtype
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Fused gate/up projection: (num_experts, 2 * intermediate_size, hidden_size).
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # Down projection: (num_experts, hidden_size, intermediate_size).
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            # Both scale tensors start at float32 min as a placeholder value.
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    (num_experts, 2),
                    torch.finfo(torch.float32).min,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    (num_experts,), torch.finfo(torch.float32).min, dtype=torch.float32
                ),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales
            # NOTE(review): extra_weight_attrs is not read again in this method
            # after this update - confirm whether the scale parameters are
            # expected to pick up "quant_method" some other way.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            # One activation scale per expert, defaulting to 1.0.
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.

        Only supports pre-quantized checkpoints with FP8 weights and scales.

        Every weight and scale on ``layer`` is re-wrapped as a
        ``requires_grad=False`` ``Parameter``. A 2-D ``w13_weight_scale`` is
        reduced to one scale per expert (requantising ``w13_weight`` to match),
        and each input-scale tensor is reduced to its single maximum value.
        """

        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales then dequant and requant each expert.
            if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
                from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    # Row offset of the current shard inside this expert's block.
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale
                        # (written back in place; the returned scale is discarded).
                        (
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            _,
                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

                        start += intermediate_size

                # Update the scale parameter to be per-expert instead of per-shard
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                # Scale is not 2-D: keep the values, only re-wrap the tensor.
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        # Input scales: `.max()` with no dim reduces over all experts to one value.
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        num_fused_shared_experts: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        """Route tokens with ``select_experts`` and run ``fused_experts`` in FP8.

        The routing arguments are forwarded to ``select_experts``; the weights
        and the four scale tensors are read from ``layer``.
        ``apply_router_weight_on_input`` is accepted but not used in this body.

        Returns:
            The result of ``fused_experts``.
        """
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
        from sglang.srt.layers.moe.topk import select_experts

        # Expert selection
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
        )

        # NOTE(review): the scale attributes read below are only registered by
        # create_weights when is_checkpoint_fp8_serialized is true - confirm
        # this method is never reached for non-serialized checkpoints.
        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=inplace,
            activation=activation,
            use_fp8_w8a8=True,
            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            no_combine=no_combine,
        )
|
||||||
|
|
||||||
|
|
||||||
class ModelOptFp4Config(QuantizationConfig):
|
class ModelOptFp4Config(QuantizationConfig):
|
||||||
"""Config class for FP4."""
|
"""Config class for FP4."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
import json as json_lib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import List, Optional, Set, Tuple
|
from typing import List, Optional, Set, Tuple
|
||||||
|
|
||||||
@@ -19,6 +22,13 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.utils import add_prefix, is_cpu
|
from sglang.srt.utils import add_prefix, is_cpu
|
||||||
|
|
||||||
_is_cpu = is_cpu()
|
_is_cpu = is_cpu()
|
||||||
|
from sglang.srt.model_loader.weight_utils import (
|
||||||
|
default_weight_loader,
|
||||||
|
maybe_remap_kv_scale_name,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import add_prefix
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Llama4ForConditionalGeneration(nn.Module):
|
class Llama4ForConditionalGeneration(nn.Module):
|
||||||
@@ -37,19 +47,85 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
self.vision_model = Llama4VisionModel(config.vision_config)
|
# Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
|
||||||
self.multi_modal_projector = Llama4MultiModalProjector(config)
|
self.has_vision = self._has_vision_weights(config)
|
||||||
|
if not self.has_vision:
|
||||||
|
logger.warning(
|
||||||
|
"No vision weights found in checkpoint. Model will run in text-only mode. "
|
||||||
|
"Multimodal capabilities (image processing) will be unavailable."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.has_vision:
|
||||||
|
self.vision_model = Llama4VisionModel(config.vision_config)
|
||||||
|
self.multi_modal_projector = Llama4MultiModalProjector(config)
|
||||||
|
else:
|
||||||
|
self.vision_model = None
|
||||||
|
self.multi_modal_projector = None
|
||||||
|
|
||||||
# Initialize the language model
|
# Initialize the language model
|
||||||
from sglang.srt.models.llama4 import Llama4ForCausalLM
|
from sglang.srt.models.llama4 import Llama4ForCausalLM
|
||||||
|
|
||||||
self.language_model = Llama4ForCausalLM(
|
self.language_model = Llama4ForCausalLM(
|
||||||
config.text_config,
|
config.text_config if hasattr(config, "text_config") else config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("language_model", prefix),
|
prefix=add_prefix("language_model", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(config.text_config)
|
self.logits_processor = LogitsProcessor(
|
||||||
|
config.text_config if hasattr(config, "text_config") else config
|
||||||
|
)
|
||||||
|
|
||||||
|
def _has_vision_weights(self, config) -> bool:
    """Report whether the checkpoint behind ``config`` ships vision weights.

    The decision is made from ``model.safetensors.index.json`` only: first in
    the local checkpoint directory, otherwise in the local Hugging Face cache.
    No network access is attempted. When no index can be located the
    checkpoint is treated as text-only.

    Args:
        config: Model config; only its ``_name_or_path`` attribute is read.

    Returns:
        True if the index lists at least one vision weight, False if it lists
        none or if no index could be found.
    """
    model_path = getattr(config, "_name_or_path", None)
    if not model_path:
        return False

    index_name = "model.safetensors.index.json"

    # Local checkpoint directory: the index next to the weights is authoritative.
    if os.path.isdir(model_path):
        index_file = os.path.join(model_path, index_name)
        if os.path.isfile(index_file):
            return self._check_vision_weights_in_index(index_file)
        # A filesystem path is not a Hub repo id, so a cache lookup could at
        # best find an unrelated repo of the same name. Stop here.
        return False

    # Otherwise treat model_path as a Hub repo id and look in the local cache.
    # The config may claim to be multimodal while the checkpoint is text-only,
    # which is why the index is inspected instead of the config.
    try:
        from huggingface_hub import try_to_load_from_cache

        cached_index = try_to_load_from_cache(
            repo_id=model_path,
            filename=index_name,
            cache_dir=None,
        )
        # A cache hit is a path string; a miss is None or a non-string
        # sentinel, neither of which may be handed to os.path.exists().
        if isinstance(cached_index, str) and os.path.exists(cached_index):
            return self._check_vision_weights_in_index(cached_index)
    except Exception:
        # Detection is best-effort: keep loading, but leave a trace instead of
        # swallowing the failure silently.
        logging.getLogger(__name__).debug(
            "Could not inspect cached safetensors index for %r; "
            "assuming a text-only checkpoint.",
            model_path,
            exc_info=True,
        )

    # No usable index found: assume text-only.
    return False
|
||||||
|
|
||||||
|
def _check_vision_weights_in_index(self, index_file: str) -> bool:
    """Check if the model.safetensors.index.json contains vision weights.

    Args:
        index_file: Path to a safetensors index file (JSON).

    Returns:
        True if any key of the index's ``weight_map`` contains one of the
        vision name fragments. False if none does, or if the file cannot be
        read, is not valid JSON, or has no mapping under ``weight_map``.
    """
    vision_patterns = ("vision_model", "vision_tower", "multi_modal_projector")

    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(index_file, "r", encoding="utf-8") as f:
            index_data = json_lib.load(f)
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        return False

    # A malformed index (non-object root, missing or null weight_map) is
    # treated like a missing one rather than raising AttributeError.
    weight_map = index_data.get("weight_map") if isinstance(index_data, dict) else None
    if not isinstance(weight_map, dict):
        return False

    return any(
        pattern in weight_name
        for weight_name in weight_map
        for pattern in vision_patterns
    )
|
||||||
|
|
||||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
|
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
|
||||||
pattern = MultiModalityDataPaddingPatternMultimodalTokens()
|
pattern = MultiModalityDataPaddingPatternMultimodalTokens()
|
||||||
@@ -59,6 +135,10 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|||||||
self,
|
self,
|
||||||
items: List[MultimodalDataItem],
|
items: List[MultimodalDataItem],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
# For text-only models, return None or raise an error
|
||||||
|
if not self.has_vision or self.vision_model is None:
|
||||||
|
raise ValueError("Vision model not available for text-only checkpoint")
|
||||||
|
|
||||||
pixel_values = (
|
pixel_values = (
|
||||||
torch.concat([item.pixel_values for item in items])
|
torch.concat([item.pixel_values for item in items])
|
||||||
.to(next(self.vision_model.parameters()).device)
|
.to(next(self.vision_model.parameters()).device)
|
||||||
@@ -79,11 +159,14 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|||||||
**kwargs: object,
|
**kwargs: object,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# For text-only models, pass None for image_data_embedding_func
|
||||||
|
image_embedding_func = self.get_image_feature if self.has_vision else None
|
||||||
|
|
||||||
hs = general_mm_embed_routine(
|
hs = general_mm_embed_routine(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
language_model=self.language_model,
|
language_model=self.language_model,
|
||||||
image_data_embedding_func=self.get_image_feature,
|
image_data_embedding_func=image_embedding_func,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -124,7 +207,6 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|||||||
return name, loaded_weight
|
return name, loaded_weight
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
|
||||||
|
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
|
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
|
||||||
@@ -137,11 +219,12 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|||||||
]
|
]
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
|
num_experts = (
|
||||||
|
self.config.text_config.num_local_experts
|
||||||
|
if hasattr(self.config, "text_config")
|
||||||
|
else self.config.num_local_experts
|
||||||
|
)
|
||||||
|
|
||||||
num_experts = self.config.text_config.num_local_experts
|
|
||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
|
||||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
@@ -150,81 +233,279 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if not "vision" in name:
|
if self._should_skip_weight(name):
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = self._transform_weight_name(name)
|
||||||
|
|
||||||
|
if "vision" not in name:
|
||||||
name, loaded_weight = self.permute_qk_weight_for_rotary(
|
name, loaded_weight = self.permute_qk_weight_for_rotary(
|
||||||
name, loaded_weight
|
name, loaded_weight
|
||||||
)
|
)
|
||||||
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
if self._handle_scale_remapping(name, params_dict):
|
||||||
if weight_name not in name:
|
continue
|
||||||
continue
|
|
||||||
|
|
||||||
if "vision" in name:
|
if self._handle_stacked_params(
|
||||||
continue
|
name, loaded_weight, stacked_params_mapping, params_dict
|
||||||
name = name.replace(weight_name, param_name)
|
):
|
||||||
param = params_dict[name]
|
continue
|
||||||
weight_loader = param.weight_loader
|
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
if self._handle_expert_weights(
|
||||||
break
|
name, loaded_weight, expert_params_mapping, params_dict, num_experts
|
||||||
else:
|
):
|
||||||
if ".experts" in name:
|
continue
|
||||||
# NOTE: llama4 fp8 has different weight format for experts
|
|
||||||
if (
|
self._handle_default_weight(name, loaded_weight, params_dict)
|
||||||
"experts.gate_up_proj" not in name
|
|
||||||
and "experts.down_proj" not in name
|
def _should_skip_weight(self, name: str) -> bool:
    """Return True for vision weights when the model was built without vision.

    Args:
        name: Weight name as it appears in the checkpoint.
    """
    if self.has_vision:
        # Vision modules exist, so nothing is skipped.
        return False
    return "vision" in name
|
||||||
param_name, weight_name, expert_id, shard_id = mapping
|
|
||||||
if weight_name not in name:
|
def _transform_weight_name(self, name: str) -> str:
    """Prefix ``name`` with ``language_model.`` unless it should be left alone.

    A name is returned unchanged when it already starts with
    ``language_model.`` or contains ``vision`` or ``multi_modal_projector``.

    Args:
        name: Weight name as it appears in the checkpoint.

    Returns:
        The possibly prefixed weight name.
    """
    keep_as_is = (
        name.startswith("language_model.")
        or "vision" in name
        or "multi_modal_projector" in name
    )
    if keep_as_is:
        return name
    return "language_model." + name
|
||||||
shard_id=shard_id,
|
|
||||||
expert_id=expert_id,
|
def _handle_scale_remapping(self, name: str, params_dict: dict) -> bool:
    """Decide whether a non-expert scale weight should be dropped.

    Only names that contain ``scale`` and do not contain ``expert`` are passed
    to ``maybe_remap_kv_scale_name``.

    Args:
        name: Weight name after prefix transformation.
        params_dict: Mapping of model parameter names to parameters.

    Returns:
        True when the remap helper returned None for this name, False in every
        other case.
    """
    if "scale" not in name or "expert" in name:
        return False

    # NOTE(review): a non-None result from the helper is only compared against
    # None and then discarded, so the caller carries on with the original
    # `name` - confirm that the remapped name is really not needed downstream.
    return maybe_remap_kv_scale_name(name, params_dict) is None
|
||||||
name.replace(
|
|
||||||
".experts.gate_up_proj", ".experts.w13_weight"
|
def _handle_stacked_params(
    self,
    name: str,
    loaded_weight: torch.Tensor,
    stacked_params_mapping: list,
    params_dict: dict,
) -> bool:
    """Load a weight that belongs to a stacked (fused) parameter.

    Args:
        name: Weight name after prefix transformation.
        loaded_weight: Tensor read from the checkpoint.
        stacked_params_mapping: ``(param_name, weight_name, shard_id)`` tuples;
            the first entry whose ``weight_name`` occurs in ``name`` is used.
        params_dict: Mapping of model parameter names to parameters.

    Returns:
        True if a mapping entry matched and the weight was loaded, False if no
        entry matched or the name contains ``vision``.
    """
    # Names containing "vision" are never routed through the stacked mapping.
    if "vision" in name:
        return False

    for fused_name, ckpt_name, shard_id in stacked_params_mapping:
        if ckpt_name not in name:
            continue
        param = params_dict[name.replace(ckpt_name, fused_name)]
        param.weight_loader(param, loaded_weight, shard_id)
        return True

    return False
|
||||||
weight_loader = param.weight_loader
|
|
||||||
for expert_id in range(num_experts):
|
def _handle_expert_weights(
    self,
    name: str,
    loaded_weight: torch.Tensor,
    expert_params_mapping: list,
    params_dict: dict,
    num_experts: int,
) -> bool:
    """Dispatch an MoE expert weight to the matching loading helper.

    Args:
        name: Weight name after prefix transformation.
        loaded_weight: Tensor read from the checkpoint.
        expert_params_mapping: Expert mapping tuples, used only for names that
            are not fused ``gate_up_proj`` / ``down_proj`` tensors.
        params_dict: Mapping of model parameter names to parameters.
        num_experts: Number of experts, forwarded to the fused-tensor helpers.

    Returns:
        False when ``name`` does not contain ``.experts``; otherwise whatever
        the selected helper returns.
    """
    if ".experts" not in name:
        return False

    is_fused_tensor = "experts.gate_up_proj" in name or "experts.down_proj" in name
    if not is_fused_tensor:
        # Every other expert parameter goes through the explicit mapping.
        return self._handle_other_expert_params(
            name, loaded_weight, expert_params_mapping, params_dict
        )

    # Fused tensors: scales and weights have their own helpers.
    if "scale" in name:
        handler = self._handle_expert_scale_params
    else:
        handler = self._handle_expert_weight_params
    return handler(name, loaded_weight, params_dict, num_experts)
|
||||||
|
|
||||||
|
def _handle_other_expert_params(
    self,
    name: str,
    loaded_weight: torch.Tensor,
    expert_params_mapping: list,
    params_dict: dict,
) -> bool:
    """Load an expert parameter through the explicit expert mapping.

    Used for expert names that are not fused ``gate_up_proj`` / ``down_proj``
    tensors.

    Args:
        name: Parameter name from the checkpoint.
        loaded_weight: Tensor read from the checkpoint.
        expert_params_mapping: ``(param_name, weight_name, expert_id, shard_id)``
            tuples; the first entry whose ``weight_name`` occurs in ``name``
            is used.
        params_dict: Mapping of model parameter names to parameters.

    Returns:
        True if a mapping entry matched and the loader was called, False if no
        entry matched.

    Raises:
        KeyError: If an entry matches but the translated name is not in
            ``params_dict``.
    """
    for target_fragment, ckpt_fragment, expert_id, shard_id in expert_params_mapping:
        if ckpt_fragment not in name:
            continue
        param = params_dict[name.replace(ckpt_fragment, target_fragment)]
        # The loader is given the original checkpoint name, not the translated one.
        param.weight_loader(
            param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
        )
        return True

    return False
|
||||||
|
|
||||||
|
def _transform_expert_name(
    self, name: str, is_weight: bool = False
) -> Tuple[str, str, List[str]]:
    """Translate a fused expert checkpoint name into the model's name.

    Names containing ``.gate_up_proj`` map to the ``w13`` parameter with shards
    ``w1`` and ``w3``; every other name is treated as ``down_proj`` and maps to
    ``w2``.

    Args:
        name: The original parameter name.
        is_weight: When True, ``_weight`` is appended to the ``w13`` / ``w2``
            fragment.

    Returns:
        Tuple of (transformed_name, shard_id, shard_id_list).
    """
    suffix = "_weight" if is_weight else ""

    if ".gate_up_proj" in name:
        ckpt_fragment, shard_id, shard_id_list = (
            ".experts.gate_up_proj",
            "w13",
            ["w1", "w3"],
        )
    else:  # down_proj
        ckpt_fragment, shard_id, shard_id_list = ".experts.down_proj", "w2", ["w2"]

    transformed_name = name.replace(ckpt_fragment, f".experts.{shard_id}{suffix}")
    return transformed_name, shard_id, shard_id_list
|
||||||
|
|
||||||
|
def _handle_expert_scale_params(
    self,
    name: str,
    loaded_weight: torch.Tensor,
    params_dict: dict,
    num_experts: int,
) -> bool:
    """Write a quantization scale for fused expert weights into its parameter.

    Args:
        name: Checkpoint name of the scale tensor.
        loaded_weight: Scale tensor read from the checkpoint.
        params_dict: Mapping of model parameter names to parameters.
        num_experts: Number of expert slots to fill when the name carries no
            expert index.

    Returns:
        Always True. This includes the case where the translated name is not
        in ``params_dict`` and nothing is written.
    """
    import re

    # Look for an explicit expert index: experts.{expert_id}.{param_name}
    expert_match = re.search(r"experts\.(\d+)\.", name)

    target_name, _, _ = self._transform_expert_name(name)
    if target_name not in params_dict:
        return True
    param = params_dict[target_name]

    if expert_match:
        # An index in the name limits the write to that one expert.
        expert_ids = [int(expert_match.group(1))]
    else:
        # No index: the same scale is written to every expert slot.
        expert_ids = range(num_experts)

    for expert_id in expert_ids:
        param.data[expert_id] = loaded_weight

    return True
|
||||||
|
|
||||||
|
def _handle_expert_weight_params(
    self,
    name: str,
    loaded_weight: torch.Tensor,
    params_dict: dict,
    num_experts: int,
) -> bool:
    """Load fused ``gate_up_proj`` / ``down_proj`` expert weights.

    A ``gate_up_proj`` tensor is split in two along its last dimension and
    loaded as shards ``w1`` and ``w3``; anything else is loaded whole as ``w2``.

    Args:
        name: Parameter name (should contain gate_up_proj or down_proj).
        loaded_weight: Weight tensor read from the checkpoint.
        params_dict: Mapping of model parameter names to parameters.
        num_experts: Number of experts to call the weight loader for.

    Returns:
        Always True. This includes the case where the translated name is not
        in ``params_dict`` and nothing is loaded.
    """
    target_name, _, shard_ids = self._transform_expert_name(name, is_weight=True)

    if ".gate_up_proj" in name:
        chunks = loaded_weight.chunk(2, dim=-1)
    else:  # down_proj
        chunks = [loaded_weight]

    if target_name not in params_dict:
        return True

    for chunk, shard_id in zip(chunks, shard_ids):
        param = params_dict[target_name]
        load = param.weight_loader
        # A 2-D chunk is one matrix reused for every expert; otherwise the
        # chunk is indexed by expert id along its first dimension.
        shared_by_all = chunk.dim() == 2
        for expert_id in range(num_experts):
            expert_weight = chunk if shared_by_all else chunk[expert_id]
            load(
                param,
                expert_weight.T,
                target_name,
                shard_id=shard_id,
                expert_id=expert_id,
            )

    return True
|
||||||
|
|
||||||
|
def _handle_default_weight(
    self, name: str, loaded_weight: torch.Tensor, params_dict: dict
):
    """Load a weight that no specialised handler claimed.

    Args:
        name: Weight name after prefix transformation.
        loaded_weight: Tensor read from the checkpoint.
        params_dict: Mapping of model parameter names to parameters.

    Raises:
        KeyError: If ``name`` is not in ``params_dict``, unless it is a
            ``.bias`` entry, which is skipped instead.
    """
    # Skip loading extra bias for GPTQ models
    extra_bias = name.endswith(".bias") and name not in params_dict
    if not extra_bias:
        target = params_dict[name]
        # Parameters without their own loader fall back to the default one.
        load = getattr(target, "weight_loader", default_weight_loader)
        load(target, loaded_weight)
|
||||||
|
|
||||||
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
|
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
|
||||||
if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
|
if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
|
||||||
|
|||||||
Reference in New Issue
Block a user