diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 5445b4f23..84470456e 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -649,6 +649,27 @@ class FusedMoE(torch.nn.Module):
         loaded_weight: torch.tensor,
         tp_rank: int,
     ):
+        """Load w2 weights for down projection.
+
+        Args:
+            expert_data: The expert data tensor to load into
+            shard_dim: The dimension to shard along
+            shard_id: The shard ID (must be "w2")
+            loaded_weight: The weight tensor to load from
+            tp_rank: The tensor parallel rank
+        """
+        if not isinstance(expert_data, torch.Tensor) or not isinstance(
+            loaded_weight, torch.Tensor
+        ):
+            raise ValueError("expert_data and loaded_weight must be torch.Tensor")
+
+        if expert_data.dim() != 2 or loaded_weight.dim() != 2:
+            raise ValueError(
+                f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
+            )
+
+        if shard_id != "w2":
+            raise ValueError(f"shard_id must be 'w2', got {shard_id}")
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
@@ -669,6 +690,10 @@ class FusedMoE(torch.nn.Module):
         if not self.use_presharded_weights:
             if self.use_triton_kernels:
                 loaded_weight = loaded_weight.transpose(-2, -1)
+            if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
+                raise ValueError(
+                    f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
+                )
             loaded_weight = loaded_weight.narrow(
                 shard_dim, shard_size * tp_rank, shard_size
             )
@@ -795,8 +820,21 @@ class FusedMoE(torch.nn.Module):
                 tp_rank=tp_rank,
             )
             return
+
         if "ModelOpt" in self.quant_method.__class__.__name__:
-            if "weight_scale_2" in weight_name or "input_scale" in weight_name:
+            # Determine per-tensor weight scale patterns based on variant
+            is_fp4_variant = (
+                "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
+            )
+
+            # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
+            per_tensor_conditions = (
+                "weight_scale_2" in weight_name
+                if is_fp4_variant
+                else "weight_scale" in weight_name
+            ) or "input_scale" in weight_name
+
+            if per_tensor_conditions:
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index 913a5bb99..85be4f8f4 100644
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -26,6 +26,7 @@ from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     is_layer_skipped,
+    per_tensor_dequantize,
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -110,7 +111,12 @@ class ModelOptFp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if self.exclude_modules and any(
-            module in prefix for module in self.exclude_modules
+            module in prefix
+            or (
+                prefix.startswith("language_model.")
+                and module in prefix.removeprefix("language_model.")
+            )
+            for module in self.exclude_modules
         ):
             return None
 
@@ -119,6 +125,12 @@ class ModelOptFp8Config(QuantizationConfig):
         if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
 
+        # Add MoE support
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, FusedMoE):
+            return ModelOptFp8MoEMethod(self)
+
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -234,6 +246,237 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
         super().__init__(quant_config)
 
 
+class ModelOptFp8MoEMethod:
+    """MoE method for ModelOpt FP8.
+    Supports loading FP8 checkpoints with static weight scale and activation scale.
+
+    Args:
+        quant_config: The ModelOpt quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Dynamic class composition pattern.
+
+        This allows us to effectively "inject" FusedMoEMethodBase as a parent class
+        at runtime while avoiding circular import issues.
+        """
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        # Use FP8 dtype if checkpoint is serialized, otherwise use the default dtype
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        w13_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        w2_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALES - Per-tensor scaling for ModelOpt
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts, 2),
+                    torch.finfo(torch.float32).min,
+                    dtype=torch.float32,
+                ),
+                weight_loader=weight_loader,
+            )
+            w2_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts,), torch.finfo(torch.float32).min, dtype=torch.float32
+                ),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+            # Set weight loader attributes for scales
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+            )
+
+            # INPUT SCALES - Per-tensor scaling for ModelOpt
+            w13_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            w2_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Process FP8 MoE weights after loading from serialized checkpoint.
+
+        Only supports pre-quantized checkpoints with FP8 weights and scales.
+        """
+
+        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+        # Handle scale parameters
+        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max of the w1 and w3 scales then dequant and requant each expert.
+            if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
+                from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+
+                # Get the maximum scale across w1 and w3 for each expert
+                max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+
+                # Requantize each expert's weights using the combined scale
+                # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
+                # where the first intermediate_size rows are w1, the next are w3
+                intermediate_size = layer.w13_weight.shape[1] // 2
+                for expert_id in range(layer.w13_weight.shape[0]):
+                    start = 0
+                    for shard_id in range(2):  # w1 and w3
+                        # Dequantize using the original scale for this shard
+                        dq_weight = per_tensor_dequantize(
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            layer.w13_weight_scale[expert_id][shard_id],
+                        )
+                        # Requantize using the combined max scale
+                        (
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            _,
+                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+
+                        start += intermediate_size
+
+                # Update the scale parameter to be per-expert instead of per-shard
+                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
+            else:
+                layer.w13_weight_scale = Parameter(
+                    layer.w13_weight_scale.data, requires_grad=False
+                )
+
+        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
+            layer.w2_weight_scale = Parameter(
+                layer.w2_weight_scale.data, requires_grad=False
+            )
+        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
+            layer.w13_input_scale = Parameter(
+                layer.w13_input_scale.max(), requires_grad=False
+            )
+        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
+            layer.w2_input_scale = Parameter(
+                layer.w2_input_scale.max(), requires_grad=False
+            )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        # Expert selection
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            activation=activation,
+            use_fp8_w8a8=True,
+            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            no_combine=no_combine,
+        )
+
+
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py
index 73d1a0068..4774e842b 100644
--- a/python/sglang/srt/models/mllama4.py
+++ b/python/sglang/srt/models/mllama4.py
@@ -1,3 +1,6 @@
+import json as json_lib
+import logging
+import os
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple
 
@@ -19,6 +22,13 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cpu
 
 _is_cpu = is_cpu()
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
 
 
 class Llama4ForConditionalGeneration(nn.Module):
@@ -37,19 +47,85 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.config = config
         self.quant_config = quant_config
 
-        self.vision_model = Llama4VisionModel(config.vision_config)
-        self.multi_modal_projector = Llama4MultiModalProjector(config)
+        # Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
+        self.has_vision = self._has_vision_weights(config)
+        if not self.has_vision:
+            logger.warning(
+                "No vision weights found in checkpoint. Model will run in text-only mode. "
+                "Multimodal capabilities (image processing) will be unavailable."
+            )
+
+        if self.has_vision:
+            self.vision_model = Llama4VisionModel(config.vision_config)
+            self.multi_modal_projector = Llama4MultiModalProjector(config)
+        else:
+            self.vision_model = None
+            self.multi_modal_projector = None
 
         # Initialize the language model
         from sglang.srt.models.llama4 import Llama4ForCausalLM
 
         self.language_model = Llama4ForCausalLM(
-            config.text_config,
+            config.text_config if hasattr(config, "text_config") else config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
 
-        self.logits_processor = LogitsProcessor(config.text_config)
+        self.logits_processor = LogitsProcessor(
+            config.text_config if hasattr(config, "text_config") else config
+        )
+
+    def _has_vision_weights(self, config) -> bool:
+        """Check if the model has vision components by examining the checkpoint."""
+        model_path = getattr(config, "_name_or_path", None)
+        if not model_path:
+            return False
+
+        # Check if this is a local path first
+        if os.path.isdir(model_path):
+            index_file = os.path.join(model_path, "model.safetensors.index.json")
+            if os.path.exists(index_file):
+                return self._check_vision_weights_in_index(index_file)
+
+        # For HuggingFace models, we need to check the actual checkpoint
+        # The config might say it's multimodal, but the checkpoint might be text-only
+        try:
+            # Try to access the HuggingFace cache directory
+            from huggingface_hub import try_to_load_from_cache
+
+            # Check if index file exists in cache
+            index_file_path = try_to_load_from_cache(
+                repo_id=model_path,
+                filename="model.safetensors.index.json",
+                cache_dir=None,
+            )
+
+            if index_file_path and os.path.exists(index_file_path):
+                return self._check_vision_weights_in_index(index_file_path)
+
+        except Exception:
+            # If we can't access the cache, fall back to config-based detection
+            pass
+
+        # Fallback: assume text-only
+        return False
+
+    def _check_vision_weights_in_index(self, index_file: str) -> bool:
+        """Check if the model.safetensors.index.json contains vision weights."""
+        try:
+            with open(index_file, "r") as f:
+                index_data = json_lib.load(f)
+
+            vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]
+            weight_names = index_data.get("weight_map", {}).keys()
+
+            return any(
+                pattern in weight_name
+                for weight_name in weight_names
+                for pattern in vision_patterns
+            )
+        except (OSError, json_lib.JSONDecodeError, KeyError):
+            return False
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
@@ -59,6 +135,10 @@ class Llama4ForConditionalGeneration(nn.Module):
         self,
         items: List[MultimodalDataItem],
     ) -> torch.Tensor:
+        # Text-only checkpoints have no vision encoder, so raise an error
+        if not self.has_vision or self.vision_model is None:
+            raise ValueError("Vision model not available for text-only checkpoint")
+
         pixel_values = (
             torch.concat([item.pixel_values for item in items])
             .to(next(self.vision_model.parameters()).device)
@@ -79,11 +159,14 @@ class Llama4ForConditionalGeneration(nn.Module):
         **kwargs: object,
     ) -> torch.Tensor:
 
+        # For text-only models, pass None for image_data_embedding_func
+        image_embedding_func = self.get_image_feature if self.has_vision else None
+
         hs = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            image_data_embedding_func=image_embedding_func,
             positions=positions,
         )
 
@@ -124,7 +207,6 @@ class Llama4ForConditionalGeneration(nn.Module):
         return name, loaded_weight
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
-
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
@@ -137,11 +219,12 @@ class Llama4ForConditionalGeneration(nn.Module):
         ]
 
         params_dict = dict(self.named_parameters())
+        num_experts = (
+            self.config.text_config.num_local_experts
+            if hasattr(self.config, "text_config")
+            else self.config.num_local_experts
+        )
 
-        num_experts = self.config.text_config.num_local_experts
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
@@ -150,81 +233,279 @@ class Llama4ForConditionalGeneration(nn.Module):
         )
 
         for name, loaded_weight in weights:
-            if not "vision" in name:
+            if self._should_skip_weight(name):
+                continue
+
+            name = self._transform_weight_name(name)
+
+            if "vision" not in name:
                 name, loaded_weight = self.permute_qk_weight_for_rotary(
                     name, loaded_weight
                 )
 
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
+            if self._handle_scale_remapping(name, params_dict):
+                continue
 
-                if "vision" in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                if ".experts" in name:
-                    # NOTE: llama4 fp8 has different weight format for experts
-                    if (
-                        "experts.gate_up_proj" not in name
-                        and "experts.down_proj" not in name
-                    ):
-                        for mapping in expert_params_mapping:
-                            param_name, weight_name, expert_id, shard_id = mapping
-                            if weight_name not in name:
-                                continue
-                            name = name.replace(weight_name, param_name)
-                            param = params_dict[name]
-                            weight_loader = param.weight_loader
-                            weight_loader(
-                                param,
-                                loaded_weight,
-                                name,
-                                shard_id=shard_id,
-                                expert_id=expert_id,
-                            )
-                            break
-                    else:
-                        if ".gate_up_proj" in name:
-                            name_list = [
-                                name.replace(
-                                    ".experts.gate_up_proj", ".experts.w13_weight"
-                                )
-                            ] * 2
-                            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
-                            shard_id_list = ["w1", "w3"]
-                        else:
-                            name_list = [
-                                name.replace(".experts.down_proj", ".experts.w2_weight")
-                            ]
-                            shard_id_list = ["w2"]
-                            loaded_weight_list = [loaded_weight]
-                        for name, loaded_weight, shard_id in zip(
-                            name_list, loaded_weight_list, shard_id_list
-                        ):
-                            param = params_dict[name]
-                            weight_loader = param.weight_loader
-                            for expert_id in range(num_experts):
-                                weight_loader(
-                                    param,
-                                    loaded_weight[expert_id].T,
-                                    name,
-                                    shard_id=shard_id,
-                                    expert_id=expert_id,
-                                )
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
+            if self._handle_stacked_params(
+                name, loaded_weight, stacked_params_mapping, params_dict
+            ):
+                continue
+
+            if self._handle_expert_weights(
+                name, loaded_weight, expert_params_mapping, params_dict, num_experts
+            ):
+                continue
+
+            self._handle_default_weight(name, loaded_weight, params_dict)
+
+    def _should_skip_weight(self, name: str) -> bool:
+        """Check if we should skip loading this weight."""
+        return "vision" in name and not self.has_vision
+
+    def _transform_weight_name(self, name: str) -> str:
+        """Transform weight name by adding language_model prefix if needed."""
+        if (
+            not name.startswith("language_model.")
+            and "vision" not in name
+            and "multi_modal_projector" not in name
+        ):
+            return f"language_model.{name}"
+        return name
+
+    def _handle_scale_remapping(self, name: str, params_dict: dict) -> bool:
+        """Handle scale parameter remapping. Returns True if handled."""
+        if "scale" in name and "expert" not in name:
+            remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+            return remapped_name is None
+        return False
+
+    def _handle_stacked_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        stacked_params_mapping: list,
+        params_dict: dict,
+    ) -> bool:
+        """Handle stacked parameter loading. Returns True if handled."""
+        for param_name, weight_name, shard_id in stacked_params_mapping:
+            if weight_name in name and "vision" not in name:
+                transformed_name = name.replace(weight_name, param_name)
+                param = params_dict[transformed_name]
+                param.weight_loader(param, loaded_weight, shard_id)
+                return True
+        return False
+
+    def _handle_expert_weights(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        expert_params_mapping: list,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle expert weight loading for MoE (Mixture of Experts) layers.
+
+        Args:
+            name: Parameter name from the checkpoint
+            loaded_weight: The weight tensor to be loaded
+            expert_params_mapping: Mapping of parameter names to expert configurations
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts in the MoE layer
+
+        Returns:
+            bool: True if the parameter was handled (is an expert parameter), False otherwise
+        """
+        if ".experts" not in name:
+            return False
+
+        if "experts.gate_up_proj" not in name and "experts.down_proj" not in name:
+            return self._handle_other_expert_params(
+                name, loaded_weight, expert_params_mapping, params_dict
+            )
+
+        if "scale" in name:
+            return self._handle_expert_scale_params(
+                name, loaded_weight, params_dict, num_experts
+            )
+        else:
+            return self._handle_expert_weight_params(
+                name, loaded_weight, params_dict, num_experts
+            )
+
+    def _handle_other_expert_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        expert_params_mapping: list,
+        params_dict: dict,
+    ) -> bool:
+        """Handle expert parameters that are not gate_up_proj or down_proj weights.
+
+        Args:
+            name: Parameter name from the checkpoint
+            loaded_weight: The weight tensor to be loaded
+            expert_params_mapping: List of tuples mapping checkpoint names to model parameters
+            params_dict: Dictionary of model parameters
+
+        Returns:
+            bool: True if parameter was found and handled, False otherwise
+        """
+        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
+            if weight_name in name:
+                transformed_name = name.replace(weight_name, param_name)
+                param = params_dict[transformed_name]
+                param.weight_loader(
+                    param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
+                )
+                return True
+        return False
+
+    def _transform_expert_name(
+        self, name: str, is_weight: bool = False
+    ) -> Tuple[str, str, List[str]]:
+        """Transform expert parameter name and get shard information.
+
+        Args:
+            name: The original parameter name
+            is_weight: Whether this is a weight parameter (adds _weight suffix)
+
+        Returns:
+            Tuple of (transformed_name, shard_id, shard_id_list)
+        """
+        suffix = "_weight" if is_weight else ""
+
+        if ".gate_up_proj" in name:
+            transformed_name = name.replace(
+                ".experts.gate_up_proj", f".experts.w13{suffix}"
+            )
+            shard_id = "w13"
+            shard_id_list = ["w1", "w3"]
+        else:  # down_proj
+            transformed_name = name.replace(
+                ".experts.down_proj", f".experts.w2{suffix}"
+            )
+            shard_id = "w2"
+            shard_id_list = ["w2"]
+
+        return transformed_name, shard_id, shard_id_list
+
+    def _handle_expert_scale_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle quantization scale parameters for expert weights.
+
+        Args:
+            name: Parameter name containing scale information
+            loaded_weight: Scale tensor to be loaded
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts for broadcast operations
+
+        Returns:
+            bool: True (always handles scale parameters)
+        """
+        import re
+
+        # Check if this matches the expert parameter pattern: experts.{expert_id}.{param_name}
+        expert_match = re.search(r"experts\.(\d+)\.", name)
+
+        # Transform name
+        transformed_name, _, _ = self._transform_expert_name(name)
+
+        if transformed_name not in params_dict:
+            return True
+
+        param = params_dict[transformed_name]
+
+        # Handle scale parameters
+        if expert_match:
+            # If we have a specific expert ID, only load for that expert
+            expert_id = int(expert_match.group(1))
+            # For scale parameters, we can directly set the value
+            param.data[expert_id] = loaded_weight
+        else:
+            # No expert ID found - this is a single scale for all experts
+            # Load the same scale for all experts
+            for expert_id in range(num_experts):
+                param.data[expert_id] = loaded_weight
+
+        return True
+
+    def _handle_expert_weight_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle actual weight tensors for expert layers (gate_up_proj and down_proj).
+
+        Args:
+            name: Parameter name (should contain gate_up_proj or down_proj)
+            loaded_weight: Weight tensor(s) to be loaded
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts for tensor distribution
+
+        Returns:
+            bool: True (always handles weight parameters)
+        """
+        # Transform name and get shard info
+        transformed_name, _, shard_id_list = self._transform_expert_name(
+            name, is_weight=True
+        )
+
+        if ".gate_up_proj" in name:
+            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
+        else:  # down_proj
+            loaded_weight_list = [loaded_weight]
+
+        for param_name, weight_chunk, shard_id in zip(
+            [transformed_name] * len(shard_id_list), loaded_weight_list, shard_id_list
+        ):
+            if param_name not in params_dict:
+                continue
+
+            param = params_dict[param_name]
+            weight_loader = param.weight_loader
+
+            # Handle the case where loaded_weight might be a single tensor for all experts
+            if weight_chunk.dim() == 2:
+                # Single tensor case - load for all experts
+                for expert_id in range(num_experts):
+                    weight_loader(
+                        param,
+                        weight_chunk.T,
+                        param_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
                     )
-                    weight_loader(param, loaded_weight)
+            else:
+                # Multiple experts case - load each expert's weights
+                for expert_id in range(num_experts):
+                    weight_loader(
+                        param,
+                        weight_chunk[expert_id].T,
+                        param_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+
+        return True
+
+    def _handle_default_weight(
+        self, name: str, loaded_weight: torch.Tensor, params_dict: dict
+    ):
+        """Handle default weight loading."""
+        # Skip loading extra bias for GPTQ models
+        if name.endswith(".bias") and name not in params_dict:
+            return
+
+        param = params_dict[name]
+        weight_loader = getattr(param, "weight_loader", default_weight_loader)
+        weight_loader(param, loaded_weight)
 
     def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
         if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
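
Editor's note (not part of the patch): the least obvious step above is the scale merge in process_weights_after_loading, which takes the two per-tensor FP8 scales that ModelOpt checkpoints store for w1 and w3, dequantizes each shard with its own scale, and requantizes both with the per-expert maximum so the fused kernel only needs one scale per expert. The sketch below is a minimal illustration of that idea using plain float32 math; fake_quant and fake_dequant are hypothetical stand-ins for scaled_fp8_quant and per_tensor_dequantize (no real float8 dtype or rounding is simulated).

# Illustrative sketch only; helper names and shapes are assumptions, not sglang APIs.
import torch

FP8_MAX = 448.0  # max finite magnitude of float8_e4m3fn


def fake_quant(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Stand-in for scaled_fp8_quant: divide by the per-tensor scale and clamp to the fp8 range.
    return torch.clamp(w / scale, -FP8_MAX, FP8_MAX)


def fake_dequant(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Stand-in for per_tensor_dequantize: multiply back by the per-tensor scale.
    return q * scale


num_experts, intermediate_size, hidden_size = 2, 4, 8
w13 = torch.randn(num_experts, 2 * intermediate_size, hidden_size)

# Checkpoint-style state: one scale per expert for w1 and one for w3.
w1_scale = w13[:, :intermediate_size].abs().amax(dim=(1, 2)) / FP8_MAX
w3_scale = w13[:, intermediate_size:].abs().amax(dim=(1, 2)) / FP8_MAX
w13_weight_scale = torch.stack([w1_scale, w3_scale], dim=1)  # (num_experts, 2)

w13_q = torch.empty_like(w13)
for e in range(num_experts):
    w13_q[e, :intermediate_size] = fake_quant(w13[e, :intermediate_size], w13_weight_scale[e, 0])
    w13_q[e, intermediate_size:] = fake_quant(w13[e, intermediate_size:], w13_weight_scale[e, 1])

# Merge step: dequantize each shard with its own scale, requantize with the per-expert max.
max_w13_scales = w13_weight_scale.max(dim=1).values  # (num_experts,)
for e in range(num_experts):
    start = 0
    for shard in range(2):  # w1, then w3
        rows = slice(start, start + intermediate_size)
        dq = fake_dequant(w13_q[e, rows], w13_weight_scale[e, shard])
        w13_q[e, rows] = fake_quant(dq, max_w13_scales[e])
        start += intermediate_size

# A single per-expert scale now reconstructs the original weights (up to float32 rounding).
recon = torch.stack([fake_dequant(w13_q[e], max_w13_scales[e]) for e in range(num_experts)])
print(torch.allclose(recon, w13, atol=1e-5))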