[feat] Support EAGLE3 for Qwen (#7745)
Co-authored-by: 纬杭 <ximing.wxm@antgroup.com> Co-authored-by: zyksir <zyksir@outlook.com>
This commit is contained in:
@@ -293,6 +293,9 @@ class Qwen2Model(nn.Module):
|
||||
else:
|
||||
self.norm = PPMissingLayer(return_tuple=True)
|
||||
|
||||
# For EAGLE3 support
|
||||
self.layers_to_capture = []
|
||||
|
||||
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
if hasattr(self.config, "scale_emb"):
|
||||
return self.get_input_embeddings()(input_ids) * self.config.scale_emb
|
||||
@@ -321,7 +324,12 @@ class Qwen2Model(nn.Module):
|
||||
hidden_states = pp_proxy_tensors["hidden_states"]
|
||||
residual = pp_proxy_tensors["residual"]
|
||||
|
||||
aux_hidden_states = []
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
if i in self.layers_to_capture:
|
||||
aux_hidden_states.append(
|
||||
hidden_states + residual if residual is not None else hidden_states
|
||||
)
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
@@ -342,7 +350,11 @@ class Qwen2Model(nn.Module):
|
||||
hidden_states = self.norm(hidden_states)
|
||||
else:
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
if len(aux_hidden_states) == 0:
|
||||
return hidden_states
|
||||
|
||||
return hidden_states, aux_hidden_states
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
|
||||
# factors (or else raise an exception). Thus, handled exceptions should
|
||||
|
||||
@@ -440,6 +440,9 @@ class Qwen2MoeModel(nn.Module):
|
||||
else:
|
||||
self.norm = PPMissingLayer(return_tuple=True)
|
||||
|
||||
# For EAGLE3 support
|
||||
self.layers_to_capture = []
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -459,6 +462,7 @@ class Qwen2MoeModel(nn.Module):
|
||||
hidden_states = pp_proxy_tensors["hidden_states"]
|
||||
residual = pp_proxy_tensors["residual"]
|
||||
|
||||
aux_hidden_states = []
|
||||
if forward_batch.can_run_tbo:
|
||||
hidden_states, residual = model_forward_maybe_tbo(
|
||||
layers=self.layers,
|
||||
@@ -471,6 +475,12 @@ class Qwen2MoeModel(nn.Module):
|
||||
)
|
||||
else:
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
if i in self.layers_to_capture:
|
||||
aux_hidden_states.append(
|
||||
hidden_states + residual
|
||||
if residual is not None
|
||||
else hidden_states
|
||||
)
|
||||
with get_global_expert_distribution_recorder().with_current_layer(i):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
@@ -489,7 +499,11 @@ class Qwen2MoeModel(nn.Module):
|
||||
hidden_states = self.norm(hidden_states)
|
||||
else:
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
if len(aux_hidden_states) == 0:
|
||||
return hidden_states
|
||||
|
||||
return hidden_states, aux_hidden_states
|
||||
|
||||
|
||||
class Qwen2MoeForCausalLM(nn.Module):
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -325,6 +325,9 @@ class Qwen3ForCausalLM(nn.Module):
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
# For EAGLE3 support
|
||||
self.capture_aux_hidden_states = False
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
@@ -346,10 +349,18 @@ class Qwen3ForCausalLM(nn.Module):
|
||||
pp_proxy_tensors=pp_proxy_tensors,
|
||||
)
|
||||
|
||||
aux_hidden_states = None
|
||||
if self.capture_aux_hidden_states:
|
||||
hidden_states, aux_hidden_states = hidden_states
|
||||
|
||||
if self.pp_group.is_last_rank:
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids,
|
||||
hidden_states,
|
||||
self.lm_head,
|
||||
forward_batch,
|
||||
aux_hidden_states,
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
@@ -447,5 +458,20 @@ class Qwen3ForCausalLM(nn.Module):
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||
self.model.load_kv_cache_scales(quantization_param_path)
|
||||
|
||||
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
|
||||
if not self.pp_group.is_last_rank:
|
||||
return
|
||||
|
||||
self.capture_aux_hidden_states = True
|
||||
if layer_ids is None:
|
||||
num_layers = self.config.num_hidden_layers
|
||||
self.model.layers_to_capture = [
|
||||
2,
|
||||
num_layers // 2,
|
||||
num_layers - 3,
|
||||
] # Specific layers for EAGLE3 support
|
||||
else:
|
||||
self.model.layers_to_capture = [val + 1 for val in layer_ids]
|
||||
|
||||
|
||||
EntryClass = Qwen3ForCausalLM
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -717,6 +717,7 @@ class Qwen3MoeForCausalLM(nn.Module):
|
||||
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.capture_aux_hidden_states = False
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
@@ -735,9 +736,13 @@ class Qwen3MoeForCausalLM(nn.Module):
|
||||
pp_proxy_tensors=pp_proxy_tensors,
|
||||
)
|
||||
|
||||
aux_hidden_states = None
|
||||
if self.capture_aux_hidden_states:
|
||||
hidden_states, aux_hidden_states = hidden_states
|
||||
|
||||
if self.pp_group.is_last_rank:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
|
||||
)
|
||||
else:
|
||||
return hidden_states
|
||||
@@ -750,6 +755,24 @@ class Qwen3MoeForCausalLM(nn.Module):
|
||||
def end_layer(self):
|
||||
return self.model.end_layer
|
||||
|
||||
def get_embed_and_head(self):
|
||||
return self.model.embed_tokens.weight, self.lm_head.weight
|
||||
|
||||
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
|
||||
if not self.pp_group.is_last_rank:
|
||||
return
|
||||
|
||||
self.capture_aux_hidden_states = True
|
||||
if layer_ids is None:
|
||||
num_layers = self.config.num_hidden_layers
|
||||
self.model.layers_to_capture = [
|
||||
2,
|
||||
num_layers // 2,
|
||||
num_layers - 3,
|
||||
] # Specific layers for EAGLE3 support
|
||||
else:
|
||||
self.model.layers_to_capture = [val + 1 for val in layer_ids]
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
|
||||
Reference in New Issue
Block a user