support llama4 eagle3 (#6985)
Co-authored-by: shuaills <shishuaiuoe@gmail.com> Co-authored-by: Shenggui Li <somerlee.9@gmail.com> Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com> Co-authored-by: yizhang2077 <1109276519@qq.com>
This commit is contained in:
@@ -306,7 +306,26 @@ class ModelRunner:
|
|||||||
|
|
||||||
# auxiliary hidden capture mode. TODO: expose this to server args?
|
# auxiliary hidden capture mode. TODO: expose this to server args?
|
||||||
if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
|
if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
|
||||||
self.model.set_eagle3_layers_to_capture()
|
# load draft config
|
||||||
|
draft_model_config = ModelConfig.from_server_args(
|
||||||
|
server_args,
|
||||||
|
model_path=(server_args.speculative_draft_model_path),
|
||||||
|
is_draft_model=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# get the aux layer from draft model config
|
||||||
|
eagle_config = getattr(
|
||||||
|
draft_model_config.hf_config, "eagle_config", None
|
||||||
|
)
|
||||||
|
eagle_aux_hidden_state_layer_ids = eagle_config[
|
||||||
|
"eagle_aux_hidden_state_layer_ids"
|
||||||
|
]
|
||||||
|
except:
|
||||||
|
# if there is no aux layer, set to None
|
||||||
|
eagle_aux_hidden_state_layer_ids = None
|
||||||
|
|
||||||
|
self.model.set_eagle3_layers_to_capture(eagle_aux_hidden_state_layer_ids)
|
||||||
|
|
||||||
def model_specific_adjustment(self):
|
def model_specific_adjustment(self):
|
||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
|
|||||||
@@ -124,6 +124,9 @@ def _get_quantization_config(
|
|||||||
quant_config = get_quant_config(
|
quant_config = get_quant_config(
|
||||||
model_config, load_config, packed_modules_mapping
|
model_config, load_config, packed_modules_mapping
|
||||||
)
|
)
|
||||||
|
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
|
||||||
|
if quant_config is None:
|
||||||
|
return None
|
||||||
major, minor = get_device_capability()
|
major, minor = get_device_capability()
|
||||||
|
|
||||||
if major is not None and minor is not None:
|
if major is not None and minor is not None:
|
||||||
|
|||||||
@@ -209,6 +209,17 @@ def get_quant_config(
|
|||||||
config["adapter_name_or_path"] = model_name_or_path
|
config["adapter_name_or_path"] = model_name_or_path
|
||||||
elif model_config.quantization == "modelopt":
|
elif model_config.quantization == "modelopt":
|
||||||
if config["producer"]["name"] == "modelopt":
|
if config["producer"]["name"] == "modelopt":
|
||||||
|
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
|
||||||
|
if config["quantization"]["quant_algo"] is None:
|
||||||
|
if (
|
||||||
|
model_config.hf_config.architectures[0]
|
||||||
|
!= "LlamaForCausalLMEagle3"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid quant_config, quantization method: {model_config.quantization},"
|
||||||
|
f"hf architectures: {model_config.hf_config.architectures[0]}. "
|
||||||
|
)
|
||||||
|
return None
|
||||||
if "FP4" in config["quantization"]["quant_algo"]:
|
if "FP4" in config["quantization"]["quant_algo"]:
|
||||||
return ModelOptFp4Config.from_config(config)
|
return ModelOptFp4Config.from_config(config)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -697,13 +697,19 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||||
self.model.load_kv_cache_scales(quantization_param_path)
|
self.model.load_kv_cache_scales(quantization_param_path)
|
||||||
|
|
||||||
def set_eagle3_layers_to_capture(self):
|
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
|
||||||
if not self.pp_group.is_last_rank:
|
if not self.pp_group.is_last_rank:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if layer_ids is None:
|
||||||
self.capture_aux_hidden_states = True
|
self.capture_aux_hidden_states = True
|
||||||
num_layers = self.config.num_hidden_layers
|
num_layers = self.config.num_hidden_layers
|
||||||
self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
|
self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
|
||||||
|
else:
|
||||||
|
self.capture_aux_hidden_states = True
|
||||||
|
# we plus 1 here because in sglang, for the ith layer, it takes the output
|
||||||
|
# of the (i-1)th layer as aux hidden state
|
||||||
|
self.model.layers_to_capture = [val + 1 for val in layer_ids]
|
||||||
|
|
||||||
|
|
||||||
class Phi3ForCausalLM(LlamaForCausalLM):
|
class Phi3ForCausalLM(LlamaForCausalLM):
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
|
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP
|
||||||
|
|
||||||
|
|
||||||
class LlamaDecoderLayer(LlamaDecoderLayer):
|
class LlamaDecoderLayer(LlamaDecoderLayer):
|
||||||
@@ -59,6 +60,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
|
|||||||
prefix=add_prefix("qkv_proj", prefix),
|
prefix=add_prefix("qkv_proj", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if config.model_type == "llama4_text":
|
||||||
|
inter_size = config.intermediate_size_mlp
|
||||||
|
else:
|
||||||
|
inter_size = config.intermediate_size
|
||||||
|
|
||||||
|
self.mlp = LlamaMLP(
|
||||||
|
config.hidden_size, inter_size, config.hidden_act, quant_config, prefix
|
||||||
|
)
|
||||||
|
|
||||||
self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -105,11 +115,19 @@ class LlamaModel(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
prefix=add_prefix("embed_tokens", prefix),
|
prefix=add_prefix("embed_tokens", prefix),
|
||||||
)
|
)
|
||||||
self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
|
|
||||||
if hasattr(config, "target_hidden_size"):
|
if hasattr(config, "target_hidden_size"):
|
||||||
self.fc = torch.nn.Linear(config.target_hidden_size * 3, config.hidden_size)
|
self.hidden_size_in = config.target_hidden_size
|
||||||
else:
|
else:
|
||||||
self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size)
|
self.hidden_size_in = config.hidden_size
|
||||||
|
|
||||||
|
self.fc = torch.nn.Linear(
|
||||||
|
self.hidden_size_in * 3,
|
||||||
|
config.hidden_size,
|
||||||
|
bias=getattr(config, "bias", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
|
||||||
|
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
@@ -179,18 +197,50 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
|
|||||||
|
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.capture_aux_hidden_states = True
|
self.capture_aux_hidden_states = True
|
||||||
|
self.hot_token_id = None
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
# Define the parameter mapping for stacked parameters
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
(".qkv_proj", ".q_proj", "q"),
|
||||||
|
(".qkv_proj", ".k_proj", "k"),
|
||||||
|
(".qkv_proj", ".v_proj", "v"),
|
||||||
|
(".gate_up_proj", ".gate_proj", 0),
|
||||||
|
(".gate_up_proj", ".up_proj", 1),
|
||||||
|
]
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "d2t" in name:
|
if "d2t" in name:
|
||||||
# d2t stores diffs between draft id and target id
|
# d2t stores diffs between draft id and target id
|
||||||
self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0])
|
self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0])
|
||||||
|
continue
|
||||||
|
|
||||||
if "d2t" not in name and "t2d" not in name and "lm_head" not in name:
|
if "t2d" in name:
|
||||||
new_name = f"model.{name}"
|
continue
|
||||||
super().load_weights([(new_name, loaded_weight)])
|
|
||||||
elif "lm_head" in name:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
super().load_weights([(name, loaded_weight)])
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
param_name = f"model.{name}" if name not in params_dict else name
|
||||||
|
if param_name in params_dict:
|
||||||
|
param = params_dict[param_name]
|
||||||
|
weight_loader = getattr(
|
||||||
|
param, "weight_loader", default_weight_loader
|
||||||
|
)
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Handle regular parameters
|
||||||
|
param_name = name if name in params_dict else f"model.{name}"
|
||||||
|
if param_name in params_dict:
|
||||||
|
param = params_dict[param_name]
|
||||||
|
weight_loader = getattr(
|
||||||
|
param, "weight_loader", default_weight_loader
|
||||||
|
)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
def get_hot_token_id(self):
|
def get_hot_token_id(self):
|
||||||
return self.hot_token_id
|
return self.hot_token_id
|
||||||
|
|||||||
@@ -223,5 +223,34 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|||||||
)
|
)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
|
||||||
|
if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
|
||||||
|
self.language_model.set_eagle3_layers_to_capture(layer_ids)
|
||||||
|
|
||||||
|
def get_embed_and_head(self):
|
||||||
|
# For EAGLE3, we delegate to the language model which should have this method
|
||||||
|
# If the language model doesn't have lm_head (like EAGLE3), we return None for head
|
||||||
|
embed = self.language_model.get_embed()
|
||||||
|
if hasattr(self.language_model, "get_embed_and_head"):
|
||||||
|
return self.language_model.get_embed_and_head()
|
||||||
|
elif hasattr(self.language_model, "lm_head"):
|
||||||
|
return embed, self.language_model.lm_head.weight
|
||||||
|
else:
|
||||||
|
# For EAGLE3, head might not be needed
|
||||||
|
return embed, None
|
||||||
|
|
||||||
|
def set_embed_and_head(self, embed, head):
|
||||||
|
if hasattr(self.language_model, "set_embed_and_head"):
|
||||||
|
return self.language_model.set_embed_and_head(embed, head)
|
||||||
|
else:
|
||||||
|
# For EAGLE3, only set embed
|
||||||
|
return self.language_model.set_embed(embed)
|
||||||
|
|
||||||
|
def get_embed(self):
|
||||||
|
return self.language_model.get_embed()
|
||||||
|
|
||||||
|
def set_embed(self, embed):
|
||||||
|
return self.language_model.set_embed(embed)
|
||||||
|
|
||||||
|
|
||||||
EntryClass = Llama4ForConditionalGeneration
|
EntryClass = Llama4ForConditionalGeneration
|
||||||
|
|||||||
@@ -140,9 +140,11 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.draft_model_runner.model.set_embed(embed)
|
self.draft_model_runner.model.set_embed(embed)
|
||||||
|
|
||||||
# grab hot token ids
|
# grab hot token ids
|
||||||
self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
|
if self.draft_model_runner.model.hot_token_id is not None:
|
||||||
|
self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
|
||||||
embed.device
|
embed.device
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if self.hot_token_id is not None:
|
if self.hot_token_id is not None:
|
||||||
head = head.clone()
|
head = head.clone()
|
||||||
|
|||||||
Reference in New Issue
Block a user