support llama4 eagle3 (#6985)
Co-authored-by: shuaills <shishuaiuoe@gmail.com>
Co-authored-by: Shenggui Li <somerlee.9@gmail.com>
Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
@@ -306,7 +306,26 @@ class ModelRunner:
         # auxiliary hidden capture mode. TODO: expose this to server args?
         if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
-            self.model.set_eagle3_layers_to_capture()
+            # load the draft model config
+            draft_model_config = ModelConfig.from_server_args(
+                server_args,
+                model_path=server_args.speculative_draft_model_path,
+                is_draft_model=True,
+            )
+
+            try:
+                # get the aux layer ids from the draft model config
+                eagle_config = getattr(
+                    draft_model_config.hf_config, "eagle_config", None
+                )
+                eagle_aux_hidden_state_layer_ids = eagle_config[
+                    "eagle_aux_hidden_state_layer_ids"
+                ]
+            except (TypeError, KeyError):
+                # if there is no aux layer config, fall back to None (defaults)
+                eagle_aux_hidden_state_layer_ids = None
+
+            self.model.set_eagle3_layers_to_capture(eagle_aux_hidden_state_layer_ids)

     def model_specific_adjustment(self):
         server_args = self.server_args
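For illustration, a minimal sketch of how the lookup above behaves; the `eagle_config` contents here are assumed example values, not copied from a real draft checkpoint:

```python
from types import SimpleNamespace

# assumed stand-in for draft_model_config.hf_config
hf_config = SimpleNamespace(
    eagle_config={"eagle_aux_hidden_state_layer_ids": [1, 23, 44]}
)

eagle_config = getattr(hf_config, "eagle_config", None)
try:
    layer_ids = eagle_config["eagle_aux_hidden_state_layer_ids"]
except (TypeError, KeyError):
    # draft models without the field fall back to the built-in defaults
    layer_ids = None

print(layer_ids)  # [1, 23, 44]
```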
@@ -124,6 +124,9 @@ def _get_quantization_config(
     quant_config = get_quant_config(
         model_config, load_config, packed_modules_mapping
     )
+    # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
+    if quant_config is None:
+        return None
     major, minor = get_device_capability()

     if major is not None and minor is not None:
@@ -209,6 +209,17 @@ def get_quant_config(
         config["adapter_name_or_path"] = model_name_or_path
     elif model_config.quantization == "modelopt":
         if config["producer"]["name"] == "modelopt":
+            # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
+            if config["quantization"]["quant_algo"] is None:
+                if (
+                    model_config.hf_config.architectures[0]
+                    != "LlamaForCausalLMEagle3"
+                ):
+                    raise ValueError(
+                        f"Invalid quant_config, quantization method: {model_config.quantization}, "
+                        f"hf architectures: {model_config.hf_config.architectures[0]}."
+                    )
+                return None
             if "FP4" in config["quantization"]["quant_algo"]:
                 return ModelOptFp4Config.from_config(config)
             else:
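A sketch of the quantization-config shape this guard targets; only the fields the guard reads are shown, and the values are assumptions based on the checkpoint named in the comment:

```python
# assumed minimal shape of the shipped modelopt config
config = {
    "producer": {"name": "modelopt"},
    "quantization": {"quant_algo": None},  # null in the draft's config
}
architectures = ["LlamaForCausalLMEagle3"]  # draft model's hf architectures

if config["quantization"]["quant_algo"] is None:
    # Eagle3 draft: skip quantization instead of raising ValueError
    assert architectures[0] == "LlamaForCausalLMEagle3"
    quant_config = None
```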
@@ -697,13 +697,19 @@ class LlamaForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)

-    def set_eagle3_layers_to_capture(self):
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
         if not self.pp_group.is_last_rank:
             return

-        self.capture_aux_hidden_states = True
-        num_layers = self.config.num_hidden_layers
-        self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # add 1 because in sglang, the ith layer takes the output
+            # of the (i-1)th layer as its aux hidden state
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]


 class Phi3ForCausalLM(LlamaForCausalLM):
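The index shift above is easy to check in isolation; the layer ids here are assumed example values:

```python
# a draft config listing layers [1, 23, 44] makes sglang capture at
# positions 2, 24, 45, i.e. the inputs of those layers, which are the
# outputs of layers 1, 23, 44
layer_ids = [1, 23, 44]
layers_to_capture = [val + 1 for val in layer_ids]
assert layers_to_capture == [2, 24, 45]
```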
@@ -35,7 +35,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP


 class LlamaDecoderLayer(LlamaDecoderLayer):
@@ -59,6 +60,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
             prefix=add_prefix("qkv_proj", prefix),
         )

+        if config.model_type == "llama4_text":
+            inter_size = config.intermediate_size_mlp
+        else:
+            inter_size = config.intermediate_size
+
+        self.mlp = LlamaMLP(
+            config.hidden_size, inter_size, config.hidden_act, quant_config, prefix
+        )
+
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
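A small sketch of the branch above. Only the attribute names come from the diff; the widths and their interpretation (MoE expert width vs. dense MLP width) are assumptions:

```python
# placeholder config with assumed values
class Cfg:
    model_type = "llama4_text"
    intermediate_size = 8192       # assumed: per-expert MoE width
    intermediate_size_mlp = 16384  # assumed: dense MLP width used here

config = Cfg()
inter_size = (
    config.intermediate_size_mlp
    if config.model_type == "llama4_text"
    else config.intermediate_size
)
assert inter_size == 16384
```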
@@ -105,11 +115,19 @@ class LlamaModel(nn.Module):
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
-        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)

         if hasattr(config, "target_hidden_size"):
-            self.fc = torch.nn.Linear(config.target_hidden_size * 3, config.hidden_size)
+            self.hidden_size_in = config.target_hidden_size
         else:
-            self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size)
+            self.hidden_size_in = config.hidden_size
+
+        self.fc = torch.nn.Linear(
+            self.hidden_size_in * 3,
+            config.hidden_size,
+            bias=getattr(config, "bias", False),
+        )
+
+        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)

         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
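A shape sketch of the `fc` projection above, under assumed dimensions: EAGLE3 concatenates three captured hidden states, and `target_hidden_size` lets the draft project down from a wider target model:

```python
import torch

# assumed sizes: a 5120-wide target model feeding a 4096-wide draft model
target_hidden, draft_hidden, seq_len = 5120, 4096, 8

fc = torch.nn.Linear(target_hidden * 3, draft_hidden, bias=False)

# three captured aux hidden states, concatenated along the feature dim
aux = [torch.randn(seq_len, target_hidden) for _ in range(3)]
out = fc(torch.cat(aux, dim=-1))
assert out.shape == (seq_len, draft_hidden)
```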
@@ -179,18 +197,50 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = True
         self.hot_token_id = None

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+        # Define the parameter mapping for stacked parameters
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+
         for name, loaded_weight in weights:
             if "d2t" in name:
                 # d2t stores diffs between draft ids and target ids
                 self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0])
+                continue
-            if "d2t" not in name and "t2d" not in name and "lm_head" not in name:
-                new_name = f"model.{name}"
-                super().load_weights([(new_name, loaded_weight)])
-            elif "lm_head" in name:
-                super().load_weights([(name, loaded_weight)])
+            if "t2d" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param_name = f"model.{name}" if name not in params_dict else name
+                if param_name in params_dict:
+                    param = params_dict[param_name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Handle regular parameters
+                param_name = name if name in params_dict else f"model.{name}"
+                if param_name in params_dict:
+                    param = params_dict[param_name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)

     def get_hot_token_id(self):
         return self.hot_token_id
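The d2t bookkeeping above can be illustrated with toy values (the offsets are made up): each entry stores the difference between a draft-vocab index and its target-vocab id, so adding `arange` recovers the target ids used as hot tokens:

```python
import torch

d2t = torch.tensor([0, 4, 9])  # toy offsets; real checkpoints ship this tensor
hot_token_id = d2t + torch.arange(d2t.shape[0])
assert hot_token_id.tolist() == [0, 5, 11]  # target-vocab ids of the hot tokens
```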
@@ -223,5 +223,34 @@ class Llama4ForConditionalGeneration(nn.Module):
             )
             weight_loader(param, loaded_weight)

+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
+            self.language_model.set_eagle3_layers_to_capture(layer_ids)
+
+    def get_embed_and_head(self):
+        # For EAGLE3, we delegate to the language model, which should have this method.
+        # If the language model doesn't have lm_head (like EAGLE3), we return None for the head.
+        embed = self.language_model.get_embed()
+        if hasattr(self.language_model, "get_embed_and_head"):
+            return self.language_model.get_embed_and_head()
+        elif hasattr(self.language_model, "lm_head"):
+            return embed, self.language_model.lm_head.weight
+        else:
+            # For EAGLE3, the head might not be needed
+            return embed, None
+
+    def set_embed_and_head(self, embed, head):
+        if hasattr(self.language_model, "set_embed_and_head"):
+            return self.language_model.set_embed_and_head(embed, head)
+        else:
+            # For EAGLE3, only set the embedding
+            return self.language_model.set_embed(embed)
+
+    def get_embed(self):
+        return self.language_model.get_embed()
+
+    def set_embed(self, embed):
+        return self.language_model.set_embed(embed)
+

 EntryClass = Llama4ForConditionalGeneration
@@ -140,9 +140,11 @@ class EAGLEWorker(TpModelWorker):
             self.draft_model_runner.model.set_embed(embed)

             # grab hot token ids
-            self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
-                embed.device
-            )
+            if self.draft_model_runner.model.hot_token_id is not None:
+                self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
+                    embed.device
+                )
+
         else:
             if self.hot_token_id is not None:
                 head = head.clone()
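A sketch of why the guard above replaces the unconditional getter call (names from the diff, scenario assumed): a draft checkpoint that ships no d2t table leaves `hot_token_id` as `None`, and calling `.to()` on it would raise:

```python
import torch

hot_token_id = None  # e.g. a llama4 draft checkpoint without a d2t table

# old behavior: hot_token_id.to(device) -> AttributeError when None
if hot_token_id is not None:
    hot_token_id = hot_token_id.to(torch.device("cpu"))
```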