From 886d344964757bb89bdacbbdccc0728a0b909aca Mon Sep 17 00:00:00 2001
From: lukec <118525388+sleepcoo@users.noreply.github.com>
Date: Tue, 1 Jul 2025 13:34:10 +0800
Subject: [PATCH] support llama4 eagle3 (#6985)
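
With this patch, EAGLE3 speculative decoding can target Llama 4. A minimal
offline sketch of how this is driven (model paths and tuning values below are
illustrative, not mandated by this patch):

    import sglang as sgl

    llm = sgl.Engine(
        model_path="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
        speculative_algorithm="EAGLE3",
        speculative_draft_model_path="nvidia/Llama-4-Maverick-17B-128E-Eagle3",
        speculative_num_steps=3,
        speculative_eagle_topk=1,
        speculative_num_draft_tokens=4,
        tp_size=8,
    )
    print(llm.generate("The capital of France is", {"temperature": 0}))
    llm.shutdown()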
" + ) + return None if "FP4" in config["quantization"]["quant_algo"]: return ModelOptFp4Config.from_config(config) else: diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index f08a50611..24a16bf21 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -697,13 +697,19 @@ class LlamaForCausalLM(nn.Module): def load_kv_cache_scales(self, quantization_param_path: str) -> None: self.model.load_kv_cache_scales(quantization_param_path) - def set_eagle3_layers_to_capture(self): + def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): if not self.pp_group.is_last_rank: return - self.capture_aux_hidden_states = True - num_layers = self.config.num_hidden_layers - self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3] + if layer_ids is None: + self.capture_aux_hidden_states = True + num_layers = self.config.num_hidden_layers + self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3] + else: + self.capture_aux_hidden_states = True + # we plus 1 here because in sglang, for the ith layer, it takes the output + # of the (i-1)th layer as aux hidden state + self.model.layers_to_capture = [val + 1 for val in layer_ids] class Phi3ForCausalLM(LlamaForCausalLM): diff --git a/python/sglang/srt/models/llama_eagle3.py b/python/sglang/srt/models/llama_eagle3.py index bbc58ae60..f8d7b608c 100644 --- a/python/sglang/srt/models/llama_eagle3.py +++ b/python/sglang/srt/models/llama_eagle3.py @@ -35,7 +35,8 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors -from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP class LlamaDecoderLayer(LlamaDecoderLayer): @@ -59,6 +60,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer): prefix=add_prefix("qkv_proj", prefix), ) + if config.model_type == "llama4_text": + inter_size = config.intermediate_size_mlp + else: + inter_size = config.intermediate_size + + self.mlp = LlamaMLP( + config.hidden_size, inter_size, config.hidden_act, quant_config, prefix + ) + self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -105,11 +115,19 @@ class LlamaModel(nn.Module): config.hidden_size, prefix=add_prefix("embed_tokens", prefix), ) - self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix) + if hasattr(config, "target_hidden_size"): - self.fc = torch.nn.Linear(config.target_hidden_size * 3, config.hidden_size) + self.hidden_size_in = config.target_hidden_size else: - self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size) + self.hidden_size_in = config.hidden_size + + self.fc = torch.nn.Linear( + self.hidden_size_in * 3, + config.hidden_size, + bias=getattr(config, "bias", False), + ) + + self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -179,18 +197,50 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM): self.logits_processor = LogitsProcessor(config) self.capture_aux_hidden_states = True + self.hot_token_id = None + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: + params_dict = dict(self.named_parameters()) + # Define the parameter mapping for stacked parameters + stacked_params_mapping = [ + # (param_name, 
diff --git a/python/sglang/srt/models/llama_eagle3.py b/python/sglang/srt/models/llama_eagle3.py
index bbc58ae60..f8d7b608c 100644
--- a/python/sglang/srt/models/llama_eagle3.py
+++ b/python/sglang/srt/models/llama_eagle3.py
@@ -35,7 +35,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP
 
 
 class LlamaDecoderLayer(LlamaDecoderLayer):
@@ -59,6 +60,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
             prefix=add_prefix("qkv_proj", prefix),
         )
 
+        if config.model_type == "llama4_text":
+            inter_size = config.intermediate_size_mlp
+        else:
+            inter_size = config.intermediate_size
+
+        self.mlp = LlamaMLP(
+            config.hidden_size, inter_size, config.hidden_act, quant_config, prefix
+        )
+
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
@@ -105,11 +115,19 @@ class LlamaModel(nn.Module):
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
-        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
+
         if hasattr(config, "target_hidden_size"):
-            self.fc = torch.nn.Linear(config.target_hidden_size * 3, config.hidden_size)
+            self.hidden_size_in = config.target_hidden_size
         else:
-            self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size)
+            self.hidden_size_in = config.hidden_size
+
+        self.fc = torch.nn.Linear(
+            self.hidden_size_in * 3,
+            config.hidden_size,
+            bias=getattr(config, "bias", False),
+        )
+
+        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -179,18 +197,50 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
             self.logits_processor = LogitsProcessor(config)
 
         self.capture_aux_hidden_states = True
+        self.hot_token_id = None
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
+        params_dict = dict(self.named_parameters())
+        # Define the parameter mapping for stacked parameters
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             if "d2t" in name:
                 # d2t stores diffs between draft id and target id
                 self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0])
+                continue
 
-            if "d2t" not in name and "t2d" not in name and "lm_head" not in name:
-                new_name = f"model.{name}"
-                super().load_weights([(new_name, loaded_weight)])
-            elif "lm_head" in name:
-                super().load_weights([(name, loaded_weight)])
+            if "t2d" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param_name = f"model.{name}" if name not in params_dict else name
+                if param_name in params_dict:
+                    param = params_dict[param_name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Handle regular (non-stacked) parameters
+                param_name = name if name in params_dict else f"model.{name}"
+                if param_name in params_dict:
+                    param = params_dict[param_name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
 
     def get_hot_token_id(self):
         return self.hot_token_id
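
The d2t buffer loaded above is an offset table: entry i maps draft-vocab id i
to target-vocab id i + d2t[i]. A toy illustration (offsets are made up):

    import torch

    d2t = torch.tensor([0, 5, 9])  # hypothetical draft-to-target diffs
    hot_token_id = d2t + torch.arange(d2t.shape[0])
    assert hot_token_id.tolist() == [0, 6, 11]  # draft ids 0,1,2 -> target ids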
diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py
index 21b6c51d6..3837989da 100644
--- a/python/sglang/srt/models/mllama4.py
+++ b/python/sglang/srt/models/mllama4.py
@@ -223,5 +223,34 @@ class Llama4ForConditionalGeneration(nn.Module):
             )
             weight_loader(param, loaded_weight)
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
+            self.language_model.set_eagle3_layers_to_capture(layer_ids)
+
+    def get_embed_and_head(self):
+        # For EAGLE3, delegate to the language model, which should have this method.
+        # If the language model has no lm_head (as in EAGLE3), return None for the head.
+        embed = self.language_model.get_embed()
+        if hasattr(self.language_model, "get_embed_and_head"):
+            return self.language_model.get_embed_and_head()
+        elif hasattr(self.language_model, "lm_head"):
+            return embed, self.language_model.lm_head.weight
+        else:
+            # For EAGLE3, the head may not be needed
+            return embed, None
+
+    def set_embed_and_head(self, embed, head):
+        if hasattr(self.language_model, "set_embed_and_head"):
+            return self.language_model.set_embed_and_head(embed, head)
+        else:
+            # For EAGLE3, only set the embedding
+            return self.language_model.set_embed(embed)
+
+    def get_embed(self):
+        return self.language_model.get_embed()
+
+    def set_embed(self, embed):
+        return self.language_model.set_embed(embed)
+
 
 EntryClass = Llama4ForConditionalGeneration
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index effcbae4a..e78da174f 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -140,9 +140,11 @@ class EAGLEWorker(TpModelWorker):
             self.draft_model_runner.model.set_embed(embed)
 
             # grab hot token ids
-            self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
-                embed.device
-            )
+            if self.draft_model_runner.model.hot_token_id is not None:
+                self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
+                    embed.device
+                )
+        else:
 
             if self.hot_token_id is not None:
                 head = head.clone()
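
How the pieces fit together: the draft checkpoint's config.json may carry an
"eagle_config" entry, which the ModelRunner change above forwards to
set_eagle3_layers_to_capture(). A sketch with hypothetical layer ids:

    # What draft_model_config.hf_config.eagle_config may look like once parsed;
    # the ids are made up for illustration.
    eagle_config = {"eagle_aux_hidden_state_layer_ids": [1, 23, 44]}
    layer_ids = eagle_config["eagle_aux_hidden_state_layer_ids"]
    # A missing or None eagle_config falls back to the default
    # [2, n // 2, n - 3] capture selection in LlamaForCausalLM.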